torch_em.transform.label

  1from typing import Callable, List, Optional, Sequence, Union, Tuple
  2
  3import numpy as np
  4import skimage.measure
  5import skimage.segmentation
  6import vigra
  7
  8from ..util import ensure_array, ensure_spatial_array
  9
 10try:
 11    from affogato.affinities import compute_affinities
 12except ImportError:
 13    compute_affinities = None
 14
 15
 16def connected_components(labels: np.ndarray, ndim: Optional[int] = None, ensure_zero: bool = False) -> np.ndarray:
 17    """Apply connected components to a segmentation.
 18
 19    Args:
 20        labels: The input segmentation.
 21        ndim: The expected dimensionality of the data.
 22        ensure_zero: Whether to ensure that the data has a zero label.
 23
 24    Returns:
 25        The segmentation after connected components.
 26    """
 27    labels = ensure_array(labels) if ndim is None else ensure_spatial_array(labels, ndim)
 28    labels = skimage.measure.label(labels)
 29    if ensure_zero and 0 not in labels:
 30        labels -= 1
 31    return labels
 32
 33
 34def labels_to_binary(labels: np.ndarray, background_label: int = 0) -> np.ndarray:
 35    """Transform a segmentation to binary labels.
 36
 37    Args:
 38        labels: The input segmentation.
 39        background_label: The id of the background label.
 40
 41    Returns:
 42        The binary segmentation.
 43    """
 44    return (labels != background_label).astype(labels.dtype)
 45
 46
 47def label_consecutive(labels: np.ndarray, with_background: bool = True) -> np.ndarray:
 48    """Ensure that the input segmentation is labeled consecutively.
 49
 50    Args:
 51        labels: The input segmentation.
 52        with_background: Whether this segmentation has a background label.
 53
 54    Returns:
 55        The consecutively labeled segmentation.
 56    """
 57    if with_background:
 58        seg = skimage.segmentation.relabel_sequential(labels)[0]
 59    else:
 60        if 0 in labels:
 61            labels += 1
 62        seg = skimage.segmentation.relabel_sequential(labels)[0]
 63        assert seg.min() == 1
 64        seg -= 1
 65    return seg
 66
 67
 68class MinSizeLabelTransform:
 69    """Transformation to filter out objects smaller than a minimal size from the segmentation.
 70
 71    Args:
 72        min_size: The minimal object size of the segmentation.
 73        ndim: The dimensionality of the segmentation.
 74        ensure_zero: Ensure that the segmentation contains the id zero.
 75    """
 76    def __init__(self, min_size: Optional[int] = None, ndim: Optional[int] = None, ensure_zero: bool = False):
 77        self.min_size = min_size
 78        self.ndim = ndim
 79        self.ensure_zero = ensure_zero
 80
 81    def __call__(self, labels: np.ndarray) -> np.ndarray:
 82        """Filter out small objects from segmentation.
 83
 84        Args:
 85            labels: The input segmentation.
 86
 87        Returns:
 88            The size filtered segmentation.
 89        """
 90        components = connected_components(labels, ndim=self.ndim, ensure_zero=self.ensure_zero)
 91        if self.min_size is not None:
 92            ids, sizes = np.unique(components, return_counts=True)
 93            filter_ids = ids[sizes < self.min_size]
 94            components[np.isin(components, filter_ids)] = 0
 95            components, _, _ = skimage.segmentation.relabel_sequential(components)
 96        return components
 97
 98
 99# TODO smoothing
class BoundaryTransform:
    """Transformation that turns an instance segmentation into a boundary map.

    Args:
        mode: The mode passed to `skimage.segmentation.find_boundaries`.
        add_binary_target: Whether to prepend a binary foreground channel to the output.
        ndim: The expected dimensionality of the data.
    """
    def __init__(self, mode: str = "thick", add_binary_target: bool = False, ndim: Optional[int] = None):
        self.mode = mode
        self.add_binary_target = add_binary_target
        self.ndim = ndim

    def __call__(self, labels: np.ndarray) -> np.ndarray:
        """Compute the boundaries of a segmentation.

        Args:
            labels: The input segmentation.

        Returns:
            The boundaries with a leading channel axis. If `add_binary_target` is set,
            the foreground mask is the first channel and the boundaries the second.
        """
        if self.ndim is None:
            labels = ensure_array(labels)
        else:
            labels = ensure_spatial_array(labels, self.ndim)
        boundaries = skimage.segmentation.find_boundaries(labels, mode=self.mode)[None]
        if not self.add_binary_target:
            return boundaries
        foreground = labels_to_binary(labels)[None].astype(boundaries.dtype)
        return np.concatenate([foreground, boundaries], axis=0)
130
131
132# TODO smoothing
class NoToBackgroundBoundaryTransform:
    """Transformation to convert an instance segmentation into boundaries, masking boundaries to the background.

    Boundaries between two objects are set to 1 in the output. Boundaries between an object and
    the background (`bg_label`) are set to `mask_label`. If a binary target is added, pixels whose
    label equals `mask_label` are also set to `mask_label` in the binary channel.
    The output has dtype int8, so `mask_label` must be representable as int8.

    Args:
        bg_label: The background label.
        mask_label: The value written for boundaries to the background.
        mode: The mode for converting the segmentation to boundaries.
        add_binary_target: Whether to add a binary mask channel to the transformation output.
        ndim: The expected dimensionality of the data.
    """
    def __init__(
        self,
        bg_label: int = 0,
        mask_label: int = -1,
        mode: str = "thick",
        add_binary_target: bool = False,
        ndim: Optional[int] = None,
    ):
        self.bg_label = bg_label
        self.mask_label = mask_label
        self.mode = mode
        self.ndim = ndim
        self.add_binary_target = add_binary_target

    def __call__(self, labels: np.ndarray) -> np.ndarray:
        """Apply the boundary transformation to an input segmentation.

        Args:
            labels: The input segmentation.

        Returns:
            The boundaries (int8) with a leading channel axis. If `add_binary_target` is set,
            the binary foreground channel comes first.
        """
        labels = ensure_array(labels) if self.ndim is None else ensure_spatial_array(labels, self.ndim)
        # Boundaries between all labels (objects and background).
        boundaries = skimage.segmentation.find_boundaries(labels, mode=self.mode)[None]

        # Binarize the labels (foreground vs. bg_label) to find the object-to-background boundaries.
        labels_binary = (labels != self.bg_label)
        to_bg_boundaries = skimage.segmentation.find_boundaries(labels_binary, mode=self.mode)[None]

        # Cast from bool to int8 so that mask_label (default -1) can be stored,
        # then overwrite the object-to-background boundaries with it.
        boundaries = boundaries.astype(np.int8)
        boundaries[to_bg_boundaries] = self.mask_label

        if self.add_binary_target:
            binary = labels_to_binary(labels, self.bg_label).astype(boundaries.dtype)
            binary[labels == self.mask_label] = self.mask_label
            target = np.concatenate([binary[None], boundaries], axis=0)
        else:
            target = boundaries

        return target
189
190
191# TODO smoothing
class BoundaryTransformWithIgnoreLabel:
    """Transformation that turns an instance segmentation into boundaries, respecting an ignore label.

    Boundaries touching the region labeled `ignore_label` are written as `ignore_label` in the
    output. In the optional binary channel, pixels with the ignore label are also set to `ignore_label`.
    The output has dtype int8.

    Args:
        ignore_label: The ignore label.
        mode: The mode for converting the segmentation to boundaries.
        add_binary_target: Whether to add a binary mask channel to the transformation output.
        ndim: The expected dimensionality of the data.
    """
    def __init__(
        self,
        ignore_label: int = -1,
        mode: str = "thick",
        add_binary_target: bool = False,
        ndim: Optional[int] = None,
    ):
        self.ignore_label = ignore_label
        self.mode = mode
        self.ndim = ndim
        self.add_binary_target = add_binary_target

    def __call__(self, labels: np.ndarray) -> np.ndarray:
        """Compute boundaries with ignore-label handling.

        Args:
            labels: The input segmentation.

        Returns:
            The int8 boundaries with a leading channel axis, preceded by the binary channel if requested.
        """
        if self.ndim is None:
            labels = ensure_array(labels)
        else:
            labels = ensure_spatial_array(labels, self.ndim)

        ignore_region = labels == self.ignore_label
        # Regular boundaries between all labels, cast to int8 so the (negative) ignore label fits.
        boundaries = skimage.segmentation.find_boundaries(labels, mode=self.mode)[None].astype(np.int8)
        # Boundaries of the ignore region are overwritten with the ignore label.
        ignore_boundaries = skimage.segmentation.find_boundaries(ignore_region, mode=self.mode)[None]
        boundaries[ignore_boundaries] = self.ignore_label

        if not self.add_binary_target:
            return boundaries
        foreground = labels_to_binary(labels).astype(boundaries.dtype)
        foreground[ignore_region] = self.ignore_label
        return np.concatenate([foreground[None], boundaries], axis=0)
245
246
247# TODO affinity smoothing
class AffinityTransform:
    """Transformation to compute affinities from a segmentation.

    The affinities are returned in the "disaffinity" convention: 1 means repulsive, 0 attractive.

    Args:
        offsets: The offsets for computing affinities. The length of the first offset determines
            the dimensionality of the data, which must be 2 or 3.
        ignore_label: The ignore label to use for computing the ignore mask.
        add_binary_target: Whether to add a binary channel to the affinities.
        add_mask: Whether to add the ignore mask as extra output channels.
        include_ignore_transitions: If True (and `ignore_label` is set), edges crossing the boundary
            of the ignore region are set to 1 (repulsive) in the affinities and to 1 (valid) in the mask.
            Otherwise the mask returned by `compute_affinities` is used unchanged.
    """
    def __init__(
        self,
        offsets: List[List[int]],
        ignore_label: Optional[int] = None,
        add_binary_target: bool = False,
        add_mask: bool = False,
        include_ignore_transitions: bool = False,
    ):
        assert compute_affinities is not None
        self.offsets = offsets
        self.ndim = len(self.offsets[0])
        assert self.ndim in (2, 3)

        self.ignore_label = ignore_label
        self.add_binary_target = add_binary_target
        self.add_mask = add_mask
        self.include_ignore_transitions = include_ignore_transitions

    def add_ignore_transitions(self, affs, mask, labels):
        """@private
        """
        # Binary segmentation of the ignore region; its affinities mark edges crossing into / out of it.
        ignore_seg = (labels == self.ignore_label).astype(labels.dtype)
        # The second output is the validity mask; negate it to get the invalid edges.
        ignore_transitions, invalid_mask = compute_affinities(ignore_seg, self.offsets)
        invalid_mask = np.logical_not(invalid_mask)
        # NOTE affinity convention returned by affogato: transitions are marked by 0
        ignore_transitions = ignore_transitions == 0
        ignore_transitions[invalid_mask] = 0
        affs[ignore_transitions] = 1
        mask[ignore_transitions] = 1
        return affs, mask

    def __call__(self, labels: np.ndarray) -> np.ndarray:
        """Compute the affinities.

        Args:
            labels: The segmentation.

        Returns:
            The affinities (one channel per offset). If requested, a binary channel is prepended and
            the mask channels (matching the affinity channels) are appended.
        """
        # Any signed integer input (including int8) is converted to int64 so that negative ids,
        # e.g. a negative ignore label, are preserved. All other inputs are converted to uint64.
        dtype = "int64" if np.issubdtype(np.dtype(labels.dtype), np.signedinteger) else "uint64"
        labels = ensure_spatial_array(labels, self.ndim, dtype=dtype)
        affs, mask = compute_affinities(labels, self.offsets,
                                        have_ignore_label=self.ignore_label is not None,
                                        ignore_label=0 if self.ignore_label is None else self.ignore_label)
        # we use the "disaffinity" convention for training; i.e. 1 means repulsive, 0 attractive
        affs = 1. - affs

        # Add transitions to the ignore label as repulsive edges and mark them as valid in the mask.
        if self.ignore_label is not None and self.include_ignore_transitions:
            affs, mask = self.add_ignore_transitions(affs, mask, labels)

        if self.add_binary_target:
            binary = labels_to_binary(labels)[None].astype(affs.dtype)
            assert binary.ndim == affs.ndim
            affs = np.concatenate([binary, affs], axis=0)

        if self.add_mask:
            if self.add_binary_target:
                # The binary channel is valid everywhere except in the ignore region.
                if self.ignore_label is None:
                    mask_for_bin = np.ones((1,) + labels.shape, dtype=mask.dtype)
                else:
                    mask_for_bin = (labels != self.ignore_label)[None].astype(mask.dtype)
                assert mask.ndim == mask_for_bin.ndim
                mask = np.concatenate([mask_for_bin, mask], axis=0)
            assert affs.shape == mask.shape
            affs = np.concatenate([affs, mask.astype(affs.dtype)], axis=0)

        return affs
330
331
class OneHotTransform:
    """Transformation that turns a semantic segmentation into one-hot encoded channels.

    Args:
        class_ids: Either the number of classes (interpreted as ids `0 .. class_ids - 1`), an explicit
            sequence of class ids, or None to use the unique values of each input.
    """
    def __init__(self, class_ids: Optional[Union[int, Sequence[int]]] = None):
        if isinstance(class_ids, int):
            class_ids = list(range(class_ids))
        self.class_ids = class_ids

    def __call__(self, labels: np.ndarray) -> np.ndarray:
        """One-hot encode a segmentation.

        Args:
            labels: The segmentation.

        Returns:
            A float32 array of shape `(n_classes,) + labels.shape` with 1.0 where a pixel has that class.
        """
        ids = np.unique(labels).tolist() if self.class_ids is None else self.class_ids
        encoded = np.zeros((len(ids),) + labels.shape, dtype="float32")
        for channel, class_id in enumerate(ids):
            encoded[channel][labels == class_id] = 1.0
        return encoded
356
357
class DistanceTransform:
    """Transformation to compute distances to foreground in the labels.

    Args:
        distances: Whether to compute the absolute distances.
        directed_distances: Whether to compute the directed distances (vector distances).
        normalize: Whether to normalize the computed distances.
        max_distance: Maximal distance at which to threshold the distances.
        foreground_id: Label id to which the distance is computed.
        invert: Whether to invert the distances.
        func: Optional function applied to the distances as the last step.
    """
    eps = 1e-7

    def __init__(
        self,
        distances: bool = True,
        directed_distances: bool = False,
        normalize: bool = True,
        max_distance: Optional[float] = None,
        foreground_id: int = 1,
        invert: bool = False,
        func: Optional[Callable] = None,
    ):
        if sum((distances, directed_distances)) == 0:
            raise ValueError("At least one of 'distances' or 'directed_distances' must be set to 'True'")
        self.directed_distances = directed_distances
        self.distances = distances
        self.normalize = normalize
        self.max_distance = max_distance
        self.foreground_id = foreground_id
        self.invert = invert
        self.func = func

    def _compute_distances(self, directed_distances):
        # Euclidean length of the (channel-first) distance vectors.
        distances = np.linalg.norm(directed_distances, axis=0)
        if self.max_distance is not None:
            distances = np.clip(distances, 0, self.max_distance)
        if self.normalize:
            distances /= (distances.max() + self.eps)
        if self.invert:
            distances = distances.max() - distances
        if self.func is not None:
            distances = self.func(distances)
        return distances

    def _compute_directed_distances(self, directed_distances):
        # Reduce over all spatial axes (everything except the leading channel axis), so that
        # normalization and inversion are per channel for 2d and 3d data alike.
        spatial_axes = tuple(range(1, directed_distances.ndim))
        if self.max_distance is not None:
            directed_distances = np.clip(directed_distances, -self.max_distance, self.max_distance)
        if self.normalize:
            directed_distances /= (np.abs(directed_distances).max(axis=spatial_axes, keepdims=True) + self.eps)
        if self.invert:
            directed_distances = directed_distances.max(axis=spatial_axes, keepdims=True) - directed_distances
        if self.func is not None:
            directed_distances = self.func(directed_distances)
        return directed_distances

    def _get_distances_for_empty_labels(self, labels):
        # Without any foreground pixel there is nothing to measure to: fill all ndim channels with a
        # constant derived from the image extent (norm(shape) / sqrt(2)), or with 0 if inverted.
        shape = labels.shape
        fill_value = 0.0 if self.invert else np.sqrt(np.linalg.norm(list(shape)) ** 2 / 2)
        data = np.full((labels.ndim,) + shape, fill_value)
        return data

    def __call__(self, labels: np.ndarray) -> np.ndarray:
        """Compute the distances.

        Args:
            labels: The segmentation.

        Returns:
            The distances (spatial shape), the directed distances (ndim channels first), or both
            concatenated along the first axis (distances first), depending on the configuration.
        """
        distance_mask = (labels == self.foreground_id).astype("uint32")
        # the distances are not computed correctly if they are all zero
        # so this case needs to be handled separately
        if distance_mask.sum() == 0:
            directed_distances = self._get_distances_for_empty_labels(labels)
        else:
            ndim = distance_mask.ndim
            to_channel_first = (ndim,) + tuple(range(ndim))
            directed_distances = vigra.filters.vectorDistanceTransform(distance_mask).transpose(to_channel_first)

        # NOTE: the undirected distances must be computed first, because
        # _compute_directed_distances may normalize its input in place.
        if self.distances:
            distances = self._compute_distances(directed_distances)

        if self.directed_distances:
            directed_distances = self._compute_directed_distances(directed_distances)

        if self.distances and self.directed_distances:
            return np.concatenate((distances[None], directed_distances), axis=0)
        if self.distances:
            return distances
        if self.directed_distances:
            return directed_distances
452
453
class PerObjectDistanceTransform:
    """Transformation to compute normalized distances per object in a segmentation.

    For every object, distances are computed inside its bounding box relative to the object center.
    The center is the rounded centroid or, if `correct_centers` is set and the centroid lies outside
    the object, the object point farthest from the object boundary. Each distance channel is
    normalized per object by its maximal absolute value.

    The output is channel-first and contains, in this order: the instance labels (if `instances`),
    the foreground mask (if `foreground`), the undirected distances (if `distances`), the directed
    distances (if `directed_distances`, one channel per spatial axis) and the boundary distances
    (if `boundary_distances`).

    Args:
        distances: Whether to compute the undirected distances.
        boundary_distances: Whether to compute the distances to the object boundaries.
        directed_distances: Whether to compute the directed distances (vector distances).
        foreground: Whether to return a foreground channel.
        instances: Whether to return the (relabeled) instance segmentation as an additional first channel.
        apply_label: Whether to apply connected components to the labels before computing distances.
        correct_centers: Whether to correct centers that are not in the objects. If False, the centroid
            is used as the center even when it lies outside of the object.
        min_size: Minimal size of objects for distance calculation.
        distance_fill_value: Fill value for the distances outside of objects.
        sampling: The spacing of the distance transform. This is especially relevant for anisotropic data;
            for which it is recommended to use a sampling of (ANISOTROPY_FACTOR, 1, 1).
    """
    eps = 1e-7

    def __init__(
        self,
        distances: bool = True,
        boundary_distances: bool = True,
        directed_distances: bool = False,
        foreground: bool = True,
        instances: bool = False,
        apply_label: bool = True,
        correct_centers: bool = True,
        min_size: int = 0,
        distance_fill_value: float = 1.0,
        sampling: Optional[Tuple[float, ...]] = None
    ):
        if sum([distances, directed_distances, boundary_distances]) == 0:
            raise ValueError("At least one of distances or directed distances has to be passed.")
        self.distances = distances
        self.boundary_distances = boundary_distances
        self.directed_distances = directed_distances
        self.foreground = foreground
        self.instances = instances

        self.apply_label = apply_label
        self.correct_centers = correct_centers
        self.min_size = min_size
        self.distance_fill_value = distance_fill_value
        self.sampling = sampling

    def compute_normalized_object_distances(self, mask, boundaries, bb, center, distances):
        """@private
        """
        # Crop the mask and generate array with the correct center.
        cropped_mask = mask[bb]
        cropped_center = tuple(ce - b.start for ce, b in zip(center, bb))

        # The centroid might not be inside of the object.
        # If `correct_centers` is set, we then correct the center by taking the maximum of the distance
        # to the boundary. Note: the centroid is still the best estimate for the center,
        # as long as it's in the object.
        correct_center = self.correct_centers and not cropped_mask[cropped_center]

        # Compute the boundary distances if necessary.
        # (Either if we need to correct the center, or compute the boundary distances anyways.)
        if correct_center or self.boundary_distances:
            # Crop the boundary mask and compute the boundary distances.
            cropped_boundary_mask = boundaries[bb]
            boundary_distances = vigra.filters.distanceTransform(cropped_boundary_mask, pixel_pitch=self.sampling)
            boundary_distances[~cropped_mask] = 0
            max_dist_point = np.unravel_index(np.argmax(boundary_distances), boundary_distances.shape)

        # Set the crop center to the max dist point
        if correct_center:
            # Find the center (= maximal distance from the boundaries).
            cropped_center = max_dist_point

        cropped_center_mask = np.zeros_like(cropped_mask, dtype="uint32")
        cropped_center_mask[cropped_center] = 1

        # Compute the directed distances to the center (channel-last).
        if self.distances or self.directed_distances:
            this_distances = vigra.filters.vectorDistanceTransform(cropped_center_mask, pixel_pitch=self.sampling)
        else:
            this_distances = None

        # Keep only the specified distances:
        if self.distances and self.directed_distances:  # all distances
            # Compute the undirected distances from directed distances and concatenate.
            undir = np.linalg.norm(this_distances, axis=-1, keepdims=True)
            this_distances = np.concatenate([undir, this_distances], axis=-1)

        elif self.distances:  # only undirected distances
            # Compute the undirected distances from directed distances and keep only them.
            this_distances = np.linalg.norm(this_distances, axis=-1, keepdims=True)

        elif self.directed_distances:  # only directed distances
            pass  # We don't have to do anything because the directed distances are already computed.

        # Add an extra channel for the boundary distances if specified.
        if self.boundary_distances:
            boundary_distances = (boundary_distances[max_dist_point] - boundary_distances)[..., None]
            if this_distances is None:
                this_distances = boundary_distances
            else:
                this_distances = np.concatenate([this_distances, boundary_distances], axis=-1)

        # Set distances outside of the mask to zero.
        this_distances[~cropped_mask] = 0

        # Normalize the distances per channel.
        spatial_axes = tuple(range(mask.ndim))
        this_distances /= (np.abs(this_distances).max(axis=spatial_axes, keepdims=True) + self.eps)

        # Set the distance values in the global result (distances[bb] is a view, so this writes through).
        distances[bb][cropped_mask] = this_distances[cropped_mask]

        return distances

    def __call__(self, labels: np.ndarray) -> np.ndarray:
        """Compute the per object distance transform.

        Args:
            labels: The segmentation

        Returns:
            The distances, channel-first.
        """
        # Apply label (connected components) if specified.
        if self.apply_label:
            labels = skimage.measure.label(labels).astype("uint32")
        else:  # Otherwise just relabel the segmentation.
            labels = vigra.analysis.relabelConsecutive(labels)[0].astype("uint32")

        # Filter out small objects if min_size is specified.
        if self.min_size > 0:
            ids, sizes = np.unique(labels, return_counts=True)
            discard_ids = ids[sizes < self.min_size]
            labels[np.isin(labels, discard_ids)] = 0
            labels = vigra.analysis.relabelConsecutive(labels)[0].astype("uint32")

        # Compute the boundaries. They will be used to determine the most central point,
        # and if 'self.boundary_distances is True' to add the boundary distances.
        boundaries = skimage.segmentation.find_boundaries(labels, mode="inner").astype("uint32")

        # Compute region properties to derive bounding boxes and centers.
        ndim = labels.ndim
        props = skimage.measure.regionprops(labels)
        bounding_boxes = {
            prop.label: tuple(slice(prop.bbox[i], prop.bbox[i + ndim]) for i in range(ndim)) for prop in props
        }

        # Compute the object centers from centroids.
        centers = {prop.label: np.round(prop.centroid).astype("int") for prop in props}

        # Compute how many distance channels we have.
        n_channels = 0
        if self.distances:  # We need one channel for the overall distances.
            n_channels += 1
        if self.boundary_distances:  # We need one channel for the boundary distances.
            n_channels += 1
        if self.directed_distances:  # And ndim channels for directed distances.
            n_channels += ndim

        # Compute the per object distances.
        distances = np.full(labels.shape + (n_channels,), self.distance_fill_value, dtype="float32")
        for prop in props:
            label_id = prop.label
            mask = labels == label_id
            distances = self.compute_normalized_object_distances(
                mask, boundaries, bounding_boxes[label_id], centers[label_id], distances
            )

        # Bring the distance channel to the first dimension.
        to_channel_first = (ndim,) + tuple(range(ndim))
        distances = distances.transpose(to_channel_first)

        # Add the foreground mask as first channel if specified.
        if self.foreground:
            binary_labels = (labels > 0).astype("float32")
            distances = np.concatenate([binary_labels[None], distances], axis=0)

        if self.instances:
            distances = np.concatenate([labels[None], distances], axis=0)

        return distances
def connected_components( labels: numpy.ndarray, ndim: Optional[int] = None, ensure_zero: bool = False) -> numpy.ndarray:
def connected_components(labels: np.ndarray, ndim: Optional[int] = None, ensure_zero: bool = False) -> np.ndarray:
    """Run connected-component labeling on a segmentation.

    Args:
        labels: The segmentation to label.
        ndim: If given, the input is converted with `ensure_spatial_array` using this dimensionality;
            otherwise `ensure_array` is used.
        ensure_zero: If True and the labeled result contains no id 0, all ids are shifted down by one
            so that id 0 is present.

    Returns:
        The connected-component labeling.
    """
    if ndim is None:
        arr = ensure_array(labels)
    else:
        arr = ensure_spatial_array(labels, ndim)
    components = skimage.measure.label(arr)
    # Only shift when requested and there is no zero id in the result.
    if ensure_zero and not (components == 0).any():
        components -= 1
    return components

Apply connected components to a segmentation.

Arguments:
  • labels: The input segmentation.
  • ndim: The expected dimensionality of the data.
  • ensure_zero: Whether to ensure that the data has a zero label.
Returns:

The segmentation after connected components.

def labels_to_binary(labels: numpy.ndarray, background_label: int = 0) -> numpy.ndarray:
def labels_to_binary(labels: np.ndarray, background_label: int = 0) -> np.ndarray:
    """Convert a label image into a binary foreground mask.

    Args:
        labels: The segmentation.
        background_label: Id treated as background; every other id becomes foreground.

    Returns:
        An array with the dtype of `labels`: 1 where the label differs from `background_label`, 0 elsewhere.
    """
    foreground = np.not_equal(labels, background_label)
    return foreground.astype(labels.dtype)

Transform a segmentation to binary labels.

Arguments:
  • labels: The input segmentation.
  • background_label: The id of the background label.
Returns:

The binary segmentation.

def label_consecutive(labels: numpy.ndarray, with_background: bool = True) -> numpy.ndarray:
def label_consecutive(labels: np.ndarray, with_background: bool = True) -> np.ndarray:
    """Ensure that the input segmentation is labeled consecutively.

    Args:
        labels: The input segmentation. The array passed in is not modified.
        with_background: Whether this segmentation has a background label. If True, id 0 is kept
            as background and the other ids are relabeled sequentially. If False, id 0 is treated
            as a regular object and the result starts at id 0.

    Returns:
        The consecutively labeled segmentation.
    """
    if with_background:
        return skimage.segmentation.relabel_sequential(labels)[0]

    # relabel_sequential never remaps id 0, so shift all ids up by one if 0 is present.
    # Use an out-of-place add: an in-place `+=` would silently modify the caller's array.
    if 0 in labels:
        labels = labels + 1
    seg = skimage.segmentation.relabel_sequential(labels)[0]
    assert seg.min() == 1
    # Shift back so that the first object gets id 0.
    seg -= 1
    return seg

Ensure that the input segmentation is labeled consecutively.

Arguments:
  • labels: The input segmentation.
  • with_background: Whether this segmentation has a background label.
Returns:

The consecutively labeled segmentation.

class MinSizeLabelTransform:
class MinSizeLabelTransform:
    """Transformation that removes objects below a minimal size from a segmentation.

    Connected components are computed first. If `min_size` is set, components with fewer pixels
    than `min_size` are set to zero and the result is relabeled sequentially.

    Args:
        min_size: Components with fewer pixels than this are removed. If None, no size filtering is done.
        ndim: The dimensionality of the segmentation (forwarded to `connected_components`).
        ensure_zero: Ensure that the segmentation contains the id zero (forwarded to `connected_components`).
    """
    def __init__(self, min_size: Optional[int] = None, ndim: Optional[int] = None, ensure_zero: bool = False):
        self.min_size = min_size
        self.ndim = ndim
        self.ensure_zero = ensure_zero

    def __call__(self, labels: np.ndarray) -> np.ndarray:
        """Remove small objects from a segmentation.

        Args:
            labels: The input segmentation.

        Returns:
            The connected components, size filtered and relabeled if `min_size` is set.
        """
        segmentation = connected_components(labels, ndim=self.ndim, ensure_zero=self.ensure_zero)
        if self.min_size is None:
            return segmentation
        label_ids, counts = np.unique(segmentation, return_counts=True)
        too_small = label_ids[counts < self.min_size]
        segmentation[np.isin(segmentation, too_small)] = 0
        return skimage.segmentation.relabel_sequential(segmentation)[0]

Transformation to filter out objects smaller than a minimal size from the segmentation.

Arguments:
  • min_size: The minimal object size of the segmentation.
  • ndim: The dimensionality of the segmentation.
  • ensure_zero: Ensure that the segmentation contains the id zero.
MinSizeLabelTransform( min_size: Optional[int] = None, ndim: Optional[int] = None, ensure_zero: bool = False)
def __init__(self, min_size: Optional[int] = None, ndim: Optional[int] = None, ensure_zero: bool = False):
    # Only records the configuration on the instance.
    self.min_size, self.ndim, self.ensure_zero = min_size, ndim, ensure_zero
min_size
ndim
ensure_zero
class BoundaryTransform:
class BoundaryTransform:
    """Transformation to convert an instance segmentation into boundaries.

    The output always has a leading channel axis. If a binary target is requested, the binary
    foreground mask is the first channel and the boundaries are the second one.

    Args:
        mode: The mode for converting the segmentation to boundaries
            (passed to `skimage.segmentation.find_boundaries`).
        add_binary_target: Whether to add a binary mask channel to the transformation output.
        ndim: The expected dimensionality of the data.
    """
    def __init__(self, mode: str = "thick", add_binary_target: bool = False, ndim: Optional[int] = None):
        self.ndim = ndim
        self.add_binary_target = add_binary_target
        self.mode = mode

    def __call__(self, labels: np.ndarray) -> np.ndarray:
        """Apply the boundary transformation to an input segmentation.

        Args:
            labels: The input segmentation.

        Returns:
            The boundaries.
        """
        if self.ndim is None:
            labels = ensure_array(labels)
        else:
            labels = ensure_spatial_array(labels, self.ndim)

        boundaries = skimage.segmentation.find_boundaries(labels, mode=self.mode)
        if not self.add_binary_target:
            return boundaries[None]

        binary = labels_to_binary(labels).astype(boundaries.dtype)
        return np.stack([binary, boundaries], axis=0)

Transformation to convert an instance segmentation into boundaries.

Arguments:
  • mode: The mode for converting the segmentation to boundaries.
  • add_binary_target: Whether to add a binary mask channel to the transformation output.
  • ndim: The expected dimensionality of the data.
BoundaryTransform( mode: str = 'thick', add_binary_target: bool = False, ndim: Optional[int] = None)
109    def __init__(self, mode: str = "thick", add_binary_target: bool = False, ndim: Optional[int] = None):
110        self.mode = mode
111        self.add_binary_target = add_binary_target
112        self.ndim = ndim
mode
add_binary_target
ndim
class NoToBackgroundBoundaryTransform:
class NoToBackgroundBoundaryTransform:
    """Transformation to convert an instance segmentation into boundaries.

    Boundary pixels that also lie on the boundary between foreground and background
    (i.e. between any label and `bg_label`) are set to `mask_label` in the output of the transformation.
    The output is an int8 array with a leading channel axis, so `mask_label` should be representable in int8.

    Args:
        bg_label: The background label.
        mask_label: The mask label.
        mode: The mode for converting the segmentation to boundaries.
        add_binary_target: Whether to add a binary mask channel to the transformation output.
            This channel comes first; pixels whose input label equals `mask_label` are set to `mask_label` in it.
        ndim: The expected dimensionality of the data.
    """
    def __init__(
        self,
        bg_label: int = 0,
        mask_label: int = -1,
        mode: str = "thick",
        add_binary_target: bool = False,
        ndim: Optional[int] = None,
    ):
        self.bg_label = bg_label
        self.mask_label = mask_label
        self.mode = mode
        self.ndim = ndim
        self.add_binary_target = add_binary_target

    def __call__(self, labels: np.ndarray) -> np.ndarray:
        """Apply the boundary transformation to an input segmentation.

        Args:
            labels: The input segmentation.

        Returns:
            The boundaries.
        """
        if self.ndim is None:
            labels = ensure_array(labels)
        else:
            labels = ensure_spatial_array(labels, self.ndim)

        # All boundaries, as int8 so that a negative mask label can be written into them.
        boundaries = skimage.segmentation.find_boundaries(labels, mode=self.mode).astype(np.int8)

        # Boundaries of the foreground / background split get the mask label.
        foreground = labels != self.bg_label
        background_boundaries = skimage.segmentation.find_boundaries(foreground, mode=self.mode)
        boundaries[background_boundaries] = self.mask_label
        boundaries = boundaries[None]

        if not self.add_binary_target:
            return boundaries

        binary = labels_to_binary(labels, self.bg_label).astype(boundaries.dtype)
        binary[labels == self.mask_label] = self.mask_label
        return np.concatenate([binary[None], boundaries], axis=0)

Transformation to convert an instance segmentation into boundaries.

This transformation sets boundaries between objects and the background (bg_label) to the mask label in the output of the transformation.

Arguments:
  • bg_label: The background label.
  • mask_label: The mask label.
  • mode: The mode for converting the segmentation to boundaries.
  • add_binary_target: Whether to add a binary mask channel to the transformation output.
  • ndim: The expected dimensionality of the data.
NoToBackgroundBoundaryTransform( bg_label: int = 0, mask_label: int = -1, mode: str = 'thick', add_binary_target: bool = False, ndim: Optional[int] = None)
147    def __init__(
148        self,
149        bg_label: int = 0,
150        mask_label: int = -1,
151        mode: str = "thick",
152        add_binary_target: bool = False,
153        ndim: Optional[int] = None,
154    ):
155        self.bg_label = bg_label
156        self.mask_label = mask_label
157        self.mode = mode
158        self.ndim = ndim
159        self.add_binary_target = add_binary_target
bg_label
mask_label
mode
ndim
add_binary_target
class BoundaryTransformWithIgnoreLabel:
class BoundaryTransformWithIgnoreLabel:
    """Transformation to convert an instance segmentation into boundaries.

    Boundary pixels that lie on the boundary of the region carrying `ignore_label` are set to
    `ignore_label` in the output of the transformation. The output is an int8 array with a leading channel axis.

    Args:
        ignore_label: The ignore label.
        mode: The mode for converting the segmentation to boundaries.
        add_binary_target: Whether to add a binary mask channel to the transformation output.
            This channel comes first; pixels with the ignore label are set to `ignore_label` in it.
        ndim: The expected dimensionality of the data.
    """
    def __init__(
        self,
        ignore_label: int = -1,
        mode: str = "thick",
        add_binary_target: bool = False,
        ndim: Optional[int] = None,
    ):
        self.ignore_label = ignore_label
        self.mode = mode
        self.ndim = ndim
        self.add_binary_target = add_binary_target

    def __call__(self, labels: np.ndarray) -> np.ndarray:
        """Apply the boundary transformation to an input segmentation.

        Args:
            labels: The input segmentation.

        Returns:
            The boundaries.
        """
        if self.ndim is None:
            labels = ensure_array(labels)
        else:
            labels = ensure_spatial_array(labels, self.ndim)

        # All boundaries, as int8 so that a negative ignore label can be written into them.
        boundaries = skimage.segmentation.find_boundaries(labels, mode=self.mode).astype(np.int8)

        # Boundaries around the ignore region get the ignore label.
        is_ignore = labels == self.ignore_label
        ignore_boundaries = skimage.segmentation.find_boundaries(is_ignore, mode=self.mode)
        boundaries[ignore_boundaries] = self.ignore_label
        boundaries = boundaries[None]

        if not self.add_binary_target:
            return boundaries

        binary = labels_to_binary(labels).astype(boundaries.dtype)
        binary[labels == self.ignore_label] = self.ignore_label
        return np.concatenate([binary[None], boundaries], axis=0)

Transformation to convert an instance segmentation into boundaries.

This transformation sets boundaries with the ignore label to the ignore label in the output of the transformation.

Arguments:
  • ignore_label: The ignore label.
  • mode: The mode for converting the segmentation to boundaries.
  • add_binary_target: Whether to add a binary mask channel to the transformation output.
  • ndim: The expected dimensionality of the data.
BoundaryTransformWithIgnoreLabel( ignore_label: int = -1, mode: str = 'thick', add_binary_target: bool = False, ndim: Optional[int] = None)
205    def __init__(
206        self,
207        ignore_label: int = -1,
208        mode: str = "thick",
209        add_binary_target: bool = False,
210        ndim: Optional[int] = None,
211    ):
212        self.ignore_label = ignore_label
213        self.mode = mode
214        self.ndim = ndim
215        self.add_binary_target = add_binary_target
ignore_label
mode
ndim
add_binary_target
class AffinityTransform:
class AffinityTransform:
    """Transformation to compute affinities from a segmentation.

    The affinities follow the "disaffinity" convention used for training: 1 means repulsive,
    0 means attractive. The output channels are stacked in this order:
    [binary target (if add_binary_target)] + affinities + [mask channels (if add_mask)].

    Args:
        offsets: The offsets for computing affinities. The length of each offset determines
            the dimensionality of the data, which must be 2 or 3.
        ignore_label: The ignore label to use for computing the ignore mask.
        add_binary_target: Whether to add a binary channel to the affinities.
        add_mask: Whether to add the ignore mask as extra output channels.
        include_ignore_transitions: Whether transitions to the ignore label
            should be positive in the ignore mask or negative in it.
    """
    def __init__(
        self,
        offsets: List[List[int]],
        ignore_label: Optional[int] = None,
        add_binary_target: bool = False,
        add_mask: bool = False,
        include_ignore_transitions: bool = False,
    ):
        assert compute_affinities is not None, "AffinityTransform requires 'affogato' to be installed."
        self.offsets = offsets
        self.ndim = len(self.offsets[0])
        assert self.ndim in (2, 3), f"Expected offsets of length 2 or 3, got {self.ndim}."

        self.ignore_label = ignore_label
        self.add_binary_target = add_binary_target
        self.add_mask = add_mask
        self.include_ignore_transitions = include_ignore_transitions

    def add_ignore_transitions(self, affs, mask, labels):
        """@private
        """
        # The affinities of the binary "is ignore label" segmentation mark exactly
        # the transitions between ignore and non-ignore pixels.
        ignore_seg = (labels == self.ignore_label).astype(labels.dtype)
        ignore_transitions, invalid_mask = compute_affinities(ignore_seg, self.offsets)
        # NOTE(review): the second affogato output is negated, i.e. it is treated as a validity mask;
        # after negation it marks edges that must not be set.
        invalid_mask = np.logical_not(invalid_mask)
        # NOTE affinity convention returned by affogato: transitions are marked by 0
        ignore_transitions = ignore_transitions == 0
        ignore_transitions[invalid_mask] = 0
        affs[ignore_transitions] = 1
        mask[ignore_transitions] = 1
        return affs, mask

    def __call__(self, labels: np.ndarray) -> np.ndarray:
        """Compute the affinities.

        Args:
            labels: The segmentation.

        Returns:
            The affinities.
        """
        # Signed integer labels (of any width, including int8) stay signed so that a negative
        # ignore label is preserved; all other inputs are converted to uint64.
        dtype = "int64" if np.issubdtype(np.dtype(labels.dtype), np.signedinteger) else "uint64"
        labels = ensure_spatial_array(labels, self.ndim, dtype=dtype)
        affs, mask = compute_affinities(labels, self.offsets,
                                        have_ignore_label=self.ignore_label is not None,
                                        ignore_label=0 if self.ignore_label is None else self.ignore_label)
        # we use the "disaffinity" convention for training; i.e. 1 means repulsive, 0 attractive
        affs = 1. - affs

        # Optionally mark transitions to the ignore label as repulsive and set them valid in the mask.
        if self.ignore_label is not None and self.include_ignore_transitions:
            affs, mask = self.add_ignore_transitions(affs, mask, labels)

        if self.add_binary_target:
            binary = labels_to_binary(labels)[None].astype(affs.dtype)
            assert binary.ndim == affs.ndim
            affs = np.concatenate([binary, affs], axis=0)

        if self.add_mask:
            if self.add_binary_target:
                # The binary channel needs its own mask channel so that the shapes match.
                if self.ignore_label is None:
                    mask_for_bin = np.ones((1,) + labels.shape, dtype=mask.dtype)
                else:
                    mask_for_bin = (labels != self.ignore_label)[None].astype(mask.dtype)
                assert mask.ndim == mask_for_bin.ndim
                mask = np.concatenate([mask_for_bin, mask], axis=0)
            assert affs.shape == mask.shape
            affs = np.concatenate([affs, mask.astype(affs.dtype)], axis=0)

        return affs

Transformation to compute affinities from a segmentation.

Arguments:
  • offsets: The offsets for computing affinities.
  • ignore_label: The ignore label to use for computing the ignore mask.
  • add_binary_target: Whether to add a binary channel to the affinities.
  • add_mask: Whether to add the ignore mask as extra output channels.
  • include_ignore_transitions: Whether transitions to the ignore label should be positive in the ignore mask or negative in it.
AffinityTransform( offsets: List[List[int]], ignore_label: Optional[int] = None, add_binary_target: bool = False, add_mask: bool = False, include_ignore_transitions: bool = False)
260    def __init__(
261        self,
262        offsets: List[List[int]],
263        ignore_label: Optional[bool] = None,
264        add_binary_target: bool = False,
265        add_mask: bool = False,
266        include_ignore_transitions: bool = False,
267    ):
268        assert compute_affinities is not None
269        self.offsets = offsets
270        self.ndim = len(self.offsets[0])
271        assert self.ndim in (2, 3)
272
273        self.ignore_label = ignore_label
274        self.add_binary_target = add_binary_target
275        self.add_mask = add_mask
276        self.include_ignore_transitions = include_ignore_transitions
offsets
ndim
ignore_label
add_binary_target
add_mask
include_ignore_transitions
class OneHotTransform:
class OneHotTransform:
    """Transformations to compute one-hot labels from a semantic segmentation.

    Args:
        class_ids: The class ids to convert to one-hot labels. An integer `n` is expanded to the ids
            `0, ..., n - 1`; if None, the unique values of each input are used.
    """
    def __init__(self, class_ids: Optional[Union[int, Sequence[int]]] = None):
        if isinstance(class_ids, int):
            class_ids = list(range(class_ids))
        self.class_ids = class_ids

    def __call__(self, labels: np.ndarray) -> np.ndarray:
        """Compute the one hot transformation.

        Args:
            labels: The segmentation.

        Returns:
            The one-hot transformation, as float32 with one leading channel per class id.
        """
        ids = self.class_ids if self.class_ids is not None else np.unique(labels).tolist()
        encoded = np.zeros((len(ids), *labels.shape), dtype="float32")
        # Each iteration yields a view on one channel, so writing into it fills `encoded`.
        for channel, class_id in zip(encoded, ids):
            channel[labels == class_id] = 1.0
        return encoded

Transformations to compute one-hot labels from a semantic segmentation.

Arguments:
  • class_ids: The class ids to convert to one-hot labels.
OneHotTransform(class_ids: Union[int, Sequence[int], NoneType] = None)
339    def __init__(self, class_ids: Optional[Union[int, Sequence[int]]] = None):
340        self.class_ids = list(range(class_ids)) if isinstance(class_ids, int) else class_ids
class_ids
class DistanceTransform:
class DistanceTransform:
    """Transformation to compute distances to foreground in the labels.

    Distances are computed with respect to the pixels whose label equals `foreground_id`.
    If both kinds of distances are requested, the undirected distance is the first channel, followed by
    one channel per spatial axis for the directed distances. If only one kind is requested, the undirected
    distances are returned without a channel axis and the directed distances with one channel per axis.

    Args:
        distances: Whether to compute the absolute distances.
        directed_distances: Whether to compute the directed distances (vector distances).
        normalize: Whether to normalize the computed distances.
        max_distance: Maximal distance at which to threshold the distances.
        foreground_id: Label id to which the distance is computed.
        invert: Whether to invert the distances.
        func: Normalization function for the distances.

    Raises:
        ValueError: If neither `distances` nor `directed_distances` is set.
    """
    eps = 1e-7

    def __init__(
        self,
        distances: bool = True,
        directed_distances: bool = False,
        normalize: bool = True,
        max_distance: Optional[float] = None,
        foreground_id: int = 1,
        invert: bool = False,
        func: Optional[Callable] = None,
    ):
        if sum((distances, directed_distances)) == 0:
            raise ValueError("At least one of 'distances' or 'directed_distances' must be set to 'True'")
        self.directed_distances = directed_distances
        self.distances = distances
        self.normalize = normalize
        self.max_distance = max_distance
        self.foreground_id = foreground_id
        self.invert = invert
        self.func = func

    def _compute_distances(self, directed_distances):
        # The undirected distance is the length of the distance vector (axis 0 holds the components).
        distances = np.linalg.norm(directed_distances, axis=0)
        if self.max_distance is not None:
            distances = np.clip(distances, 0, self.max_distance)
        if self.normalize:
            distances /= (distances.max() + self.eps)
        if self.invert:
            distances = distances.max() - distances
        if self.func is not None:
            distances = self.func(distances)
        return distances

    def _compute_directed_distances(self, directed_distances):
        # Axis 0 holds the vector components; normalization and inversion are done per component
        # over ALL spatial axes, so that 2d and 3d data are treated the same way.
        spatial_axes = tuple(range(1, directed_distances.ndim))
        if self.max_distance is not None:
            directed_distances = np.clip(directed_distances, -self.max_distance, self.max_distance)
        if self.normalize:
            directed_distances /= (np.abs(directed_distances).max(axis=spatial_axes, keepdims=True) + self.eps)
        if self.invert:
            directed_distances = directed_distances.max(axis=spatial_axes, keepdims=True) - directed_distances
        if self.func is not None:
            directed_distances = self.func(directed_distances)
        return directed_distances

    def _get_distances_for_empty_labels(self, labels):
        # Without foreground every pixel gets the same value: 0 if inverted, otherwise
        # sqrt(|shape|^2 / 2), i.e. the length of the image diagonal divided by sqrt(2).
        shape = labels.shape
        fill_value = 0.0 if self.invert else np.sqrt(np.linalg.norm(list(shape)) ** 2 / 2)
        data = np.full((labels.ndim,) + shape, fill_value)
        return data

    def __call__(self, labels: np.ndarray) -> np.ndarray:
        """Compute the distances.

        Args:
            labels: The segmentation.

        Returns:
            The distances.
        """
        distance_mask = (labels == self.foreground_id).astype("uint32")
        # the distances are not computed corrected if they are all zero
        # so this case needs to be handled separately
        if distance_mask.sum() == 0:
            directed_distances = self._get_distances_for_empty_labels(labels)
        else:
            ndim = distance_mask.ndim
            # Move the trailing vector-component axis of the vigra output to the front.
            to_channel_first = (ndim,) + tuple(range(ndim))
            directed_distances = vigra.filters.vectorDistanceTransform(distance_mask).transpose(to_channel_first)

        # The undirected distances must be derived before the directed distances are
        # post-processed, since that post-processing may modify them in place.
        if self.distances:
            distances = self._compute_distances(directed_distances)

        if self.directed_distances:
            directed_distances = self._compute_directed_distances(directed_distances)

        if self.distances and self.directed_distances:
            return np.concatenate((distances[None], directed_distances), axis=0)
        if self.distances:
            return distances
        if self.directed_distances:
            return directed_distances

Transformation to compute distances to foreground in the labels.

Arguments:
  • distances: Whether to compute the absolute distances.
  • directed_distances: Whether to compute the directed distances (vector distances).
  • normalize: Whether to normalize the computed distances.
  • max_distance: Maximal distance at which to threshold the distances.
  • foreground_id: Label id to which the distance is computed.
  • invert: Whether to invert the distances.
  • func: Normalization function for the distances.
DistanceTransform( distances: bool = True, directed_distances: bool = False, normalize: bool = True, max_distance: Optional[float] = None, foreground_id: int = 1, invert: bool = False, func: Optional[Callable] = None)
373    def __init__(
374        self,
375        distances: bool = True,
376        directed_distances: bool = False,
377        normalize: bool = True,
378        max_distance: Optional[float] = None,
379        foreground_id: int = 1,
380        invert: bool = False,
381        func: Optional[Callable] = None,
382    ):
383        if sum((distances, directed_distances)) == 0:
384            raise ValueError("At least one of 'distances' or 'directed_distances' must be set to 'True'")
385        self.directed_distances = directed_distances
386        self.distances = distances
387        self.normalize = normalize
388        self.max_distance = max_distance
389        self.foreground_id = foreground_id
390        self.invert = invert
391        self.func = func
eps = 1e-07
directed_distances
distances
normalize
max_distance
foreground_id
invert
func
class PerObjectDistanceTransform:
class PerObjectDistanceTransform:
    """Transformation to compute normalized distances per object in a segmentation.

    The output channels are stacked in this order: [instances] + [foreground] + [distances]
    + [directed distances (one per spatial axis)] + [boundary distances], each only if requested.

    Args:
        distances: Whether to compute the undirected distances.
        boundary_distances: Whether to compute the distances to the object boundaries.
        directed_distances: Whether to compute the directed distances (vector distances).
        foreground: Whether to return a foreground channel.
        instances: Whether to return the (relabeled) instance segmentation as the first channel.
        apply_label: Whether to apply connected components to the labels before computing distances.
        correct_centers: Whether to correct centers that are not in the objects. If True, an object whose
            centroid lies outside of it uses the point inside the object with maximal distance to its boundary.
        min_size: Minimal size of objects for distance calculation.
        distance_fill_value: Fill value for the distances outside of objects.
        sampling: The spacing of the distance transform. This is especially relevant for anisotropic data;
            for which it is recommended to use a sampling of (ANISOTROPY_FACTOR, 1, 1).

    Raises:
        ValueError: If none of `distances`, `directed_distances` or `boundary_distances` is set.
    """
    eps = 1e-7

    def __init__(
        self,
        distances: bool = True,
        boundary_distances: bool = True,
        directed_distances: bool = False,
        foreground: bool = True,
        instances: bool = False,
        apply_label: bool = True,
        correct_centers: bool = True,
        min_size: int = 0,
        distance_fill_value: float = 1.0,
        sampling: Optional[Tuple[float, ...]] = None
    ):
        if sum([distances, directed_distances, boundary_distances]) == 0:
            raise ValueError(
                "At least one of distances, directed distances or boundary distances has to be passed."
            )
        self.distances = distances
        self.boundary_distances = boundary_distances
        self.directed_distances = directed_distances
        self.foreground = foreground
        self.instances = instances

        self.apply_label = apply_label
        self.correct_centers = correct_centers
        self.min_size = min_size
        self.distance_fill_value = distance_fill_value
        self.sampling = sampling

    def compute_normalized_object_distances(self, mask, boundaries, bb, center, distances):
        """@private
        """
        # Crop the mask and generate array with the correct center.
        cropped_mask = mask[bb]
        cropped_center = tuple(ce - b.start for ce, b in zip(center, bb))

        # The centroid might not be inside of the object.
        # If enabled, we correct the center by taking the maximum of the distance to the boundary.
        # Note: the centroid is still the best estimate for the center, as long as it's in the object.
        correct_center = self.correct_centers and not cropped_mask[cropped_center]

        # Compute the boundary distances if necessary.
        # (Either if we need to correct the center, or compute the boundary distances anyways.)
        if correct_center or self.boundary_distances:
            # Crop the boundary mask and compute the boundary distances.
            cropped_boundary_mask = boundaries[bb]
            boundary_distances = vigra.filters.distanceTransform(cropped_boundary_mask, pixel_pitch=self.sampling)
            boundary_distances[~cropped_mask] = 0
            max_dist_point = np.unravel_index(np.argmax(boundary_distances), boundary_distances.shape)

        # Set the crop center to the max dist point.
        if correct_center:
            # Find the center (= maximal distance from the boundaries).
            cropped_center = max_dist_point

        cropped_center_mask = np.zeros_like(cropped_mask, dtype="uint32")
        cropped_center_mask[cropped_center] = 1

        # Compute the directed distances.
        if self.distances or self.directed_distances:
            this_distances = vigra.filters.vectorDistanceTransform(cropped_center_mask, pixel_pitch=self.sampling)
        else:
            this_distances = None

        # Keep only the specified distances:
        if self.distances and self.directed_distances:  # all distances
            # Compute the undirected distances from directed distances and concatenate.
            undir = np.linalg.norm(this_distances, axis=-1, keepdims=True)
            this_distances = np.concatenate([undir, this_distances], axis=-1)

        elif self.distances:  # only undirected distances
            # Compute the undirected distances from directed distances and keep only them.
            this_distances = np.linalg.norm(this_distances, axis=-1, keepdims=True)

        elif self.directed_distances:  # only directed distances
            pass  # We don't have to do anything because the directed distances are already computed.

        # Add an extra channel for the boundary distances if specified.
        if self.boundary_distances:
            boundary_distances = (boundary_distances[max_dist_point] - boundary_distances)[..., None]
            if this_distances is None:
                this_distances = boundary_distances
            else:
                this_distances = np.concatenate([this_distances, boundary_distances], axis=-1)

        # Set distances outside of the mask to zero.
        this_distances[~cropped_mask] = 0

        # Normalize the distances.
        spatial_axes = tuple(range(mask.ndim))
        this_distances /= (np.abs(this_distances).max(axis=spatial_axes, keepdims=True) + self.eps)

        # Set the distance values in the global result.
        distances[bb][cropped_mask] = this_distances[cropped_mask]

        return distances

    def __call__(self, labels: np.ndarray) -> np.ndarray:
        """Compute the per object distance transform.

        Args:
            labels: The segmentation

        Returns:
            The distances.
        """
        # Apply label (connected components) if specified.
        if self.apply_label:
            labels = skimage.measure.label(labels).astype("uint32")
        else:  # Otherwise just relabel the segmentation.
            labels = vigra.analysis.relabelConsecutive(labels)[0].astype("uint32")

        # Filter out small objects if min_size is specified.
        if self.min_size > 0:
            ids, sizes = np.unique(labels, return_counts=True)
            discard_ids = ids[sizes < self.min_size]
            labels[np.isin(labels, discard_ids)] = 0
            labels = vigra.analysis.relabelConsecutive(labels)[0].astype("uint32")

        # Compute the boundaries. They will be used to determine the most central point,
        # and if 'self.boundary_distances is True' to add the boundary distances.
        boundaries = skimage.segmentation.find_boundaries(labels, mode="inner").astype("uint32")

        # Compute region properties to derive bounding boxes and centers.
        ndim = labels.ndim
        props = skimage.measure.regionprops(labels)
        bounding_boxes = {
            prop.label: tuple(slice(prop.bbox[i], prop.bbox[i + ndim]) for i in range(ndim)) for prop in props
        }

        # Compute the object centers from centroids.
        centers = {prop.label: np.round(prop.centroid).astype("int") for prop in props}

        # Compute how many distance channels we have.
        n_channels = 0
        if self.distances:  # We need one channel for the overall distances.
            n_channels += 1
        if self.boundary_distances:  # We need one channel for the boundary distances.
            n_channels += 1
        if self.directed_distances:  # And ndim channels for directed distances.
            n_channels += ndim

        # Compute the per object distances.
        distances = np.full(labels.shape + (n_channels,), self.distance_fill_value, dtype="float32")
        for prop in props:
            label_id = prop.label
            mask = labels == label_id
            distances = self.compute_normalized_object_distances(
                mask, boundaries, bounding_boxes[label_id], centers[label_id], distances
            )

        # Bring the distance channel to the first dimension.
        to_channel_first = (ndim,) + tuple(range(ndim))
        distances = distances.transpose(to_channel_first)

        # Add the foreground mask as first channel if specified.
        if self.foreground:
            binary_labels = (labels > 0).astype("float32")
            distances = np.concatenate([binary_labels[None], distances], axis=0)

        # Add the instance labels in front of all other channels if specified.
        if self.instances:
            distances = np.concatenate([labels[None], distances], axis=0)

        return distances

Transformation to compute normalized distances per object in a segmentation.

Arguments:
  • distances: Whether to compute the undirected distances.
  • boundary_distances: Whether to compute the distances to the object boundaries.
  • directed_distances: Whether to compute the directed distances (vector distances).
  • foreground: Whether to return a foreground channel.
  • instances: Whether to return the (relabeled) instance segmentation as the first output channel.
  • apply_label: Whether to apply connected components to the labels before computing distances.
  • correct_centers: Whether to correct centers that are not in the objects.
  • min_size: Minimal size of objects for distance calculation.
  • distance_fill_value: Fill value for the distances outside of objects.
  • sampling: The spacing of the distance transform. This is especially relevant for anisotropic data; for which it is recommended to use a sampling of (ANISOTROPY_FACTOR, 1, 1).
PerObjectDistanceTransform( distances: bool = True, boundary_distances: bool = True, directed_distances: bool = False, foreground: bool = True, instances: bool = False, apply_label: bool = True, correct_centers: bool = True, min_size: int = 0, distance_fill_value: float = 1.0, sampling: Optional[Tuple[float, ...]] = None)
472    def __init__(
473        self,
474        distances: bool = True,
475        boundary_distances: bool = True,
476        directed_distances: bool = False,
477        foreground: bool = True,
478        instances: bool = False,
479        apply_label: bool = True,
480        correct_centers: bool = True,
481        min_size: int = 0,
482        distance_fill_value: float = 1.0,
483        sampling: Optional[Tuple[float, ...]] = None
484    ):
485        if sum([distances, directed_distances, boundary_distances]) == 0:
486            raise ValueError("At least one of distances or directed distances has to be passed.")
487        self.distances = distances
488        self.boundary_distances = boundary_distances
489        self.directed_distances = directed_distances
490        self.foreground = foreground
491        self.instances = instances
492
493        self.apply_label = apply_label
494        self.correct_centers = correct_centers
495        self.min_size = min_size
496        self.distance_fill_value = distance_fill_value
497        self.sampling = sampling
eps = 1e-07
distances
boundary_distances
directed_distances
foreground
instances
apply_label
correct_centers
min_size
distance_fill_value
sampling