torch_em.transform.defect

  1import numpy as np
  2
  3from scipy.ndimage import binary_dilation, map_coordinates
  4from skimage.draw import line
  5from skimage.filters import gaussian
  6from skimage.measure import label
  7
  8from .augmentation import get_augmentations
  9from .raw import standardize
 10from ..data import SegmentationDataset, MinForegroundSampler
 11
 12
 13#
 14# defect augmentations
 15#
 16# TODO
 17# - alignment jitter
 18
 19
 20def get_artifact_source(artifact_path, patch_shape, min_mask_fraction,
 21                        normalizer=standardize,
 22                        raw_key="artifacts", mask_key="alpha_mask"):
 23    augmentation = get_augmentations(ndim=2)
 24    sampler = MinForegroundSampler(min_mask_fraction)
 25    return SegmentationDataset(
 26        artifact_path, raw_key,
 27        artifact_path, mask_key,
 28        patch_shape=patch_shape,
 29        raw_transform=standardize,
 30        transform=augmentation,
 31        sampler=sampler
 32    )
 33
 34
 35class EMDefectAugmentation:
 36    """Augment raw data with transformations similar to defects common in EM data.
 37
 38    Arguments:
 39        p_drop_slice: probability for a missing slice
 40        p_low_contrast: probability for a low contrast slice
 41        p_deform_slice: probaboloty for a deformed slice
 42        p_paste_artifact: probability for inserting an artifact from data source
 43        contrast_scale: scale of low contrast transformation
 44        deformation_mode: deformation mode that should be used
 45        deformation_strength: deformation strength in pixel
 46        artifact_source: data source for additional artifacts
 47        mean_val: mean value for artifact normalization
 48        std_val: std value for artifact normalization
 49    """
 50    def __init__(
 51        self,
 52        p_drop_slice,
 53        p_low_contrast,
 54        p_deform_slice,
 55        p_paste_artifact=0.0,
 56        contrast_scale=0.1,
 57        deformation_mode='undirected',
 58        deformation_strength=10,
 59        artifact_source=None,
 60        mean_val=None,
 61        std_val=None
 62    ):
 63        if p_paste_artifact > 0.0:
 64            assert artifact_source is not None
 65        self.artifact_source = artifact_source
 66
 67        # use cumulative probabilities
 68        self.p_drop_slice = p_drop_slice
 69        self.p_low_contrast = self.p_drop_slice + p_low_contrast
 70        self.p_deform_slice = self.p_low_contrast + p_deform_slice
 71        self.p_paste_artifact = self.p_deform_slice + p_paste_artifact
 72        assert self.p_paste_artifact < 1.0
 73
 74        self.contrast_scale = contrast_scale
 75        self.mean_val = mean_val
 76        self.std_val = std_val
 77
 78        # set the parameters for deformation augments
 79        if isinstance(deformation_mode, str):
 80            assert deformation_mode in ('all', 'undirected', 'compress')
 81            self.deformation_mode = deformation_mode
 82        elif isinstance(deformation_mode, (list, tuple)):
 83            assert len(deformation_mode) == 2
 84            assert 'undirected' in deformation_mode
 85            assert 'compress' in deformation_mode
 86            self.deformation_mode = 'all'
 87        self.deformation_strength = deformation_strength
 88
 89    def drop_slice(self, raw):
 90        raw[:] = 0
 91        return raw
 92
 93    def low_contrast(self, raw):
 94        mean = raw.mean()
 95        raw -= mean
 96        raw *= self.contrast_scale
 97        raw += mean
 98        return raw
 99
100    # this simulates a typical defect:
101    # missing line of data with rest of data compressed towards the line
102    def compress_slice(self, raw):
103        shape = raw.shape
104        # randomly choose fixed x or fixed y with p = 1/2
105        fixed_x = np.random.rand() < .5
106        if fixed_x:
107            x0, y0 = 0, np.random.randint(1, shape[1] - 2)
108            x1, y1 = shape[0] - 1, np.random.randint(1, shape[1] - 2)
109        else:
110            x0, y0 = np.random.randint(1, shape[0] - 2), 0
111            x1, y1 = np.random.randint(1, shape[0] - 2), shape[1] - 1
112
113        # generate the mask of the line that should be blacked out
114        line_mask = np.zeros_like(raw, dtype='bool')
115        rr, cc = line(x0, y0, x1, y1)
116        line_mask[rr, cc] = 1
117
118        # generate vectorfield pointing towards the line to compress the image
119        # first we get the unit vector representing the line
120        line_vector = np.array([x1 - x0, y1 - y0], dtype='float32')
121        line_vector /= np.linalg.norm(line_vector)
122        # next, we generate the normal to the line
123        normal_vector = np.zeros_like(line_vector)
124        normal_vector[0] = - line_vector[1]
125        normal_vector[1] = line_vector[0]
126
127        # make meshgrid
128        x, y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]))
129        # generate the vector field
130        flow_x, flow_y = np.zeros_like(raw), np.zeros_like(raw)
131
132        # find the 2 components where coordinates are bigger / smaller than the line
133        # to apply normal vector in the correct direction
134        components = label(np.logical_not(line_mask))
135        assert len(np.unique(components)) == 3, "%i" % len(np.unique(components))
136        neg_val = components[0, 0] if fixed_x else components[-1, -1]
137        pos_val = components[-1, -1] if fixed_x else components[0, 0]
138
139        flow_x[components == pos_val] = self.deformation_strength * normal_vector[1]
140        flow_y[components == pos_val] = self.deformation_strength * normal_vector[0]
141        flow_x[components == neg_val] = - self.deformation_strength * normal_vector[1]
142        flow_y[components == neg_val] = - self.deformation_strength * normal_vector[0]
143
144        # add small random noise
145        flow_x += np.random.uniform(-1, 1, shape) * (self.deformation_strength / 8.)
146        flow_y += np.random.uniform(-1, 1, shape) * (self.deformation_strength / 8.)
147
148        # apply the flow fields
149        flow_x, flow_y = (x + flow_x).reshape(-1, 1), (y + flow_y).reshape(-1, 1)
150        cval = 0.0 if self.mean_val is None else self.mean_val
151        raw = map_coordinates(
152            raw, (flow_y, flow_x), mode='constant', order=3, cval=cval
153        ).reshape(shape)
154
155        # dilate the line mask and zero out the raw below it
156        line_mask = binary_dilation(line_mask, iterations=10)
157        raw[line_mask] = 0.
158        return raw
159
160    def undirected_deformation(self, raw):
161        shape = raw.shape
162
163        # make meshgrid
164        x, y = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]))
165
166        # generate random vector field and smooth it
167        flow_x = np.random.uniform(-1, 1, shape) * self.deformation_strength
168        flow_y = np.random.uniform(-1, 1, shape) * self.deformation_strength
169        flow_x = gaussian(flow_x, sigma=3.)  # sigma is hard-coded for now
170        flow_y = gaussian(flow_y, sigma=3.)  # sigma is hard-coded for now
171
172        # apply the flow fields
173        flow_x, flow_y = (x + flow_x).reshape(-1, 1), (y + flow_y).reshape(-1, 1)
174        raw = map_coordinates(raw, (flow_y, flow_x), mode='constant').reshape(shape)
175        return raw
176
177    def deform_slice(self, raw):
178        if self.deformation_mode in ('undirected', 'compress'):
179            mode = self.deformation_mode
180        else:
181            mode = 'undireccted' if np.random.rand() < .5 else 'compress'
182        if mode == 'compress':
183            raw = self.compress_slice(raw)
184        else:
185            raw = self.undirected_deformation(raw)
186        return raw
187
188    def paste_artifact(self, raw):
189        # draw a random artifact location
190        artifact_index = np.random.randint(len(self.artifact_source))
191        artifact, alpha_mask = self.artifact_source[artifact_index]
192        artifact = artifact.numpy().squeeze()
193        alpha_mask = alpha_mask.numpy().squeeze()
194        assert artifact.shape == raw.shape, f"{artifact.shape}, {raw.shape}"
195        assert alpha_mask.shape == raw.shape
196        assert alpha_mask.min() >= 0., f"{alpha_mask.min()}"
197        assert alpha_mask.max() <= 1., f"{alpha_mask.max()}"
198
199        # blend the raw raw data and the artifact according to the alpha mask
200        raw = raw * (1. - alpha_mask) + artifact * alpha_mask
201        return raw
202
203    def __call__(self, raw):
204        raw = raw.astype("float32")  # needs to be floating point to avoid errors
205        for z in range(raw.shape[0]):
206            r = np.random.rand()
207            if r < self.p_drop_slice:
208                # print("Drop slice", z)
209                raw[z] = self.drop_slice(raw[z])
210            elif r < self.p_low_contrast:
211                # print("Low contrast", z)
212                raw[z] = self.low_contrast(raw[z])
213            elif r < self.p_deform_slice:
214                # print("Deform slice", z)
215                raw[z] = self.deform_slice(raw[z])
216            elif r < self.p_paste_artifact:
217                # print("Paste artifact", z)
218                raw[z] = self.paste_artifact(raw[z])
219        return raw
def get_artifact_source( artifact_path, patch_shape, min_mask_fraction, normalizer=<function standardize>, raw_key='artifacts', mask_key='alpha_mask'):
def get_artifact_source(artifact_path, patch_shape, min_mask_fraction,
                        normalizer=standardize,
                        raw_key="artifacts", mask_key="alpha_mask"):
    """Create a dataset serving artifact patches together with their alpha masks.

    The returned dataset can be passed as `artifact_source` to `EMDefectAugmentation`.

    Arguments:
        artifact_path: path to the data holding both the artifacts (`raw_key`)
            and the alpha masks (`mask_key`)
        patch_shape: shape of the patches sampled from the data
        min_mask_fraction: minimal mask fraction, passed to `MinForegroundSampler`
        normalizer: transformation applied to the raw artifact data
        raw_key: key of the artifact data in `artifact_path`
        mask_key: key of the alpha mask data in `artifact_path`
    Returns:
        the `SegmentationDataset` yielding (artifact, alpha_mask) samples
    """
    augmentation = get_augmentations(ndim=2)
    sampler = MinForegroundSampler(min_mask_fraction)
    return SegmentationDataset(
        artifact_path, raw_key,
        artifact_path, mask_key,
        patch_shape=patch_shape,
        # use the requested normalizer; this was hard-coded to `standardize`,
        # silently ignoring the `normalizer` argument
        raw_transform=normalizer,
        transform=augmentation,
        sampler=sampler
    )
class EMDefectAugmentation:
 36class EMDefectAugmentation:
 37    """Augment raw data with transformations similar to defects common in EM data.
 38
 39    Arguments:
 40        p_drop_slice: probability for a missing slice
 41        p_low_contrast: probability for a low contrast slice
 42        p_deform_slice: probaboloty for a deformed slice
 43        p_paste_artifact: probability for inserting an artifact from data source
 44        contrast_scale: scale of low contrast transformation
 45        deformation_mode: deformation mode that should be used
 46        deformation_strength: deformation strength in pixel
 47        artifact_source: data source for additional artifacts
 48        mean_val: mean value for artifact normalization
 49        std_val: std value for artifact normalization
 50    """
 51    def __init__(
 52        self,
 53        p_drop_slice,
 54        p_low_contrast,
 55        p_deform_slice,
 56        p_paste_artifact=0.0,
 57        contrast_scale=0.1,
 58        deformation_mode='undirected',
 59        deformation_strength=10,
 60        artifact_source=None,
 61        mean_val=None,
 62        std_val=None
 63    ):
 64        if p_paste_artifact > 0.0:
 65            assert artifact_source is not None
 66        self.artifact_source = artifact_source
 67
 68        # use cumulative probabilities
 69        self.p_drop_slice = p_drop_slice
 70        self.p_low_contrast = self.p_drop_slice + p_low_contrast
 71        self.p_deform_slice = self.p_low_contrast + p_deform_slice
 72        self.p_paste_artifact = self.p_deform_slice + p_paste_artifact
 73        assert self.p_paste_artifact < 1.0
 74
 75        self.contrast_scale = contrast_scale
 76        self.mean_val = mean_val
 77        self.std_val = std_val
 78
 79        # set the parameters for deformation augments
 80        if isinstance(deformation_mode, str):
 81            assert deformation_mode in ('all', 'undirected', 'compress')
 82            self.deformation_mode = deformation_mode
 83        elif isinstance(deformation_mode, (list, tuple)):
 84            assert len(deformation_mode) == 2
 85            assert 'undirected' in deformation_mode
 86            assert 'compress' in deformation_mode
 87            self.deformation_mode = 'all'
 88        self.deformation_strength = deformation_strength
 89
 90    def drop_slice(self, raw):
 91        raw[:] = 0
 92        return raw
 93
 94    def low_contrast(self, raw):
 95        mean = raw.mean()
 96        raw -= mean
 97        raw *= self.contrast_scale
 98        raw += mean
 99        return raw
100
101    # this simulates a typical defect:
102    # missing line of data with rest of data compressed towards the line
103    def compress_slice(self, raw):
104        shape = raw.shape
105        # randomly choose fixed x or fixed y with p = 1/2
106        fixed_x = np.random.rand() < .5
107        if fixed_x:
108            x0, y0 = 0, np.random.randint(1, shape[1] - 2)
109            x1, y1 = shape[0] - 1, np.random.randint(1, shape[1] - 2)
110        else:
111            x0, y0 = np.random.randint(1, shape[0] - 2), 0
112            x1, y1 = np.random.randint(1, shape[0] - 2), shape[1] - 1
113
114        # generate the mask of the line that should be blacked out
115        line_mask = np.zeros_like(raw, dtype='bool')
116        rr, cc = line(x0, y0, x1, y1)
117        line_mask[rr, cc] = 1
118
119        # generate vectorfield pointing towards the line to compress the image
120        # first we get the unit vector representing the line
121        line_vector = np.array([x1 - x0, y1 - y0], dtype='float32')
122        line_vector /= np.linalg.norm(line_vector)
123        # next, we generate the normal to the line
124        normal_vector = np.zeros_like(line_vector)
125        normal_vector[0] = - line_vector[1]
126        normal_vector[1] = line_vector[0]
127
128        # make meshgrid
129        x, y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]))
130        # generate the vector field
131        flow_x, flow_y = np.zeros_like(raw), np.zeros_like(raw)
132
133        # find the 2 components where coordinates are bigger / smaller than the line
134        # to apply normal vector in the correct direction
135        components = label(np.logical_not(line_mask))
136        assert len(np.unique(components)) == 3, "%i" % len(np.unique(components))
137        neg_val = components[0, 0] if fixed_x else components[-1, -1]
138        pos_val = components[-1, -1] if fixed_x else components[0, 0]
139
140        flow_x[components == pos_val] = self.deformation_strength * normal_vector[1]
141        flow_y[components == pos_val] = self.deformation_strength * normal_vector[0]
142        flow_x[components == neg_val] = - self.deformation_strength * normal_vector[1]
143        flow_y[components == neg_val] = - self.deformation_strength * normal_vector[0]
144
145        # add small random noise
146        flow_x += np.random.uniform(-1, 1, shape) * (self.deformation_strength / 8.)
147        flow_y += np.random.uniform(-1, 1, shape) * (self.deformation_strength / 8.)
148
149        # apply the flow fields
150        flow_x, flow_y = (x + flow_x).reshape(-1, 1), (y + flow_y).reshape(-1, 1)
151        cval = 0.0 if self.mean_val is None else self.mean_val
152        raw = map_coordinates(
153            raw, (flow_y, flow_x), mode='constant', order=3, cval=cval
154        ).reshape(shape)
155
156        # dilate the line mask and zero out the raw below it
157        line_mask = binary_dilation(line_mask, iterations=10)
158        raw[line_mask] = 0.
159        return raw
160
161    def undirected_deformation(self, raw):
162        shape = raw.shape
163
164        # make meshgrid
165        x, y = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]))
166
167        # generate random vector field and smooth it
168        flow_x = np.random.uniform(-1, 1, shape) * self.deformation_strength
169        flow_y = np.random.uniform(-1, 1, shape) * self.deformation_strength
170        flow_x = gaussian(flow_x, sigma=3.)  # sigma is hard-coded for now
171        flow_y = gaussian(flow_y, sigma=3.)  # sigma is hard-coded for now
172
173        # apply the flow fields
174        flow_x, flow_y = (x + flow_x).reshape(-1, 1), (y + flow_y).reshape(-1, 1)
175        raw = map_coordinates(raw, (flow_y, flow_x), mode='constant').reshape(shape)
176        return raw
177
178    def deform_slice(self, raw):
179        if self.deformation_mode in ('undirected', 'compress'):
180            mode = self.deformation_mode
181        else:
182            mode = 'undireccted' if np.random.rand() < .5 else 'compress'
183        if mode == 'compress':
184            raw = self.compress_slice(raw)
185        else:
186            raw = self.undirected_deformation(raw)
187        return raw
188
189    def paste_artifact(self, raw):
190        # draw a random artifact location
191        artifact_index = np.random.randint(len(self.artifact_source))
192        artifact, alpha_mask = self.artifact_source[artifact_index]
193        artifact = artifact.numpy().squeeze()
194        alpha_mask = alpha_mask.numpy().squeeze()
195        assert artifact.shape == raw.shape, f"{artifact.shape}, {raw.shape}"
196        assert alpha_mask.shape == raw.shape
197        assert alpha_mask.min() >= 0., f"{alpha_mask.min()}"
198        assert alpha_mask.max() <= 1., f"{alpha_mask.max()}"
199
200        # blend the raw raw data and the artifact according to the alpha mask
201        raw = raw * (1. - alpha_mask) + artifact * alpha_mask
202        return raw
203
204    def __call__(self, raw):
205        raw = raw.astype("float32")  # needs to be floating point to avoid errors
206        for z in range(raw.shape[0]):
207            r = np.random.rand()
208            if r < self.p_drop_slice:
209                # print("Drop slice", z)
210                raw[z] = self.drop_slice(raw[z])
211            elif r < self.p_low_contrast:
212                # print("Low contrast", z)
213                raw[z] = self.low_contrast(raw[z])
214            elif r < self.p_deform_slice:
215                # print("Deform slice", z)
216                raw[z] = self.deform_slice(raw[z])
217            elif r < self.p_paste_artifact:
218                # print("Paste artifact", z)
219                raw[z] = self.paste_artifact(raw[z])
220        return raw

Augment raw data with transformations similar to defects common in EM data.

Arguments:
  • p_drop_slice: probability for a missing slice
  • p_low_contrast: probability for a low contrast slice
  • p_deform_slice: probability for a deformed slice
  • p_paste_artifact: probability for inserting an artifact from data source
  • contrast_scale: scale of low contrast transformation
  • deformation_mode: deformation mode that should be used
  • deformation_strength: deformation strength in pixel
  • artifact_source: data source for additional artifacts
  • mean_val: mean value for artifact normalization
  • std_val: std value for artifact normalization
EMDefectAugmentation( p_drop_slice, p_low_contrast, p_deform_slice, p_paste_artifact=0.0, contrast_scale=0.1, deformation_mode='undirected', deformation_strength=10, artifact_source=None, mean_val=None, std_val=None)
def __init__(
    self,
    p_drop_slice,
    p_low_contrast,
    p_deform_slice,
    p_paste_artifact=0.0,
    contrast_scale=0.1,
    deformation_mode='undirected',
    deformation_strength=10,
    artifact_source=None,
    mean_val=None,
    std_val=None
):
    """Store the augmentation parameters.

    The per-defect probabilities are stored as cumulative thresholds, so a single
    uniform draw per slice selects at most one augmentation. Their sum must be < 1.
    `deformation_mode` must be 'undirected', 'compress', 'all', or a list/tuple
    containing exactly 'undirected' and 'compress' (stored as 'all').

    Raises:
        AssertionError: for invalid probabilities, a missing artifact source
            or an invalid mode string / list
        ValueError: if deformation_mode is neither a str nor a list/tuple
    """
    if p_paste_artifact > 0.0:
        assert artifact_source is not None
    self.artifact_source = artifact_source

    # use cumulative probabilities
    self.p_drop_slice = p_drop_slice
    self.p_low_contrast = self.p_drop_slice + p_low_contrast
    self.p_deform_slice = self.p_low_contrast + p_deform_slice
    self.p_paste_artifact = self.p_deform_slice + p_paste_artifact
    assert self.p_paste_artifact < 1.0

    self.contrast_scale = contrast_scale
    self.mean_val = mean_val
    self.std_val = std_val

    # set the parameters for deformation augments
    if isinstance(deformation_mode, str):
        assert deformation_mode in ('all', 'undirected', 'compress')
        self.deformation_mode = deformation_mode
    elif isinstance(deformation_mode, (list, tuple)):
        assert len(deformation_mode) == 2
        assert 'undirected' in deformation_mode
        assert 'compress' in deformation_mode
        self.deformation_mode = 'all'
    else:
        # fail early: otherwise the attribute stays unset and the error only
        # shows up as an AttributeError once a slice gets deformed
        raise ValueError(f"Invalid deformation_mode: {deformation_mode!r}")
    self.deformation_strength = deformation_strength
artifact_source
p_drop_slice
p_low_contrast
p_deform_slice
p_paste_artifact
contrast_scale
mean_val
std_val
deformation_strength
def drop_slice(self, raw):
def drop_slice(self, raw):
    """Simulate a missing slice by blanking it.

    Arguments:
        raw: the slice; it is overwritten with zeros in place
    Returns:
        the same array object, now all zero
    """
    blanked = raw
    blanked[:] = 0
    return blanked
def low_contrast(self, raw):
def low_contrast(self, raw):
    """Reduce the contrast of a slice around its mean intensity.

    Computes raw <- mean + contrast_scale * (raw - mean) in place, where mean is
    the mean of the slice before modification.

    Returns:
        the same (modified) array object
    """
    scale = self.contrast_scale
    center = raw.mean()
    raw -= center
    raw *= scale
    raw += center
    return raw
def compress_slice(self, raw):
def compress_slice(self, raw):
    """Black out a random line and warp the rest of the slice towards it.

    This simulates a typical defect: a missing line of data with the rest of the
    data compressed towards the line. A line crossing the 2d slice from one border
    to the opposite border is drawn at random. The pixels on the two sides of the
    line are resampled with opposite displacements along the line normal (plus
    small noise). Finally, a band around the line (10 dilation iterations) is set
    to 0.

    Arguments:
        raw: the 2d slice
    Returns:
        the deformed slice (a new array)
    """
    shape = raw.shape
    # randomly choose fixed x or fixed y with p = 1/2
    # (fixed_x: the line runs from the first to the last row,
    #  otherwise it runs from the first to the last column)
    fixed_x = np.random.rand() < .5
    if fixed_x:
        x0, y0 = 0, np.random.randint(1, shape[1] - 2)
        x1, y1 = shape[0] - 1, np.random.randint(1, shape[1] - 2)
    else:
        x0, y0 = np.random.randint(1, shape[0] - 2), 0
        x1, y1 = np.random.randint(1, shape[0] - 2), shape[1] - 1

    # generate the mask of the line that should be blacked out
    line_mask = np.zeros_like(raw, dtype='bool')
    rr, cc = line(x0, y0, x1, y1)
    line_mask[rr, cc] = 1

    # generate vectorfield pointing towards the line to compress the image
    # first we get the unit vector representing the line, in (row, col) order
    line_vector = np.array([x1 - x0, y1 - y0], dtype='float32')
    line_vector /= np.linalg.norm(line_vector)
    # next, we generate the normal to the line
    normal_vector = np.zeros_like(line_vector)
    normal_vector[0] = - line_vector[1]
    normal_vector[1] = line_vector[0]

    # make meshgrid: x holds the column and y the row coordinates
    x, y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]))
    # generate the vector field
    flow_x, flow_y = np.zeros_like(raw), np.zeros_like(raw)

    # find the 2 components on either side of the line, to apply the
    # normal vector in the correct direction.
    # The line from `skimage.draw.line` is 8-connected: it has diagonal steps
    # whenever it is not axis-aligned. With skimage's default full
    # (8-)connectivity, the background pixels on both sides of such a step
    # touch diagonally and merge into one component, so the assertion below
    # would fail. 4-connectivity keeps the two sides separate.
    components = label(np.logical_not(line_mask), connectivity=1)
    assert len(np.unique(components)) == 3, "%i" % len(np.unique(components))
    neg_val = components[0, 0] if fixed_x else components[-1, -1]
    pos_val = components[-1, -1] if fixed_x else components[0, 0]

    flow_x[components == pos_val] = self.deformation_strength * normal_vector[1]
    flow_y[components == pos_val] = self.deformation_strength * normal_vector[0]
    flow_x[components == neg_val] = - self.deformation_strength * normal_vector[1]
    flow_y[components == neg_val] = - self.deformation_strength * normal_vector[0]

    # add small random noise
    flow_x += np.random.uniform(-1, 1, shape) * (self.deformation_strength / 8.)
    flow_y += np.random.uniform(-1, 1, shape) * (self.deformation_strength / 8.)

    # apply the flow fields
    flow_x, flow_y = (x + flow_x).reshape(-1, 1), (y + flow_y).reshape(-1, 1)
    cval = 0.0 if self.mean_val is None else self.mean_val
    raw = map_coordinates(
        raw, (flow_y, flow_x), mode='constant', order=3, cval=cval
    ).reshape(shape)

    # dilate the line mask and zero out the raw below it
    line_mask = binary_dilation(line_mask, iterations=10)
    raw[line_mask] = 0.
    return raw
def undirected_deformation(self, raw):
def undirected_deformation(self, raw):
    """Warp the 2d slice with a smooth random displacement field.

    Displacements are drawn uniformly from +-deformation_strength per pixel,
    smoothed with a gaussian (sigma=3) and used to resample the slice with
    `map_coordinates` (constant mode, i.e. 0 outside the slice).

    Arguments:
        raw: the 2d slice
    Returns:
        the deformed slice (a new array)
    """
    shape = raw.shape

    # make meshgrid: x holds the column and y the row coordinates, both with
    # `shape`. Passing (shape[0], shape[1]) here instead would yield the
    # transposed shape and fail to broadcast for non-square slices.
    x, y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]))

    # generate random vector field and smooth it
    flow_x = np.random.uniform(-1, 1, shape) * self.deformation_strength
    flow_y = np.random.uniform(-1, 1, shape) * self.deformation_strength
    flow_x = gaussian(flow_x, sigma=3.)  # sigma is hard-coded for now
    flow_y = gaussian(flow_y, sigma=3.)  # sigma is hard-coded for now

    # apply the flow fields
    flow_x, flow_y = (x + flow_x).reshape(-1, 1), (y + flow_y).reshape(-1, 1)
    raw = map_coordinates(raw, (flow_y, flow_x), mode='constant').reshape(shape)
    return raw
def deform_slice(self, raw):
def deform_slice(self, raw):
    """Apply the configured deformation to a slice.

    In 'undirected' or 'compress' mode that deformation is used. Otherwise
    ('all'), one of the two is chosen at random with p = 1/2.
    """
    if self.deformation_mode in ('undirected', 'compress'):
        mode = self.deformation_mode
    else:
        # fixed misspelled mode name ('undireccted'); it only worked because
        # every non-'compress' value falls through to the undirected branch
        mode = 'undirected' if np.random.rand() < .5 else 'compress'
    if mode == 'compress':
        raw = self.compress_slice(raw)
    else:
        raw = self.undirected_deformation(raw)
    return raw
def paste_artifact(self, raw):
def paste_artifact(self, raw):
    """Blend a randomly selected artifact into the slice via its alpha mask.

    A random entry of `self.artifact_source` is taken; it must unpack into an
    (artifact, alpha_mask) pair of objects with a `numpy()` method. After
    squeezing, both must match the slice shape, and the mask must lie in [0, 1].

    Returns:
        raw * (1 - alpha) + artifact * alpha (a new array)
    """
    source = self.artifact_source
    picked_artifact, picked_alpha = source[np.random.randint(len(source))]
    artifact = picked_artifact.numpy().squeeze()
    alpha_mask = picked_alpha.numpy().squeeze()

    target_shape = raw.shape
    assert artifact.shape == target_shape, f"{artifact.shape}, {raw.shape}"
    assert alpha_mask.shape == target_shape
    assert alpha_mask.min() >= 0., f"{alpha_mask.min()}"
    assert alpha_mask.max() <= 1., f"{alpha_mask.max()}"

    # alpha = 1 takes the artifact, alpha = 0 keeps the original data
    return raw * (1. - alpha_mask) + artifact * alpha_mask