torch_em.transform.defect
import numpy as np

from scipy.ndimage import binary_dilation, map_coordinates
from skimage.draw import line
from skimage.filters import gaussian
from skimage.measure import label

from .augmentation import get_augmentations
from .raw import standardize
from ..data import SegmentationDataset, MinForegroundSampler


#
# defect augmentations
#
# TODO
# - alignment jitter


def get_artifact_source(artifact_path, patch_shape, min_mask_fraction,
                        normalizer=standardize,
                        raw_key="artifacts", mask_key="alpha_mask"):
    """Build the dataset that serves artifact patches and their alpha masks.

    Arguments:
        artifact_path: path to the artifact data; used for both raw data and mask
        patch_shape: patch shape passed on to the dataset
        min_mask_fraction: value handed to `MinForegroundSampler`
            (presumably the minimal mask fraction of an accepted patch)
        normalizer: transformation applied to the raw artifact data
        raw_key: key of the artifact data in `artifact_path`
        mask_key: key of the alpha mask in `artifact_path`

    Returns:
        a `SegmentationDataset` yielding (artifact, alpha_mask) samples
    """
    augmentation = get_augmentations(ndim=2)
    sampler = MinForegroundSampler(min_mask_fraction)
    return SegmentationDataset(
        artifact_path, raw_key,
        artifact_path, mask_key,
        patch_shape=patch_shape,
        # honor the caller's normalizer instead of always using `standardize`
        raw_transform=normalizer,
        transform=augmentation,
        sampler=sampler
    )


class EMDefectAugmentation:
    """Augment raw data with transformations similar to defects common in EM data.

    When called, every slice along the first axis receives at most one defect,
    selected by a single random draw against the cumulative probabilities.

    Arguments:
        p_drop_slice: probability for a missing slice
        p_low_contrast: probability for a low contrast slice
        p_deform_slice: probability for a deformed slice
        p_paste_artifact: probability for inserting an artifact from data source
        contrast_scale: scale of low contrast transformation
        deformation_mode: deformation mode that should be used; one of
            'all', 'undirected', 'compress' or a list / tuple with both
            'undirected' and 'compress' (equivalent to 'all')
        deformation_strength: deformation strength in pixel
        artifact_source: data source for additional artifacts
        mean_val: mean value for artifact normalization; used as the fill
            value for out-of-image samples in `compress_slice`
        std_val: std value for artifact normalization (stored only)
    """
    def __init__(
        self,
        p_drop_slice,
        p_low_contrast,
        p_deform_slice,
        p_paste_artifact=0.0,
        contrast_scale=0.1,
        deformation_mode='undirected',
        deformation_strength=10,
        artifact_source=None,
        mean_val=None,
        std_val=None
    ):
        if p_paste_artifact > 0.0:
            assert artifact_source is not None
        # always define the attribute so that it exists even without artifacts
        self.artifact_source = artifact_source

        # use cumulative probabilities
        self.p_drop_slice = p_drop_slice
        self.p_low_contrast = self.p_drop_slice + p_low_contrast
        self.p_deform_slice = self.p_low_contrast + p_deform_slice
        self.p_paste_artifact = self.p_deform_slice + p_paste_artifact
        assert self.p_paste_artifact < 1.0

        self.contrast_scale = contrast_scale
        self.mean_val = mean_val
        self.std_val = std_val

        # set the parameters for deformation augments
        if isinstance(deformation_mode, str):
            assert deformation_mode in ('all', 'undirected', 'compress')
            self.deformation_mode = deformation_mode
        elif isinstance(deformation_mode, (list, tuple)):
            assert len(deformation_mode) == 2
            assert 'undirected' in deformation_mode
            assert 'compress' in deformation_mode
            self.deformation_mode = 'all'
        else:
            # fail here instead of with an AttributeError in `deform_slice`
            raise ValueError(
                f"Invalid type for deformation_mode: {type(deformation_mode)}"
            )
        self.deformation_strength = deformation_strength

    def drop_slice(self, raw):
        """Set all values of the slice to zero (in-place) and return it."""
        raw[:] = 0
        return raw

    def low_contrast(self, raw):
        """Scale the deviation from the slice mean by `contrast_scale` (in-place)."""
        mean = raw.mean()
        raw -= mean
        raw *= self.contrast_scale
        raw += mean
        return raw

    # this simulates a typical defect:
    # missing line of data with rest of data compressed towards the line
    def compress_slice(self, raw):
        """Black out a random line and shift the data on both sides along its normal.

        Arguments:
            raw: 2d slice; the random end points require at least 4 pixels per axis

        Returns:
            the deformed slice (a new array)
        """
        shape = raw.shape
        # randomly choose fixed x or fixed y with p = 1/2
        fixed_x = np.random.rand() < .5
        if fixed_x:
            x0, y0 = 0, np.random.randint(1, shape[1] - 2)
            x1, y1 = shape[0] - 1, np.random.randint(1, shape[1] - 2)
        else:
            x0, y0 = np.random.randint(1, shape[0] - 2), 0
            x1, y1 = np.random.randint(1, shape[0] - 2), shape[1] - 1

        # generate the mask of the line that should be blacked out
        line_mask = np.zeros_like(raw, dtype='bool')
        rr, cc = line(x0, y0, x1, y1)
        line_mask[rr, cc] = 1

        # generate vectorfield pointing towards the line to compress the image
        # first we get the unit vector representing the line
        line_vector = np.array([x1 - x0, y1 - y0], dtype='float32')
        line_vector /= np.linalg.norm(line_vector)
        # next, we generate the normal to the line
        normal_vector = np.zeros_like(line_vector)
        normal_vector[0] = - line_vector[1]
        normal_vector[1] = line_vector[0]

        # make meshgrid
        x, y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]))
        # generate the vector field
        flow_x, flow_y = np.zeros_like(raw), np.zeros_like(raw)

        # find the 2 components where coordinates are bigger / smaller than the line
        # to apply normal vector in the correct direction
        # connectivity=1 (4-neighborhood) is required: the rasterized line is only
        # 8-connected, so with full connectivity both sides would merge across
        # its diagonal steps
        components = label(np.logical_not(line_mask), connectivity=1)
        assert len(np.unique(components)) == 3, "%i" % len(np.unique(components))
        neg_val = components[0, 0] if fixed_x else components[-1, -1]
        pos_val = components[-1, -1] if fixed_x else components[0, 0]

        flow_x[components == pos_val] = self.deformation_strength * normal_vector[1]
        flow_y[components == pos_val] = self.deformation_strength * normal_vector[0]
        flow_x[components == neg_val] = - self.deformation_strength * normal_vector[1]
        flow_y[components == neg_val] = - self.deformation_strength * normal_vector[0]

        # add small random noise
        flow_x += np.random.uniform(-1, 1, shape) * (self.deformation_strength / 8.)
        flow_y += np.random.uniform(-1, 1, shape) * (self.deformation_strength / 8.)

        # apply the flow fields
        flow_x, flow_y = (x + flow_x).reshape(-1, 1), (y + flow_y).reshape(-1, 1)
        cval = 0.0 if self.mean_val is None else self.mean_val
        raw = map_coordinates(
            raw, (flow_y, flow_x), mode='constant', order=3, cval=cval
        ).reshape(shape)

        # dilate the line mask and zero out the raw below it
        line_mask = binary_dilation(line_mask, iterations=10)
        raw[line_mask] = 0.
        return raw

    def undirected_deformation(self, raw):
        """Deform the slice with a smoothed random displacement field.

        Arguments:
            raw: 2d slice; non-square shapes are supported

        Returns:
            the deformed slice (a new array)
        """
        shape = raw.shape

        # make meshgrid; with numpy's default 'xy' indexing the first argument
        # spans the columns, so both grids have the same shape as the slice
        x, y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]))

        # generate random vector field and smooth it
        flow_x = np.random.uniform(-1, 1, shape) * self.deformation_strength
        flow_y = np.random.uniform(-1, 1, shape) * self.deformation_strength
        flow_x = gaussian(flow_x, sigma=3.)  # sigma is hard-coded for now
        flow_y = gaussian(flow_y, sigma=3.)  # sigma is hard-coded for now

        # apply the flow fields
        flow_x, flow_y = (x + flow_x).reshape(-1, 1), (y + flow_y).reshape(-1, 1)
        raw = map_coordinates(raw, (flow_y, flow_x), mode='constant').reshape(shape)
        return raw

    def deform_slice(self, raw):
        """Apply the configured deformation; for mode 'all' pick one at random."""
        mode = self.deformation_mode
        if mode not in ('undirected', 'compress'):
            # mode 'all': choose both deformations with equal probability
            mode = 'undirected' if np.random.rand() < .5 else 'compress'
        if mode == 'compress':
            return self.compress_slice(raw)
        return self.undirected_deformation(raw)

    def paste_artifact(self, raw):
        """Blend a randomly drawn artifact into the slice using its alpha mask.

        The artifact and the mask must have the shape of `raw` after squeezing
        and the mask values must lie in [0, 1].
        """
        # draw a random artifact location
        artifact_index = np.random.randint(len(self.artifact_source))
        artifact, alpha_mask = self.artifact_source[artifact_index]
        artifact = artifact.numpy().squeeze()
        alpha_mask = alpha_mask.numpy().squeeze()
        assert artifact.shape == raw.shape, f"{artifact.shape}, {raw.shape}"
        assert alpha_mask.shape == raw.shape
        assert alpha_mask.min() >= 0., f"{alpha_mask.min()}"
        assert alpha_mask.max() <= 1., f"{alpha_mask.max()}"

        # blend the raw data and the artifact according to the alpha mask
        raw = raw * (1. - alpha_mask) + artifact * alpha_mask
        return raw

    def __call__(self, raw):
        """Return a float32 copy of `raw` with random defects applied per slice.

        Arguments:
            raw: array that is iterated along its first axis; each `raw[z]`
                is treated as one slice

        Returns:
            the augmented float32 array
        """
        raw = raw.astype("float32")  # needs to be floating point to avoid errors
        for z in range(raw.shape[0]):
            # a single draw selects at most one defect via cumulative thresholds
            r = np.random.rand()
            if r < self.p_drop_slice:
                raw[z] = self.drop_slice(raw[z])
            elif r < self.p_low_contrast:
                raw[z] = self.low_contrast(raw[z])
            elif r < self.p_deform_slice:
                raw[z] = self.deform_slice(raw[z])
            elif r < self.p_paste_artifact:
                raw[z] = self.paste_artifact(raw[z])
        return raw
def
get_artifact_source( artifact_path, patch_shape, min_mask_fraction, normalizer=<function standardize>, raw_key='artifacts', mask_key='alpha_mask'):
def get_artifact_source(artifact_path, patch_shape, min_mask_fraction,
                        normalizer=standardize,
                        raw_key="artifacts", mask_key="alpha_mask"):
    """Build the dataset that serves artifact patches and their alpha masks.

    Arguments:
        artifact_path: path to the artifact data; used for both raw data and mask
        patch_shape: patch shape passed on to the dataset
        min_mask_fraction: value handed to `MinForegroundSampler`
            (presumably the minimal mask fraction of an accepted patch)
        normalizer: transformation applied to the raw artifact data
        raw_key: key of the artifact data in `artifact_path`
        mask_key: key of the alpha mask in `artifact_path`

    Returns:
        a `SegmentationDataset` yielding (artifact, alpha_mask) samples
    """
    augmentation = get_augmentations(ndim=2)
    sampler = MinForegroundSampler(min_mask_fraction)
    return SegmentationDataset(
        artifact_path, raw_key,
        artifact_path, mask_key,
        patch_shape=patch_shape,
        # honor the caller's normalizer instead of always using `standardize`
        raw_transform=normalizer,
        transform=augmentation,
        sampler=sampler
    )
class
EMDefectAugmentation:
class EMDefectAugmentation:
    """Augment raw data with transformations similar to defects common in EM data.

    When called, every slice along the first axis receives at most one defect,
    selected by a single random draw against the cumulative probabilities.

    Arguments:
        p_drop_slice: probability for a missing slice
        p_low_contrast: probability for a low contrast slice
        p_deform_slice: probability for a deformed slice
        p_paste_artifact: probability for inserting an artifact from data source
        contrast_scale: scale of low contrast transformation
        deformation_mode: deformation mode that should be used; one of
            'all', 'undirected', 'compress' or a list / tuple with both
            'undirected' and 'compress' (equivalent to 'all')
        deformation_strength: deformation strength in pixel
        artifact_source: data source for additional artifacts
        mean_val: mean value for artifact normalization; used as the fill
            value for out-of-image samples in `compress_slice`
        std_val: std value for artifact normalization (stored only)
    """
    def __init__(
        self,
        p_drop_slice,
        p_low_contrast,
        p_deform_slice,
        p_paste_artifact=0.0,
        contrast_scale=0.1,
        deformation_mode='undirected',
        deformation_strength=10,
        artifact_source=None,
        mean_val=None,
        std_val=None
    ):
        if p_paste_artifact > 0.0:
            assert artifact_source is not None
        # always define the attribute so that it exists even without artifacts
        self.artifact_source = artifact_source

        # use cumulative probabilities
        self.p_drop_slice = p_drop_slice
        self.p_low_contrast = self.p_drop_slice + p_low_contrast
        self.p_deform_slice = self.p_low_contrast + p_deform_slice
        self.p_paste_artifact = self.p_deform_slice + p_paste_artifact
        assert self.p_paste_artifact < 1.0

        self.contrast_scale = contrast_scale
        self.mean_val = mean_val
        self.std_val = std_val

        # set the parameters for deformation augments
        if isinstance(deformation_mode, str):
            assert deformation_mode in ('all', 'undirected', 'compress')
            self.deformation_mode = deformation_mode
        elif isinstance(deformation_mode, (list, tuple)):
            assert len(deformation_mode) == 2
            assert 'undirected' in deformation_mode
            assert 'compress' in deformation_mode
            self.deformation_mode = 'all'
        else:
            # fail here instead of with an AttributeError in `deform_slice`
            raise ValueError(
                f"Invalid type for deformation_mode: {type(deformation_mode)}"
            )
        self.deformation_strength = deformation_strength

    def drop_slice(self, raw):
        """Set all values of the slice to zero (in-place) and return it."""
        raw[:] = 0
        return raw

    def low_contrast(self, raw):
        """Scale the deviation from the slice mean by `contrast_scale` (in-place)."""
        mean = raw.mean()
        raw -= mean
        raw *= self.contrast_scale
        raw += mean
        return raw

    # this simulates a typical defect:
    # missing line of data with rest of data compressed towards the line
    def compress_slice(self, raw):
        """Black out a random line and shift the data on both sides along its normal.

        Arguments:
            raw: 2d slice; the random end points require at least 4 pixels per axis

        Returns:
            the deformed slice (a new array)
        """
        shape = raw.shape
        # randomly choose fixed x or fixed y with p = 1/2
        fixed_x = np.random.rand() < .5
        if fixed_x:
            x0, y0 = 0, np.random.randint(1, shape[1] - 2)
            x1, y1 = shape[0] - 1, np.random.randint(1, shape[1] - 2)
        else:
            x0, y0 = np.random.randint(1, shape[0] - 2), 0
            x1, y1 = np.random.randint(1, shape[0] - 2), shape[1] - 1

        # generate the mask of the line that should be blacked out
        line_mask = np.zeros_like(raw, dtype='bool')
        rr, cc = line(x0, y0, x1, y1)
        line_mask[rr, cc] = 1

        # generate vectorfield pointing towards the line to compress the image
        # first we get the unit vector representing the line
        line_vector = np.array([x1 - x0, y1 - y0], dtype='float32')
        line_vector /= np.linalg.norm(line_vector)
        # next, we generate the normal to the line
        normal_vector = np.zeros_like(line_vector)
        normal_vector[0] = - line_vector[1]
        normal_vector[1] = line_vector[0]

        # make meshgrid
        x, y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]))
        # generate the vector field
        flow_x, flow_y = np.zeros_like(raw), np.zeros_like(raw)

        # find the 2 components where coordinates are bigger / smaller than the line
        # to apply normal vector in the correct direction
        # connectivity=1 (4-neighborhood) is required: the rasterized line is only
        # 8-connected, so with full connectivity both sides would merge across
        # its diagonal steps
        components = label(np.logical_not(line_mask), connectivity=1)
        assert len(np.unique(components)) == 3, "%i" % len(np.unique(components))
        neg_val = components[0, 0] if fixed_x else components[-1, -1]
        pos_val = components[-1, -1] if fixed_x else components[0, 0]

        flow_x[components == pos_val] = self.deformation_strength * normal_vector[1]
        flow_y[components == pos_val] = self.deformation_strength * normal_vector[0]
        flow_x[components == neg_val] = - self.deformation_strength * normal_vector[1]
        flow_y[components == neg_val] = - self.deformation_strength * normal_vector[0]

        # add small random noise
        flow_x += np.random.uniform(-1, 1, shape) * (self.deformation_strength / 8.)
        flow_y += np.random.uniform(-1, 1, shape) * (self.deformation_strength / 8.)

        # apply the flow fields
        flow_x, flow_y = (x + flow_x).reshape(-1, 1), (y + flow_y).reshape(-1, 1)
        cval = 0.0 if self.mean_val is None else self.mean_val
        raw = map_coordinates(
            raw, (flow_y, flow_x), mode='constant', order=3, cval=cval
        ).reshape(shape)

        # dilate the line mask and zero out the raw below it
        line_mask = binary_dilation(line_mask, iterations=10)
        raw[line_mask] = 0.
        return raw

    def undirected_deformation(self, raw):
        """Deform the slice with a smoothed random displacement field.

        Arguments:
            raw: 2d slice; non-square shapes are supported

        Returns:
            the deformed slice (a new array)
        """
        shape = raw.shape

        # make meshgrid; with numpy's default 'xy' indexing the first argument
        # spans the columns, so both grids have the same shape as the slice
        x, y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]))

        # generate random vector field and smooth it
        flow_x = np.random.uniform(-1, 1, shape) * self.deformation_strength
        flow_y = np.random.uniform(-1, 1, shape) * self.deformation_strength
        flow_x = gaussian(flow_x, sigma=3.)  # sigma is hard-coded for now
        flow_y = gaussian(flow_y, sigma=3.)  # sigma is hard-coded for now

        # apply the flow fields
        flow_x, flow_y = (x + flow_x).reshape(-1, 1), (y + flow_y).reshape(-1, 1)
        raw = map_coordinates(raw, (flow_y, flow_x), mode='constant').reshape(shape)
        return raw

    def deform_slice(self, raw):
        """Apply the configured deformation; for mode 'all' pick one at random."""
        mode = self.deformation_mode
        if mode not in ('undirected', 'compress'):
            # mode 'all': choose both deformations with equal probability
            mode = 'undirected' if np.random.rand() < .5 else 'compress'
        if mode == 'compress':
            return self.compress_slice(raw)
        return self.undirected_deformation(raw)

    def paste_artifact(self, raw):
        """Blend a randomly drawn artifact into the slice using its alpha mask.

        The artifact and the mask must have the shape of `raw` after squeezing
        and the mask values must lie in [0, 1].
        """
        # draw a random artifact location
        artifact_index = np.random.randint(len(self.artifact_source))
        artifact, alpha_mask = self.artifact_source[artifact_index]
        artifact = artifact.numpy().squeeze()
        alpha_mask = alpha_mask.numpy().squeeze()
        assert artifact.shape == raw.shape, f"{artifact.shape}, {raw.shape}"
        assert alpha_mask.shape == raw.shape
        assert alpha_mask.min() >= 0., f"{alpha_mask.min()}"
        assert alpha_mask.max() <= 1., f"{alpha_mask.max()}"

        # blend the raw data and the artifact according to the alpha mask
        raw = raw * (1. - alpha_mask) + artifact * alpha_mask
        return raw

    def __call__(self, raw):
        """Return a float32 copy of `raw` with random defects applied per slice.

        Arguments:
            raw: array that is iterated along its first axis; each `raw[z]`
                is treated as one slice

        Returns:
            the augmented float32 array
        """
        raw = raw.astype("float32")  # needs to be floating point to avoid errors
        for z in range(raw.shape[0]):
            # a single draw selects at most one defect via cumulative thresholds
            r = np.random.rand()
            if r < self.p_drop_slice:
                raw[z] = self.drop_slice(raw[z])
            elif r < self.p_low_contrast:
                raw[z] = self.low_contrast(raw[z])
            elif r < self.p_deform_slice:
                raw[z] = self.deform_slice(raw[z])
            elif r < self.p_paste_artifact:
                raw[z] = self.paste_artifact(raw[z])
        return raw
Augment raw data with transformations similar to defects common in EM data.
Arguments:
- p_drop_slice: probability for a missing slice
- p_low_contrast: probability for a low contrast slice
- p_deform_slice: probability for a deformed slice
- p_paste_artifact: probability for inserting an artifact from data source
- contrast_scale: scale of low contrast transformation
- deformation_mode: deformation mode that should be used
- deformation_strength: deformation strength in pixel
- artifact_source: data source for additional artifacts
- mean_val: mean value for artifact normalization
- std_val: std value for artifact normalization
EMDefectAugmentation( p_drop_slice, p_low_contrast, p_deform_slice, p_paste_artifact=0.0, contrast_scale=0.1, deformation_mode='undirected', deformation_strength=10, artifact_source=None, mean_val=None, std_val=None)
51 def __init__( 52 self, 53 p_drop_slice, 54 p_low_contrast, 55 p_deform_slice, 56 p_paste_artifact=0.0, 57 contrast_scale=0.1, 58 deformation_mode='undirected', 59 deformation_strength=10, 60 artifact_source=None, 61 mean_val=None, 62 std_val=None 63 ): 64 if p_paste_artifact > 0.0: 65 assert artifact_source is not None 66 self.artifact_source = artifact_source 67 68 # use cumulative probabilities 69 self.p_drop_slice = p_drop_slice 70 self.p_low_contrast = self.p_drop_slice + p_low_contrast 71 self.p_deform_slice = self.p_low_contrast + p_deform_slice 72 self.p_paste_artifact = self.p_deform_slice + p_paste_artifact 73 assert self.p_paste_artifact < 1.0 74 75 self.contrast_scale = contrast_scale 76 self.mean_val = mean_val 77 self.std_val = std_val 78 79 # set the parameters for deformation augments 80 if isinstance(deformation_mode, str): 81 assert deformation_mode in ('all', 'undirected', 'compress') 82 self.deformation_mode = deformation_mode 83 elif isinstance(deformation_mode, (list, tuple)): 84 assert len(deformation_mode) == 2 85 assert 'undirected' in deformation_mode 86 assert 'compress' in deformation_mode 87 self.deformation_mode = 'all' 88 self.deformation_strength = deformation_strength
def
compress_slice(self, raw):
103 def compress_slice(self, raw): 104 shape = raw.shape 105 # randomly choose fixed x or fixed y with p = 1/2 106 fixed_x = np.random.rand() < .5 107 if fixed_x: 108 x0, y0 = 0, np.random.randint(1, shape[1] - 2) 109 x1, y1 = shape[0] - 1, np.random.randint(1, shape[1] - 2) 110 else: 111 x0, y0 = np.random.randint(1, shape[0] - 2), 0 112 x1, y1 = np.random.randint(1, shape[0] - 2), shape[1] - 1 113 114 # generate the mask of the line that should be blacked out 115 line_mask = np.zeros_like(raw, dtype='bool') 116 rr, cc = line(x0, y0, x1, y1) 117 line_mask[rr, cc] = 1 118 119 # generate vectorfield pointing towards the line to compress the image 120 # first we get the unit vector representing the line 121 line_vector = np.array([x1 - x0, y1 - y0], dtype='float32') 122 line_vector /= np.linalg.norm(line_vector) 123 # next, we generate the normal to the line 124 normal_vector = np.zeros_like(line_vector) 125 normal_vector[0] = - line_vector[1] 126 normal_vector[1] = line_vector[0] 127 128 # make meshgrid 129 x, y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0])) 130 # generate the vector field 131 flow_x, flow_y = np.zeros_like(raw), np.zeros_like(raw) 132 133 # find the 2 components where coordinates are bigger / smaller than the line 134 # to apply normal vector in the correct direction 135 components = label(np.logical_not(line_mask)) 136 assert len(np.unique(components)) == 3, "%i" % len(np.unique(components)) 137 neg_val = components[0, 0] if fixed_x else components[-1, -1] 138 pos_val = components[-1, -1] if fixed_x else components[0, 0] 139 140 flow_x[components == pos_val] = self.deformation_strength * normal_vector[1] 141 flow_y[components == pos_val] = self.deformation_strength * normal_vector[0] 142 flow_x[components == neg_val] = - self.deformation_strength * normal_vector[1] 143 flow_y[components == neg_val] = - self.deformation_strength * normal_vector[0] 144 145 # add small random noise 146 flow_x += np.random.uniform(-1, 1, shape) * 
(self.deformation_strength / 8.) 147 flow_y += np.random.uniform(-1, 1, shape) * (self.deformation_strength / 8.) 148 149 # apply the flow fields 150 flow_x, flow_y = (x + flow_x).reshape(-1, 1), (y + flow_y).reshape(-1, 1) 151 cval = 0.0 if self.mean_val is None else self.mean_val 152 raw = map_coordinates( 153 raw, (flow_y, flow_x), mode='constant', order=3, cval=cval 154 ).reshape(shape) 155 156 # dilate the line mask and zero out the raw below it 157 line_mask = binary_dilation(line_mask, iterations=10) 158 raw[line_mask] = 0. 159 return raw
def
undirected_deformation(self, raw):
161 def undirected_deformation(self, raw): 162 shape = raw.shape 163 164 # make meshgrid 165 x, y = np.meshgrid(np.arange(shape[0]), np.arange(shape[1])) 166 167 # generate random vector field and smooth it 168 flow_x = np.random.uniform(-1, 1, shape) * self.deformation_strength 169 flow_y = np.random.uniform(-1, 1, shape) * self.deformation_strength 170 flow_x = gaussian(flow_x, sigma=3.) # sigma is hard-coded for now 171 flow_y = gaussian(flow_y, sigma=3.) # sigma is hard-coded for now 172 173 # apply the flow fields 174 flow_x, flow_y = (x + flow_x).reshape(-1, 1), (y + flow_y).reshape(-1, 1) 175 raw = map_coordinates(raw, (flow_y, flow_x), mode='constant').reshape(shape) 176 return raw
def
deform_slice(self, raw):
178 def deform_slice(self, raw): 179 if self.deformation_mode in ('undirected', 'compress'): 180 mode = self.deformation_mode 181 else: 182 mode = 'undireccted' if np.random.rand() < .5 else 'compress' 183 if mode == 'compress': 184 raw = self.compress_slice(raw) 185 else: 186 raw = self.undirected_deformation(raw) 187 return raw
def
paste_artifact(self, raw):
189 def paste_artifact(self, raw): 190 # draw a random artifact location 191 artifact_index = np.random.randint(len(self.artifact_source)) 192 artifact, alpha_mask = self.artifact_source[artifact_index] 193 artifact = artifact.numpy().squeeze() 194 alpha_mask = alpha_mask.numpy().squeeze() 195 assert artifact.shape == raw.shape, f"{artifact.shape}, {raw.shape}" 196 assert alpha_mask.shape == raw.shape 197 assert alpha_mask.min() >= 0., f"{alpha_mask.min()}" 198 assert alpha_mask.max() <= 1., f"{alpha_mask.max()}" 199 200 # blend the raw raw data and the artifact according to the alpha mask 201 raw = raw * (1. - alpha_mask) + artifact * alpha_mask 202 return raw