torch_em.transform.defect

  1from typing import Optional
  2
  3import numpy as np
  4import torch
  5
  6from scipy.ndimage import binary_dilation, map_coordinates
  7from skimage.draw import line
  8from skimage.filters import gaussian
  9
 10import bioimage_cpp as bic
 11
 12from .augmentation import get_augmentations
 13from .raw import standardize
 14from ..data import SegmentationDataset, MinForegroundSampler
 15
 16
 17#
 18# defect augmentations
 19#
 20# TODO
 21# - alignment jitter
 22
 23
 24def get_artifact_source(artifact_path, patch_shape, min_mask_fraction,
 25                        normalizer=standardize,
 26                        raw_key="artifacts", mask_key="alpha_mask"):
 27    """@private
 28    """
 29    augmentation = get_augmentations(ndim=2)
 30    sampler = MinForegroundSampler(min_mask_fraction)
 31    return SegmentationDataset(
 32        artifact_path, raw_key,
 33        artifact_path, mask_key,
 34        patch_shape=patch_shape,
 35        raw_transform=standardize,
 36        transform=augmentation,
 37        sampler=sampler
 38    )
 39
 40
 41class EMDefectAugmentation:
 42    """Augment raw data with transformations similar to defects common in EM data.
 43
 44    Args:
 45        p_drop_slice: Probability for a missing slice.
 46        p_low_contrast: Probability for a low contrast slice.
 47        p_deform_slice: Probability for a deformed slice.
 48        p_paste_artifact: Probability for inserting an artifact from data source.
 49        contrast_scale: Scale of low contrast transformation.
 50        deformation_mode: Deformation mode that should be used.
 51        deformation_strength: Deformation strength in pixel.
 52        artifact_source: Data source for additional artifacts.
 53        mean_val: Mean value for artifact normalization.
 54        std_val: Std value for artifact normalization.
 55    """
 56    def __init__(
 57        self,
 58        p_drop_slice: float,
 59        p_low_contrast: float,
 60        p_deform_slice: float,
 61        p_paste_artifact: float = 0.0,
 62        contrast_scale: float = 0.1,
 63        deformation_mode: str = "undirected",
 64        deformation_strength: float = 10.0,
 65        artifact_source: Optional[torch.utils.data.Dataset] = None,
 66        mean_val: Optional[float] = None,
 67        std_val: Optional[float] = None,
 68    ):
 69        if p_paste_artifact > 0.0:
 70            assert artifact_source is not None
 71        self.artifact_source = artifact_source
 72
 73        # use cumulative probabilities
 74        self.p_drop_slice = p_drop_slice
 75        self.p_low_contrast = self.p_drop_slice + p_low_contrast
 76        self.p_deform_slice = self.p_low_contrast + p_deform_slice
 77        self.p_paste_artifact = self.p_deform_slice + p_paste_artifact
 78        assert self.p_paste_artifact < 1.0
 79
 80        self.contrast_scale = contrast_scale
 81        self.mean_val = mean_val
 82        self.std_val = std_val
 83
 84        # set the parameters for deformation augments
 85        if isinstance(deformation_mode, str):
 86            assert deformation_mode in ('all', 'undirected', 'compress')
 87            self.deformation_mode = deformation_mode
 88        elif isinstance(deformation_mode, (list, tuple)):
 89            assert len(deformation_mode) == 2
 90            assert 'undirected' in deformation_mode
 91            assert 'compress' in deformation_mode
 92            self.deformation_mode = 'all'
 93        self.deformation_strength = deformation_strength
 94
 95    def drop_slice(self, raw):
 96        """@private
 97        """
 98        raw[:] = 0
 99        return raw
100
101    def low_contrast(self, raw):
102        """@private
103        """
104        mean = raw.mean()
105        raw -= mean
106        raw *= self.contrast_scale
107        raw += mean
108        return raw
109
110    # this simulates a typical defect:
111    # missing line of data with rest of data compressed towards the line
112    def compress_slice(self, raw):
113        """@private
114        """
115        shape = raw.shape
116        # randomly choose fixed x or fixed y with p = 1/2
117        fixed_x = np.random.rand() < .5
118        if fixed_x:
119            x0, y0 = 0, np.random.randint(1, shape[1] - 2)
120            x1, y1 = shape[0] - 1, np.random.randint(1, shape[1] - 2)
121        else:
122            x0, y0 = np.random.randint(1, shape[0] - 2), 0
123            x1, y1 = np.random.randint(1, shape[0] - 2), shape[1] - 1
124
125        # generate the mask of the line that should be blacked out
126        line_mask = np.zeros_like(raw, dtype='bool')
127        rr, cc = line(x0, y0, x1, y1)
128        line_mask[rr, cc] = 1
129
130        # generate vectorfield pointing towards the line to compress the image
131        # first we get the unit vector representing the line
132        line_vector = np.array([x1 - x0, y1 - y0], dtype='float32')
133        line_vector /= np.linalg.norm(line_vector)
134        # next, we generate the normal to the line
135        normal_vector = np.zeros_like(line_vector)
136        normal_vector[0] = - line_vector[1]
137        normal_vector[1] = line_vector[0]
138
139        # make meshgrid
140        x, y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]))
141        # generate the vector field
142        flow_x, flow_y = np.zeros_like(raw), np.zeros_like(raw)
143
144        # find the 2 components where coordinates are bigger / smaller than the line
145        # to apply normal vector in the correct direction
146        components = bic.segmentation.label(np.logical_not(line_mask))
147        assert len(np.unique(components)) == 3, "%i" % len(np.unique(components))
148        neg_val = components[0, 0] if fixed_x else components[-1, -1]
149        pos_val = components[-1, -1] if fixed_x else components[0, 0]
150
151        flow_x[components == pos_val] = self.deformation_strength * normal_vector[1]
152        flow_y[components == pos_val] = self.deformation_strength * normal_vector[0]
153        flow_x[components == neg_val] = - self.deformation_strength * normal_vector[1]
154        flow_y[components == neg_val] = - self.deformation_strength * normal_vector[0]
155
156        # add small random noise
157        flow_x += np.random.uniform(-1, 1, shape) * (self.deformation_strength / 8.)
158        flow_y += np.random.uniform(-1, 1, shape) * (self.deformation_strength / 8.)
159
160        # apply the flow fields
161        flow_x, flow_y = (x + flow_x).reshape(-1, 1), (y + flow_y).reshape(-1, 1)
162        cval = 0.0 if self.mean_val is None else self.mean_val
163        raw = map_coordinates(
164            raw, (flow_y, flow_x), mode='constant', order=3, cval=cval
165        ).reshape(shape)
166
167        # dilate the line mask and zero out the raw below it
168        line_mask = binary_dilation(line_mask, iterations=10)
169        raw[line_mask] = 0.
170        return raw
171
172    def undirected_deformation(self, raw):
173        """@private
174        """
175        shape = raw.shape
176
177        # make meshgrid
178        x, y = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]))
179
180        # generate random vector field and smooth it
181        flow_x = np.random.uniform(-1, 1, shape) * self.deformation_strength
182        flow_y = np.random.uniform(-1, 1, shape) * self.deformation_strength
183        flow_x = gaussian(flow_x, sigma=3.)  # sigma is hard-coded for now
184        flow_y = gaussian(flow_y, sigma=3.)  # sigma is hard-coded for now
185
186        # apply the flow fields
187        flow_x, flow_y = (x + flow_x).reshape(-1, 1), (y + flow_y).reshape(-1, 1)
188        raw = map_coordinates(raw, (flow_y, flow_x), mode='constant').reshape(shape)
189        return raw
190
191    def deform_slice(self, raw):
192        """@private
193        """
194        if self.deformation_mode in ('undirected', 'compress'):
195            mode = self.deformation_mode
196        else:
197            mode = 'undireccted' if np.random.rand() < .5 else 'compress'
198        if mode == 'compress':
199            raw = self.compress_slice(raw)
200        else:
201            raw = self.undirected_deformation(raw)
202        return raw
203
204    def paste_artifact(self, raw):
205        """@private
206        """
207        # draw a random artifact location
208        artifact_index = np.random.randint(len(self.artifact_source))
209        artifact, alpha_mask = self.artifact_source[artifact_index]
210        artifact = artifact.numpy().squeeze()
211        alpha_mask = alpha_mask.numpy().squeeze()
212        assert artifact.shape == raw.shape, f"{artifact.shape}, {raw.shape}"
213        assert alpha_mask.shape == raw.shape
214        assert alpha_mask.min() >= 0., f"{alpha_mask.min()}"
215        assert alpha_mask.max() <= 1., f"{alpha_mask.max()}"
216
217        # blend the raw raw data and the artifact according to the alpha mask
218        raw = raw * (1. - alpha_mask) + artifact * alpha_mask
219        return raw
220
221    def __call__(self, raw: np.ndarray) -> np.ndarray:
222        """Apply defect augmentations to input data.
223
224        Args:
225            raw: The input data.
226
227        Returns:
228            The augmented data.
229        """
230        raw = raw.astype("float32")  # needs to be floating point to avoid errors
231        for z in range(raw.shape[0]):
232            r = np.random.rand()
233            if r < self.p_drop_slice:
234                # print("Drop slice", z)
235                raw[z] = self.drop_slice(raw[z])
236            elif r < self.p_low_contrast:
237                # print("Low contrast", z)
238                raw[z] = self.low_contrast(raw[z])
239            elif r < self.p_deform_slice:
240                # print("Deform slice", z)
241                raw[z] = self.deform_slice(raw[z])
242            elif r < self.p_paste_artifact:
243                # print("Paste artifact", z)
244                raw[z] = self.paste_artifact(raw[z])
245        return raw
class EMDefectAugmentation:
 42class EMDefectAugmentation:
 43    """Augment raw data with transformations similar to defects common in EM data.
 44
 45    Args:
 46        p_drop_slice: Probability for a missing slice.
 47        p_low_contrast: Probability for a low contrast slice.
 48        p_deform_slice: Probability for a deformed slice.
 49        p_paste_artifact: Probability for inserting an artifact from data source.
 50        contrast_scale: Scale of low contrast transformation.
 51        deformation_mode: Deformation mode that should be used.
 52        deformation_strength: Deformation strength in pixel.
 53        artifact_source: Data source for additional artifacts.
 54        mean_val: Mean value for artifact normalization.
 55        std_val: Std value for artifact normalization.
 56    """
 57    def __init__(
 58        self,
 59        p_drop_slice: float,
 60        p_low_contrast: float,
 61        p_deform_slice: float,
 62        p_paste_artifact: float = 0.0,
 63        contrast_scale: float = 0.1,
 64        deformation_mode: str = "undirected",
 65        deformation_strength: float = 10.0,
 66        artifact_source: Optional[torch.utils.data.Dataset] = None,
 67        mean_val: Optional[float] = None,
 68        std_val: Optional[float] = None,
 69    ):
 70        if p_paste_artifact > 0.0:
 71            assert artifact_source is not None
 72        self.artifact_source = artifact_source
 73
 74        # use cumulative probabilities
 75        self.p_drop_slice = p_drop_slice
 76        self.p_low_contrast = self.p_drop_slice + p_low_contrast
 77        self.p_deform_slice = self.p_low_contrast + p_deform_slice
 78        self.p_paste_artifact = self.p_deform_slice + p_paste_artifact
 79        assert self.p_paste_artifact < 1.0
 80
 81        self.contrast_scale = contrast_scale
 82        self.mean_val = mean_val
 83        self.std_val = std_val
 84
 85        # set the parameters for deformation augments
 86        if isinstance(deformation_mode, str):
 87            assert deformation_mode in ('all', 'undirected', 'compress')
 88            self.deformation_mode = deformation_mode
 89        elif isinstance(deformation_mode, (list, tuple)):
 90            assert len(deformation_mode) == 2
 91            assert 'undirected' in deformation_mode
 92            assert 'compress' in deformation_mode
 93            self.deformation_mode = 'all'
 94        self.deformation_strength = deformation_strength
 95
 96    def drop_slice(self, raw):
 97        """@private
 98        """
 99        raw[:] = 0
100        return raw
101
102    def low_contrast(self, raw):
103        """@private
104        """
105        mean = raw.mean()
106        raw -= mean
107        raw *= self.contrast_scale
108        raw += mean
109        return raw
110
111    # this simulates a typical defect:
112    # missing line of data with rest of data compressed towards the line
113    def compress_slice(self, raw):
114        """@private
115        """
116        shape = raw.shape
117        # randomly choose fixed x or fixed y with p = 1/2
118        fixed_x = np.random.rand() < .5
119        if fixed_x:
120            x0, y0 = 0, np.random.randint(1, shape[1] - 2)
121            x1, y1 = shape[0] - 1, np.random.randint(1, shape[1] - 2)
122        else:
123            x0, y0 = np.random.randint(1, shape[0] - 2), 0
124            x1, y1 = np.random.randint(1, shape[0] - 2), shape[1] - 1
125
126        # generate the mask of the line that should be blacked out
127        line_mask = np.zeros_like(raw, dtype='bool')
128        rr, cc = line(x0, y0, x1, y1)
129        line_mask[rr, cc] = 1
130
131        # generate vectorfield pointing towards the line to compress the image
132        # first we get the unit vector representing the line
133        line_vector = np.array([x1 - x0, y1 - y0], dtype='float32')
134        line_vector /= np.linalg.norm(line_vector)
135        # next, we generate the normal to the line
136        normal_vector = np.zeros_like(line_vector)
137        normal_vector[0] = - line_vector[1]
138        normal_vector[1] = line_vector[0]
139
140        # make meshgrid
141        x, y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]))
142        # generate the vector field
143        flow_x, flow_y = np.zeros_like(raw), np.zeros_like(raw)
144
145        # find the 2 components where coordinates are bigger / smaller than the line
146        # to apply normal vector in the correct direction
147        components = bic.segmentation.label(np.logical_not(line_mask))
148        assert len(np.unique(components)) == 3, "%i" % len(np.unique(components))
149        neg_val = components[0, 0] if fixed_x else components[-1, -1]
150        pos_val = components[-1, -1] if fixed_x else components[0, 0]
151
152        flow_x[components == pos_val] = self.deformation_strength * normal_vector[1]
153        flow_y[components == pos_val] = self.deformation_strength * normal_vector[0]
154        flow_x[components == neg_val] = - self.deformation_strength * normal_vector[1]
155        flow_y[components == neg_val] = - self.deformation_strength * normal_vector[0]
156
157        # add small random noise
158        flow_x += np.random.uniform(-1, 1, shape) * (self.deformation_strength / 8.)
159        flow_y += np.random.uniform(-1, 1, shape) * (self.deformation_strength / 8.)
160
161        # apply the flow fields
162        flow_x, flow_y = (x + flow_x).reshape(-1, 1), (y + flow_y).reshape(-1, 1)
163        cval = 0.0 if self.mean_val is None else self.mean_val
164        raw = map_coordinates(
165            raw, (flow_y, flow_x), mode='constant', order=3, cval=cval
166        ).reshape(shape)
167
168        # dilate the line mask and zero out the raw below it
169        line_mask = binary_dilation(line_mask, iterations=10)
170        raw[line_mask] = 0.
171        return raw
172
173    def undirected_deformation(self, raw):
174        """@private
175        """
176        shape = raw.shape
177
178        # make meshgrid
179        x, y = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]))
180
181        # generate random vector field and smooth it
182        flow_x = np.random.uniform(-1, 1, shape) * self.deformation_strength
183        flow_y = np.random.uniform(-1, 1, shape) * self.deformation_strength
184        flow_x = gaussian(flow_x, sigma=3.)  # sigma is hard-coded for now
185        flow_y = gaussian(flow_y, sigma=3.)  # sigma is hard-coded for now
186
187        # apply the flow fields
188        flow_x, flow_y = (x + flow_x).reshape(-1, 1), (y + flow_y).reshape(-1, 1)
189        raw = map_coordinates(raw, (flow_y, flow_x), mode='constant').reshape(shape)
190        return raw
191
192    def deform_slice(self, raw):
193        """@private
194        """
195        if self.deformation_mode in ('undirected', 'compress'):
196            mode = self.deformation_mode
197        else:
198            mode = 'undireccted' if np.random.rand() < .5 else 'compress'
199        if mode == 'compress':
200            raw = self.compress_slice(raw)
201        else:
202            raw = self.undirected_deformation(raw)
203        return raw
204
205    def paste_artifact(self, raw):
206        """@private
207        """
208        # draw a random artifact location
209        artifact_index = np.random.randint(len(self.artifact_source))
210        artifact, alpha_mask = self.artifact_source[artifact_index]
211        artifact = artifact.numpy().squeeze()
212        alpha_mask = alpha_mask.numpy().squeeze()
213        assert artifact.shape == raw.shape, f"{artifact.shape}, {raw.shape}"
214        assert alpha_mask.shape == raw.shape
215        assert alpha_mask.min() >= 0., f"{alpha_mask.min()}"
216        assert alpha_mask.max() <= 1., f"{alpha_mask.max()}"
217
218        # blend the raw raw data and the artifact according to the alpha mask
219        raw = raw * (1. - alpha_mask) + artifact * alpha_mask
220        return raw
221
222    def __call__(self, raw: np.ndarray) -> np.ndarray:
223        """Apply defect augmentations to input data.
224
225        Args:
226            raw: The input data.
227
228        Returns:
229            The augmented data.
230        """
231        raw = raw.astype("float32")  # needs to be floating point to avoid errors
232        for z in range(raw.shape[0]):
233            r = np.random.rand()
234            if r < self.p_drop_slice:
235                # print("Drop slice", z)
236                raw[z] = self.drop_slice(raw[z])
237            elif r < self.p_low_contrast:
238                # print("Low contrast", z)
239                raw[z] = self.low_contrast(raw[z])
240            elif r < self.p_deform_slice:
241                # print("Deform slice", z)
242                raw[z] = self.deform_slice(raw[z])
243            elif r < self.p_paste_artifact:
244                # print("Paste artifact", z)
245                raw[z] = self.paste_artifact(raw[z])
246        return raw

Augment raw data with transformations similar to defects common in EM data.

Arguments:
  • p_drop_slice: Probability for a missing slice.
  • p_low_contrast: Probability for a low contrast slice.
  • p_deform_slice: Probability for a deformed slice.
  • p_paste_artifact: Probability for inserting an artifact from data source.
  • contrast_scale: Scale of low contrast transformation.
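  • deformation_mode: Deformation mode that should be used. One of 'undirected', 'compress' or 'all'; a list or tuple containing both 'undirected' and 'compress' is treated as 'all'.
  • deformation_strength: Deformation strength in pixels.
  • artifact_source: Data source for additional artifacts. Must be given if p_paste_artifact > 0.
  • mean_val: Mean value for artifact normalization. Also used as the fill value for samples from outside the image when a slice is compressed.
  • std_val: Std value for artifact normalization.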
EMDefectAugmentation( p_drop_slice: float, p_low_contrast: float, p_deform_slice: float, p_paste_artifact: float = 0.0, contrast_scale: float = 0.1, deformation_mode: str = 'undirected', deformation_strength: float = 10.0, artifact_source: Optional[torch.utils.data.dataset.Dataset] = None, mean_val: Optional[float] = None, std_val: Optional[float] = None)
57    def __init__(
58        self,
59        p_drop_slice: float,
60        p_low_contrast: float,
61        p_deform_slice: float,
62        p_paste_artifact: float = 0.0,
63        contrast_scale: float = 0.1,
64        deformation_mode: str = "undirected",
65        deformation_strength: float = 10.0,
66        artifact_source: Optional[torch.utils.data.Dataset] = None,
67        mean_val: Optional[float] = None,
68        std_val: Optional[float] = None,
69    ):
70        if p_paste_artifact > 0.0:
71            assert artifact_source is not None
72        self.artifact_source = artifact_source
73
74        # use cumulative probabilities
75        self.p_drop_slice = p_drop_slice
76        self.p_low_contrast = self.p_drop_slice + p_low_contrast
77        self.p_deform_slice = self.p_low_contrast + p_deform_slice
78        self.p_paste_artifact = self.p_deform_slice + p_paste_artifact
79        assert self.p_paste_artifact < 1.0
80
81        self.contrast_scale = contrast_scale
82        self.mean_val = mean_val
83        self.std_val = std_val
84
85        # set the parameters for deformation augments
86        if isinstance(deformation_mode, str):
87            assert deformation_mode in ('all', 'undirected', 'compress')
88            self.deformation_mode = deformation_mode
89        elif isinstance(deformation_mode, (list, tuple)):
90            assert len(deformation_mode) == 2
91            assert 'undirected' in deformation_mode
92            assert 'compress' in deformation_mode
93            self.deformation_mode = 'all'
94        self.deformation_strength = deformation_strength
artifact_source
p_drop_slice
p_low_contrast
p_deform_slice
p_paste_artifact
contrast_scale
mean_val
std_val
deformation_strength