torch_em.shallow2deep.transform

 1import numpy as np
 2import skimage.segmentation
 3from scipy.ndimage.morphology import distance_transform_edt
 4from torch_em.util import ensure_array, ensure_spatial_array
 5
 6
 7class ForegroundTransform:
 8    def __init__(self, label_id=None, ndim=None, ignore_radius=1):
 9        self.label_id = label_id
10        self.ndim = ndim
11        self.ignore_radius = ignore_radius
12
13    def __call__(self, labels):
14        labels = ensure_array(labels) if self.ndim is None else ensure_spatial_array(labels, self.ndim)
15        target = labels > 0 if self.label_id is None else labels == self.label_id
16        if self.ignore_radius > 0:
17            dist = distance_transform_edt(target == 0)
18            ignore_mask = np.logical_and(dist <= self.ignore_radius, target == 0)
19            target[ignore_mask] = -1
20        return target[None]
21
22
class BoundaryTransform:
    """Convert a label image into a boundary target, optionally with foreground.

    Without ``add_binary_target`` the target is 1 on boundary pixels (as found by
    ``skimage.segmentation.find_boundaries``) and 0 elsewhere. Non-boundary
    pixels within ``ignore_radius`` of a boundary are set to -1 ("ignore").
    With ``add_binary_target`` boundary pixels are relabeled to 2, and foreground
    pixels (label > 0) that are neither boundary nor ignored are set to 1.

    Args:
        mode: Boundary mode forwarded to ``skimage.segmentation.find_boundaries``.
        ndim: If given, labels are passed through ``ensure_spatial_array`` with
            this value; if None, ``ensure_array`` is used instead.
        ignore_radius: Radius of the ignore ring around boundaries.
            A value <= 0 disables the ignore label.
        add_binary_target: Whether to add the foreground class (see above).
    """

    def __init__(self, mode="thick", ndim=None, ignore_radius=2, add_binary_target=False):
        self.mode = mode
        self.ndim = ndim
        self.ignore_radius = ignore_radius
        self.foreground_trafo = ForegroundTransform(ndim=ndim, ignore_radius=0) if add_binary_target else None

    def __call__(self, labels):
        """Compute the boundary target for ``labels``.

        Returns:
            An int8 array with a leading singleton axis.
        """
        labels = ensure_array(labels) if self.ndim is None else ensure_spatial_array(labels, self.ndim)
        target = skimage.segmentation.find_boundaries(labels, mode=self.mode).astype("int8")

        # Without any boundary there is nothing to build an ignore ring around
        # (and the distance transform would have no zero pixel to measure to).
        if self.ignore_radius > 0 and target.any():
            dist = distance_transform_edt(target == 0)
            ignore_mask = np.logical_and(dist <= self.ignore_radius, target == 0)
            target[ignore_mask] = -1

        if self.foreground_trafo is not None:
            # Move boundaries to label 2 so that label 1 is free for foreground.
            target[target == 1] = 2
            fg_target = self.foreground_trafo(labels)[0]
            assert fg_target.shape == target.shape, f"{fg_target.shape}, {target.shape}"
            # Only mark foreground where the pixel is neither boundary nor ignored.
            fg_mask = np.logical_and(fg_target == 1, target == 0)
            target[fg_mask] = 1

        return target[None]
class ForegroundTransform:
 8class ForegroundTransform:
 9    def __init__(self, label_id=None, ndim=None, ignore_radius=1):
10        self.label_id = label_id
11        self.ndim = ndim
12        self.ignore_radius = ignore_radius
13
14    def __call__(self, labels):
15        labels = ensure_array(labels) if self.ndim is None else ensure_spatial_array(labels, self.ndim)
16        target = labels > 0 if self.label_id is None else labels == self.label_id
17        if self.ignore_radius > 0:
18            dist = distance_transform_edt(target == 0)
19            ignore_mask = np.logical_and(dist <= self.ignore_radius, target == 0)
20            target[ignore_mask] = -1
21        return target[None]
ForegroundTransform(label_id=None, ndim=None, ignore_radius=1)
 9    def __init__(self, label_id=None, ndim=None, ignore_radius=1):
10        self.label_id = label_id
11        self.ndim = ndim
12        self.ignore_radius = ignore_radius
label_id
ndim
ignore_radius
class BoundaryTransform:
class BoundaryTransform:
    """Convert a label image into a boundary target, optionally with foreground.

    Without ``add_binary_target`` the target is 1 on boundary pixels (as found by
    ``skimage.segmentation.find_boundaries``) and 0 elsewhere. Non-boundary
    pixels within ``ignore_radius`` of a boundary are set to -1 ("ignore").
    With ``add_binary_target`` boundary pixels are relabeled to 2, and foreground
    pixels (label > 0) that are neither boundary nor ignored are set to 1.

    Args:
        mode: Boundary mode forwarded to ``skimage.segmentation.find_boundaries``.
        ndim: If given, labels are passed through ``ensure_spatial_array`` with
            this value; if None, ``ensure_array`` is used instead.
        ignore_radius: Radius of the ignore ring around boundaries.
            A value <= 0 disables the ignore label.
        add_binary_target: Whether to add the foreground class (see above).
    """

    def __init__(self, mode="thick", ndim=None, ignore_radius=2, add_binary_target=False):
        self.mode = mode
        self.ndim = ndim
        self.ignore_radius = ignore_radius
        self.foreground_trafo = ForegroundTransform(ndim=ndim, ignore_radius=0) if add_binary_target else None

    def __call__(self, labels):
        """Compute the boundary target for ``labels``.

        Returns:
            An int8 array with a leading singleton axis.
        """
        labels = ensure_array(labels) if self.ndim is None else ensure_spatial_array(labels, self.ndim)
        target = skimage.segmentation.find_boundaries(labels, mode=self.mode).astype("int8")

        # Without any boundary there is nothing to build an ignore ring around
        # (and the distance transform would have no zero pixel to measure to).
        if self.ignore_radius > 0 and target.any():
            dist = distance_transform_edt(target == 0)
            ignore_mask = np.logical_and(dist <= self.ignore_radius, target == 0)
            target[ignore_mask] = -1

        if self.foreground_trafo is not None:
            # Move boundaries to label 2 so that label 1 is free for foreground.
            target[target == 1] = 2
            fg_target = self.foreground_trafo(labels)[0]
            assert fg_target.shape == target.shape, f"{fg_target.shape}, {target.shape}"
            # Only mark foreground where the pixel is neither boundary nor ignored.
            fg_mask = np.logical_and(fg_target == 1, target == 0)
            target[fg_mask] = 1

        return target[None]
BoundaryTransform(mode='thick', ndim=None, ignore_radius=2, add_binary_target=False)
25    def __init__(self, mode="thick", ndim=None, ignore_radius=2, add_binary_target=False):
26        self.mode = mode
27        self.ndim = ndim
28        self.ignore_radius = ignore_radius
29        self.foreground_trafo = ForegroundTransform(ndim=ndim, ignore_radius=0) if add_binary_target else None
mode
ndim
ignore_radius
foreground_trafo