torch_em.shallow2deep.transform
from typing import Optional

import numpy as np
import skimage.segmentation
from scipy.ndimage import distance_transform_edt
from torch_em.util import ensure_array, ensure_spatial_array


class ForegroundTransform:
    """Transformation to convert labels into a foreground mask.

    Args:
        label_id: The label id to use for extracting the foreground mask.
            If None, all label values larger than zero will be used to compute the foreground mask.
        ndim: The dimensionality of the data.
        ignore_radius: Background pixels within this (Euclidean, pixel-unit) distance of the
            foreground are set to the ignore label -1. A value of 0 disables the ignore region.
    """
    def __init__(self, label_id: Optional[int] = None, ndim: Optional[int] = None, ignore_radius: int = 1):
        self.label_id = label_id
        self.ndim = ndim
        self.ignore_radius = ignore_radius

    def __call__(self, labels: np.ndarray) -> np.ndarray:
        """Apply the transformation to the segmentation data.

        Args:
            labels: The segmentation data.

        Returns:
            The foreground mask with an additional leading channel axis.
            If `ignore_radius` is 0 the mask is boolean; otherwise it is int8 with
            1 for foreground, 0 for background and -1 for the ignore region.
        """
        labels = ensure_array(labels) if self.ndim is None else ensure_spatial_array(labels, self.ndim)
        target = labels > 0 if self.label_id is None else labels == self.label_id
        if self.ignore_radius > 0:
            # A boolean array cannot hold the ignore label: assigning -1 to it yields True,
            # which would grow the foreground instead of marking an ignore region.
            target = target.astype("int8")
            # Without any foreground pixel the distance transform has no zero element to
            # measure distances to, so there is no band around the foreground to ignore.
            if target.any():
                dist = distance_transform_edt(target == 0)
                ignore_mask = np.logical_and(dist <= self.ignore_radius, target == 0)
                target[ignore_mask] = -1
        return target[None]


class BoundaryTransform:
    """Transformation to convert labels into boundaries.

    Args:
        mode: The mode for computing the boundaries, passed to `skimage.segmentation.find_boundaries`.
        ndim: The dimensionality of the data.
        ignore_radius: Non-boundary pixels within this (Euclidean, pixel-unit) distance of a boundary
            are set to the ignore label -1. A value of 0 disables the ignore region.
        add_binary_target: Whether to also encode the foreground in the target. If True, boundaries
            are labeled 2 and the remaining foreground pixels 1, within the same single channel.
    """
    def __init__(
        self,
        mode: str = "thick",
        ndim: Optional[int] = None,
        ignore_radius: int = 2,
        add_binary_target: bool = False
    ):
        self.mode = mode
        self.ndim = ndim
        self.ignore_radius = ignore_radius
        # Foreground for `add_binary_target` comes from a foreground transform without ignore region.
        self.foreground_trafo = ForegroundTransform(ndim=ndim, ignore_radius=0) if add_binary_target else None

    def __call__(self, labels: np.ndarray) -> np.ndarray:
        """Apply the boundary transform to the data.

        Args:
            labels: The segmentation data.

        Returns:
            The int8 target with an additional leading channel axis. Values are -1 for the
            ignore region, 0 for background and 1 for boundaries. If `add_binary_target` is set,
            boundaries are 2 and the remaining (non-ignored) foreground pixels are 1.
        """
        labels = ensure_array(labels) if self.ndim is None else ensure_spatial_array(labels, self.ndim)
        target = skimage.segmentation.find_boundaries(labels, mode=self.mode).astype("int8")

        # Only build an ignore band when there is at least one boundary pixel; otherwise the
        # distance transform has no zero element to measure distances to.
        if self.ignore_radius > 0 and target.any():
            dist = distance_transform_edt(target == 0)
            ignore_mask = np.logical_and(dist <= self.ignore_radius, target == 0)
            target[ignore_mask] = -1

        if self.foreground_trafo is not None:
            # Move boundaries to 2 so that the value 1 is free for the foreground.
            target[target == 1] = 2
            fg_target = self.foreground_trafo(labels)[0]
            assert fg_target.shape == target.shape, f"{fg_target.shape}, {target.shape}"
            # Foreground pixels that are neither boundary nor ignored become 1.
            fg_mask = np.logical_and(fg_target == 1, target == 0)
            target[fg_mask] = 1

        return target[None]
class
ForegroundTransform:
class ForegroundTransform:
    """Transformation to convert labels into a foreground mask.

    Args:
        label_id: The label id to use for extracting the foreground mask.
            If None, all label values larger than zero will be used to compute the foreground mask.
        ndim: The dimensionality of the data.
        ignore_radius: Background pixels within this (Euclidean, pixel-unit) distance of the
            foreground are set to the ignore label -1. A value of 0 disables the ignore region.
    """
    def __init__(self, label_id: Optional[int] = None, ndim: Optional[int] = None, ignore_radius: int = 1):
        self.label_id = label_id
        self.ndim = ndim
        self.ignore_radius = ignore_radius

    def __call__(self, labels: np.ndarray) -> np.ndarray:
        """Apply the transformation to the segmentation data.

        Args:
            labels: The segmentation data.

        Returns:
            The foreground mask with an additional leading channel axis.
            If `ignore_radius` is 0 the mask is boolean; otherwise it is int8 with
            1 for foreground, 0 for background and -1 for the ignore region.
        """
        labels = ensure_array(labels) if self.ndim is None else ensure_spatial_array(labels, self.ndim)
        target = labels > 0 if self.label_id is None else labels == self.label_id
        if self.ignore_radius > 0:
            # A boolean array cannot hold the ignore label: assigning -1 to it yields True,
            # which would grow the foreground instead of marking an ignore region.
            target = target.astype("int8")
            # Without any foreground pixel the distance transform has no zero element to
            # measure distances to, so there is no band around the foreground to ignore.
            if target.any():
                dist = distance_transform_edt(target == 0)
                ignore_mask = np.logical_and(dist <= self.ignore_radius, target == 0)
                target[ignore_mask] = -1
        return target[None]
Transformation to convert labels into a foreground mask.
Arguments:
- label_id: The label id to use for extracting the foreground mask. If None, all label values larger than zero will be used to compute the foreground mask.
- ndim: The dimensionality of the data.
- ignore_radius: Background pixels within this distance (in pixels) of the foreground are set to the ignore label (-1); a value of 0 disables the ignore region.
class
BoundaryTransform:
class BoundaryTransform:
    """Transformation to convert labels into boundaries.

    Args:
        mode: The mode for computing the boundaries, passed to `skimage.segmentation.find_boundaries`.
        ndim: The dimensionality of the data.
        ignore_radius: Non-boundary pixels within this (Euclidean, pixel-unit) distance of a boundary
            are set to the ignore label -1. A value of 0 disables the ignore region.
        add_binary_target: Whether to also encode the foreground in the target. If True, boundaries
            are labeled 2 and the remaining foreground pixels 1, within the same single channel.
    """
    def __init__(
        self,
        mode: str = "thick",
        ndim: Optional[int] = None,
        ignore_radius: int = 2,
        add_binary_target: bool = False
    ):
        self.mode = mode
        self.ndim = ndim
        self.ignore_radius = ignore_radius
        # Foreground for `add_binary_target` comes from a foreground transform without ignore region.
        self.foreground_trafo = ForegroundTransform(ndim=ndim, ignore_radius=0) if add_binary_target else None

    def __call__(self, labels: np.ndarray) -> np.ndarray:
        """Apply the boundary transform to the data.

        Args:
            labels: The segmentation data.

        Returns:
            The int8 target with an additional leading channel axis. Values are -1 for the
            ignore region, 0 for background and 1 for boundaries. If `add_binary_target` is set,
            boundaries are 2 and the remaining (non-ignored) foreground pixels are 1.
        """
        labels = ensure_array(labels) if self.ndim is None else ensure_spatial_array(labels, self.ndim)
        target = skimage.segmentation.find_boundaries(labels, mode=self.mode).astype("int8")

        # Only build an ignore band when there is at least one boundary pixel; otherwise the
        # distance transform has no zero element to measure distances to.
        if self.ignore_radius > 0 and target.any():
            dist = distance_transform_edt(target == 0)
            ignore_mask = np.logical_and(dist <= self.ignore_radius, target == 0)
            target[ignore_mask] = -1

        if self.foreground_trafo is not None:
            # Move boundaries to 2 so that the value 1 is free for the foreground.
            target[target == 1] = 2
            fg_target = self.foreground_trafo(labels)[0]
            assert fg_target.shape == target.shape, f"{fg_target.shape}, {target.shape}"
            # Foreground pixels that are neither boundary nor ignored become 1.
            fg_mask = np.logical_and(fg_target == 1, target == 0)
            target[fg_mask] = 1

        return target[None]
Transformation to convert labels into boundaries.
Arguments:
- mode: The mode for computing the boundaries; it is passed to `skimage.segmentation.find_boundaries`.
- ndim: The dimensionality of the data.
- ignore_radius: Non-boundary pixels within this distance (in pixels) of a boundary are set to the ignore label (-1); a value of 0 disables the ignore region.
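- add_binary_target: Whether to also encode the foreground in the target. If set, boundaries are labeled 2 and the remaining foreground 1, within the same single channel (no additional channel is added).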
BoundaryTransform( mode: str = 'thick', ndim: Optional[int] = None, ignore_radius: int = 2, add_binary_target: bool = False)
def __init__(
    self,
    mode: str = "thick",
    ndim: Optional[int] = None,
    ignore_radius: int = 2,
    add_binary_target: bool = False
):
    """Store the boundary settings and optionally prepare a foreground helper.

    Args:
        mode: The mode for computing the boundaries.
        ndim: The dimensionality of the data.
        ignore_radius: The distance around boundaries that is set to the ignore label.
        add_binary_target: Whether a foreground transform (without ignore region) is created
            to additionally encode the foreground in the target.
    """
    self.mode, self.ndim, self.ignore_radius = mode, ndim, ignore_radius
    if add_binary_target:
        # The foreground encoding uses a foreground transform that has no ignore band.
        self.foreground_trafo = ForegroundTransform(ndim=ndim, ignore_radius=0)
    else:
        self.foreground_trafo = None