torch_em.shallow2deep.transform

 1from typing import Optional
 2
 3import numpy as np
 4import skimage.segmentation
 5from scipy.ndimage.morphology import distance_transform_edt
 6from torch_em.util import ensure_array, ensure_spatial_array
 7
 8
 9class ForegroundTransform:
10    """Transformation to convert labels into a foreground mask.
11
12    Args:
13        label_id: The label id to use for extracting the foreground mask.
14            If None, all label values larger than zero will be used to compute the foreground mask.
15        ndim: The dimensionality of the data.
16        ignore_radius: The radius around the foreground label to set to the ignore label.
17    """
18    def __init__(self, label_id: Optional[int] = None, ndim: Optional[int] = None, ignore_radius: int = 1):
19        self.label_id = label_id
20        self.ndim = ndim
21        self.ignore_radius = ignore_radius
22
23    def __call__(self, labels: np.ndarray) -> np.ndarray:
24        """Apply the transformation to the segmentation data.
25
26        Args:
27            labels: The segmentation data.
28
29        Returns:
30            The foreground mask.
31        """
32        labels = ensure_array(labels) if self.ndim is None else ensure_spatial_array(labels, self.ndim)
33        target = labels > 0 if self.label_id is None else labels == self.label_id
34        if self.ignore_radius > 0:
35            dist = distance_transform_edt(target == 0)
36            ignore_mask = np.logical_and(dist <= self.ignore_radius, target == 0)
37            target[ignore_mask] = -1
38        return target[None]
39
40
class BoundaryTransform:
    """Transformation to convert labels into boundaries.

    Args:
        mode: The mode for computing the boundaries, passed to `skimage.segmentation.find_boundaries`.
        ndim: The dimensionality of the data.
        ignore_radius: The radius (in pixels) around the boundaries within which non-boundary pixels
            are set to the ignore label -1. No ignore region is added if it is 0.
        add_binary_target: Whether to also encode the foreground in the target. If True, boundaries
            are labeled 2 and the remaining foreground is labeled 1, within the same single channel.
    """
    def __init__(
        self,
        mode: str = "thick",
        ndim: Optional[int] = None,
        ignore_radius: int = 2,
        add_binary_target: bool = False
    ):
        self.mode = mode
        self.ndim = ndim
        self.ignore_radius = ignore_radius
        # The foreground helper gets no ignore band of its own; the ignore band comes from the boundaries.
        self.foreground_trafo = ForegroundTransform(ndim=ndim, ignore_radius=0) if add_binary_target else None

    def __call__(self, labels: np.ndarray) -> np.ndarray:
        """Apply the boundary transform to the data.

        Args:
            labels: The segmentation data.

        Returns:
            The transformed labels as int8 array with a leading channel axis.
            Boundaries are 1, the ignore band around them is -1 (if ignore_radius > 0)
            and all other pixels are 0. If `add_binary_target` is set, boundaries are 2
            and foreground pixels that are neither boundary nor ignored are 1.
        """
        labels = ensure_array(labels) if self.ndim is None else ensure_spatial_array(labels, self.ndim)
        target = skimage.segmentation.find_boundaries(labels, mode=self.mode).astype("int8")

        if self.ignore_radius > 0:
            # Mark the band of non-boundary pixels close to a boundary as ignored.
            dist = distance_transform_edt(target == 0)
            ignore_mask = np.logical_and(dist <= self.ignore_radius, target == 0)
            target[ignore_mask] = -1

        if self.foreground_trafo is not None:
            # Move boundaries to label 2 so that label 1 can be used for the foreground.
            target[target == 1] = 2
            fg_target = self.foreground_trafo(labels)[0]
            assert fg_target.shape == target.shape, f"{fg_target.shape}, {target.shape}"
            # Only pixels that are still 0 become foreground; boundary and ignore pixels keep their label.
            fg_mask = np.logical_and(fg_target == 1, target == 0)
            target[fg_mask] = 1

        return target[None]
class ForegroundTransform:
class ForegroundTransform:
    """Transformation to convert labels into a foreground mask.

    Args:
        label_id: The label id to use for extracting the foreground mask.
            If None, all label values larger than zero will be used to compute the foreground mask.
        ndim: The dimensionality of the data.
        ignore_radius: The radius (in pixels) around the foreground within which background pixels
            are set to the ignore label -1. No ignore region is added if it is 0.
    """
    def __init__(self, label_id: Optional[int] = None, ndim: Optional[int] = None, ignore_radius: int = 1):
        self.label_id = label_id
        self.ndim = ndim
        self.ignore_radius = ignore_radius

    def __call__(self, labels: np.ndarray) -> np.ndarray:
        """Apply the transformation to the segmentation data.

        Args:
            labels: The segmentation data.

        Returns:
            The foreground mask as int8 array with a leading channel axis:
            1 for foreground, 0 for background and -1 for the ignore band
            (only present if ignore_radius > 0).
        """
        labels = ensure_array(labels) if self.ndim is None else ensure_spatial_array(labels, self.ndim)
        foreground = labels > 0 if self.label_id is None else labels == self.label_id
        # Use a signed integer dtype: assigning the ignore label -1 into a boolean array
        # would silently store True, i.e. mark the ignore band as foreground.
        target = foreground.astype("int8")
        if self.ignore_radius > 0:
            # Distance of every background pixel to the closest foreground pixel.
            dist = distance_transform_edt(target == 0)
            ignore_mask = np.logical_and(dist <= self.ignore_radius, target == 0)
            target[ignore_mask] = -1
        return target[None]

Transformation to convert labels into a foreground mask.

Arguments:
  • label_id: The label id to use for extracting the foreground mask. If None, all label values larger than zero will be used to compute the foreground mask.
  • ndim: The dimensionality of the data.
  • ignore_radius: The radius (in pixels) around the foreground within which background pixels are set to the ignore label (-1). No ignore region is added if it is 0.
ForegroundTransform( label_id: Optional[int] = None, ndim: Optional[int] = None, ignore_radius: int = 1)
19    def __init__(self, label_id: Optional[int] = None, ndim: Optional[int] = None, ignore_radius: int = 1):
20        self.label_id = label_id
21        self.ndim = ndim
22        self.ignore_radius = ignore_radius
label_id
ndim
ignore_radius
class BoundaryTransform:
class BoundaryTransform:
    """Transformation to convert labels into boundaries.

    Args:
        mode: The mode for computing the boundaries, passed to `skimage.segmentation.find_boundaries`.
        ndim: The dimensionality of the data.
        ignore_radius: The radius (in pixels) around the boundaries within which non-boundary pixels
            are set to the ignore label -1. No ignore region is added if it is 0.
        add_binary_target: Whether to also encode the foreground in the target. If True, boundaries
            are labeled 2 and the remaining foreground is labeled 1, within the same single channel.
    """
    def __init__(
        self,
        mode: str = "thick",
        ndim: Optional[int] = None,
        ignore_radius: int = 2,
        add_binary_target: bool = False
    ):
        self.mode = mode
        self.ndim = ndim
        self.ignore_radius = ignore_radius
        # The foreground helper gets no ignore band of its own; the ignore band comes from the boundaries.
        self.foreground_trafo = ForegroundTransform(ndim=ndim, ignore_radius=0) if add_binary_target else None

    def __call__(self, labels: np.ndarray) -> np.ndarray:
        """Apply the boundary transform to the data.

        Args:
            labels: The segmentation data.

        Returns:
            The transformed labels as int8 array with a leading channel axis.
            Boundaries are 1, the ignore band around them is -1 (if ignore_radius > 0)
            and all other pixels are 0. If `add_binary_target` is set, boundaries are 2
            and foreground pixels that are neither boundary nor ignored are 1.
        """
        labels = ensure_array(labels) if self.ndim is None else ensure_spatial_array(labels, self.ndim)
        target = skimage.segmentation.find_boundaries(labels, mode=self.mode).astype("int8")

        if self.ignore_radius > 0:
            # Mark the band of non-boundary pixels close to a boundary as ignored.
            dist = distance_transform_edt(target == 0)
            ignore_mask = np.logical_and(dist <= self.ignore_radius, target == 0)
            target[ignore_mask] = -1

        if self.foreground_trafo is not None:
            # Move boundaries to label 2 so that label 1 can be used for the foreground.
            target[target == 1] = 2
            fg_target = self.foreground_trafo(labels)[0]
            assert fg_target.shape == target.shape, f"{fg_target.shape}, {target.shape}"
            # Only pixels that are still 0 become foreground; boundary and ignore pixels keep their label.
            fg_mask = np.logical_and(fg_target == 1, target == 0)
            target[fg_mask] = 1

        return target[None]

Transformation to convert labels into boundaries.

Arguments:
  • mode: The mode for computing the boundaries, passed to skimage.segmentation.find_boundaries (default: 'thick').
  • ndim: The dimensionality of the data.
  • ignore_radius: The radius (in pixels) around the boundaries within which non-boundary pixels are set to the ignore label (-1). No ignore region is added if it is 0.
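  • add_binary_target: Whether to also encode the foreground in the target. If True, boundaries are labeled 2 and the remaining foreground is labeled 1, within the same single channel (no additional channel is added).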
BoundaryTransform( mode: str = 'thick', ndim: Optional[int] = None, ignore_radius: int = 2, add_binary_target: bool = False)
51    def __init__(
52        self,
53        mode: str = "thick",
54        ndim: Optional[int] = None,
55        ignore_radius: int = 2,
56        add_binary_target: bool = False
57    ):
58        self.mode = mode
59        self.ndim = ndim
60        self.ignore_radius = ignore_radius
61        self.foreground_trafo = ForegroundTransform(ndim=ndim, ignore_radius=0) if add_binary_target else None
mode
ndim
ignore_radius
foreground_trafo