torch_em.transform.label

  1from typing import Optional
  2
  3import numpy as np
  4import skimage.measure
  5import skimage.segmentation
  6import vigra
  7
  8from ..util import ensure_array, ensure_spatial_array
  9
 10try:
 11    from affogato.affinities import compute_affinities
 12except ImportError:
 13    compute_affinities = None
 14
 15
 16def connected_components(labels, ndim=None, ensure_zero=False):
 17    labels = ensure_array(labels) if ndim is None else ensure_spatial_array(labels, ndim)
 18    labels = skimage.measure.label(labels)
 19    if ensure_zero and 0 not in labels:
 20        labels -= 1
 21    return labels
 22
 23
 24def labels_to_binary(labels, background_label=0):
 25    return (labels != background_label).astype(labels.dtype)
 26
 27
 28def label_consecutive(labels, with_background=True):
 29    if with_background:
 30        seg = skimage.segmentation.relabel_sequential(labels)[0]
 31    else:
 32        if 0 in labels:
 33            labels += 1
 34        seg = skimage.segmentation.relabel_sequential(labels)[0]
 35        assert seg.min() == 1
 36        seg -= 1
 37    return seg
 38
 39
 40# TODO smoothing
 41class BoundaryTransform:
 42    def __init__(self, mode="thick", add_binary_target=False, ndim=None):
 43        self.mode = mode
 44        self.add_binary_target = add_binary_target
 45        self.ndim = ndim
 46
 47    def __call__(self, labels):
 48        labels = ensure_array(labels) if self.ndim is None else ensure_spatial_array(labels, self.ndim)
 49        boundaries = skimage.segmentation.find_boundaries(labels, mode=self.mode)[None]
 50        if self.add_binary_target:
 51            binary = labels_to_binary(labels)[None].astype(boundaries.dtype)
 52            target = np.concatenate([binary, boundaries], axis=0)
 53        else:
 54            target = boundaries
 55        return target
 56
 57
 58# TODO smoothing
 59class NoToBackgroundBoundaryTransform:
 60    def __init__(self, bg_label=0, mask_label=-1, mode="thick", add_binary_target=False, ndim=None):
 61        self.bg_label = bg_label
 62        self.mask_label = mask_label
 63        self.mode = mode
 64        self.ndim = ndim
 65        self.add_binary_target = add_binary_target
 66
 67    def __call__(self, labels):
 68        labels = ensure_array(labels) if self.ndim is None else ensure_spatial_array(labels, self.ndim)
 69        # calc normal boundaries
 70        boundaries = skimage.segmentation.find_boundaries(labels, mode=self.mode)[None]
 71
 72        # make label image binary and calculate to-background-boundaries
 73        labels_binary = (labels != self.bg_label)
 74        to_bg_boundaries = skimage.segmentation.find_boundaries(labels_binary, mode=self.mode)[None]
 75
 76        # mask the to-background-boundaries
 77        boundaries = boundaries.astype(np.int8)
 78        boundaries[to_bg_boundaries] = self.mask_label
 79
 80        if self.add_binary_target:
 81            binary = labels_to_binary(labels, self.bg_label).astype(boundaries.dtype)
 82            binary[labels == self.mask_label] = self.mask_label
 83            target = np.concatenate([binary[None], boundaries], axis=0)
 84        else:
 85            target = boundaries
 86
 87        return target
 88
 89
 90# TODO smoothing
 91class BoundaryTransformWithIgnoreLabel:
 92    def __init__(self, ignore_label=-1, mode="thick", add_binary_target=False, ndim=None):
 93        self.ignore_label = ignore_label
 94        self.mode = mode
 95        self.ndim = ndim
 96        self.add_binary_target = add_binary_target
 97
 98    def __call__(self, labels):
 99        labels = ensure_array(labels) if self.ndim is None else ensure_spatial_array(labels, self.ndim)
100        # calculate the normal boundaries
101        boundaries = skimage.segmentation.find_boundaries(labels, mode=self.mode)[None]
102
103        # calculate the boundaries for the ignore label
104        labels_ignore = (labels == self.ignore_label)
105        to_ignore_boundaries = skimage.segmentation.find_boundaries(labels_ignore, mode=self.mode)[None]
106
107        # mask the to-background-boundaries
108        boundaries = boundaries.astype(np.int8)
109        boundaries[to_ignore_boundaries] = self.ignore_label
110
111        if self.add_binary_target:
112            binary = labels_to_binary(labels).astype(boundaries.dtype)
113            binary[labels == self.ignore_label] = self.ignore_label
114            target = np.concatenate([binary[None], boundaries], axis=0)
115        else:
116            target = boundaries
117
118        return target
119
120
121# TODO affinity smoothing
class AffinityTransform:
    """Compute affinity targets from a label image with `affogato`.

    The affinities are returned in the "disaffinity" convention: 1 means repulsive, 0 attractive.

    Args:
        offsets: The offsets for which the affinities are computed. The length of the
            first offset determines the dimensionality, which must be 2 or 3.
        ignore_label: Optional ignore label that is passed to `compute_affinities`.
        add_binary_target: Whether to prepend a binary foreground channel to the affinities.
        add_mask: Whether to append the mask returned by `compute_affinities` to the output.
        include_ignore_transitions: Whether to set the transitions to the ignore label
            to 1, in both the affinities and the mask. Only has an effect if `ignore_label` is given.
    """
    def __init__(self, offsets,
                 ignore_label=None,
                 add_binary_target=False,
                 add_mask=False,
                 include_ignore_transitions=False):
        assert compute_affinities is not None
        self.offsets = offsets
        self.ndim = len(offsets[0])
        assert self.ndim in (2, 3)

        self.ignore_label = ignore_label
        self.add_binary_target = add_binary_target
        self.add_mask = add_mask
        self.include_ignore_transitions = include_ignore_transitions

    def add_ignore_transitions(self, affs, mask, labels):
        """Set the transitions to the ignore label to 1, in `affs` and `mask` (both in place)."""
        ignore_region = (labels == self.ignore_label).astype(labels.dtype)
        ignore_affs, valid = compute_affinities(ignore_region, self.offsets)
        # NOTE affinity convention returned by affogato: transitions are marked by 0
        transitions = ignore_affs == 0
        transitions[np.logical_not(valid)] = 0
        affs[transitions] = 1
        mask[transitions] = 1
        return affs, mask

    def __call__(self, labels):
        # Signed integer inputs are converted to int64, everything else to uint64.
        signed_dtypes = (np.dtype("int16"), np.dtype("int32"), np.dtype("int64"))
        dtype = "int64" if np.dtype(labels.dtype) in signed_dtypes else "uint64"
        labels = ensure_spatial_array(labels, self.ndim, dtype=dtype)

        has_ignore_label = self.ignore_label is not None
        affs, mask = compute_affinities(
            labels, self.offsets,
            have_ignore_label=has_ignore_label,
            ignore_label=self.ignore_label if has_ignore_label else 0,
        )
        # we use the "disaffinity" convention for training; i.e. 1 means repulsive, 0 attractive
        affs = 1. - affs

        if has_ignore_label and self.include_ignore_transitions:
            affs, mask = self.add_ignore_transitions(affs, mask, labels)

        if self.add_binary_target:
            binary = labels_to_binary(labels)[None].astype(affs.dtype)
            assert binary.ndim == affs.ndim
            affs = np.concatenate([binary, affs], axis=0)

        if not self.add_mask:
            return affs

        # The binary channel needs its own mask channel, so that
        # the targets and the mask have the same number of channels.
        if self.add_binary_target:
            if has_ignore_label:
                mask_for_bin = (labels != self.ignore_label)[None].astype(mask.dtype)
            else:
                mask_for_bin = np.ones((1,) + labels.shape, dtype=mask.dtype)
            assert mask.ndim == mask_for_bin.ndim
            mask = np.concatenate([mask_for_bin, mask], axis=0)

        assert affs.shape == mask.shape
        return np.concatenate([affs, mask.astype(affs.dtype)], axis=0)
181
182
class OneHotTransform:
    """Convert a label array into a one-hot encoding with one channel per class id.

    Args:
        class_ids: The class ids to encode. An integer `n` stands for the ids `0, ..., n - 1`.
            If None, the unique values of the labels passed to the call are used.
    """
    def __init__(self, class_ids=None):
        if isinstance(class_ids, int):
            class_ids = list(range(class_ids))
        self.class_ids = class_ids

    def __call__(self, labels):
        if self.class_ids is None:
            class_ids = np.unique(labels).tolist()
        else:
            class_ids = self.class_ids

        one_hot = np.zeros((len(class_ids),) + labels.shape, dtype="float32")
        # Iterating over the first axis yields views, so the output is filled in place.
        for channel, class_id in zip(one_hot, class_ids):
            channel[labels == class_id] = 1.0
        return one_hot
194
195
class DistanceTransform:
    """Compute distances to foreground in the labels.

    The output of a call contains the absolute distances (one channel), the directed
    distances (one channel per spatial dimension), or both stacked along the first axis.

    Args:
        distances: Whether to compute the absolute distances.
        directed_distances: Whether to compute the directed distances (vector distances).
        normalize: Whether to normalize the computed distances.
        max_distance: Maximal distance at which to threshold the distances.
        foreground_id: Label id to which the distance is computed.
        invert: Whether to invert the distances.
        func: Optional function that is applied to the distances as the last step.

    Raises:
        ValueError: If both `distances` and `directed_distances` are False.
    """
    eps = 1e-7

    def __init__(
        self,
        distances: bool = True,
        directed_distances: bool = False,
        normalize: bool = True,
        max_distance: Optional[float] = None,
        foreground_id=1,
        invert=False,
        func=None
    ):
        if sum((distances, directed_distances)) == 0:
            raise ValueError("At least one of 'distances' or 'directed_distances' must be set to 'True'")
        self.directed_distances = directed_distances
        self.distances = distances
        self.normalize = normalize
        self.max_distance = max_distance
        self.foreground_id = foreground_id
        self.invert = invert
        self.func = func

    def _compute_distances(self, directed_distances):
        """Derive the absolute distances from the channel-first vector distances."""
        distances = np.linalg.norm(directed_distances, axis=0)
        if self.max_distance is not None:
            distances = np.clip(distances, 0, self.max_distance)
        if self.normalize:
            distances /= (distances.max() + self.eps)
        if self.invert:
            distances = distances.max() - distances
        if self.func is not None:
            distances = self.func(distances)
        return distances

    def _compute_directed_distances(self, directed_distances):
        """Clip, normalize and invert the channel-first vector distances, per channel."""
        # Axis 0 is the channel axis, all other axes are spatial. Reducing over all of
        # them (instead of a hard-coded `(1, 2)`) makes this correct for 3D data as well.
        spatial_axes = tuple(range(1, directed_distances.ndim))
        if self.max_distance is not None:
            directed_distances = np.clip(directed_distances, -self.max_distance, self.max_distance)
        if self.normalize:
            directed_distances /= (np.abs(directed_distances).max(axis=spatial_axes, keepdims=True) + self.eps)
        if self.invert:
            directed_distances = directed_distances.max(axis=spatial_axes, keepdims=True) - directed_distances
        if self.func is not None:
            directed_distances = self.func(directed_distances)
        return directed_distances

    def _get_distances_for_empty_labels(self, labels):
        """Return constant vector distances for labels without any foreground."""
        shape = labels.shape
        fill_value = 0.0 if self.invert else np.sqrt(np.linalg.norm(list(shape)) ** 2 / 2)
        data = np.full((labels.ndim,) + shape, fill_value)
        return data

    def __call__(self, labels):
        distance_mask = (labels == self.foreground_id).astype("uint32")
        # the distances are not computed correctly if the mask is all zero,
        # so this case needs to be handled separately
        if distance_mask.sum() == 0:
            directed_distances = self._get_distances_for_empty_labels(labels)
        else:
            ndim = distance_mask.ndim
            # Move the vector components from the last to the first axis.
            to_channel_first = (ndim,) + tuple(range(ndim))
            directed_distances = vigra.filters.vectorDistanceTransform(distance_mask).transpose(to_channel_first)

        # The absolute distances have to be computed first, because the
        # directed distances may be normalized in place afterwards.
        if self.distances:
            distances = self._compute_distances(directed_distances)

        if self.directed_distances:
            directed_distances = self._compute_directed_distances(directed_distances)

        if self.distances and self.directed_distances:
            return np.concatenate((distances[None], directed_distances), axis=0)
        if self.distances:
            return distances
        if self.directed_distances:
            return directed_distances
282
283
class PerObjectDistanceTransform:
    """Compute normalized distances per object in a segmentation.

    The distance channels are ordered as: undirected distances, directed distances,
    boundary distances (each only if enabled). The foreground channel (if enabled) is
    prepended to them, and the instance channel (if enabled) is prepended to all of that.

    Args:
        distances: Whether to compute the undirected distances.
        boundary_distances: Whether to compute the distances to the object boundaries.
        directed_distances: Whether to compute the directed distances (vector distances).
        foreground: Whether to return a foreground channel.
        instances: Whether to return the (relabeled) instance labels as an additional channel.
        apply_label: Whether to apply connected components to the labels before computing distances.
        correct_centers: Whether to correct centers that are not in the objects.
        min_size: Minimal size of objects for distance calculation.
        distance_fill_value: Fill value for the distances outside of objects.

    Raises:
        ValueError: If none of `distances`, `boundary_distances` and `directed_distances` is set.
    """
    eps = 1e-7

    def __init__(
        self,
        distances=True,
        boundary_distances=True,
        directed_distances=False,
        foreground=True,
        instances=False,
        apply_label=True,
        correct_centers=True,
        min_size=0,
        distance_fill_value=1.0,
    ):
        if sum([distances, directed_distances, boundary_distances]) == 0:
            raise ValueError(
                "At least one of 'distances', 'directed_distances' or 'boundary_distances' must be set to 'True'."
            )
        self.distances = distances
        self.boundary_distances = boundary_distances
        self.directed_distances = directed_distances
        self.foreground = foreground
        self.instances = instances

        self.apply_label = apply_label
        # NOTE(review): 'correct_centers' is stored but not read anywhere in this class;
        # centers outside of their object are always corrected below.
        self.correct_centers = correct_centers
        self.min_size = min_size
        self.distance_fill_value = distance_fill_value

    def compute_normalized_object_distances(self, mask, boundaries, bb, center, distances):
        """Compute the normalized distances for one object and write them into `distances`.

        Args:
            mask: Binary mask of the object, with the shape of the full label image.
            boundaries: Boundary mask of the full label image.
            bb: Bounding box of the object, as a tuple of slices.
            center: Center coordinate of the object, in coordinates of the full image.
            distances: The channel-last result array. It is updated in place.

        Returns:
            The updated result array.
        """
        # Crop the mask and generate array with the correct center.
        cropped_mask = mask[bb]
        cropped_center = tuple(ce - b.start for ce, b in zip(center, bb))

        # The centroid might not be inside of the object.
        # In this case we correct the center by taking the maximum of the distance to the boundary.
        # Note: the centroid is still the best estimate for the center, as long as it's in the object.
        correct_center = not cropped_mask[cropped_center]

        # Compute the boundary distances if necessary.
        # (Either if we need to correct the center, or compute the boundary distances anyways.)
        if correct_center or self.boundary_distances:
            # Crop the boundary mask and compute the boundary distances.
            cropped_boundary_mask = boundaries[bb]
            boundary_distances = vigra.filters.distanceTransform(cropped_boundary_mask)
            boundary_distances[~cropped_mask] = 0
            max_dist_point = np.unravel_index(np.argmax(boundary_distances), boundary_distances.shape)

        # Set the crop center to the max dist point
        if correct_center:
            # Find the center (= maximal distance from the boundaries).
            cropped_center = max_dist_point

        cropped_center_mask = np.zeros_like(cropped_mask, dtype="uint32")
        cropped_center_mask[cropped_center] = 1

        # Compute the directed distances,
        if self.distances or self.directed_distances:
            this_distances = vigra.filters.vectorDistanceTransform(cropped_center_mask)
        else:
            this_distances = None

        # Keep only the specified distances:
        if self.distances and self.directed_distances:  # all distances
            # Compute the undirected distances from directed distances and concatenate,
            undir = np.linalg.norm(this_distances, axis=-1, keepdims=True)
            this_distances = np.concatenate([undir, this_distances], axis=-1)

        elif self.distances:  # only undirected distances
            # Compute the undirected distances from directed distances and keep only them.
            this_distances = np.linalg.norm(this_distances, axis=-1, keepdims=True)

        elif self.directed_distances:  # only directed distances
            pass  # We don't have to do anything because the directed distances are already computed.

        # Add an extra channel for the boundary distances if specified.
        if self.boundary_distances:
            boundary_distances = (boundary_distances[max_dist_point] - boundary_distances)[..., None]
            if this_distances is None:
                this_distances = boundary_distances
            else:
                this_distances = np.concatenate([this_distances, boundary_distances], axis=-1)

        # Set distances outside of the mask to zero.
        this_distances[~cropped_mask] = 0

        # Normalize the distances.
        spatial_axes = tuple(range(mask.ndim))
        this_distances /= (np.abs(this_distances).max(axis=spatial_axes, keepdims=True) + self.eps)

        # Set the distance values in the global result.
        # ('distances[bb]' is a view, because 'bb' consists only of slices.)
        distances[bb][cropped_mask] = this_distances[cropped_mask]

        return distances

    def __call__(self, labels):
        # Apply label (connected components) if specified.
        if self.apply_label:
            labels = skimage.measure.label(labels).astype("uint32")
        else:  # Otherwise just relabel the segmentation.
            labels = vigra.analysis.relabelConsecutive(labels)[0].astype("uint32")

        # Filter out small objects if min_size is specified.
        if self.min_size > 0:
            ids, sizes = np.unique(labels, return_counts=True)
            discard_ids = ids[sizes < self.min_size]
            labels[np.isin(labels, discard_ids)] = 0
            labels = vigra.analysis.relabelConsecutive(labels)[0].astype("uint32")

        # Compute the boundaries. They will be used to determine the most central point,
        # and if 'self.boundary_distances is True' to add the boundary distances.
        boundaries = skimage.segmentation.find_boundaries(labels, mode="inner").astype("uint32")

        # Compute region properties to derive bounding boxes and centers.
        ndim = labels.ndim
        props = skimage.measure.regionprops(labels)
        bounding_boxes = {
            prop.label: tuple(slice(prop.bbox[i], prop.bbox[i + ndim]) for i in range(ndim))
            for prop in props
        }

        # Compute the object centers from centroids.
        centers = {prop.label: np.round(prop.centroid).astype("int") for prop in props}

        # Compute how many distance channels we have.
        n_channels = 0
        if self.distances:  # We need one channel for the overall distances.
            n_channels += 1
        if self.boundary_distances:  # We need one channel for the boundary distances.
            n_channels += 1
        if self.directed_distances:  # And ndim channels for directed distances.
            n_channels += ndim

        # Compute the per object distances.
        distances = np.full(labels.shape + (n_channels,), self.distance_fill_value, dtype="float32")
        for prop in props:
            label_id = prop.label
            mask = labels == label_id
            distances = self.compute_normalized_object_distances(
                mask, boundaries, bounding_boxes[label_id], centers[label_id], distances
            )

        # Bring the distance channel to the first dimension.
        to_channel_first = (ndim,) + tuple(range(ndim))
        distances = distances.transpose(to_channel_first)

        # Add the foreground mask as first channel if specified.
        if self.foreground:
            binary_labels = (labels > 0).astype("float32")
            distances = np.concatenate([binary_labels[None], distances], axis=0)

        # Add the instance labels as first channel if specified.
        if self.instances:
            distances = np.concatenate([labels[None], distances], axis=0)

        return distances
def connected_components(labels, ndim=None, ensure_zero=False):
def connected_components(labels, ndim=None, ensure_zero=False):
    """Label the connected components of the input.

    Args:
        labels: The input label data.
        ndim: If given, the input is converted with `ensure_spatial_array` for this
            number of dimensions. Otherwise it is converted with `ensure_array`.
        ensure_zero: Whether to shift the result down by one if it does not contain the id 0.

    Returns:
        The output of `skimage.measure.label`, shifted by one if requested and necessary.
    """
    if ndim is None:
        array = ensure_array(labels)
    else:
        array = ensure_spatial_array(labels, ndim)

    components = skimage.measure.label(array)
    needs_shift = ensure_zero and 0 not in components
    if needs_shift:
        components -= 1
    return components
def labels_to_binary(labels, background_label=0):
def labels_to_binary(labels, background_label=0):
    """Binarize the labels: 1 where they differ from the background label, else 0.

    Args:
        labels: The input label array.
        background_label: The id that is mapped to 0.

    Returns:
        The binary array, with the same dtype as the input.
    """
    is_foreground = labels != background_label
    return is_foreground.astype(labels.dtype)
def label_consecutive(labels, with_background=True):
def label_consecutive(labels, with_background=True):
    """Relabel the input so that its ids are consecutive.

    Args:
        labels: The input label array. It is not modified.
        with_background: If True, the labels are passed directly to
            `skimage.segmentation.relabel_sequential`. If False, the id 0 is treated like
            any other id: the labels are shifted by one before relabeling (if they
            contain 0) and the result is shifted back, so that the output ids start at 0.

    Returns:
        The relabeled segmentation.
    """
    if with_background:
        return skimage.segmentation.relabel_sequential(labels)[0]

    # Shift out-of-place. An in-place `labels += 1` would silently change
    # the array owned by the caller.
    if 0 in labels:
        labels = labels + 1
    seg = skimage.segmentation.relabel_sequential(labels)[0]
    assert seg.min() == 1
    seg -= 1
    return seg
class BoundaryTransform:
class BoundaryTransform:
    """Compute a boundary target from a label image.

    Args:
        mode: The mode passed to `skimage.segmentation.find_boundaries`.
        add_binary_target: Whether to prepend a binary foreground channel to the boundaries.
        ndim: If given, the input is converted with `ensure_spatial_array` for this
            number of dimensions. Otherwise it is converted with `ensure_array`.
    """
    def __init__(self, mode="thick", add_binary_target=False, ndim=None):
        self.mode = mode
        self.add_binary_target = add_binary_target
        self.ndim = ndim

    def __call__(self, labels):
        if self.ndim is None:
            labels = ensure_array(labels)
        else:
            labels = ensure_spatial_array(labels, self.ndim)

        # The boundaries get a leading channel axis.
        boundary_channel = skimage.segmentation.find_boundaries(labels, mode=self.mode)[None]
        if not self.add_binary_target:
            return boundary_channel

        foreground_channel = labels_to_binary(labels)[None].astype(boundary_channel.dtype)
        return np.concatenate([foreground_channel, boundary_channel], axis=0)
BoundaryTransform(mode='thick', add_binary_target=False, ndim=None)
43    def __init__(self, mode="thick", add_binary_target=False, ndim=None):
44        self.mode = mode
45        self.add_binary_target = add_binary_target
46        self.ndim = ndim
mode
add_binary_target
ndim
class NoToBackgroundBoundaryTransform:
class NoToBackgroundBoundaryTransform:
    """Compute a boundary target in which the boundaries to the background are masked.

    Args:
        bg_label: The id of the background.
        mask_label: The value written at boundaries between foreground and background.
        mode: The mode passed to `skimage.segmentation.find_boundaries`.
        add_binary_target: Whether to prepend a binary foreground channel to the boundaries.
        ndim: If given, the input is converted with `ensure_spatial_array` for this
            number of dimensions. Otherwise it is converted with `ensure_array`.
    """
    def __init__(self, bg_label=0, mask_label=-1, mode="thick", add_binary_target=False, ndim=None):
        self.bg_label = bg_label
        self.mask_label = mask_label
        self.mode = mode
        self.ndim = ndim
        self.add_binary_target = add_binary_target

    def __call__(self, labels):
        if self.ndim is None:
            labels = ensure_array(labels)
        else:
            labels = ensure_spatial_array(labels, self.ndim)

        # Boundaries between all label ids.
        all_boundaries = skimage.segmentation.find_boundaries(labels, mode=self.mode)[None]

        # Boundaries of the binarized labels, i.e. between foreground and background.
        foreground = labels != self.bg_label
        bg_boundaries = skimage.segmentation.find_boundaries(foreground, mode=self.mode)[None]

        # Use a signed dtype so that the (by default negative) mask label can be stored.
        target = all_boundaries.astype(np.int8)
        target[bg_boundaries] = self.mask_label

        if self.add_binary_target:
            foreground_channel = labels_to_binary(labels, self.bg_label).astype(target.dtype)
            foreground_channel[labels == self.mask_label] = self.mask_label
            target = np.concatenate([foreground_channel[None], target], axis=0)

        return target
NoToBackgroundBoundaryTransform( bg_label=0, mask_label=-1, mode='thick', add_binary_target=False, ndim=None)
61    def __init__(self, bg_label=0, mask_label=-1, mode="thick", add_binary_target=False, ndim=None):
62        self.bg_label = bg_label
63        self.mask_label = mask_label
64        self.mode = mode
65        self.ndim = ndim
66        self.add_binary_target = add_binary_target
bg_label
mask_label
mode
ndim
add_binary_target
class BoundaryTransformWithIgnoreLabel:
 92class BoundaryTransformWithIgnoreLabel:
 93    def __init__(self, ignore_label=-1, mode="thick", add_binary_target=False, ndim=None):
 94        self.ignore_label = ignore_label
 95        self.mode = mode
 96        self.ndim = ndim
 97        self.add_binary_target = add_binary_target
 98
 99    def __call__(self, labels):
100        labels = ensure_array(labels) if self.ndim is None else ensure_spatial_array(labels, self.ndim)
101        # calculate the normal boundaries
102        boundaries = skimage.segmentation.find_boundaries(labels, mode=self.mode)[None]
103
104        # calculate the boundaries for the ignore label
105        labels_ignore = (labels == self.ignore_label)
106        to_ignore_boundaries = skimage.segmentation.find_boundaries(labels_ignore, mode=self.mode)[None]
107
108        # mask the to-background-boundaries
109        boundaries = boundaries.astype(np.int8)
110        boundaries[to_ignore_boundaries] = self.ignore_label
111
112        if self.add_binary_target:
113            binary = labels_to_binary(labels).astype(boundaries.dtype)
114            binary[labels == self.ignore_label] = self.ignore_label
115            target = np.concatenate([binary[None], boundaries], axis=0)
116        else:
117            target = boundaries
118
119        return target
BoundaryTransformWithIgnoreLabel(ignore_label=-1, mode='thick', add_binary_target=False, ndim=None)
93    def __init__(self, ignore_label=-1, mode="thick", add_binary_target=False, ndim=None):
94        self.ignore_label = ignore_label
95        self.mode = mode
96        self.ndim = ndim
97        self.add_binary_target = add_binary_target
ignore_label
mode
ndim
add_binary_target
class AffinityTransform:
class AffinityTransform:
    """Compute affinity targets from a label image with `affogato`.

    The affinities are returned in the "disaffinity" convention: 1 means repulsive, 0 attractive.

    Args:
        offsets: The offsets for which the affinities are computed. The length of the
            first offset determines the dimensionality, which must be 2 or 3.
        ignore_label: Optional ignore label that is passed to `compute_affinities`.
        add_binary_target: Whether to prepend a binary foreground channel to the affinities.
        add_mask: Whether to append the mask returned by `compute_affinities` to the output.
        include_ignore_transitions: Whether to set the transitions to the ignore label
            to 1, in both the affinities and the mask. Only has an effect if `ignore_label` is given.
    """
    def __init__(self, offsets,
                 ignore_label=None,
                 add_binary_target=False,
                 add_mask=False,
                 include_ignore_transitions=False):
        assert compute_affinities is not None
        self.offsets = offsets
        self.ndim = len(offsets[0])
        assert self.ndim in (2, 3)

        self.ignore_label = ignore_label
        self.add_binary_target = add_binary_target
        self.add_mask = add_mask
        self.include_ignore_transitions = include_ignore_transitions

    def add_ignore_transitions(self, affs, mask, labels):
        """Set the transitions to the ignore label to 1, in `affs` and `mask` (both in place)."""
        ignore_region = (labels == self.ignore_label).astype(labels.dtype)
        ignore_affs, valid = compute_affinities(ignore_region, self.offsets)
        # NOTE affinity convention returned by affogato: transitions are marked by 0
        transitions = ignore_affs == 0
        transitions[np.logical_not(valid)] = 0
        affs[transitions] = 1
        mask[transitions] = 1
        return affs, mask

    def __call__(self, labels):
        # Signed integer inputs are converted to int64, everything else to uint64.
        signed_dtypes = (np.dtype("int16"), np.dtype("int32"), np.dtype("int64"))
        dtype = "int64" if np.dtype(labels.dtype) in signed_dtypes else "uint64"
        labels = ensure_spatial_array(labels, self.ndim, dtype=dtype)

        has_ignore_label = self.ignore_label is not None
        affs, mask = compute_affinities(
            labels, self.offsets,
            have_ignore_label=has_ignore_label,
            ignore_label=self.ignore_label if has_ignore_label else 0,
        )
        # we use the "disaffinity" convention for training; i.e. 1 means repulsive, 0 attractive
        affs = 1. - affs

        if has_ignore_label and self.include_ignore_transitions:
            affs, mask = self.add_ignore_transitions(affs, mask, labels)

        if self.add_binary_target:
            binary = labels_to_binary(labels)[None].astype(affs.dtype)
            assert binary.ndim == affs.ndim
            affs = np.concatenate([binary, affs], axis=0)

        if not self.add_mask:
            return affs

        # The binary channel needs its own mask channel, so that
        # the targets and the mask have the same number of channels.
        if self.add_binary_target:
            if has_ignore_label:
                mask_for_bin = (labels != self.ignore_label)[None].astype(mask.dtype)
            else:
                mask_for_bin = np.ones((1,) + labels.shape, dtype=mask.dtype)
            assert mask.ndim == mask_for_bin.ndim
            mask = np.concatenate([mask_for_bin, mask], axis=0)

        assert affs.shape == mask.shape
        return np.concatenate([affs, mask.astype(affs.dtype)], axis=0)
AffinityTransform( offsets, ignore_label=None, add_binary_target=False, add_mask=False, include_ignore_transitions=False)
124    def __init__(self, offsets,
125                 ignore_label=None,
126                 add_binary_target=False,
127                 add_mask=False,
128                 include_ignore_transitions=False):
129        assert compute_affinities is not None
130        self.offsets = offsets
131        self.ndim = len(self.offsets[0])
132        assert self.ndim in (2, 3)
133
134        self.ignore_label = ignore_label
135        self.add_binary_target = add_binary_target
136        self.add_mask = add_mask
137        self.include_ignore_transitions = include_ignore_transitions
offsets
ndim
ignore_label
add_binary_target
add_mask
include_ignore_transitions
def add_ignore_transitions(self, affs, mask, labels):
139    def add_ignore_transitions(self, affs, mask, labels):
140        ignore_seg = (labels == self.ignore_label).astype(labels.dtype)
141        ignore_transitions, invalid_mask = compute_affinities(ignore_seg, self.offsets)
142        invalid_mask = np.logical_not(invalid_mask)
143        # NOTE affinity convention returned by affogato: transitions are marked by 0
144        ignore_transitions = ignore_transitions == 0
145        ignore_transitions[invalid_mask] = 0
146        affs[ignore_transitions] = 1
147        mask[ignore_transitions] = 1
148        return affs, mask
class OneHotTransform:
class OneHotTransform:
    """Convert a label array into a one-hot encoding with one channel per class id.

    Args:
        class_ids: The class ids to encode. An integer `n` stands for the ids `0, ..., n - 1`.
            If None, the unique values of the labels passed to the call are used.
    """
    def __init__(self, class_ids=None):
        if isinstance(class_ids, int):
            class_ids = list(range(class_ids))
        self.class_ids = class_ids

    def __call__(self, labels):
        if self.class_ids is None:
            class_ids = np.unique(labels).tolist()
        else:
            class_ids = self.class_ids

        one_hot = np.zeros((len(class_ids),) + labels.shape, dtype="float32")
        # Iterating over the first axis yields views, so the output is filled in place.
        for channel, class_id in zip(one_hot, class_ids):
            channel[labels == class_id] = 1.0
        return one_hot
OneHotTransform(class_ids=None)
185    def __init__(self, class_ids=None):
186        self.class_ids = list(range(class_ids)) if isinstance(class_ids, int) else class_ids
class_ids
class DistanceTransform:
class DistanceTransform:
    """Compute distances to foreground in the labels.

    Args:
        distances: Whether to compute the absolute distances.
        directed_distances: Whether to compute the directed distances (vector distances).
        normalize: Whether to normalize the computed distances.
        max_distance: Maximal distance at which to threshold the distances.
        foreground_id: Label id to which the distance is computed.
        invert: Whether to invert the distances.
        func: Normalization function for the distances.

    Raises:
        ValueError: If neither `distances` nor `directed_distances` is set.
    """
    eps = 1e-7

    def __init__(
        self,
        distances: bool = True,
        directed_distances: bool = False,
        normalize: bool = True,
        max_distance: Optional[float] = None,
        foreground_id=1,
        invert=False,
        func=None
    ):
        if sum((distances, directed_distances)) == 0:
            raise ValueError("At least one of 'distances' or 'directed_distances' must be set to 'True'")
        self.directed_distances = directed_distances
        self.distances = distances
        self.normalize = normalize
        self.max_distance = max_distance
        self.foreground_id = foreground_id
        self.invert = invert
        self.func = func

    def _compute_distances(self, directed_distances):
        """Derive the (optionally clipped, normalized, inverted) absolute distances.

        The input is the channel-first vector field; the norm is taken over axis 0.
        """
        distances = np.linalg.norm(directed_distances, axis=0)
        if self.max_distance is not None:
            distances = np.clip(distances, 0, self.max_distance)
        if self.normalize:
            distances /= (distances.max() + self.eps)
        if self.invert:
            distances = distances.max() - distances
        if self.func is not None:
            distances = self.func(distances)
        return distances

    def _compute_directed_distances(self, directed_distances):
        """Post-process the channel-first vector distances per vector channel.

        Note: the normalization divides in place, so the input array may be modified.
        """
        # Reduce over every spatial axis (all axes but the leading vector axis),
        # so that 2D and 3D inputs are both normalized / inverted per channel.
        spatial_axes = tuple(range(1, directed_distances.ndim))
        if self.max_distance is not None:
            directed_distances = np.clip(directed_distances, -self.max_distance, self.max_distance)
        if self.normalize:
            directed_distances /= (np.abs(directed_distances).max(axis=spatial_axes, keepdims=True) + self.eps)
        if self.invert:
            directed_distances = directed_distances.max(axis=spatial_axes, keepdims=True) - directed_distances
        if self.func is not None:
            directed_distances = self.func(directed_distances)
        return directed_distances

    def _get_distances_for_empty_labels(self, labels):
        """Return a constant channel-first vector field for labels without foreground."""
        shape = labels.shape
        fill_value = 0.0 if self.invert else np.sqrt(np.linalg.norm(list(shape)) ** 2 / 2)
        data = np.full((labels.ndim,) + shape, fill_value)
        return data

    def __call__(self, labels):
        """Compute the configured distance channels for `labels`.

        Returns:
            The absolute distances, the directed distances (channel-first),
            or both concatenated along axis 0 with the absolute distances first.
        """
        distance_mask = (labels == self.foreground_id).astype("uint32")
        # The distances are not computed correctly if the mask is all zero,
        # so this case needs to be handled separately.
        if distance_mask.sum() == 0:
            directed_distances = self._get_distances_for_empty_labels(labels)
        else:
            ndim = distance_mask.ndim
            # Move the trailing vector axis of the transform output to the front.
            to_channel_first = (ndim,) + tuple(range(ndim))
            directed_distances = vigra.filters.vectorDistanceTransform(distance_mask).transpose(to_channel_first)

        # The absolute distances have to be derived first, because the directed
        # distances may be normalized in place afterwards.
        if self.distances:
            distances = self._compute_distances(directed_distances)

        if self.directed_distances:
            directed_distances = self._compute_directed_distances(directed_distances)

        if self.distances and self.directed_distances:
            return np.concatenate((distances[None], directed_distances), axis=0)
        if self.distances:
            return distances
        if self.directed_distances:
            return directed_distances

Compute distances to foreground in the labels.

Arguments:
  • distances: Whether to compute the absolute distances.
  • directed_distances: Whether to compute the directed distances (vector distances).
  • normalize: Whether to normalize the computed distances.
  • max_distance: Maximal distance at which to threshold the distances.
  • foreground_id: Label id to which the distance is computed.
  • invert: Whether to invert the distances.
  • func: Normalization function for the distances.
DistanceTransform( distances: bool = True, directed_distances: bool = False, normalize: bool = True, max_distance: Optional[float] = None, foreground_id=1, invert=False, func=None)
211    def __init__(
212        self,
213        distances: bool = True,
214        directed_distances: bool = False,
215        normalize: bool = True,
216        max_distance: Optional[float] = None,
217        foreground_id=1,
218        invert=False,
219        func=None
220    ):
221        if sum((distances, directed_distances)) == 0:
222            raise ValueError("At least one of 'distances' or 'directed_distances' must be set to 'True'")
223        self.directed_distances = directed_distances
224        self.distances = distances
225        self.normalize = normalize
226        self.max_distance = max_distance
227        self.foreground_id = foreground_id
228        self.invert = invert
229        self.func = func
eps = 1e-07
directed_distances
distances
normalize
max_distance
foreground_id
invert
func
class PerObjectDistanceTransform:
class PerObjectDistanceTransform:
    """Compute normalized distances per object in a segmentation.

    Args:
        distances: Whether to compute the undirected distances.
        boundary_distances: Whether to compute the distances to the object boundaries.
        directed_distances: Whether to compute the directed distances (vector distances).
        foreground: Whether to return a foreground channel.
        instances: Whether to return the (relabeled) instance labels as the first channel.
        apply_label: Whether to apply connected components to the labels before computing distances.
        correct_centers: Whether to correct centers that are not in the objects.
        min_size: Minimal size of objects for distance calculation.
        distance_fill_value: Fill value for the distances outside of objects.

    Raises:
        ValueError: If none of `distances`, `directed_distances` and `boundary_distances` is set.
    """
    eps = 1e-7

    def __init__(
        self,
        distances=True,
        boundary_distances=True,
        directed_distances=False,
        foreground=True,
        instances=False,
        apply_label=True,
        correct_centers=True,
        min_size=0,
        distance_fill_value=1.0,
    ):
        if sum([distances, directed_distances, boundary_distances]) == 0:
            raise ValueError(
                "At least one of 'distances', 'directed_distances' or 'boundary_distances' must be set to 'True'."
            )
        self.distances = distances
        self.boundary_distances = boundary_distances
        self.directed_distances = directed_distances
        self.foreground = foreground
        self.instances = instances

        self.apply_label = apply_label
        self.correct_centers = correct_centers
        self.min_size = min_size
        self.distance_fill_value = distance_fill_value

    def compute_normalized_object_distances(self, mask, boundaries, bb, center, distances):
        """Compute the normalized distance channels of one object and write them into `distances`.

        Args:
            mask: Boolean mask of the object, in the full label shape.
            boundaries: Boundary mask of the full segmentation.
            bb: Bounding box of the object as a tuple of slices.
            center: Center coordinate of the object in full-image coordinates.
            distances: Channel-last result array; updated in place inside the object.

        Returns:
            The updated `distances` array.
        """
        # Crop the mask and express the center in crop coordinates.
        cropped_mask = mask[bb]
        cropped_center = tuple(ce - b.start for ce, b in zip(center, bb))

        # The centroid might not be inside of the object.
        # In this case we correct the center by taking the maximum of the distance to the boundary,
        # but only if center correction is enabled.
        # Note: the centroid is still the best estimate for the center, as long as it's in the object.
        correct_center = bool(self.correct_centers) and not cropped_mask[cropped_center]

        # Compute the boundary distances if necessary.
        # (Either if we need to correct the center, or compute the boundary distances anyways.)
        if correct_center or self.boundary_distances:
            # Crop the boundary mask and compute the boundary distances.
            cropped_boundary_mask = boundaries[bb]
            boundary_distances = vigra.filters.distanceTransform(cropped_boundary_mask)
            boundary_distances[~cropped_mask] = 0
            max_dist_point = np.unravel_index(np.argmax(boundary_distances), boundary_distances.shape)

        # Use the point with maximal distance from the boundaries as the center.
        if correct_center:
            cropped_center = max_dist_point

        cropped_center_mask = np.zeros_like(cropped_mask, dtype="uint32")
        cropped_center_mask[cropped_center] = 1

        # Compute the directed distances to the center (vector channel last).
        if self.distances or self.directed_distances:
            this_distances = vigra.filters.vectorDistanceTransform(cropped_center_mask)
        else:
            this_distances = None

        # Keep only the specified distances:
        if self.distances and self.directed_distances:  # all distances
            # Compute the undirected distances from the directed distances and concatenate.
            undir = np.linalg.norm(this_distances, axis=-1, keepdims=True)
            this_distances = np.concatenate([undir, this_distances], axis=-1)
        elif self.distances:  # only undirected distances
            this_distances = np.linalg.norm(this_distances, axis=-1, keepdims=True)
        # Otherwise only the directed distances are requested (already computed), or none at all.

        # Add an extra channel for the boundary distances if specified.
        if self.boundary_distances:
            boundary_distances = (boundary_distances[max_dist_point] - boundary_distances)[..., None]
            if this_distances is None:
                this_distances = boundary_distances
            else:
                this_distances = np.concatenate([this_distances, boundary_distances], axis=-1)

        # Set distances outside of the mask to zero.
        this_distances[~cropped_mask] = 0

        # Normalize each channel by its maximal absolute value over the spatial axes.
        spatial_axes = tuple(range(mask.ndim))
        this_distances /= (np.abs(this_distances).max(axis=spatial_axes, keepdims=True) + self.eps)

        # Set the distance values in the global result (the slicing yields a view, so this writes in place).
        distances[bb][cropped_mask] = this_distances[cropped_mask]

        return distances

    def __call__(self, labels):
        """Compute the per-object distance channels for a segmentation.

        Returns:
            A channel-first array. The order of channels is: instance labels (if `instances`),
            foreground (if `foreground`), undirected distances (if `distances`),
            directed distances (if `directed_distances`), boundary distances (if `boundary_distances`).
        """
        # Apply label (connected components) if specified.
        if self.apply_label:
            labels = skimage.measure.label(labels).astype("uint32")
        else:  # Otherwise just relabel the segmentation.
            labels = vigra.analysis.relabelConsecutive(labels)[0].astype("uint32")

        # Filter out small objects if min_size is specified.
        if self.min_size > 0:
            ids, sizes = np.unique(labels, return_counts=True)
            discard_ids = ids[sizes < self.min_size]
            labels[np.isin(labels, discard_ids)] = 0
            labels = vigra.analysis.relabelConsecutive(labels)[0].astype("uint32")

        # Compute the boundaries. They will be used to determine the most central point,
        # and if 'self.boundary_distances is True' to add the boundary distances.
        boundaries = skimage.segmentation.find_boundaries(labels, mode="inner").astype("uint32")

        # Compute region properties to derive bounding boxes and centers.
        ndim = labels.ndim
        props = skimage.measure.regionprops(labels)

        # Compute how many distance channels we have.
        n_channels = 0
        if self.distances:  # We need one channel for the overall distances.
            n_channels += 1
        if self.boundary_distances:  # We need one channel for the boundary distances.
            n_channels += 1
        if self.directed_distances:  # And ndim channels for directed distances.
            n_channels += ndim

        # Compute the per object distances.
        distances = np.full(labels.shape + (n_channels,), self.distance_fill_value, dtype="float32")
        for prop in props:
            # The bbox holds all start coordinates followed by all stop coordinates.
            bounding_box = tuple(slice(prop.bbox[i], prop.bbox[i + ndim]) for i in range(ndim))
            # The object center is derived from the (rounded) centroid.
            center = np.round(prop.centroid).astype("int")
            mask = labels == prop.label
            distances = self.compute_normalized_object_distances(
                mask, boundaries, bounding_box, center, distances
            )

        # Bring the distance channel to the first dimension.
        to_channel_first = (ndim,) + tuple(range(ndim))
        distances = distances.transpose(to_channel_first)

        # Add the foreground mask as first channel if specified.
        if self.foreground:
            binary_labels = (labels > 0).astype("float32")
            distances = np.concatenate([binary_labels[None], distances], axis=0)

        # Add the instance labels in front of all other channels if specified.
        if self.instances:
            distances = np.concatenate([labels[None], distances], axis=0)

        return distances

Compute normalized distances per object in a segmentation.

Arguments:
  • distances: Whether to compute the undirected distances.
  • boundary_distances: Whether to compute the distances to the object boundaries.
  • directed_distances: Whether to compute the directed distances (vector distances).
  • foreground: Whether to return a foreground channel.
  • apply_label: Whether to apply connected components to the labels before computing distances.
  • correct_centers: Whether to correct centers that are not in the objects.
  • min_size: Minimal size of objects for distance calculation.
  • distance_fill_value: Fill value for the distances outside of objects.
PerObjectDistanceTransform( distances=True, boundary_distances=True, directed_distances=False, foreground=True, instances=False, apply_label=True, correct_centers=True, min_size=0, distance_fill_value=1.0)
300    def __init__(
301        self,
302        distances=True,
303        boundary_distances=True,
304        directed_distances=False,
305        foreground=True,
306        instances=False,
307        apply_label=True,
308        correct_centers=True,
309        min_size=0,
310        distance_fill_value=1.0,
311    ):
312        if sum([distances, directed_distances, boundary_distances]) == 0:
313            raise ValueError("At least one of distances or directed distances has to be passed.")
314        self.distances = distances
315        self.boundary_distances = boundary_distances
316        self.directed_distances = directed_distances
317        self.foreground = foreground
318        self.instances = instances
319
320        self.apply_label = apply_label
321        self.correct_centers = correct_centers
322        self.min_size = min_size
323        self.distance_fill_value = distance_fill_value
eps = 1e-07
distances
boundary_distances
directed_distances
foreground
instances
apply_label
correct_centers
min_size
distance_fill_value
def compute_normalized_object_distances(self, mask, boundaries, bb, center, distances):
325    def compute_normalized_object_distances(self, mask, boundaries, bb, center, distances):
326        # Crop the mask and generate array with the correct center.
327        cropped_mask = mask[bb]
328        cropped_center = tuple(ce - b.start for ce, b in zip(center, bb))
329
330        # The centroid might not be inside of the object.
331        # In this case we correct the center by taking the maximum of the distance to the boundary.
332        # Note: the centroid is still the best estimate for the center, as long as it's in the object.
333        correct_center = not cropped_mask[cropped_center]
334
335        # Compute the boundary distances if necessary.
336        # (Either if we need to correct the center, or compute the boundary distances anyways.)
337        if correct_center or self.boundary_distances:
338            # Crop the boundary mask and compute the boundary distances.
339            cropped_boundary_mask = boundaries[bb]
340            boundary_distances = vigra.filters.distanceTransform(cropped_boundary_mask)
341            boundary_distances[~cropped_mask] = 0
342            max_dist_point = np.unravel_index(np.argmax(boundary_distances), boundary_distances.shape)
343
344        # Set the crop center to the max dist point
345        if correct_center:
346            # Find the center (= maximal distance from the boundaries).
347            cropped_center = max_dist_point
348
349        cropped_center_mask = np.zeros_like(cropped_mask, dtype="uint32")
350        cropped_center_mask[cropped_center] = 1
351
352        # Compute the directed distances,
353        if self.distances or self.directed_distances:
354            this_distances = vigra.filters.vectorDistanceTransform(cropped_center_mask)
355        else:
356            this_distances = None
357
358        # Keep only the specified distances:
359        if self.distances and self.directed_distances:  # all distances
360            # Compute the undirected ditacnes from directed distances and concatenate,
361            undir = np.linalg.norm(this_distances, axis=-1, keepdims=True)
362            this_distances = np.concatenate([undir, this_distances], axis=-1)
363
364        elif self.distances:  # only undirected distances
365            # Compute the undirected distances from directed distances and keep only them.
366            this_distances = np.linalg.norm(this_distances, axis=-1, keepdims=True)
367
368        elif self.directed_distances:  # only directed distances
369            pass  # We don't have to do anything becasue the directed distances are already computed.
370
371        # Add an extra channel for the boundary distances if specified.
372        if self.boundary_distances:
373            boundary_distances = (boundary_distances[max_dist_point] - boundary_distances)[..., None]
374            if this_distances is None:
375                this_distances = boundary_distances
376            else:
377                this_distances = np.concatenate([this_distances, boundary_distances], axis=-1)
378
379        # Set distances outside of the mask to zero.
380        this_distances[~cropped_mask] = 0
381
382        # Normalize the distances.
383        spatial_axes = tuple(range(mask.ndim))
384        this_distances /= (np.abs(this_distances).max(axis=spatial_axes, keepdims=True) + self.eps)
385
386        # Set the distance values in the global result.
387        distances[bb][cropped_mask] = this_distances[cropped_mask]
388
389        return distances