3 changes: 2 additions & 1 deletion ffn/training/examples.py
@@ -70,7 +70,8 @@ def get_example(load_example, eval_tracker: tracker.EvalTracker,
assert predicted.base is seed
yield predicted, patches, labels, weights

eval_tracker.add_patch(full_labels, seed, loss_weights, coord)
eval_tracker.add_patch(full_labels, seed, loss_weights, coord,
volume_name=volname)


ExampleGenerator = Iterable[tuple[np.ndarray, np.ndarray, np.ndarray,
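The change to examples.py simply threads the source volume's name into the eval tracker. A minimal sketch of the new call, assuming volname holds the (possibly bytes-valued) name of the volume the patch was sampled from:

  eval_tracker.add_patch(
      full_labels, seed, loss_weights, coord, volume_name=volname)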
242 changes: 163 additions & 79 deletions ffn/training/tracker.py
@@ -17,16 +17,17 @@
import collections
import enum
import io
from typing import Optional, Sequence
from typing import Any, Sequence

from absl import logging
import numpy as np

import PIL
import PIL.Image
import PIL.ImageDraw
import PIL.ImageFont
from scipy import special

import tensorflow.compat.v1 as tf

from . import mask
from . import variables

@@ -62,20 +62,26 @@ class FovStat(enum.IntEnum):
class EvalTracker:
"""Tracks eval results over multiple training steps."""

def __init__(self,
eval_shape: list[int],
shifts: Sequence[tuple[int, int, int]]):
def __init__(
self, eval_shape: list[int], shifts: Sequence[tuple[int, int, int]]
):
# TODO(mjanusz): Remove this TFv1 code once no longer used.
if not tf.executing_eagerly():
self.eval_labels = tf.compat.v1.placeholder(
tf.float32, [1] + eval_shape + [1], name='eval_labels')
tf.float32, [1] + eval_shape + [1], name='eval_labels'
)
self.eval_preds = tf.compat.v1.placeholder(
tf.float32, [1] + eval_shape + [1], name='eval_preds')
tf.float32, [1] + eval_shape + [1], name='eval_preds'
)
self.eval_weights = tf.compat.v1.placeholder(
tf.float32, [1] + eval_shape + [1], name='eval_weights')
tf.float32, [1] + eval_shape + [1], name='eval_weights'
)
self.eval_loss = tf.reduce_mean(
self.eval_weights * tf.nn.sigmoid_cross_entropy_with_logits(
logits=self.eval_preds, labels=self.eval_labels))
self.eval_weights
* tf.nn.sigmoid_cross_entropy_with_logits(
logits=self.eval_preds, labels=self.eval_labels
)
)
self.sess = None
self.eval_threshold = special.logit(0.9)
self._eval_shape = eval_shape # zyx
@@ -138,12 +145,15 @@ def track_weights(self, weights: np.ndarray):
self.fov_stats.value[FovStat.MASKED_VOXELS] += np.sum(weights == 0.0)
self.fov_stats.value[FovStat.WEIGHTS_SUM] += np.sum(weights)

def record_move(self, wanted: bool, executed: bool,
offset_xyz: Sequence[int]):
def record_move(
self, wanted: bool, executed: bool, offset_xyz: Sequence[int]
):
"""Records an FFN FOV move."""
r = int(np.linalg.norm(offset_xyz))
assert r in self.moves_by_r, ('%d not in %r' %
(r, list(self.moves_by_r.keys())))
assert r in self.moves_by_r, '%d not in %r' % (
r,
list(self.moves_by_r.keys()),
)

if wanted:
if executed:
@@ -156,9 +166,15 @@ def record_move(self, wanted: bool, executed: bool,
self.moves.value[MoveType.SPURIOUS] += 1
self.moves_by_r[r].value[MoveType.SPURIOUS] += 1

def slice_image(self, coord: np.ndarray, labels: np.ndarray,
predicted: np.ndarray, weights: np.ndarray,
slice_axis: int) -> tf.Summary.Value:
def slice_image(
self,
coord: np.ndarray,
labels: np.ndarray,
predicted: np.ndarray,
weights: np.ndarray,
slice_axis: int,
volume_name: str | bytes | Sequence[Any] | np.ndarray | None = None,
) -> tf.Summary.Value:
"""Builds a tf.Summary showing a slice of an object mask.

The object mask slice is shown side by side with the corresponding
@@ -172,6 +188,7 @@ def slice_image(self, coord: np.ndarray, labels: np.ndarray,
slice_axis: axis in the middle of which to place the cutting plane for
which the summary image will be generated, valid values are 2 ('x'), 1
('y'), and 0 ('z').
volume_name: name of the volume to be displayed on the image.

Returns:
tf.Summary.Value object with the image.
@@ -191,14 +208,37 @@ def slice_image(self, coord: np.ndarray, labels: np.ndarray,

im = PIL.Image.fromarray(
np.repeat(
np.concatenate([labels, predicted, weights], axis=1)[...,
np.newaxis],
np.concatenate([labels, predicted, weights], axis=1)[
..., np.newaxis
],
3,
axis=2), 'RGB')
axis=2,
),
'RGB',
)
draw = PIL.ImageDraw.Draw(im)

x, y, z = coord.squeeze()
draw.text((1, 1), '%d %d %d' % (x, y, z), fill='rgb(255,64,64)')
text = f'{x},{y},{z}'
if volume_name is not None:
if (
isinstance(volume_name, (list, tuple, np.ndarray))
and len(volume_name) == 1
):
volume_name = volume_name[0]

if isinstance(volume_name, bytes):
volume_name = volume_name.decode('utf-8')

text += f'\n{volume_name}'

try:
# font = PIL.ImageFont.load_default()
# The specific truetype font below is an assumption; any available .ttf
# works, with the default bitmap font as a fallback.
font = PIL.ImageFont.truetype('DejaVuSans.ttf', 10)
except (IOError, ValueError):
font = PIL.ImageFont.load_default()

draw.text((1, 1), text, fill='rgb(255,64,64)', font=font)
del draw

im.save(buf, 'PNG')
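The volume name may reach slice_image as a plain str, as bytes coming out of a TF string tensor, or as a length-1 list/tuple/ndarray wrapping either; the block above unwraps and decodes all of these before appending the name under the coordinate text. A standalone paraphrase of that normalization, for illustration only (not part of the diff):

  import numpy as np

  def normalize_volume_name(volume_name):
    # Mirrors the handling in slice_image: unwrap length-1 sequences/arrays,
    # then decode bytes coming from TF string tensors.
    if volume_name is None:
      return None
    if isinstance(volume_name, (list, tuple, np.ndarray)) and len(volume_name) == 1:
      volume_name = volume_name[0]
    if isinstance(volume_name, bytes):
      volume_name = volume_name.decode('utf-8')
    return str(volume_name)

  assert normalize_volume_name(np.array([b'val_vol_2'])) == 'val_vol_2'
  assert normalize_volume_name(['val_vol_2']) == 'val_vol_2'
  assert normalize_volume_name('val_vol_2') == 'val_vol_2'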
@@ -212,14 +252,19 @@ def slice_image(self, coord: np.ndarray, labels: np.ndarray,
height=h,
width=w * 3,
colorspace=3, # RGB
encoded_image_string=buf.getvalue()))

def add_patch(self,
labels: np.ndarray,
predicted: np.ndarray,
weights: np.ndarray,
coord: Optional[np.ndarray] = None,
image_summaries: bool = True):
encoded_image_string=buf.getvalue(),
),
)

def add_patch(
self,
labels: np.ndarray,
predicted: np.ndarray,
weights: np.ndarray,
coord: np.ndarray | None = None,
image_summaries: bool = True,
volume_name: str | None = None,
):
"""Evaluates single-object segmentation quality."""

predicted = mask.crop_and_pad(predicted, (0, 0, 0), self._eval_shape)
@@ -228,15 +273,21 @@ def add_patch(self,

if not tf.executing_eagerly():
assert self.sess is not None
loss, = self.sess.run(
[self.eval_loss], {
(loss,) = self.sess.run(
[self.eval_loss],
{
self.eval_labels: labels,
self.eval_preds: predicted,
self.eval_weights: weights
})
self.eval_weights: weights,
},
)
else:
loss = tf.reduce_mean(weights * tf.nn.sigmoid_cross_entropy_with_logits(
logits=predicted, labels=labels))
loss = tf.reduce_mean(
weights
* tf.nn.sigmoid_cross_entropy_with_logits(
logits=predicted, labels=labels
)
)

self.loss.value[:] += loss
self.num_voxels.value[VoxelType.TOTAL] += labels.size
@@ -247,23 +298,29 @@ def add_patch(self,
pred_bg = np.logical_not(pred_mask)
true_bg = np.logical_not(true_mask)

self.prediction_counts.value[PredictionType.TP] += np.sum(pred_mask
& true_mask)
self.prediction_counts.value[PredictionType.TP] += np.sum(
pred_mask & true_mask
)
self.prediction_counts.value[PredictionType.TN] += np.sum(pred_bg & true_bg)
self.prediction_counts.value[PredictionType.FP] += np.sum(pred_mask
& true_bg)
self.prediction_counts.value[PredictionType.FN] += np.sum(pred_bg
& true_mask)
self.prediction_counts.value[PredictionType.FP] += np.sum(
pred_mask & true_bg
)
self.prediction_counts.value[PredictionType.FN] += np.sum(
pred_bg & true_mask
)
self.num_patches.value[:] += 1

if image_summaries:
predicted = special.expit(predicted)
self.images_xy.append(
self.slice_image(coord, labels, predicted, weights, 0))
self.slice_image(coord, labels, predicted, weights, 0, volume_name)
)
self.images_xz.append(
self.slice_image(coord, labels, predicted, weights, 1))
self.slice_image(coord, labels, predicted, weights, 1, volume_name)
)
self.images_yz.append(
self.slice_image(coord, labels, predicted, weights, 2))
self.slice_image(coord, labels, predicted, weights, 2, volume_name)
)

def _compute_classification_metrics(self, prediction_counts, prefix):
"""Computes standard classification metrics."""
@@ -276,19 +333,21 @@ def _compute_classification_metrics(self, prediction_counts, prefix):
recall = tp / max(tp + fn, 1)

if precision > 0 or recall > 0:
f1 = (2.0 * precision * recall / (precision + recall))
f1 = 2.0 * precision * recall / (precision + recall)
else:
f1 = 0.0

return [
tf.Summary.Value(
tag='%s/accuracy' % prefix,
simple_value=(tp + tn) / max(tp + tn + fp + fn, 1)),
simple_value=(tp + tn) / max(tp + tn + fp + fn, 1),
),
tf.Summary.Value(tag='%s/precision' % prefix, simple_value=precision),
tf.Summary.Value(tag='%s/recall' % prefix, simple_value=recall),
tf.Summary.Value(
tag='%s/specificity' % prefix, simple_value=tn / max(tn + fp, 1)),
tf.Summary.Value(tag='%s/f1' % prefix, simple_value=f1)
tag='%s/specificity' % prefix, simple_value=tn / max(tn + fp, 1)
),
tf.Summary.Value(tag='%s/f1' % prefix, simple_value=f1),
]

def get_summaries(self) -> list[tf.Summary.Value]:
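As a quick sanity check of the metric formulas in _compute_classification_metrics (precision, computed as tp / max(tp + fp, 1), sits just above the visible hunk), a worked example with made-up counts:

  tp, fp, fn, tn = 80, 10, 20, 890                      # illustrative numbers only

  precision = tp / max(tp + fp, 1)                      # 80 / 90   ~= 0.889
  recall = tp / max(tp + fn, 1)                         # 80 / 100   = 0.800
  specificity = tn / max(tn + fp, 1)                    # 890 / 900 ~= 0.989
  accuracy = (tp + tn) / max(tp + tn + fp + fn, 1)      # 970 / 1000 = 0.970
  f1 = 2.0 * precision * recall / (precision + recall)  # ~= 0.842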
@@ -308,49 +367,74 @@ def get_summaries(self) -> list[tf.Summary.Value]:
move_summaries.append(
tf.Summary.Value(
tag='moves/all/%s' % mt.name.lower(),
simple_value=self.moves.tf_value[mt] / total_moves))

summaries = [
tf.Summary.Value(
tag='fov/masked_voxel_fraction',
simple_value=(self.fov_stats.tf_value[FovStat.MASKED_VOXELS] /
self.fov_stats.tf_value[FovStat.TOTAL_VOXELS])),
tf.Summary.Value(
tag='fov/average_weight',
simple_value=(self.fov_stats.tf_value[FovStat.WEIGHTS_SUM] /
self.fov_stats.tf_value[FovStat.TOTAL_VOXELS])),
tf.Summary.Value(
tag='masked_voxel_fraction',
simple_value=(self.num_voxels.tf_value[VoxelType.MASKED] /
self.num_voxels.tf_value[VoxelType.TOTAL])),
tf.Summary.Value(
tag='eval/patch_loss',
simple_value=self.loss.tf_value[0] / self.num_patches.tf_value[0]),
tf.Summary.Value(
tag='eval/patches', simple_value=self.num_patches.tf_value[0]),
tf.Summary.Value(tag='moves/total', simple_value=total_moves)
] + move_summaries + (
list(self.meshes) + list(self.images_xy) + list(self.images_xz) +
list(self.images_yz))
simple_value=self.moves.tf_value[mt] / total_moves,
)
)

summaries = (
[
tf.Summary.Value(
tag='fov/masked_voxel_fraction',
simple_value=(
self.fov_stats.tf_value[FovStat.MASKED_VOXELS]
/ self.fov_stats.tf_value[FovStat.TOTAL_VOXELS]
),
),
tf.Summary.Value(
tag='fov/average_weight',
simple_value=(
self.fov_stats.tf_value[FovStat.WEIGHTS_SUM]
/ self.fov_stats.tf_value[FovStat.TOTAL_VOXELS]
),
),
tf.Summary.Value(
tag='masked_voxel_fraction',
simple_value=(
self.num_voxels.tf_value[VoxelType.MASKED]
/ self.num_voxels.tf_value[VoxelType.TOTAL]
),
),
tf.Summary.Value(
tag='eval/patch_loss',
simple_value=self.loss.tf_value[0]
/ self.num_patches.tf_value[0],
),
tf.Summary.Value(
tag='eval/patches', simple_value=self.num_patches.tf_value[0]
),
tf.Summary.Value(tag='moves/total', simple_value=total_moves),
]
+ move_summaries
+ (
list(self.meshes)
+ list(self.images_xy)
+ list(self.images_xz)
+ list(self.images_yz)
)
)

summaries.extend(
self._compute_classification_metrics(self.prediction_counts,
'eval/all'))
self._compute_classification_metrics(self.prediction_counts, 'eval/all')
)

for r, r_moves in self.moves_by_r.items():
total_moves = sum(r_moves.tf_value)
summaries.extend([
tf.Summary.Value(
tag='moves/r=%d/correct' % r,
simple_value=r_moves.tf_value[MoveType.CORRECT] / total_moves),
simple_value=r_moves.tf_value[MoveType.CORRECT] / total_moves,
),
tf.Summary.Value(
tag='moves/r=%d/spurious' % r,
simple_value=r_moves.tf_value[MoveType.SPURIOUS] / total_moves),
simple_value=r_moves.tf_value[MoveType.SPURIOUS] / total_moves,
),
tf.Summary.Value(
tag='moves/r=%d/missed' % r,
simple_value=r_moves.tf_value[MoveType.MISSED] / total_moves),
simple_value=r_moves.tf_value[MoveType.MISSED] / total_moves,
),
tf.Summary.Value(
tag='moves/r=%d/total' % r, simple_value=total_moves)
tag='moves/r=%d/total' % r, simple_value=total_moves
),
])

return summaries
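Putting the tracker changes together, an eval loop can now tag its summary images with the volume each patch came from. A minimal sketch under eager execution; the shapes, shifts, and volume name below are assumptions for illustration, not values from this PR:

  import numpy as np
  from ffn.training import tracker

  eval_shape = [32, 32, 32]  # zyx
  et = tracker.EvalTracker(eval_shape, shifts=[(8, 0, 0), (0, 8, 0), (0, 0, 8)])

  labels = np.zeros([1] + eval_shape + [1], dtype=np.float32)
  logits = np.zeros_like(labels)        # network output, pre-sigmoid
  weights = np.ones_like(labels)
  coord = np.array([[100, 200, 300]])   # xyz patch center

  # volume_name may be a str, bytes, or a length-1 sequence/array of either;
  # it is rendered below the coordinate on the xy/xz/yz summary images.
  et.add_patch(labels, logits, weights, coord, volume_name=b'validation_vol_0')

  # The training loop then collects et.get_summaries() and writes them out
  # with a summary writer, exactly as before this change.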