torch_em.transform.defect

  1from typing import Optional
  2
  3import numpy as np
  4import torch
  5
  6from scipy.ndimage import binary_dilation, map_coordinates
  7from skimage.draw import line
  8from skimage.filters import gaussian
  9from skimage.measure import label
 10
 11from .augmentation import get_augmentations
 12from .raw import standardize
 13from ..data import SegmentationDataset, MinForegroundSampler
 14
 15
 16#
 17# defect augmentations
 18#
 19# TODO
 20# - alignment jitter
 21
 22
 23def get_artifact_source(artifact_path, patch_shape, min_mask_fraction,
 24                        normalizer=standardize,
 25                        raw_key="artifacts", mask_key="alpha_mask"):
 26    """@private
 27    """
 28    augmentation = get_augmentations(ndim=2)
 29    sampler = MinForegroundSampler(min_mask_fraction)
 30    return SegmentationDataset(
 31        artifact_path, raw_key,
 32        artifact_path, mask_key,
 33        patch_shape=patch_shape,
 34        raw_transform=standardize,
 35        transform=augmentation,
 36        sampler=sampler
 37    )
 38
 39
 40class EMDefectAugmentation:
 41    """Augment raw data with transformations similar to defects common in EM data.
 42
 43    Args:
 44        p_drop_slice: Probability for a missing slice.
 45        p_low_contrast: Probability for a low contrast slice.
 46        p_deform_slice: Probaboloty for a deformed slice.
 47        p_paste_artifact: Probability for inserting an artifact from data source.
 48        contrast_scale: Scale of low contrast transformation.
 49        deformation_mode: Deformation mode that should be used.
 50        deformation_strength: Deformation strength in pixel.
 51        artifact_source: Data source for additional artifacts.
 52        mean_val: Mean value for artifact normalization.
 53        std_val: Std value for artifact normalization.
 54    """
 55    def __init__(
 56        self,
 57        p_drop_slice: float,
 58        p_low_contrast: float,
 59        p_deform_slice: float,
 60        p_paste_artifact: float = 0.0,
 61        contrast_scale: float = 0.1,
 62        deformation_mode: str = "undirected",
 63        deformation_strength: float = 10.0,
 64        artifact_source: Optional[torch.utils.data.Dataset] = None,
 65        mean_val: Optional[float] = None,
 66        std_val: Optional[float] = None,
 67    ):
 68        if p_paste_artifact > 0.0:
 69            assert artifact_source is not None
 70        self.artifact_source = artifact_source
 71
 72        # use cumulative probabilities
 73        self.p_drop_slice = p_drop_slice
 74        self.p_low_contrast = self.p_drop_slice + p_low_contrast
 75        self.p_deform_slice = self.p_low_contrast + p_deform_slice
 76        self.p_paste_artifact = self.p_deform_slice + p_paste_artifact
 77        assert self.p_paste_artifact < 1.0
 78
 79        self.contrast_scale = contrast_scale
 80        self.mean_val = mean_val
 81        self.std_val = std_val
 82
 83        # set the parameters for deformation augments
 84        if isinstance(deformation_mode, str):
 85            assert deformation_mode in ('all', 'undirected', 'compress')
 86            self.deformation_mode = deformation_mode
 87        elif isinstance(deformation_mode, (list, tuple)):
 88            assert len(deformation_mode) == 2
 89            assert 'undirected' in deformation_mode
 90            assert 'compress' in deformation_mode
 91            self.deformation_mode = 'all'
 92        self.deformation_strength = deformation_strength
 93
 94    def drop_slice(self, raw):
 95        """@private
 96        """
 97        raw[:] = 0
 98        return raw
 99
100    def low_contrast(self, raw):
101        """@private
102        """
103        mean = raw.mean()
104        raw -= mean
105        raw *= self.contrast_scale
106        raw += mean
107        return raw
108
109    # this simulates a typical defect:
110    # missing line of data with rest of data compressed towards the line
111    def compress_slice(self, raw):
112        """@private
113        """
114        shape = raw.shape
115        # randomly choose fixed x or fixed y with p = 1/2
116        fixed_x = np.random.rand() < .5
117        if fixed_x:
118            x0, y0 = 0, np.random.randint(1, shape[1] - 2)
119            x1, y1 = shape[0] - 1, np.random.randint(1, shape[1] - 2)
120        else:
121            x0, y0 = np.random.randint(1, shape[0] - 2), 0
122            x1, y1 = np.random.randint(1, shape[0] - 2), shape[1] - 1
123
124        # generate the mask of the line that should be blacked out
125        line_mask = np.zeros_like(raw, dtype='bool')
126        rr, cc = line(x0, y0, x1, y1)
127        line_mask[rr, cc] = 1
128
129        # generate vectorfield pointing towards the line to compress the image
130        # first we get the unit vector representing the line
131        line_vector = np.array([x1 - x0, y1 - y0], dtype='float32')
132        line_vector /= np.linalg.norm(line_vector)
133        # next, we generate the normal to the line
134        normal_vector = np.zeros_like(line_vector)
135        normal_vector[0] = - line_vector[1]
136        normal_vector[1] = line_vector[0]
137
138        # make meshgrid
139        x, y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]))
140        # generate the vector field
141        flow_x, flow_y = np.zeros_like(raw), np.zeros_like(raw)
142
143        # find the 2 components where coordinates are bigger / smaller than the line
144        # to apply normal vector in the correct direction
145        components = label(np.logical_not(line_mask))
146        assert len(np.unique(components)) == 3, "%i" % len(np.unique(components))
147        neg_val = components[0, 0] if fixed_x else components[-1, -1]
148        pos_val = components[-1, -1] if fixed_x else components[0, 0]
149
150        flow_x[components == pos_val] = self.deformation_strength * normal_vector[1]
151        flow_y[components == pos_val] = self.deformation_strength * normal_vector[0]
152        flow_x[components == neg_val] = - self.deformation_strength * normal_vector[1]
153        flow_y[components == neg_val] = - self.deformation_strength * normal_vector[0]
154
155        # add small random noise
156        flow_x += np.random.uniform(-1, 1, shape) * (self.deformation_strength / 8.)
157        flow_y += np.random.uniform(-1, 1, shape) * (self.deformation_strength / 8.)
158
159        # apply the flow fields
160        flow_x, flow_y = (x + flow_x).reshape(-1, 1), (y + flow_y).reshape(-1, 1)
161        cval = 0.0 if self.mean_val is None else self.mean_val
162        raw = map_coordinates(
163            raw, (flow_y, flow_x), mode='constant', order=3, cval=cval
164        ).reshape(shape)
165
166        # dilate the line mask and zero out the raw below it
167        line_mask = binary_dilation(line_mask, iterations=10)
168        raw[line_mask] = 0.
169        return raw
170
171    def undirected_deformation(self, raw):
172        """@private
173        """
174        shape = raw.shape
175
176        # make meshgrid
177        x, y = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]))
178
179        # generate random vector field and smooth it
180        flow_x = np.random.uniform(-1, 1, shape) * self.deformation_strength
181        flow_y = np.random.uniform(-1, 1, shape) * self.deformation_strength
182        flow_x = gaussian(flow_x, sigma=3.)  # sigma is hard-coded for now
183        flow_y = gaussian(flow_y, sigma=3.)  # sigma is hard-coded for now
184
185        # apply the flow fields
186        flow_x, flow_y = (x + flow_x).reshape(-1, 1), (y + flow_y).reshape(-1, 1)
187        raw = map_coordinates(raw, (flow_y, flow_x), mode='constant').reshape(shape)
188        return raw
189
190    def deform_slice(self, raw):
191        """@private
192        """
193        if self.deformation_mode in ('undirected', 'compress'):
194            mode = self.deformation_mode
195        else:
196            mode = 'undireccted' if np.random.rand() < .5 else 'compress'
197        if mode == 'compress':
198            raw = self.compress_slice(raw)
199        else:
200            raw = self.undirected_deformation(raw)
201        return raw
202
203    def paste_artifact(self, raw):
204        """@private
205        """
206        # draw a random artifact location
207        artifact_index = np.random.randint(len(self.artifact_source))
208        artifact, alpha_mask = self.artifact_source[artifact_index]
209        artifact = artifact.numpy().squeeze()
210        alpha_mask = alpha_mask.numpy().squeeze()
211        assert artifact.shape == raw.shape, f"{artifact.shape}, {raw.shape}"
212        assert alpha_mask.shape == raw.shape
213        assert alpha_mask.min() >= 0., f"{alpha_mask.min()}"
214        assert alpha_mask.max() <= 1., f"{alpha_mask.max()}"
215
216        # blend the raw raw data and the artifact according to the alpha mask
217        raw = raw * (1. - alpha_mask) + artifact * alpha_mask
218        return raw
219
220    def __call__(self, raw: np.ndarray) -> np.ndarray:
221        """Apply defect augmentations to input data.
222
223        Args:
224            raw: The input data.
225
226        Returns:
227            The augmented data.
228        """
229        raw = raw.astype("float32")  # needs to be floating point to avoid errors
230        for z in range(raw.shape[0]):
231            r = np.random.rand()
232            if r < self.p_drop_slice:
233                # print("Drop slice", z)
234                raw[z] = self.drop_slice(raw[z])
235            elif r < self.p_low_contrast:
236                # print("Low contrast", z)
237                raw[z] = self.low_contrast(raw[z])
238            elif r < self.p_deform_slice:
239                # print("Deform slice", z)
240                raw[z] = self.deform_slice(raw[z])
241            elif r < self.p_paste_artifact:
242                # print("Paste artifact", z)
243                raw[z] = self.paste_artifact(raw[z])
244        return raw
class EMDefectAugmentation:
 41class EMDefectAugmentation:
 42    """Augment raw data with transformations similar to defects common in EM data.
 43
 44    Args:
 45        p_drop_slice: Probability for a missing slice.
 46        p_low_contrast: Probability for a low contrast slice.
 47        p_deform_slice: Probaboloty for a deformed slice.
 48        p_paste_artifact: Probability for inserting an artifact from data source.
 49        contrast_scale: Scale of low contrast transformation.
 50        deformation_mode: Deformation mode that should be used.
 51        deformation_strength: Deformation strength in pixel.
 52        artifact_source: Data source for additional artifacts.
 53        mean_val: Mean value for artifact normalization.
 54        std_val: Std value for artifact normalization.
 55    """
 56    def __init__(
 57        self,
 58        p_drop_slice: float,
 59        p_low_contrast: float,
 60        p_deform_slice: float,
 61        p_paste_artifact: float = 0.0,
 62        contrast_scale: float = 0.1,
 63        deformation_mode: str = "undirected",
 64        deformation_strength: float = 10.0,
 65        artifact_source: Optional[torch.utils.data.Dataset] = None,
 66        mean_val: Optional[float] = None,
 67        std_val: Optional[float] = None,
 68    ):
 69        if p_paste_artifact > 0.0:
 70            assert artifact_source is not None
 71        self.artifact_source = artifact_source
 72
 73        # use cumulative probabilities
 74        self.p_drop_slice = p_drop_slice
 75        self.p_low_contrast = self.p_drop_slice + p_low_contrast
 76        self.p_deform_slice = self.p_low_contrast + p_deform_slice
 77        self.p_paste_artifact = self.p_deform_slice + p_paste_artifact
 78        assert self.p_paste_artifact < 1.0
 79
 80        self.contrast_scale = contrast_scale
 81        self.mean_val = mean_val
 82        self.std_val = std_val
 83
 84        # set the parameters for deformation augments
 85        if isinstance(deformation_mode, str):
 86            assert deformation_mode in ('all', 'undirected', 'compress')
 87            self.deformation_mode = deformation_mode
 88        elif isinstance(deformation_mode, (list, tuple)):
 89            assert len(deformation_mode) == 2
 90            assert 'undirected' in deformation_mode
 91            assert 'compress' in deformation_mode
 92            self.deformation_mode = 'all'
 93        self.deformation_strength = deformation_strength
 94
 95    def drop_slice(self, raw):
 96        """@private
 97        """
 98        raw[:] = 0
 99        return raw
100
101    def low_contrast(self, raw):
102        """@private
103        """
104        mean = raw.mean()
105        raw -= mean
106        raw *= self.contrast_scale
107        raw += mean
108        return raw
109
110    # this simulates a typical defect:
111    # missing line of data with rest of data compressed towards the line
112    def compress_slice(self, raw):
113        """@private
114        """
115        shape = raw.shape
116        # randomly choose fixed x or fixed y with p = 1/2
117        fixed_x = np.random.rand() < .5
118        if fixed_x:
119            x0, y0 = 0, np.random.randint(1, shape[1] - 2)
120            x1, y1 = shape[0] - 1, np.random.randint(1, shape[1] - 2)
121        else:
122            x0, y0 = np.random.randint(1, shape[0] - 2), 0
123            x1, y1 = np.random.randint(1, shape[0] - 2), shape[1] - 1
124
125        # generate the mask of the line that should be blacked out
126        line_mask = np.zeros_like(raw, dtype='bool')
127        rr, cc = line(x0, y0, x1, y1)
128        line_mask[rr, cc] = 1
129
130        # generate vectorfield pointing towards the line to compress the image
131        # first we get the unit vector representing the line
132        line_vector = np.array([x1 - x0, y1 - y0], dtype='float32')
133        line_vector /= np.linalg.norm(line_vector)
134        # next, we generate the normal to the line
135        normal_vector = np.zeros_like(line_vector)
136        normal_vector[0] = - line_vector[1]
137        normal_vector[1] = line_vector[0]
138
139        # make meshgrid
140        x, y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]))
141        # generate the vector field
142        flow_x, flow_y = np.zeros_like(raw), np.zeros_like(raw)
143
144        # find the 2 components where coordinates are bigger / smaller than the line
145        # to apply normal vector in the correct direction
146        components = label(np.logical_not(line_mask))
147        assert len(np.unique(components)) == 3, "%i" % len(np.unique(components))
148        neg_val = components[0, 0] if fixed_x else components[-1, -1]
149        pos_val = components[-1, -1] if fixed_x else components[0, 0]
150
151        flow_x[components == pos_val] = self.deformation_strength * normal_vector[1]
152        flow_y[components == pos_val] = self.deformation_strength * normal_vector[0]
153        flow_x[components == neg_val] = - self.deformation_strength * normal_vector[1]
154        flow_y[components == neg_val] = - self.deformation_strength * normal_vector[0]
155
156        # add small random noise
157        flow_x += np.random.uniform(-1, 1, shape) * (self.deformation_strength / 8.)
158        flow_y += np.random.uniform(-1, 1, shape) * (self.deformation_strength / 8.)
159
160        # apply the flow fields
161        flow_x, flow_y = (x + flow_x).reshape(-1, 1), (y + flow_y).reshape(-1, 1)
162        cval = 0.0 if self.mean_val is None else self.mean_val
163        raw = map_coordinates(
164            raw, (flow_y, flow_x), mode='constant', order=3, cval=cval
165        ).reshape(shape)
166
167        # dilate the line mask and zero out the raw below it
168        line_mask = binary_dilation(line_mask, iterations=10)
169        raw[line_mask] = 0.
170        return raw
171
172    def undirected_deformation(self, raw):
173        """@private
174        """
175        shape = raw.shape
176
177        # make meshgrid
178        x, y = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]))
179
180        # generate random vector field and smooth it
181        flow_x = np.random.uniform(-1, 1, shape) * self.deformation_strength
182        flow_y = np.random.uniform(-1, 1, shape) * self.deformation_strength
183        flow_x = gaussian(flow_x, sigma=3.)  # sigma is hard-coded for now
184        flow_y = gaussian(flow_y, sigma=3.)  # sigma is hard-coded for now
185
186        # apply the flow fields
187        flow_x, flow_y = (x + flow_x).reshape(-1, 1), (y + flow_y).reshape(-1, 1)
188        raw = map_coordinates(raw, (flow_y, flow_x), mode='constant').reshape(shape)
189        return raw
190
191    def deform_slice(self, raw):
192        """@private
193        """
194        if self.deformation_mode in ('undirected', 'compress'):
195            mode = self.deformation_mode
196        else:
197            mode = 'undireccted' if np.random.rand() < .5 else 'compress'
198        if mode == 'compress':
199            raw = self.compress_slice(raw)
200        else:
201            raw = self.undirected_deformation(raw)
202        return raw
203
204    def paste_artifact(self, raw):
205        """@private
206        """
207        # draw a random artifact location
208        artifact_index = np.random.randint(len(self.artifact_source))
209        artifact, alpha_mask = self.artifact_source[artifact_index]
210        artifact = artifact.numpy().squeeze()
211        alpha_mask = alpha_mask.numpy().squeeze()
212        assert artifact.shape == raw.shape, f"{artifact.shape}, {raw.shape}"
213        assert alpha_mask.shape == raw.shape
214        assert alpha_mask.min() >= 0., f"{alpha_mask.min()}"
215        assert alpha_mask.max() <= 1., f"{alpha_mask.max()}"
216
217        # blend the raw raw data and the artifact according to the alpha mask
218        raw = raw * (1. - alpha_mask) + artifact * alpha_mask
219        return raw
220
221    def __call__(self, raw: np.ndarray) -> np.ndarray:
222        """Apply defect augmentations to input data.
223
224        Args:
225            raw: The input data.
226
227        Returns:
228            The augmented data.
229        """
230        raw = raw.astype("float32")  # needs to be floating point to avoid errors
231        for z in range(raw.shape[0]):
232            r = np.random.rand()
233            if r < self.p_drop_slice:
234                # print("Drop slice", z)
235                raw[z] = self.drop_slice(raw[z])
236            elif r < self.p_low_contrast:
237                # print("Low contrast", z)
238                raw[z] = self.low_contrast(raw[z])
239            elif r < self.p_deform_slice:
240                # print("Deform slice", z)
241                raw[z] = self.deform_slice(raw[z])
242            elif r < self.p_paste_artifact:
243                # print("Paste artifact", z)
244                raw[z] = self.paste_artifact(raw[z])
245        return raw

Augment raw data with transformations similar to defects common in EM data.

Arguments:
  • p_drop_slice: Probability for a missing slice.
  • p_low_contrast: Probability for a low contrast slice.
  • p_deform_slice: Probability for a deformed slice.
  • p_paste_artifact: Probability for inserting an artifact from data source.
  • contrast_scale: Scale of low contrast transformation.
  • deformation_mode: Deformation mode that should be used.
  • deformation_strength: Deformation strength in pixel.
  • artifact_source: Data source for additional artifacts.
  • mean_val: Mean value for artifact normalization.
  • std_val: Std value for artifact normalization.
EMDefectAugmentation( p_drop_slice: float, p_low_contrast: float, p_deform_slice: float, p_paste_artifact: float = 0.0, contrast_scale: float = 0.1, deformation_mode: str = 'undirected', deformation_strength: float = 10.0, artifact_source: Optional[torch.utils.data.dataset.Dataset] = None, mean_val: Optional[float] = None, std_val: Optional[float] = None)
56    def __init__(
57        self,
58        p_drop_slice: float,
59        p_low_contrast: float,
60        p_deform_slice: float,
61        p_paste_artifact: float = 0.0,
62        contrast_scale: float = 0.1,
63        deformation_mode: str = "undirected",
64        deformation_strength: float = 10.0,
65        artifact_source: Optional[torch.utils.data.Dataset] = None,
66        mean_val: Optional[float] = None,
67        std_val: Optional[float] = None,
68    ):
69        if p_paste_artifact > 0.0:
70            assert artifact_source is not None
71        self.artifact_source = artifact_source
72
73        # use cumulative probabilities
74        self.p_drop_slice = p_drop_slice
75        self.p_low_contrast = self.p_drop_slice + p_low_contrast
76        self.p_deform_slice = self.p_low_contrast + p_deform_slice
77        self.p_paste_artifact = self.p_deform_slice + p_paste_artifact
78        assert self.p_paste_artifact < 1.0
79
80        self.contrast_scale = contrast_scale
81        self.mean_val = mean_val
82        self.std_val = std_val
83
84        # set the parameters for deformation augments
85        if isinstance(deformation_mode, str):
86            assert deformation_mode in ('all', 'undirected', 'compress')
87            self.deformation_mode = deformation_mode
88        elif isinstance(deformation_mode, (list, tuple)):
89            assert len(deformation_mode) == 2
90            assert 'undirected' in deformation_mode
91            assert 'compress' in deformation_mode
92            self.deformation_mode = 'all'
93        self.deformation_strength = deformation_strength
artifact_source
p_drop_slice
p_low_contrast
p_deform_slice
p_paste_artifact
contrast_scale
mean_val
std_val
deformation_strength