torch_em.transform.label

  1from typing import Callable, List, Optional, Sequence, Union
  2
  3import numpy as np
  4import skimage.measure
  5import skimage.segmentation
  6import vigra
  7
  8from ..util import ensure_array, ensure_spatial_array
  9
 10try:
 11    from affogato.affinities import compute_affinities
 12except ImportError:
 13    compute_affinities = None
 14
 15
 16def connected_components(labels: np.ndarray, ndim: Optional[int] = None, ensure_zero: bool = False) -> np.ndarray:
 17    """Apply connected components to a segmentation.
 18
 19    Args:
 20        labels: The input segmentation.
 21        ndim: The expected dimensionality of the data.
 22        ensure_zero: Whether to ensure that the data has a zero label.
 23
 24    Returns:
 25        The segmentation after connected components.
 26    """
 27    labels = ensure_array(labels) if ndim is None else ensure_spatial_array(labels, ndim)
 28    labels = skimage.measure.label(labels)
 29    if ensure_zero and 0 not in labels:
 30        labels -= 1
 31    return labels
 32
 33
 34def labels_to_binary(labels: np.ndarray, background_label: int = 0) -> np.ndarray:
 35    """Transform a segmentation to binary labels.
 36
 37    Args:
 38        labels: The input segmentation.
 39        background_label: The id of the background label.
 40
 41    Returns:
 42        The binary segmentation.
 43    """
 44    return (labels != background_label).astype(labels.dtype)
 45
 46
 47def label_consecutive(labels: np.ndarray, with_background: bool = True) -> np.ndarray:
 48    """Ensure that the input segmentation is labeled consecutively.
 49
 50    Args:
 51        labels: The input segmentation.
 52        with_background: Whether this segmentation has a background label.
 53
 54    Returns:
 55        The consecutively labeled segmentation.
 56    """
 57    if with_background:
 58        seg = skimage.segmentation.relabel_sequential(labels)[0]
 59    else:
 60        if 0 in labels:
 61            labels += 1
 62        seg = skimage.segmentation.relabel_sequential(labels)[0]
 63        assert seg.min() == 1
 64        seg -= 1
 65    return seg
 66
 67
 68class MinSizeLabelTransform:
 69    """Transformation to filter out objects smaller than a minimal size from the segmentation.
 70
 71    Args:
 72        min_size: The minimal object size of the segmentation.
 73        ndim: The dimensionality of the segmentation.
 74        ensure_zero: Ensure that the segmentation contains the id zero.
 75    """
 76    def __init__(self, min_size: Optional[int] = None, ndim: Optional[int] = None, ensure_zero: bool = False):
 77        self.min_size = min_size
 78        self.ndim = ndim
 79        self.ensure_zero = ensure_zero
 80
 81    def __call__(self, labels: np.ndarray) -> np.ndarray:
 82        """Filter out small objects from segmentation.
 83
 84        Args:
 85            labels: The input segmentation.
 86
 87        Returns:
 88            The size filtered segmentation.
 89        """
 90        components = connected_components(labels, ndim=self.ndim, ensure_zero=self.ensure_zero)
 91        if self.min_size is not None:
 92            ids, sizes = np.unique(components, return_counts=True)
 93            filter_ids = ids[sizes < self.min_size]
 94            components[np.isin(components, filter_ids)] = 0
 95            components, _, _ = skimage.segmentation.relabel_sequential(components)
 96        return components
 97
 98
 99# TODO smoothing
class BoundaryTransform:
    """Transformation to convert an instance segmentation into boundaries.

    Args:
        mode: The mode for converting the segmentation to boundaries.
        add_binary_target: Whether to add a binary mask channel to the transformation output.
        ndim: The expected dimensionality of the data.
    """
    def __init__(self, mode: str = "thick", add_binary_target: bool = False, ndim: Optional[int] = None):
        self.mode = mode
        self.add_binary_target = add_binary_target
        self.ndim = ndim

    def __call__(self, labels: np.ndarray) -> np.ndarray:
        """Apply the boundary transformation to an input segmentation.

        Args:
            labels: The input segmentation.

        Returns:
            The boundaries.
        """
        if self.ndim is None:
            labels = ensure_array(labels)
        else:
            labels = ensure_spatial_array(labels, self.ndim)
        # Boundary map with a leading channel axis.
        boundary_channel = skimage.segmentation.find_boundaries(labels, mode=self.mode)[None]
        if not self.add_binary_target:
            return boundary_channel
        # Prepend a binary foreground channel to the boundary channel.
        foreground_channel = labels_to_binary(labels)[None].astype(boundary_channel.dtype)
        return np.concatenate([foreground_channel, boundary_channel], axis=0)
130
131
132# TODO smoothing
class NoToBackgroundBoundaryTransform:
    """Transformation to convert an instance segmentation into boundaries.

    Boundaries between objects and the background are set to the mask label
    in the output of the transformation.

    Args:
        bg_label: The background label.
        mask_label: The mask label.
        mode: The mode for converting the segmentation to boundaries.
        add_binary_target: Whether to add a binary mask channel to the transformation output.
        ndim: The expected dimensionality of the data.
    """
    def __init__(
        self,
        bg_label: int = 0,
        mask_label: int = -1,
        mode: str = "thick",
        add_binary_target: bool = False,
        ndim: Optional[int] = None,
    ):
        self.bg_label = bg_label
        self.mask_label = mask_label
        self.mode = mode
        self.ndim = ndim
        self.add_binary_target = add_binary_target

    def __call__(self, labels: np.ndarray) -> np.ndarray:
        """Apply the boundary transformation to an input segmentation.

        Args:
            labels: The input segmentation.

        Returns:
            The boundaries.
        """
        if self.ndim is None:
            labels = ensure_array(labels)
        else:
            labels = ensure_spatial_array(labels, self.ndim)

        # Boundaries between all objects, with a leading channel axis.
        # Cast to int8 so that the (possibly negative) mask label can be stored.
        boundaries = skimage.segmentation.find_boundaries(labels, mode=self.mode)[None].astype(np.int8)

        # Boundaries of the binarized segmentation, i.e. transitions to the background.
        is_foreground = (labels != self.bg_label)
        bg_boundaries = skimage.segmentation.find_boundaries(is_foreground, mode=self.mode)[None]

        # Mark the object-to-background transitions with the mask label.
        boundaries[bg_boundaries] = self.mask_label

        if not self.add_binary_target:
            return boundaries

        # Prepend the binary foreground channel, with the mask label applied as well.
        binary = labels_to_binary(labels, self.bg_label).astype(boundaries.dtype)
        binary[labels == self.mask_label] = self.mask_label
        return np.concatenate([binary[None], boundaries], axis=0)
189
190
191# TODO smoothing
class BoundaryTransformWithIgnoreLabel:
    """Transformation to convert an instance segmentation into boundaries.

    Boundaries that touch regions with the ignore label are set to the ignore
    label in the output of the transformation.

    Args:
        ignore_label: The ignore label.
        mode: The mode for converting the segmentation to boundaries.
        add_binary_target: Whether to add a binary mask channel to the transformation output.
        ndim: The expected dimensionality of the data.
    """
    def __init__(
        self,
        ignore_label: int = -1,
        mode: str = "thick",
        add_binary_target: bool = False,
        ndim: Optional[int] = None,
    ):
        self.ignore_label = ignore_label
        self.mode = mode
        self.ndim = ndim
        self.add_binary_target = add_binary_target

    def __call__(self, labels: np.ndarray) -> np.ndarray:
        """Apply the boundary transformation to an input segmentation.

        Args:
            labels: The input segmentation.

        Returns:
            The boundaries.
        """
        if self.ndim is None:
            labels = ensure_array(labels)
        else:
            labels = ensure_spatial_array(labels, self.ndim)

        # Boundaries between all objects, with a leading channel axis.
        # Cast to int8 so that the (possibly negative) ignore label can be stored.
        boundaries = skimage.segmentation.find_boundaries(labels, mode=self.mode)[None].astype(np.int8)

        # Boundaries of the ignore mask, i.e. transitions from / to the ignore label.
        ignore_mask = (labels == self.ignore_label)
        ignore_boundaries = skimage.segmentation.find_boundaries(ignore_mask, mode=self.mode)[None]

        # Mark the transitions to the ignore label in the boundary channel.
        boundaries[ignore_boundaries] = self.ignore_label

        if not self.add_binary_target:
            return boundaries

        # Prepend the binary foreground channel, with the ignore label applied as well.
        binary = labels_to_binary(labels).astype(boundaries.dtype)
        binary[labels == self.ignore_label] = self.ignore_label
        return np.concatenate([binary[None], boundaries], axis=0)
245
246
247# TODO affinity smoothing
class AffinityTransform:
    """Transformation to compute affinities from a segmentation.

    Args:
        offsets: The offsets for computing affinities.
        ignore_label: The ignore label to use for computing the ignore mask.
        add_binary_target: Whether to add a binary channel to the affinities.
        add_mask: Whether to add the ignore mask as extra output channels.
        include_ignore_transitions: Whether transitions to the ignore label
            should be positive in the ignore mask or negative in it.
    """
    def __init__(
        self,
        offsets: List[List[int]],
        ignore_label: Optional[int] = None,
        add_binary_target: bool = False,
        add_mask: bool = False,
        include_ignore_transitions: bool = False,
    ):
        # The affogato import at the top of the file is optional; fail early if it is missing.
        assert compute_affinities is not None
        self.offsets = offsets
        # The dimensionality is derived from the length of the offset vectors.
        self.ndim = len(self.offsets[0])
        assert self.ndim in (2, 3)

        self.ignore_label = ignore_label
        self.add_binary_target = add_binary_target
        self.add_mask = add_mask
        self.include_ignore_transitions = include_ignore_transitions

    def add_ignore_transitions(self, affs, mask, labels):
        """@private
        """
        # Binarize the ignore label and compute where transitions to it happen.
        ignore_seg = (labels == self.ignore_label).astype(labels.dtype)
        ignore_transitions, invalid_mask = compute_affinities(ignore_seg, self.offsets)
        invalid_mask = np.logical_not(invalid_mask)
        # NOTE affinity convention returned by affogato: transitions are marked by 0
        ignore_transitions = ignore_transitions == 0
        ignore_transitions[invalid_mask] = 0
        # Mark the ignore transitions as repulsive (1 in the disaffinity convention)
        # and as valid in the mask.
        affs[ignore_transitions] = 1
        mask[ignore_transitions] = 1
        return affs, mask

    def __call__(self, labels: np.ndarray) -> np.ndarray:
        """Compute the affinities.

        Args:
            labels: The segmentation.

        Returns:
            The affinities.
        """
        # Preserve signedness: signed integer label dtypes are cast to int64,
        # everything else to uint64.
        dtype = "uint64"
        if np.dtype(labels.dtype) in (np.dtype("int16"), np.dtype("int32"), np.dtype("int64")):
            dtype = "int64"
        labels = ensure_spatial_array(labels, self.ndim, dtype=dtype)
        affs, mask = compute_affinities(labels, self.offsets,
                                        have_ignore_label=self.ignore_label is not None,
                                        ignore_label=0 if self.ignore_label is None else self.ignore_label)
        # we use the "disaffinity" convention for training; i.e. 1 means repulsive, 0 attractive
        affs = 1. - affs

        # Mark transitions to the ignore label as repulsive and as valid in the mask.
        if self.ignore_label is not None and self.include_ignore_transitions:
            affs, mask = self.add_ignore_transitions(affs, mask, labels)

        # Prepend the binary foreground channel if specified.
        if self.add_binary_target:
            binary = labels_to_binary(labels)[None].astype(affs.dtype)
            assert binary.ndim == affs.ndim
            affs = np.concatenate([binary, affs], axis=0)

        # Append the mask channels if specified.
        if self.add_mask:
            if self.add_binary_target:
                # The binary channel is valid everywhere, except where the ignore label is set.
                if self.ignore_label is None:
                    mask_for_bin = np.ones((1,) + labels.shape, dtype=mask.dtype)
                else:
                    mask_for_bin = (labels != self.ignore_label)[None].astype(mask.dtype)
                assert mask.ndim == mask_for_bin.ndim
                mask = np.concatenate([mask_for_bin, mask], axis=0)
            assert affs.shape == mask.shape
            affs = np.concatenate([affs, mask.astype(affs.dtype)], axis=0)

        return affs
330
331
class OneHotTransform:
    """Transformation to compute one-hot labels from a semantic segmentation.

    Args:
        class_ids: The class ids to convert to one-hot labels.
            An integer is interpreted as the number of classes, i.e. the ids 0 to class_ids - 1.
            If None, the class ids are derived from the unique values of the input labels.
    """
    def __init__(self, class_ids: Optional[Union[int, Sequence[int]]] = None):
        if isinstance(class_ids, int):
            self.class_ids = list(range(class_ids))
        else:
            self.class_ids = class_ids

    def __call__(self, labels: np.ndarray) -> np.ndarray:
        """Compute the one hot transformation.

        Args:
            labels: The segmentation.

        Returns:
            The one-hot transformation.
        """
        if self.class_ids is None:
            class_ids = np.unique(labels).tolist()
        else:
            class_ids = self.class_ids
        # One channel per class, set to one where the labels match the class id.
        one_hot = np.zeros((len(class_ids),) + labels.shape, dtype="float32")
        for channel, class_id in enumerate(class_ids):
            one_hot[channel][labels == class_id] = 1.0
        return one_hot
356
357
class DistanceTransform:
    """Transformation to compute distances to foreground in the labels.

    Args:
        distances: Whether to compute the absolute distances.
        directed_distances: Whether to compute the directed distances (vector distances).
        normalize: Whether to normalize the computed distances.
        max_distance: Maximal distance at which to threshold the distances.
        foreground_id: Label id to which the distance is computed.
        invert: Whether to invert the distances.
        func: Normalization function for the distances.
    """
    # Small epsilon to avoid division by zero when normalizing.
    eps = 1e-7

    def __init__(
        self,
        distances: bool = True,
        directed_distances: bool = False,
        normalize: bool = True,
        max_distance: Optional[float] = None,
        foreground_id: int = 1,
        invert: bool = False,
        func: Optional[Callable] = None,
    ):
        if sum((distances, directed_distances)) == 0:
            raise ValueError("At least one of 'distances' or 'directed_distances' must be set to 'True'")
        self.directed_distances = directed_distances
        self.distances = distances
        self.normalize = normalize
        self.max_distance = max_distance
        self.foreground_id = foreground_id
        self.invert = invert
        self.func = func

    def _compute_distances(self, directed_distances):
        # The undirected distances are the euclidean norm over the vector components.
        distances = np.linalg.norm(directed_distances, axis=0)
        if self.max_distance is not None:
            distances = np.clip(distances, 0, self.max_distance)
        if self.normalize:
            distances /= (distances.max() + self.eps)
        if self.invert:
            distances = distances.max() - distances
        if self.func is not None:
            distances = self.func(distances)
        return distances

    def _compute_directed_distances(self, directed_distances):
        if self.max_distance is not None:
            directed_distances = np.clip(directed_distances, -self.max_distance, self.max_distance)
        if self.normalize:
            # Normalize each vector component independently over the spatial axes.
            # NOTE(review): axis=(1, 2) assumes 2d spatial input here — confirm for 3d use.
            directed_distances /= (np.abs(directed_distances).max(axis=(1, 2), keepdims=True) + self.eps)
        if self.invert:
            directed_distances = directed_distances.max(axis=(1, 2), keepdims=True) - directed_distances
        if self.func is not None:
            directed_distances = self.func(directed_distances)
        return directed_distances

    def _get_distances_for_empty_labels(self, labels):
        # If there is no foreground, fill the distances with a constant value:
        # zero for inverted distances, otherwise a large value derived from the shape.
        shape = labels.shape
        fill_value = 0.0 if self.invert else np.sqrt(np.linalg.norm(list(shape)) ** 2 / 2)
        data = np.full((labels.ndim,) + shape, fill_value)
        return data

    def __call__(self, labels: np.ndarray) -> np.ndarray:
        """Compute the distances.

        Args:
            labels: The segmentation.

        Returns:
            The distances.
        """
        distance_mask = (labels == self.foreground_id).astype("uint32")
        # the distances are not computed correctly if they are all zero
        # so this case needs to be handled separately
        if distance_mask.sum() == 0:
            directed_distances = self._get_distances_for_empty_labels(labels)
        else:
            # vigra returns the vector components in a trailing channel axis;
            # bring them to a leading channel axis.
            ndim = distance_mask.ndim
            to_channel_first = (ndim,) + tuple(range(ndim))
            directed_distances = vigra.filters.vectorDistanceTransform(distance_mask).transpose(to_channel_first)

        if self.distances:
            distances = self._compute_distances(directed_distances)

        if self.directed_distances:
            directed_distances = self._compute_directed_distances(directed_distances)

        # If both are computed, the undirected distances form the first channel.
        if self.distances and self.directed_distances:
            return np.concatenate((distances[None], directed_distances), axis=0)
        if self.distances:
            return distances
        if self.directed_distances:
            return directed_distances
452
453
class PerObjectDistanceTransform:
    """Transformation to compute normalized distances per object in a segmentation.

    Args:
        distances: Whether to compute the undirected distances.
        boundary_distances: Whether to compute the distances to the object boundaries.
        directed_distances: Whether to compute the directed distances (vector distances).
        foreground: Whether to return a foreground channel.
        instances: Whether to add the instance segmentation as an additional channel.
        apply_label: Whether to apply connected components to the labels before computing distances.
        correct_centers: Whether to correct centers that are not in the objects.
        min_size: Minimal size of objects for distance calculation.
        distance_fill_value: Fill value for the distances outside of objects.
    """
    # Small epsilon to avoid division by zero when normalizing.
    eps = 1e-7

    def __init__(
        self,
        distances: bool = True,
        boundary_distances: bool = True,
        directed_distances: bool = False,
        foreground: bool = True,
        instances: bool = False,
        apply_label: bool = True,
        correct_centers: bool = True,
        min_size: int = 0,
        distance_fill_value: float = 1.0,
    ):
        if sum([distances, directed_distances, boundary_distances]) == 0:
            raise ValueError("At least one of distances or directed distances has to be passed.")
        self.distances = distances
        self.boundary_distances = boundary_distances
        self.directed_distances = directed_distances
        self.foreground = foreground
        self.instances = instances

        self.apply_label = apply_label
        self.correct_centers = correct_centers
        self.min_size = min_size
        self.distance_fill_value = distance_fill_value

    def compute_normalized_object_distances(self, mask, boundaries, bb, center, distances):
        """@private
        """
        # Crop the mask and generate array with the correct center.
        cropped_mask = mask[bb]
        cropped_center = tuple(ce - b.start for ce, b in zip(center, bb))

        # The centroid might not be inside of the object.
        # In this case we correct the center by taking the maximum of the distance to the boundary.
        # Note: the centroid is still the best estimate for the center, as long as it's in the object.
        # NOTE(review): self.correct_centers is stored but not checked here; the
        # correction always runs when the centroid is outside the object — confirm intended.
        correct_center = not cropped_mask[cropped_center]

        # Compute the boundary distances if necessary.
        # (Either if we need to correct the center, or compute the boundary distances anyways.)
        if correct_center or self.boundary_distances:
            # Crop the boundary mask and compute the boundary distances.
            cropped_boundary_mask = boundaries[bb]
            boundary_distances = vigra.filters.distanceTransform(cropped_boundary_mask)
            boundary_distances[~cropped_mask] = 0
            max_dist_point = np.unravel_index(np.argmax(boundary_distances), boundary_distances.shape)

        # Set the crop center to the max dist point
        if correct_center:
            # Find the center (= maximal distance from the boundaries).
            cropped_center = max_dist_point

        # One-hot mask that marks the (possibly corrected) center.
        cropped_center_mask = np.zeros_like(cropped_mask, dtype="uint32")
        cropped_center_mask[cropped_center] = 1

        # Compute the directed distances to the center.
        if self.distances or self.directed_distances:
            this_distances = vigra.filters.vectorDistanceTransform(cropped_center_mask)
        else:
            this_distances = None

        # Keep only the specified distances:
        if self.distances and self.directed_distances:  # all distances
            # Compute the undirected distances from directed distances and concatenate.
            undir = np.linalg.norm(this_distances, axis=-1, keepdims=True)
            this_distances = np.concatenate([undir, this_distances], axis=-1)

        elif self.distances:  # only undirected distances
            # Compute the undirected distances from directed distances and keep only them.
            this_distances = np.linalg.norm(this_distances, axis=-1, keepdims=True)

        elif self.directed_distances:  # only directed distances
            pass  # We don't have to do anything because the directed distances are already computed.

        # Add an extra channel for the boundary distances if specified.
        if self.boundary_distances:
            # Invert relative to the value at the max dist point, so the result
            # is zero there and maximal at the boundaries.
            boundary_distances = (boundary_distances[max_dist_point] - boundary_distances)[..., None]
            if this_distances is None:
                this_distances = boundary_distances
            else:
                this_distances = np.concatenate([this_distances, boundary_distances], axis=-1)

        # Set distances outside of the mask to zero.
        this_distances[~cropped_mask] = 0

        # Normalize the distances.
        spatial_axes = tuple(range(mask.ndim))
        this_distances /= (np.abs(this_distances).max(axis=spatial_axes, keepdims=True) + self.eps)

        # Set the distance values in the global result.
        distances[bb][cropped_mask] = this_distances[cropped_mask]

        return distances

    def __call__(self, labels: np.ndarray) -> np.ndarray:
        """Compute the per object distance transform.

        Args:
            labels: The segmentation.

        Returns:
            The distances.
        """
        # Apply label (connected components) if specified.
        if self.apply_label:
            labels = skimage.measure.label(labels).astype("uint32")
        else:  # Otherwise just relabel the segmentation.
            labels = vigra.analysis.relabelConsecutive(labels)[0].astype("uint32")

        # Filter out small objects if min_size is specified.
        if self.min_size > 0:
            ids, sizes = np.unique(labels, return_counts=True)
            discard_ids = ids[sizes < self.min_size]
            labels[np.isin(labels, discard_ids)] = 0
            labels = vigra.analysis.relabelConsecutive(labels)[0].astype("uint32")

        # Compute the boundaries. They will be used to determine the most central point,
        # and if 'self.boundary_distances is True' to add the boundary distances.
        boundaries = skimage.segmentation.find_boundaries(labels, mode="inner").astype("uint32")

        # Compute region properties to derive bounding boxes and centers.
        ndim = labels.ndim
        props = skimage.measure.regionprops(labels)
        bounding_boxes = {
            prop.label: tuple(slice(prop.bbox[i], prop.bbox[i + ndim]) for i in range(ndim)) for prop in props
        }

        # Compute the object centers from centroids.
        centers = {prop.label: np.round(prop.centroid).astype("int") for prop in props}

        # Compute how many distance channels we have.
        n_channels = 0
        if self.distances:  # We need one channel for the overall distances.
            n_channels += 1
        if self.boundary_distances:  # We need one channel for the boundary distances.
            n_channels += 1
        if self.directed_distances:  # And ndim channels for directed distances.
            n_channels += ndim

        # Compute the per object distances.
        distances = np.full(labels.shape + (n_channels,), self.distance_fill_value, dtype="float32")
        for prop in props:
            label_id = prop.label
            mask = labels == label_id
            distances = self.compute_normalized_object_distances(
                mask, boundaries, bounding_boxes[label_id], centers[label_id], distances
            )

        # Bring the distance channel to the first dimension.
        to_channel_first = (ndim,) + tuple(range(ndim))
        distances = distances.transpose(to_channel_first)

        # Add the foreground mask as first channel if specified.
        if self.foreground:
            binary_labels = (labels > 0).astype("float32")
            distances = np.concatenate([binary_labels[None], distances], axis=0)

        # Add the instance segmentation as an additional (first) channel if specified.
        if self.instances:
            distances = np.concatenate([labels[None], distances], axis=0)

        return distances
def connected_components( labels: numpy.ndarray, ndim: Optional[int] = None, ensure_zero: bool = False) -> numpy.ndarray:
17def connected_components(labels: np.ndarray, ndim: Optional[int] = None, ensure_zero: bool = False) -> np.ndarray:
18    """Apply connected components to a segmentation.
19
20    Args:
21        labels: The input segmentation.
22        ndim: The expected dimensionality of the data.
23        ensure_zero: Whether to ensure that the data has a zero label.
24
25    Returns:
26        The segmentation after connected components.
27    """
28    labels = ensure_array(labels) if ndim is None else ensure_spatial_array(labels, ndim)
29    labels = skimage.measure.label(labels)
30    if ensure_zero and 0 not in labels:
31        labels -= 1
32    return labels

Apply connected components to a segmentation.

Arguments:
  • labels: The input segmentation.
  • ndim: The expected dimensionality of the data.
  • ensure_zero: Whether to ensure that the data has a zero label.
Returns:

The segmentation after connected components.

def labels_to_binary(labels: numpy.ndarray, background_label: int = 0) -> numpy.ndarray:
35def labels_to_binary(labels: np.ndarray, background_label: int = 0) -> np.ndarray:
36    """Transform a segmentation to binary labels.
37
38    Args:
39        labels: The input segmentation.
40        background_label: The id of the background label.
41
42    Returns:
43        The binary segmentation.
44    """
45    return (labels != background_label).astype(labels.dtype)

Transform a segmentation to binary labels.

Arguments:
  • labels: The input segmentation.
  • background_label: The id of the background label.
Returns:

The binary segmentation.

def label_consecutive(labels: numpy.ndarray, with_background: bool = True) -> numpy.ndarray:
48def label_consecutive(labels: np.ndarray, with_background: bool = True) -> np.ndarray:
49    """Ensure that the input segmentation is labeled consecutively.
50
51    Args:
52        labels: The input segmentation.
53        with_background: Whether this segmentation has a background label.
54
55    Returns:
56        The consecutively labeled segmentation.
57    """
58    if with_background:
59        seg = skimage.segmentation.relabel_sequential(labels)[0]
60    else:
61        if 0 in labels:
62            labels += 1
63        seg = skimage.segmentation.relabel_sequential(labels)[0]
64        assert seg.min() == 1
65        seg -= 1
66    return seg

Ensure that the input segmentation is labeled consecutively.

Arguments:
  • labels: The input segmentation.
  • with_background: Whether this segmentation has a background label.
Returns:

The consecutively labeled segmentation.

class MinSizeLabelTransform:
69class MinSizeLabelTransform:
70    """Transformation to filter out objects smaller than a minimal size from the segmentation.
71
72    Args:
73        min_size: The minimal object size of the segmentation.
74        ndim: The dimensionality of the segmentation.
75        ensure_zero: Ensure that the segmentation contains the id zero.
76    """
77    def __init__(self, min_size: Optional[int] = None, ndim: Optional[int] = None, ensure_zero: bool = False):
78        self.min_size = min_size
79        self.ndim = ndim
80        self.ensure_zero = ensure_zero
81
82    def __call__(self, labels: np.ndarray) -> np.ndarray:
83        """Filter out small objects from segmentation.
84
85        Args:
86            labels: The input segmentation.
87
88        Returns:
89            The size filtered segmentation.
90        """
91        components = connected_components(labels, ndim=self.ndim, ensure_zero=self.ensure_zero)
92        if self.min_size is not None:
93            ids, sizes = np.unique(components, return_counts=True)
94            filter_ids = ids[sizes < self.min_size]
95            components[np.isin(components, filter_ids)] = 0
96            components, _, _ = skimage.segmentation.relabel_sequential(components)
97        return components

Transformation to filter out objects smaller than a minimal size from the segmentation.

Arguments:
  • min_size: The minimal object size of the segmentation.
  • ndim: The dimensionality of the segmentation.
  • ensure_zero: Ensure that the segmentation contains the id zero.
MinSizeLabelTransform( min_size: Optional[int] = None, ndim: Optional[int] = None, ensure_zero: bool = False)
77    def __init__(self, min_size: Optional[int] = None, ndim: Optional[int] = None, ensure_zero: bool = False):
78        self.min_size = min_size
79        self.ndim = ndim
80        self.ensure_zero = ensure_zero
min_size
ndim
ensure_zero
class BoundaryTransform:
class BoundaryTransform:
    """Transformation to convert an instance segmentation into boundaries.

    Args:
        mode: The mode for converting the segmentation to boundaries.
        add_binary_target: Whether to add a binary mask channel to the transformation output.
        ndim: The expected dimensionality of the data.
    """
    def __init__(self, mode: str = "thick", add_binary_target: bool = False, ndim: Optional[int] = None):
        self.mode = mode
        self.add_binary_target = add_binary_target
        self.ndim = ndim

    def __call__(self, labels: np.ndarray) -> np.ndarray:
        """Apply the boundary transformation to an input segmentation.

        Args:
            labels: The input segmentation.

        Returns:
            The boundaries.
        """
        if self.ndim is None:
            labels = ensure_array(labels)
        else:
            labels = ensure_spatial_array(labels, self.ndim)
        # Boundary channel, with a leading channel axis.
        boundaries = skimage.segmentation.find_boundaries(labels, mode=self.mode)[None]
        if not self.add_binary_target:
            return boundaries
        # Prepend a binary foreground channel to the boundary channel.
        foreground = labels_to_binary(labels)[None].astype(boundaries.dtype)
        return np.concatenate([foreground, boundaries], axis=0)

Transformation to convert an instance segmentation into boundaries.

Arguments:
  • mode: The mode for converting the segmentation to boundaries.
  • add_binary_target: Whether to add a binary mask channel to the transformation output.
  • ndim: The expected dimensionality of the data.
BoundaryTransform( mode: str = 'thick', add_binary_target: bool = False, ndim: Optional[int] = None)
109    def __init__(self, mode: str = "thick", add_binary_target: bool = False, ndim: Optional[int] = None):
110        self.mode = mode
111        self.add_binary_target = add_binary_target
112        self.ndim = ndim
mode
add_binary_target
ndim
class NoToBackgroundBoundaryTransform:
134class NoToBackgroundBoundaryTransform:
135    """Transformation to convert an instance segmentation into boundaries.
136
137    Boundary pixels that touch the background are not treated as valid object boundaries;
138    they are set to the mask label in the output, so that they can be masked out downstream.
139
140    Args:
141        bg_label: The background label.
142        mask_label: The mask label written to boundaries that touch the background.
143        mode: The mode for converting the segmentation to boundaries.
144        add_binary_target: Whether to add a binary mask channel to the transformation output.
145        ndim: The expected dimensionality of the data.
146    """
147    def __init__(
148        self,
149        bg_label: int = 0,
150        mask_label: int = -1,
151        mode: str = "thick",
152        add_binary_target: bool = False,
153        ndim: Optional[int] = None,
154    ):
155        self.bg_label = bg_label
156        self.mask_label = mask_label
157        self.mode = mode
158        self.ndim = ndim
159        self.add_binary_target = add_binary_target
160
161    def __call__(self, labels: np.ndarray) -> np.ndarray:
162        """Apply the boundary transformation to an input segmentation.
163
164        Args:
165            labels: The input segmentation.
166
167        Returns:
168            The boundaries, with boundaries to the background set to the mask label.
169        """
170        labels = ensure_array(labels) if self.ndim is None else ensure_spatial_array(labels, self.ndim)
171        # Compute the regular instance boundaries.
172        boundaries = skimage.segmentation.find_boundaries(labels, mode=self.mode)[None]
173
174        # Binarize the label image and compute the boundaries between foreground and background.
175        labels_binary = (labels != self.bg_label)
176        to_bg_boundaries = skimage.segmentation.find_boundaries(labels_binary, mode=self.mode)[None]
177
178        # Mask out the foreground-background boundaries; int8 so that a negative mask_label fits.
179        boundaries = boundaries.astype(np.int8)
180        boundaries[to_bg_boundaries] = self.mask_label
181
182        if self.add_binary_target:
183            binary = labels_to_binary(labels, self.bg_label).astype(boundaries.dtype)
184            # Carry the mask label over to the binary channel as well.
185            binary[labels == self.mask_label] = self.mask_label
186            target = np.concatenate([binary[None], boundaries], axis=0)
187        else:
188            target = boundaries
189
190        return target

Transformation to convert an instance segmentation into boundaries.

This transformation masks out boundaries between foreground and background: boundary pixels that touch the background are set to the mask label in the output of the transformation.

Arguments:
  • bg_label: The background label.
  • mask_label: The mask label.
  • mode: The mode for converting the segmentation to boundaries.
  • add_binary_target: Whether to add a binary mask channel to the transformation output.
  • ndim: The expected dimensionality of the data.
NoToBackgroundBoundaryTransform( bg_label: int = 0, mask_label: int = -1, mode: str = 'thick', add_binary_target: bool = False, ndim: Optional[int] = None)
147    def __init__(
148        self,
149        bg_label: int = 0,
150        mask_label: int = -1,
151        mode: str = "thick",
152        add_binary_target: bool = False,
153        ndim: Optional[int] = None,
154    ):
155        self.bg_label = bg_label
156        self.mask_label = mask_label
157        self.mode = mode
158        self.ndim = ndim
159        self.add_binary_target = add_binary_target
bg_label
mask_label
mode
ndim
add_binary_target
class BoundaryTransformWithIgnoreLabel:
class BoundaryTransformWithIgnoreLabel:
    """Transformation to convert an instance segmentation into boundaries.

    This transformation sets boundaries with the ignore label to the ignore label
    in the output of the transformation.

    Args:
        ignore_label: The ignore label.
        mode: The mode for converting the segmentation to boundaries.
        add_binary_target: Whether to add a binary mask channel to the transformation output.
        ndim: The expected dimensionality of the data.
    """
    def __init__(
        self,
        ignore_label: int = -1,
        mode: str = "thick",
        add_binary_target: bool = False,
        ndim: Optional[int] = None,
    ):
        self.ignore_label = ignore_label
        self.mode = mode
        self.ndim = ndim
        self.add_binary_target = add_binary_target

    def __call__(self, labels: np.ndarray) -> np.ndarray:
        """Apply the boundary transformation to an input segmentation.

        Args:
            labels: The input segmentation.

        Returns:
            The boundaries.
        """
        if self.ndim is None:
            labels = ensure_array(labels)
        else:
            labels = ensure_spatial_array(labels, self.ndim)

        # Compute the regular instance boundaries; int8 so that a negative ignore label fits.
        boundaries = skimage.segmentation.find_boundaries(labels, mode=self.mode)[None].astype(np.int8)

        # Boundaries of the ignore-label region are marked with the ignore label in the output.
        ignore_region = (labels == self.ignore_label)
        to_ignore_boundaries = skimage.segmentation.find_boundaries(ignore_region, mode=self.mode)[None]
        boundaries[to_ignore_boundaries] = self.ignore_label

        if not self.add_binary_target:
            return boundaries
        # Prepend the binary foreground channel, carrying the ignore label over to it.
        binary = labels_to_binary(labels).astype(boundaries.dtype)
        binary[labels == self.ignore_label] = self.ignore_label
        return np.concatenate([binary[None], boundaries], axis=0)

Transformation to convert an instance segmentation into boundaries.

This transformation sets boundaries with the ignore label to the ignore label in the output of the transformation.

Arguments:
  • ignore_label: The ignore label.
  • mode: The mode for converting the segmentation to boundaries.
  • add_binary_target: Whether to add a binary mask channel to the transformation output.
  • ndim: The expected dimensionality of the data.
BoundaryTransformWithIgnoreLabel( ignore_label: int = -1, mode: str = 'thick', add_binary_target: bool = False, ndim: Optional[int] = None)
205    def __init__(
206        self,
207        ignore_label: int = -1,
208        mode: str = "thick",
209        add_binary_target: bool = False,
210        ndim: Optional[int] = None,
211    ):
212        self.ignore_label = ignore_label
213        self.mode = mode
214        self.ndim = ndim
215        self.add_binary_target = add_binary_target
ignore_label
mode
ndim
add_binary_target
class AffinityTransform:
class AffinityTransform:
    """Transformation to compute affinities from a segmentation.

    Args:
        offsets: The offsets for computing affinities.
        ignore_label: The ignore label to use for computing the ignore mask.
        add_binary_target: Whether to add a binary channel to the affinities.
        add_mask: Whether to add the ignore mask as extra output channels.
        include_ignore_transitions: Whether transitions to the ignore label
            should be positive in the ignore mask or negative in it.
    """
    def __init__(
        self,
        offsets: List[List[int]],
        # Fix: this is a label value, not a flag; the previous Optional[bool] annotation was wrong.
        ignore_label: Optional[int] = None,
        add_binary_target: bool = False,
        add_mask: bool = False,
        include_ignore_transitions: bool = False,
    ):
        # compute_affinities is None if the optional affogato dependency could not be imported.
        assert compute_affinities is not None
        self.offsets = offsets
        # The dimensionality is given by the length of the offset vectors.
        self.ndim = len(self.offsets[0])
        assert self.ndim in (2, 3)

        self.ignore_label = ignore_label
        self.add_binary_target = add_binary_target
        self.add_mask = add_mask
        self.include_ignore_transitions = include_ignore_transitions

    def add_ignore_transitions(self, affs, mask, labels):
        """@private
        """
        ignore_seg = (labels == self.ignore_label).astype(labels.dtype)
        ignore_transitions, invalid_mask = compute_affinities(ignore_seg, self.offsets)
        invalid_mask = np.logical_not(invalid_mask)
        # NOTE affinity convention returned by affogato: transitions are marked by 0
        ignore_transitions = ignore_transitions == 0
        ignore_transitions[invalid_mask] = 0
        affs[ignore_transitions] = 1
        mask[ignore_transitions] = 1
        return affs, mask

    def __call__(self, labels: np.ndarray) -> np.ndarray:
        """Compute the affinities.

        Args:
            labels: The segmentation.

        Returns:
            The affinities.
        """
        # Preserve signedness of the input label dtype; a signed dtype may carry a negative ignore label.
        dtype = "uint64"
        if np.dtype(labels.dtype) in (np.dtype("int16"), np.dtype("int32"), np.dtype("int64")):
            dtype = "int64"
        labels = ensure_spatial_array(labels, self.ndim, dtype=dtype)
        affs, mask = compute_affinities(labels, self.offsets,
                                        have_ignore_label=self.ignore_label is not None,
                                        ignore_label=0 if self.ignore_label is None else self.ignore_label)
        # we use the "disaffinity" convention for training; i.e. 1 means repulsive, 0 attractive
        affs = 1. - affs

        # remove transitions to the ignore label from the mask
        if self.ignore_label is not None and self.include_ignore_transitions:
            affs, mask = self.add_ignore_transitions(affs, mask, labels)

        if self.add_binary_target:
            binary = labels_to_binary(labels)[None].astype(affs.dtype)
            assert binary.ndim == affs.ndim
            affs = np.concatenate([binary, affs], axis=0)

        if self.add_mask:
            if self.add_binary_target:
                # The binary channel needs its own mask channel: everything valid if there
                # is no ignore label, otherwise everything except the ignore label region.
                if self.ignore_label is None:
                    mask_for_bin = np.ones((1,) + labels.shape, dtype=mask.dtype)
                else:
                    mask_for_bin = (labels != self.ignore_label)[None].astype(mask.dtype)
                assert mask.ndim == mask_for_bin.ndim
                mask = np.concatenate([mask_for_bin, mask], axis=0)
            assert affs.shape == mask.shape
            affs = np.concatenate([affs, mask.astype(affs.dtype)], axis=0)

        return affs

Transformation to compute affinities from a segmentation.

Arguments:
  • offsets: The offsets for computing affinities.
  • ignore_label: The ignore label to use for computing the ignore mask.
  • add_binary_target: Whether to add a binary channel to the affinities.
  • add_mask: Whether to add the ignore mask as extra output channels.
  • include_ignore_transitions: Whether transitions to the ignore label should be positive in the ignore mask or negative in it.
AffinityTransform( offsets: List[List[int]], ignore_label: Optional[bool] = None, add_binary_target: bool = False, add_mask: bool = False, include_ignore_transitions: bool = False)
260    def __init__(
261        self,
262        offsets: List[List[int]],
263        ignore_label: Optional[bool] = None,
264        add_binary_target: bool = False,
265        add_mask: bool = False,
266        include_ignore_transitions: bool = False,
267    ):
268        assert compute_affinities is not None
269        self.offsets = offsets
270        self.ndim = len(self.offsets[0])
271        assert self.ndim in (2, 3)
272
273        self.ignore_label = ignore_label
274        self.add_binary_target = add_binary_target
275        self.add_mask = add_mask
276        self.include_ignore_transitions = include_ignore_transitions
offsets
ndim
ignore_label
add_binary_target
add_mask
include_ignore_transitions
class OneHotTransform:
class OneHotTransform:
    """Transformations to compute one-hot labels from a semantic segmentation.

    Args:
        class_ids: The class ids to convert to one-hot labels.
    """
    def __init__(self, class_ids: Optional[Union[int, Sequence[int]]] = None):
        # An integer n is interpreted as the class ids 0, ..., n-1.
        if isinstance(class_ids, int):
            self.class_ids = list(range(class_ids))
        else:
            self.class_ids = class_ids

    def __call__(self, labels: np.ndarray) -> np.ndarray:
        """Compute the one hot transformation.

        Args:
            labels: The segmentation.

        Returns:
            The one-hot transformation.
        """
        # Fall back to all ids present in the segmentation if no class ids were given.
        if self.class_ids is None:
            class_ids = np.unique(labels).tolist()
        else:
            class_ids = self.class_ids
        one_hot = np.zeros((len(class_ids),) + labels.shape, dtype="float32")
        for channel, class_id in enumerate(class_ids):
            one_hot[channel] = (labels == class_id)
        return one_hot

Transformations to compute one-hot labels from a semantic segmentation.

Arguments:
  • class_ids: The class ids to convert to one-hot labels.
OneHotTransform(class_ids: Union[int, Sequence[int], NoneType] = None)
339    def __init__(self, class_ids: Optional[Union[int, Sequence[int]]] = None):
340        self.class_ids = list(range(class_ids)) if isinstance(class_ids, int) else class_ids
class_ids
class DistanceTransform:
class DistanceTransform:
    """Transformation to compute distances to foreground in the labels.

    Args:
        distances: Whether to compute the absolute distances.
        directed_distances: Whether to compute the directed distances (vector distances).
        normalize: Whether to normalize the computed distances.
        max_distance: Maximal distance at which to threshold the distances.
        foreground_id: Label id to which the distance is computed.
        invert: Whether to invert the distances.
        func: Normalization function for the distances.
    """
    eps = 1e-7

    def __init__(
        self,
        distances: bool = True,
        directed_distances: bool = False,
        normalize: bool = True,
        max_distance: Optional[float] = None,
        foreground_id: int = 1,
        invert: bool = False,
        func: Optional[Callable] = None,
    ):
        if sum((distances, directed_distances)) == 0:
            raise ValueError("At least one of 'distances' or 'directed_distances' must be set to 'True'")
        self.directed_distances = directed_distances
        self.distances = distances
        self.normalize = normalize
        self.max_distance = max_distance
        self.foreground_id = foreground_id
        self.invert = invert
        self.func = func

    def _compute_distances(self, directed_distances):
        # The undirected distance is the euclidean norm over the vector components (axis 0).
        distances = np.linalg.norm(directed_distances, axis=0)
        if self.max_distance is not None:
            distances = np.clip(distances, 0, self.max_distance)
        if self.normalize:
            distances /= (distances.max() + self.eps)
        if self.invert:
            distances = distances.max() - distances
        if self.func is not None:
            distances = self.func(distances)
        return distances

    def _compute_directed_distances(self, directed_distances):
        if self.max_distance is not None:
            directed_distances = np.clip(directed_distances, -self.max_distance, self.max_distance)
        # Normalize / invert over all spatial axes; axis 0 holds the vector components.
        # (Fix: this was previously hard-coded to axis=(1, 2) and hence only correct for 2d data.)
        spatial_axes = tuple(range(1, directed_distances.ndim))
        if self.normalize:
            directed_distances /= (np.abs(directed_distances).max(axis=spatial_axes, keepdims=True) + self.eps)
        if self.invert:
            directed_distances = directed_distances.max(axis=spatial_axes, keepdims=True) - directed_distances
        if self.func is not None:
            directed_distances = self.func(directed_distances)
        return directed_distances

    def _get_distances_for_empty_labels(self, labels):
        # Fill value if there is no foreground: zero for inverted distances,
        # otherwise a large distance derived from the image shape.
        shape = labels.shape
        fill_value = 0.0 if self.invert else np.sqrt(np.linalg.norm(list(shape)) ** 2 / 2)
        data = np.full((labels.ndim,) + shape, fill_value)
        return data

    def __call__(self, labels: np.ndarray) -> np.ndarray:
        """Compute the distances.

        Args:
            labels: The segmentation.

        Returns:
            The distances.
        """
        distance_mask = (labels == self.foreground_id).astype("uint32")
        # The distances are not computed correctly if the mask is all zero,
        # so this case needs to be handled separately.
        if distance_mask.sum() == 0:
            directed_distances = self._get_distances_for_empty_labels(labels)
        else:
            ndim = distance_mask.ndim
            to_channel_first = (ndim,) + tuple(range(ndim))
            directed_distances = vigra.filters.vectorDistanceTransform(distance_mask).transpose(to_channel_first)

        if self.distances:
            distances = self._compute_distances(directed_distances)

        if self.directed_distances:
            directed_distances = self._compute_directed_distances(directed_distances)

        if self.distances and self.directed_distances:
            return np.concatenate((distances[None], directed_distances), axis=0)
        if self.distances:
            return distances
        if self.directed_distances:
            return directed_distances

Transformation to compute distances to foreground in the labels.

Arguments:
  • distances: Whether to compute the absolute distances.
  • directed_distances: Whether to compute the directed distances (vector distances).
  • normalize: Whether to normalize the computed distances.
  • max_distance: Maximal distance at which to threshold the distances.
  • foreground_id: Label id to which the distance is computed.
  • invert: Whether to invert the distances.
  • func: Normalization function for the distances.
DistanceTransform( distances: bool = True, directed_distances: bool = False, normalize: bool = True, max_distance: Optional[float] = None, foreground_id: int = 1, invert: bool = False, func: Optional[Callable] = None)
373    def __init__(
374        self,
375        distances: bool = True,
376        directed_distances: bool = False,
377        normalize: bool = True,
378        max_distance: Optional[float] = None,
379        foreground_id: int = 1,
380        invert: bool = False,
381        func: Optional[Callable] = None,
382    ):
383        if sum((distances, directed_distances)) == 0:
384            raise ValueError("At least one of 'distances' or 'directed_distances' must be set to 'True'")
385        self.directed_distances = directed_distances
386        self.distances = distances
387        self.normalize = normalize
388        self.max_distance = max_distance
389        self.foreground_id = foreground_id
390        self.invert = invert
391        self.func = func
eps = 1e-07
directed_distances
distances
normalize
max_distance
foreground_id
invert
func
class PerObjectDistanceTransform:
455class PerObjectDistanceTransform:
456    """Transformation to compute normalized distances per object in a segmentation.
457
458    Args:
459        distances: Whether to compute the undirected distances.
460        boundary_distances: Whether to compute the distances to the object boundaries.
461        directed_distances: Whether to compute the directed distances (vector distances).
462        foreground: Whether to return a foreground channel.
        instances: Whether to add the instance segmentation as an extra output channel.
463        apply_label: Whether to apply connected components to the labels before computing distances.
464        correct_centers: Whether to correct centers that are not in the objects.
465        min_size: Minimal size of objects for distance calculation.
466        distance_fill_value: Fill value for the distances outside of objects.
467    """
468    eps = 1e-7
469
470    def __init__(
471        self,
472        distances: bool = True,
473        boundary_distances: bool = True,
474        directed_distances: bool = False,
475        foreground: bool = True,
476        instances: bool = False,
477        apply_label: bool = True,
478        correct_centers: bool = True,
479        min_size: int = 0,
480        distance_fill_value: float = 1.0,
481    ):
482        if sum([distances, directed_distances, boundary_distances]) == 0:
483            raise ValueError("At least one of distances or directed distances has to be passed.")
484        self.distances = distances
485        self.boundary_distances = boundary_distances
486        self.directed_distances = directed_distances
487        self.foreground = foreground
488        self.instances = instances
489
490        self.apply_label = apply_label
491        self.correct_centers = correct_centers  # NOTE(review): stored but never read in this class; centers are always corrected below.
492        self.min_size = min_size
493        self.distance_fill_value = distance_fill_value
494
495    def compute_normalized_object_distances(self, mask, boundaries, bb, center, distances):
496        """@private
        Compute the normalized distances for a single object and write them into `distances`.
497        """
498        # Crop the mask and generate array with the correct center.
499        cropped_mask = mask[bb]
500        cropped_center = tuple(ce - b.start for ce, b in zip(center, bb))
501
502        # The centroid might not be inside of the object.
503        # In this case we correct the center by taking the maximum of the distance to the boundary.
504        # Note: the centroid is still the best estimate for the center, as long as it's in the object.
505        correct_center = not cropped_mask[cropped_center]
506
507        # Compute the boundary distances if necessary.
508        # (Either if we need to correct the center, or compute the boundary distances anyways.)
509        if correct_center or self.boundary_distances:
510            # Crop the boundary mask and compute the boundary distances.
511            cropped_boundary_mask = boundaries[bb]
512            boundary_distances = vigra.filters.distanceTransform(cropped_boundary_mask)
513            boundary_distances[~cropped_mask] = 0
514            max_dist_point = np.unravel_index(np.argmax(boundary_distances), boundary_distances.shape)
515
516        # Set the crop center to the max dist point
517        if correct_center:
518            # Find the center (= maximal distance from the boundaries).
519            cropped_center = max_dist_point
520
521        cropped_center_mask = np.zeros_like(cropped_mask, dtype="uint32")
522        cropped_center_mask[cropped_center] = 1
523
524        # Compute the directed distances.
525        if self.distances or self.directed_distances:
526            this_distances = vigra.filters.vectorDistanceTransform(cropped_center_mask)
527        else:
528            this_distances = None
529
530        # Keep only the specified distances:
531        if self.distances and self.directed_distances:  # all distances
532            # Compute the undirected distances from directed distances and concatenate them.
533            undir = np.linalg.norm(this_distances, axis=-1, keepdims=True)
534            this_distances = np.concatenate([undir, this_distances], axis=-1)
535
536        elif self.distances:  # only undirected distances
537            # Compute the undirected distances from directed distances and keep only them.
538            this_distances = np.linalg.norm(this_distances, axis=-1, keepdims=True)
539
540        elif self.directed_distances:  # only directed distances
541            pass  # We don't have to do anything because the directed distances are already computed.
542
543        # Add an extra channel for the boundary distances if specified.
544        if self.boundary_distances:
545            boundary_distances = (boundary_distances[max_dist_point] - boundary_distances)[..., None]
546            if this_distances is None:
547                this_distances = boundary_distances
548            else:
549                this_distances = np.concatenate([this_distances, boundary_distances], axis=-1)
550
551        # Set distances outside of the mask to zero.
552        this_distances[~cropped_mask] = 0
553
554        # Normalize the distances.
555        spatial_axes = tuple(range(mask.ndim))
556        this_distances /= (np.abs(this_distances).max(axis=spatial_axes, keepdims=True) + self.eps)
557
558        # Set the distance values in the global result.
559        distances[bb][cropped_mask] = this_distances[cropped_mask]
560
561        return distances
562
563    def __call__(self, labels: np.ndarray) -> np.ndarray:
564        """Compute the per object distance transform.
565
566        Args:
567            labels: The segmentation.
568
569        Returns:
570            The distances.
571        """
572        # Apply label (connected components) if specified.
573        if self.apply_label:
574            labels = skimage.measure.label(labels).astype("uint32")
575        else:  # Otherwise just relabel the segmentation.
576            labels = vigra.analysis.relabelConsecutive(labels)[0].astype("uint32")
577
578        # Filter out small objects if min_size is specified.
579        if self.min_size > 0:
580            ids, sizes = np.unique(labels, return_counts=True)
581            discard_ids = ids[sizes < self.min_size]
582            labels[np.isin(labels, discard_ids)] = 0
583            labels = vigra.analysis.relabelConsecutive(labels)[0].astype("uint32")
584
585        # Compute the boundaries. They will be used to determine the most central point,
586        # and if 'self.boundary_distances is True' to add the boundary distances.
587        boundaries = skimage.segmentation.find_boundaries(labels, mode="inner").astype("uint32")
588
589        # Compute region properties to derive bounding boxes and centers.
590        ndim = labels.ndim
591        props = skimage.measure.regionprops(labels)
592        bounding_boxes = {
593            prop.label: tuple(slice(prop.bbox[i], prop.bbox[i + ndim]) for i in range(ndim)) for prop in props
594        }
595
596        # Compute the object centers from centroids.
597        centers = {prop.label: np.round(prop.centroid).astype("int") for prop in props}
598
599        # Compute how many distance channels we have.
600        n_channels = 0
601        if self.distances:  # We need one channel for the overall distances.
602            n_channels += 1
603        if self.boundary_distances:  # We need one channel for the boundary distances.
604            n_channels += 1
605        if self.directed_distances:  # And ndim channels for directed distances.
606            n_channels += ndim
607
608        # Compute the per object distances.
609        distances = np.full(labels.shape + (n_channels,), self.distance_fill_value, dtype="float32")
610        for prop in props:
611            label_id = prop.label
612            mask = labels == label_id
613            distances = self.compute_normalized_object_distances(
614                mask, boundaries, bounding_boxes[label_id], centers[label_id], distances
615            )
616
617        # Bring the distance channel to the first dimension.
618        to_channel_first = (ndim,) + tuple(range(ndim))
619        distances = distances.transpose(to_channel_first)
620
621        # Add the foreground mask as first channel if specified.
622        if self.foreground:
623            binary_labels = (labels > 0).astype("float32")
624            distances = np.concatenate([binary_labels[None], distances], axis=0)
625
626        # Add the instance segmentation as an extra (first) channel if specified.
627        if self.instances:
628            distances = np.concatenate([labels[None], distances], axis=0)
629
630        return distances

Transformation to compute normalized distances per object in a segmentation.

Arguments:
  • distances: Whether to compute the undirected distances.
  • boundary_distances: Whether to compute the distances to the object boundaries.
  • directed_distances: Whether to compute the directed distances (vector distances).
  • foreground: Whether to return a foreground channel.
  • apply_label: Whether to apply connected components to the labels before computing distances.
  • correct_centers: Whether to correct centers that are not in the objects.
  • min_size: Minimal size of objects for distance calculation.
  • distance_fill_value: Fill value for the distances outside of objects.
PerObjectDistanceTransform( distances: bool = True, boundary_distances: bool = True, directed_distances: bool = False, foreground: bool = True, instances: bool = False, apply_label: bool = True, correct_centers: bool = True, min_size: int = 0, distance_fill_value: float = 1.0)
470    def __init__(
471        self,
472        distances: bool = True,
473        boundary_distances: bool = True,
474        directed_distances: bool = False,
475        foreground: bool = True,
476        instances: bool = False,
477        apply_label: bool = True,
478        correct_centers: bool = True,
479        min_size: int = 0,
480        distance_fill_value: float = 1.0,
481    ):
482        if sum([distances, directed_distances, boundary_distances]) == 0:
483            raise ValueError("At least one of distances or directed distances has to be passed.")
484        self.distances = distances
485        self.boundary_distances = boundary_distances
486        self.directed_distances = directed_distances
487        self.foreground = foreground
488        self.instances = instances
489
490        self.apply_label = apply_label
491        self.correct_centers = correct_centers
492        self.min_size = min_size
493        self.distance_fill_value = distance_fill_value
eps = 1e-07
distances
boundary_distances
directed_distances
foreground
instances
apply_label
correct_centers
min_size
distance_fill_value