torch_em.shallow2deep.transform
import numpy as np
import skimage.segmentation
from scipy.ndimage import distance_transform_edt
from torch_em.util import ensure_array, ensure_spatial_array


class ForegroundTransform:
    """Convert a label image into a foreground target with an optional ignore band.

    Args:
        label_id: If given, only pixels equal to this id are foreground;
            otherwise every pixel with a label > 0 is foreground.
        ndim: If given, labels are converted with ``ensure_spatial_array(labels, ndim)``;
            otherwise with ``ensure_array(labels)``.
        ignore_radius: Background pixels whose Euclidean distance to the nearest
            foreground pixel is <= this value are marked -1 (ignore). 0 disables this.
    """

    def __init__(self, label_id=None, ndim=None, ignore_radius=1):
        self.label_id = label_id
        self.ndim = ndim
        self.ignore_radius = ignore_radius

    def __call__(self, labels):
        """Return the foreground target with a new leading axis.

        With ``ignore_radius == 0`` the result is boolean; otherwise it is int8
        with values -1 (ignore), 0 (background) and 1 (foreground).
        """
        labels = ensure_array(labels) if self.ndim is None else ensure_spatial_array(labels, self.ndim)
        target = labels > 0 if self.label_id is None else labels == self.label_id
        if self.ignore_radius > 0:
            # A boolean array cannot hold -1 (it would be stored as True and turn the
            # ignore band into foreground), so switch to a signed dtype first.
            target = target.astype("int8")
            # Without any foreground no pixel lies near it; the EDT of an input with
            # no zeros has no nearest background pixel, so it is skipped.
            if target.any():
                dist = distance_transform_edt(target == 0)
                ignore_mask = np.logical_and(dist <= self.ignore_radius, target == 0)
                target[ignore_mask] = -1
        return target[None]


class BoundaryTransform:
    """Convert a label image into a boundary target, optionally with a foreground class.

    Args:
        mode: Boundary mode passed to ``skimage.segmentation.find_boundaries``.
        ndim: If given, labels are converted with ``ensure_spatial_array(labels, ndim)``;
            otherwise with ``ensure_array(labels)``.
        ignore_radius: Non-boundary pixels whose Euclidean distance to the nearest
            boundary pixel is <= this value are marked -1 (ignore). 0 disables this.
        add_binary_target: If True, boundaries are encoded as 2 and remaining
            foreground pixels (label > 0) as 1.
    """

    def __init__(self, mode="thick", ndim=None, ignore_radius=2, add_binary_target=False):
        self.mode = mode
        self.ndim = ndim
        self.ignore_radius = ignore_radius
        self.foreground_trafo = ForegroundTransform(ndim=ndim, ignore_radius=0) if add_binary_target else None

    def __call__(self, labels):
        """Return the int8 boundary target with a new leading axis."""
        labels = ensure_array(labels) if self.ndim is None else ensure_spatial_array(labels, self.ndim)
        target = skimage.segmentation.find_boundaries(labels, mode=self.mode).astype("int8")

        # Without any boundary no pixel lies near one; the EDT of an input with no
        # zeros has no nearest background pixel, so it is skipped.
        if self.ignore_radius > 0 and target.any():
            dist = distance_transform_edt(target == 0)
            ignore_mask = np.logical_and(dist <= self.ignore_radius, target == 0)
            target[ignore_mask] = -1

        if self.foreground_trafo is not None:
            # Move boundaries to class 2 so that class 1 can denote foreground.
            target[target == 1] = 2
            fg_target = self.foreground_trafo(labels)[0]
            assert fg_target.shape == target.shape, f"{fg_target.shape}, {target.shape}"
            # Only still-unassigned pixels become foreground; boundary (2) and
            # ignore (-1) pixels keep their value.
            fg_mask = np.logical_and(fg_target == 1, target == 0)
            target[fg_mask] = 1

        return target[None]
class ForegroundTransform:
class ForegroundTransform:
    """Convert a label image into a foreground target with an optional ignore band.

    Args:
        label_id: If given, only pixels equal to this id are foreground;
            otherwise every pixel with a label > 0 is foreground.
        ndim: If given, labels are converted with ``ensure_spatial_array(labels, ndim)``;
            otherwise with ``ensure_array(labels)``.
        ignore_radius: Background pixels whose Euclidean distance to the nearest
            foreground pixel is <= this value are marked -1 (ignore). 0 disables this.
    """

    def __init__(self, label_id=None, ndim=None, ignore_radius=1):
        self.label_id = label_id
        self.ndim = ndim
        self.ignore_radius = ignore_radius

    def __call__(self, labels):
        """Return the foreground target with a new leading axis.

        With ``ignore_radius == 0`` the result is boolean; otherwise it is int8
        with values -1 (ignore), 0 (background) and 1 (foreground).
        """
        labels = ensure_array(labels) if self.ndim is None else ensure_spatial_array(labels, self.ndim)
        target = labels > 0 if self.label_id is None else labels == self.label_id
        if self.ignore_radius > 0:
            # A boolean array cannot hold -1 (it would be stored as True and turn the
            # ignore band into foreground), so switch to a signed dtype first.
            target = target.astype("int8")
            # Without any foreground no pixel lies near it; the EDT of an input with
            # no zeros has no nearest background pixel, so it is skipped.
            if target.any():
                dist = distance_transform_edt(target == 0)
                ignore_mask = np.logical_and(dist <= self.ignore_radius, target == 0)
                target[ignore_mask] = -1
        return target[None]
class BoundaryTransform:
class BoundaryTransform:
    """Convert a label image into a boundary target, optionally with a foreground class.

    Args:
        mode: Boundary mode passed to ``skimage.segmentation.find_boundaries``.
        ndim: If given, labels are converted with ``ensure_spatial_array(labels, ndim)``;
            otherwise with ``ensure_array(labels)``.
        ignore_radius: Non-boundary pixels whose Euclidean distance to the nearest
            boundary pixel is <= this value are marked -1 (ignore). 0 disables this.
        add_binary_target: If True, boundaries are encoded as 2 and remaining
            foreground pixels (label > 0) as 1.
    """

    def __init__(self, mode="thick", ndim=None, ignore_radius=2, add_binary_target=False):
        self.mode = mode
        self.ndim = ndim
        self.ignore_radius = ignore_radius
        self.foreground_trafo = ForegroundTransform(ndim=ndim, ignore_radius=0) if add_binary_target else None

    def __call__(self, labels):
        """Return the int8 boundary target with a new leading axis."""
        labels = ensure_array(labels) if self.ndim is None else ensure_spatial_array(labels, self.ndim)
        target = skimage.segmentation.find_boundaries(labels, mode=self.mode).astype("int8")

        # Without any boundary no pixel lies near one; the EDT of an input with no
        # zeros has no nearest background pixel, so it is skipped.
        if self.ignore_radius > 0 and target.any():
            dist = distance_transform_edt(target == 0)
            ignore_mask = np.logical_and(dist <= self.ignore_radius, target == 0)
            target[ignore_mask] = -1

        if self.foreground_trafo is not None:
            # Move boundaries to class 2 so that class 1 can denote foreground.
            target[target == 1] = 2
            fg_target = self.foreground_trafo(labels)[0]
            assert fg_target.shape == target.shape, f"{fg_target.shape}, {target.shape}"
            # Only still-unassigned pixels become foreground; boundary (2) and
            # ignore (-1) pixels keep their value.
            fg_mask = np.logical_and(fg_target == 1, target == 0)
            target[fg_mask] = 1

        return target[None]