torch_em.transform.defect
from typing import Optional

import numpy as np
import torch

from scipy.ndimage import binary_dilation, map_coordinates
from skimage.draw import line
from skimage.filters import gaussian
from skimage.measure import label

from .augmentation import get_augmentations
from .raw import standardize
from ..data import SegmentationDataset, MinForegroundSampler


#
# defect augmentations
#
# TODO
# - alignment jitter


def get_artifact_source(artifact_path, patch_shape, min_mask_fraction,
                        normalizer=standardize,
                        raw_key="artifacts", mask_key="alpha_mask"):
    """@private
    """
    # Build a dataset of (artifact, alpha_mask) patches for EMDefectAugmentation.paste_artifact.
    # Artifacts and alpha masks are read from the same file under `raw_key` and `mask_key`.
    # NOTE(review): MinForegroundSampler presumably rejects patches whose mask foreground
    # fraction is below `min_mask_fraction` - confirm against ..data.
    augmentation = get_augmentations(ndim=2)
    sampler = MinForegroundSampler(min_mask_fraction)
    return SegmentationDataset(
        artifact_path, raw_key,
        artifact_path, mask_key,
        patch_shape=patch_shape,
        # Honor the caller-provided normalizer (it used to be ignored in favor of `standardize`).
        raw_transform=normalizer,
        transform=augmentation,
        sampler=sampler,
    )


class EMDefectAugmentation:
    """Augment raw data with transformations similar to defects common in EM data.

    Every slice along the first axis is handled independently. A single uniform random
    number is drawn per slice and compared against cumulative probabilities, so at most
    one defect is applied to each slice. The four probabilities must sum to less than 1.

    Args:
        p_drop_slice: Probability for a missing slice.
        p_low_contrast: Probability for a low contrast slice.
        p_deform_slice: Probability for a deformed slice.
        p_paste_artifact: Probability for inserting an artifact from data source.
        contrast_scale: Scale of low contrast transformation.
        deformation_mode: Deformation mode that should be used: 'undirected', 'compress' or 'all'.
            A list/tuple containing both 'undirected' and 'compress' is equivalent to 'all'.
        deformation_strength: Deformation strength in pixels.
        artifact_source: Data source for additional artifacts, indexable and yielding
            (artifact, alpha_mask) tensor pairs. Required if p_paste_artifact > 0.
        mean_val: Fill value for pixels sampled from outside the slice by the compression
            deformation (0 if None).
        std_val: Std value for artifact normalization. Currently stored but not used by this class.
    """
    def __init__(
        self,
        p_drop_slice: float,
        p_low_contrast: float,
        p_deform_slice: float,
        p_paste_artifact: float = 0.0,
        contrast_scale: float = 0.1,
        deformation_mode: str = "undirected",
        deformation_strength: float = 10.0,
        artifact_source: Optional[torch.utils.data.Dataset] = None,
        mean_val: Optional[float] = None,
        std_val: Optional[float] = None,
    ):
        # An artifact source is only needed when artifacts can actually be pasted.
        if p_paste_artifact > 0.0:
            assert artifact_source is not None
        self.artifact_source = artifact_source

        # use cumulative probabilities (thresholds for one uniform draw per slice in __call__)
        self.p_drop_slice = p_drop_slice
        self.p_low_contrast = self.p_drop_slice + p_low_contrast
        self.p_deform_slice = self.p_low_contrast + p_deform_slice
        self.p_paste_artifact = self.p_deform_slice + p_paste_artifact
        assert self.p_paste_artifact < 1.0

        self.contrast_scale = contrast_scale
        self.mean_val = mean_val
        self.std_val = std_val

        # set the parameters for deformation augments
        if isinstance(deformation_mode, str):
            assert deformation_mode in ('all', 'undirected', 'compress')
            self.deformation_mode = deformation_mode
        elif isinstance(deformation_mode, (list, tuple)):
            # naming both modes explicitly is equivalent to 'all'
            assert len(deformation_mode) == 2
            assert 'undirected' in deformation_mode
            assert 'compress' in deformation_mode
            self.deformation_mode = 'all'
        self.deformation_strength = deformation_strength

    def drop_slice(self, raw):
        """@private
        """
        # Simulate a missing section: zero the slice in place.
        raw[:] = 0
        return raw

    def low_contrast(self, raw):
        """@private
        """
        # Shrink intensities towards the slice mean by `contrast_scale` (in place).
        mean = raw.mean()
        raw -= mean
        raw *= self.contrast_scale
        raw += mean
        return raw

    # this simulates a typical defect:
    # missing line of data with rest of data compressed towards the line
    def compress_slice(self, raw):
        """@private
        """
        shape = raw.shape
        # randomly choose fixed x or fixed y with p = 1/2
        fixed_x = np.random.rand() < .5
        if fixed_x:
            # line spans from the first to the last row
            x0, y0 = 0, np.random.randint(1, shape[1] - 2)
            x1, y1 = shape[0] - 1, np.random.randint(1, shape[1] - 2)
        else:
            # line spans from the first to the last column
            x0, y0 = np.random.randint(1, shape[0] - 2), 0
            x1, y1 = np.random.randint(1, shape[0] - 2), shape[1] - 1

        # generate the mask of the line that should be blacked out
        line_mask = np.zeros_like(raw, dtype='bool')
        rr, cc = line(x0, y0, x1, y1)
        line_mask[rr, cc] = 1

        # generate vectorfield pointing towards the line to compress the image
        # first we get the unit vector representing the line, in (row, column) order
        line_vector = np.array([x1 - x0, y1 - y0], dtype='float32')
        line_vector /= np.linalg.norm(line_vector)
        # next, we generate the normal to the line
        normal_vector = np.zeros_like(line_vector)
        normal_vector[0] = - line_vector[1]
        normal_vector[1] = line_vector[0]

        # make meshgrid (x: column indices, y: row indices)
        x, y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]))
        # generate the vector field
        flow_x, flow_y = np.zeros_like(raw), np.zeros_like(raw)

        # find the 2 components where coordinates are bigger / smaller than the line
        # to apply normal vector in the correct direction.
        # connectivity=1 (4-neighborhood) is required: with skimage's default full
        # connectivity the two sides touch diagonally across any slanted line and
        # merge into a single component.
        components = label(np.logical_not(line_mask), connectivity=1)
        assert len(np.unique(components)) == 3, "%i" % len(np.unique(components))
        neg_val = components[0, 0] if fixed_x else components[-1, -1]
        pos_val = components[-1, -1] if fixed_x else components[0, 0]

        flow_x[components == pos_val] = self.deformation_strength * normal_vector[1]
        flow_y[components == pos_val] = self.deformation_strength * normal_vector[0]
        flow_x[components == neg_val] = - self.deformation_strength * normal_vector[1]
        flow_y[components == neg_val] = - self.deformation_strength * normal_vector[0]

        # add small random noise
        flow_x += np.random.uniform(-1, 1, shape) * (self.deformation_strength / 8.)
        flow_y += np.random.uniform(-1, 1, shape) * (self.deformation_strength / 8.)

        # apply the flow fields
        flow_x, flow_y = (x + flow_x).reshape(-1, 1), (y + flow_y).reshape(-1, 1)
        cval = 0.0 if self.mean_val is None else self.mean_val
        raw = map_coordinates(
            raw, (flow_y, flow_x), mode='constant', order=3, cval=cval
        ).reshape(shape)

        # dilate the line mask and zero out the raw below it
        line_mask = binary_dilation(line_mask, iterations=10)
        raw[line_mask] = 0.
        return raw

    def undirected_deformation(self, raw):
        """@private
        """
        shape = raw.shape

        # make meshgrid: np.meshgrid uses 'xy' indexing, so the column range goes first
        # to get grids with the slice's shape (x: column indices, y: row indices).
        # The reversed order produced transposed grids and failed for non-square slices.
        x, y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]))

        # generate random vector field and smooth it
        flow_x = np.random.uniform(-1, 1, shape) * self.deformation_strength
        flow_y = np.random.uniform(-1, 1, shape) * self.deformation_strength
        flow_x = gaussian(flow_x, sigma=3.)  # sigma is hard-coded for now
        flow_y = gaussian(flow_y, sigma=3.)  # sigma is hard-coded for now

        # apply the flow fields
        flow_x, flow_y = (x + flow_x).reshape(-1, 1), (y + flow_y).reshape(-1, 1)
        raw = map_coordinates(raw, (flow_y, flow_x), mode='constant').reshape(shape)
        return raw

    def deform_slice(self, raw):
        """@private
        """
        # In 'all' mode pick one of the two deformations with p = 1/2.
        if self.deformation_mode in ('undirected', 'compress'):
            mode = self.deformation_mode
        else:
            mode = 'undirected' if np.random.rand() < .5 else 'compress'
        if mode == 'compress':
            return self.compress_slice(raw)
        return self.undirected_deformation(raw)

    def paste_artifact(self, raw):
        """@private
        """
        # draw a random artifact from the source
        artifact_index = np.random.randint(len(self.artifact_source))
        artifact, alpha_mask = self.artifact_source[artifact_index]
        artifact = artifact.numpy().squeeze()
        alpha_mask = alpha_mask.numpy().squeeze()
        assert artifact.shape == raw.shape, f"{artifact.shape}, {raw.shape}"
        assert alpha_mask.shape == raw.shape
        assert alpha_mask.min() >= 0., f"{alpha_mask.min()}"
        assert alpha_mask.max() <= 1., f"{alpha_mask.max()}"

        # blend the raw data and the artifact according to the alpha mask
        raw = raw * (1. - alpha_mask) + artifact * alpha_mask
        return raw

    def __call__(self, raw: np.ndarray) -> np.ndarray:
        """Apply defect augmentations to input data.

        Args:
            raw: The input data. Slices along the first axis are augmented independently.

        Returns:
            The augmented data, as a float32 copy of the input.
        """
        raw = raw.astype("float32")  # needs to be floating point to avoid errors
        for z in range(raw.shape[0]):
            # one draw per slice, compared against the cumulative thresholds
            r = np.random.rand()
            if r < self.p_drop_slice:
                raw[z] = self.drop_slice(raw[z])
            elif r < self.p_low_contrast:
                raw[z] = self.low_contrast(raw[z])
            elif r < self.p_deform_slice:
                raw[z] = self.deform_slice(raw[z])
            elif r < self.p_paste_artifact:
                raw[z] = self.paste_artifact(raw[z])
        return raw
class
EMDefectAugmentation:
class EMDefectAugmentation:
    """Augment raw data with transformations similar to defects common in EM data.

    Every slice along the first axis is handled independently. A single uniform random
    number is drawn per slice and compared against cumulative probabilities, so at most
    one defect is applied to each slice. The four probabilities must sum to less than 1.

    Args:
        p_drop_slice: Probability for a missing slice.
        p_low_contrast: Probability for a low contrast slice.
        p_deform_slice: Probability for a deformed slice.
        p_paste_artifact: Probability for inserting an artifact from data source.
        contrast_scale: Scale of low contrast transformation.
        deformation_mode: Deformation mode that should be used: 'undirected', 'compress' or 'all'.
            A list/tuple containing both 'undirected' and 'compress' is equivalent to 'all'.
        deformation_strength: Deformation strength in pixels.
        artifact_source: Data source for additional artifacts, indexable and yielding
            (artifact, alpha_mask) tensor pairs. Required if p_paste_artifact > 0.
        mean_val: Fill value for pixels sampled from outside the slice by the compression
            deformation (0 if None).
        std_val: Std value for artifact normalization. Currently stored but not used by this class.
    """
    def __init__(
        self,
        p_drop_slice: float,
        p_low_contrast: float,
        p_deform_slice: float,
        p_paste_artifact: float = 0.0,
        contrast_scale: float = 0.1,
        deformation_mode: str = "undirected",
        deformation_strength: float = 10.0,
        artifact_source: Optional[torch.utils.data.Dataset] = None,
        mean_val: Optional[float] = None,
        std_val: Optional[float] = None,
    ):
        # An artifact source is only needed when artifacts can actually be pasted.
        if p_paste_artifact > 0.0:
            assert artifact_source is not None
        self.artifact_source = artifact_source

        # use cumulative probabilities (thresholds for one uniform draw per slice in __call__)
        self.p_drop_slice = p_drop_slice
        self.p_low_contrast = self.p_drop_slice + p_low_contrast
        self.p_deform_slice = self.p_low_contrast + p_deform_slice
        self.p_paste_artifact = self.p_deform_slice + p_paste_artifact
        assert self.p_paste_artifact < 1.0

        self.contrast_scale = contrast_scale
        self.mean_val = mean_val
        self.std_val = std_val

        # set the parameters for deformation augments
        if isinstance(deformation_mode, str):
            assert deformation_mode in ('all', 'undirected', 'compress')
            self.deformation_mode = deformation_mode
        elif isinstance(deformation_mode, (list, tuple)):
            # naming both modes explicitly is equivalent to 'all'
            assert len(deformation_mode) == 2
            assert 'undirected' in deformation_mode
            assert 'compress' in deformation_mode
            self.deformation_mode = 'all'
        self.deformation_strength = deformation_strength

    def drop_slice(self, raw):
        """@private
        """
        # Simulate a missing section: zero the slice in place.
        raw[:] = 0
        return raw

    def low_contrast(self, raw):
        """@private
        """
        # Shrink intensities towards the slice mean by `contrast_scale` (in place).
        mean = raw.mean()
        raw -= mean
        raw *= self.contrast_scale
        raw += mean
        return raw

    # this simulates a typical defect:
    # missing line of data with rest of data compressed towards the line
    def compress_slice(self, raw):
        """@private
        """
        shape = raw.shape
        # randomly choose fixed x or fixed y with p = 1/2
        fixed_x = np.random.rand() < .5
        if fixed_x:
            # line spans from the first to the last row
            x0, y0 = 0, np.random.randint(1, shape[1] - 2)
            x1, y1 = shape[0] - 1, np.random.randint(1, shape[1] - 2)
        else:
            # line spans from the first to the last column
            x0, y0 = np.random.randint(1, shape[0] - 2), 0
            x1, y1 = np.random.randint(1, shape[0] - 2), shape[1] - 1

        # generate the mask of the line that should be blacked out
        line_mask = np.zeros_like(raw, dtype='bool')
        rr, cc = line(x0, y0, x1, y1)
        line_mask[rr, cc] = 1

        # generate vectorfield pointing towards the line to compress the image
        # first we get the unit vector representing the line, in (row, column) order
        line_vector = np.array([x1 - x0, y1 - y0], dtype='float32')
        line_vector /= np.linalg.norm(line_vector)
        # next, we generate the normal to the line
        normal_vector = np.zeros_like(line_vector)
        normal_vector[0] = - line_vector[1]
        normal_vector[1] = line_vector[0]

        # make meshgrid (x: column indices, y: row indices)
        x, y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]))
        # generate the vector field
        flow_x, flow_y = np.zeros_like(raw), np.zeros_like(raw)

        # find the 2 components where coordinates are bigger / smaller than the line
        # to apply normal vector in the correct direction.
        # connectivity=1 (4-neighborhood) is required: with skimage's default full
        # connectivity the two sides touch diagonally across any slanted line and
        # merge into a single component.
        components = label(np.logical_not(line_mask), connectivity=1)
        assert len(np.unique(components)) == 3, "%i" % len(np.unique(components))
        neg_val = components[0, 0] if fixed_x else components[-1, -1]
        pos_val = components[-1, -1] if fixed_x else components[0, 0]

        flow_x[components == pos_val] = self.deformation_strength * normal_vector[1]
        flow_y[components == pos_val] = self.deformation_strength * normal_vector[0]
        flow_x[components == neg_val] = - self.deformation_strength * normal_vector[1]
        flow_y[components == neg_val] = - self.deformation_strength * normal_vector[0]

        # add small random noise
        flow_x += np.random.uniform(-1, 1, shape) * (self.deformation_strength / 8.)
        flow_y += np.random.uniform(-1, 1, shape) * (self.deformation_strength / 8.)

        # apply the flow fields
        flow_x, flow_y = (x + flow_x).reshape(-1, 1), (y + flow_y).reshape(-1, 1)
        cval = 0.0 if self.mean_val is None else self.mean_val
        raw = map_coordinates(
            raw, (flow_y, flow_x), mode='constant', order=3, cval=cval
        ).reshape(shape)

        # dilate the line mask and zero out the raw below it
        line_mask = binary_dilation(line_mask, iterations=10)
        raw[line_mask] = 0.
        return raw

    def undirected_deformation(self, raw):
        """@private
        """
        shape = raw.shape

        # make meshgrid: np.meshgrid uses 'xy' indexing, so the column range goes first
        # to get grids with the slice's shape (x: column indices, y: row indices).
        # The reversed order produced transposed grids and failed for non-square slices.
        x, y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]))

        # generate random vector field and smooth it
        flow_x = np.random.uniform(-1, 1, shape) * self.deformation_strength
        flow_y = np.random.uniform(-1, 1, shape) * self.deformation_strength
        flow_x = gaussian(flow_x, sigma=3.)  # sigma is hard-coded for now
        flow_y = gaussian(flow_y, sigma=3.)  # sigma is hard-coded for now

        # apply the flow fields
        flow_x, flow_y = (x + flow_x).reshape(-1, 1), (y + flow_y).reshape(-1, 1)
        raw = map_coordinates(raw, (flow_y, flow_x), mode='constant').reshape(shape)
        return raw

    def deform_slice(self, raw):
        """@private
        """
        # In 'all' mode pick one of the two deformations with p = 1/2.
        if self.deformation_mode in ('undirected', 'compress'):
            mode = self.deformation_mode
        else:
            mode = 'undirected' if np.random.rand() < .5 else 'compress'
        if mode == 'compress':
            return self.compress_slice(raw)
        return self.undirected_deformation(raw)

    def paste_artifact(self, raw):
        """@private
        """
        # draw a random artifact from the source
        artifact_index = np.random.randint(len(self.artifact_source))
        artifact, alpha_mask = self.artifact_source[artifact_index]
        artifact = artifact.numpy().squeeze()
        alpha_mask = alpha_mask.numpy().squeeze()
        assert artifact.shape == raw.shape, f"{artifact.shape}, {raw.shape}"
        assert alpha_mask.shape == raw.shape
        assert alpha_mask.min() >= 0., f"{alpha_mask.min()}"
        assert alpha_mask.max() <= 1., f"{alpha_mask.max()}"

        # blend the raw data and the artifact according to the alpha mask
        raw = raw * (1. - alpha_mask) + artifact * alpha_mask
        return raw

    def __call__(self, raw: np.ndarray) -> np.ndarray:
        """Apply defect augmentations to input data.

        Args:
            raw: The input data. Slices along the first axis are augmented independently.

        Returns:
            The augmented data, as a float32 copy of the input.
        """
        raw = raw.astype("float32")  # needs to be floating point to avoid errors
        for z in range(raw.shape[0]):
            # one draw per slice, compared against the cumulative thresholds
            r = np.random.rand()
            if r < self.p_drop_slice:
                raw[z] = self.drop_slice(raw[z])
            elif r < self.p_low_contrast:
                raw[z] = self.low_contrast(raw[z])
            elif r < self.p_deform_slice:
                raw[z] = self.deform_slice(raw[z])
            elif r < self.p_paste_artifact:
                raw[z] = self.paste_artifact(raw[z])
        return raw
Augment raw data with transformations similar to defects common in EM data.
Arguments:
- p_drop_slice: Probability for a missing slice.
- p_low_contrast: Probability for a low contrast slice.
- p_deform_slice: Probability for a deformed slice.
- p_paste_artifact: Probability for inserting an artifact from data source.
- contrast_scale: Scale of low contrast transformation.
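- deformation_mode: Deformation mode that should be used: 'undirected', 'compress' or 'all' (a list/tuple containing both 'undirected' and 'compress' is equivalent to 'all').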
- deformation_strength: Deformation strength in pixels.
- artifact_source: Data source for additional artifacts. Required if p_paste_artifact > 0.
- mean_val: Mean value for artifact normalization.
- std_val: Std value for artifact normalization.
EMDefectAugmentation( p_drop_slice: float, p_low_contrast: float, p_deform_slice: float, p_paste_artifact: float = 0.0, contrast_scale: float = 0.1, deformation_mode: str = 'undirected', deformation_strength: float = 10.0, artifact_source: Optional[torch.utils.data.dataset.Dataset] = None, mean_val: Optional[float] = None, std_val: Optional[float] = None)
def __init__(
    self,
    p_drop_slice: float,
    p_low_contrast: float,
    p_deform_slice: float,
    p_paste_artifact: float = 0.0,
    contrast_scale: float = 0.1,
    deformation_mode: str = "undirected",
    deformation_strength: float = 10.0,
    artifact_source: Optional[torch.utils.data.Dataset] = None,
    mean_val: Optional[float] = None,
    std_val: Optional[float] = None,
):
    """Configure the defect probabilities and transformation parameters.

    The individual probabilities are stored as cumulative thresholds. For a single
    uniform draw r per slice:
    - the slice is dropped if r < p_drop_slice;
    - otherwise it gets low contrast if r < p_low_contrast;
    - and so on for the remaining defects.
    The total must stay below 1.
    """
    # An artifact source is only mandatory when artifacts may be pasted.
    if p_paste_artifact > 0.0:
        assert artifact_source is not None
    self.artifact_source = artifact_source

    # Running sum turns the per-defect probabilities into cumulative thresholds.
    threshold = p_drop_slice
    self.p_drop_slice = threshold
    threshold = threshold + p_low_contrast
    self.p_low_contrast = threshold
    threshold = threshold + p_deform_slice
    self.p_deform_slice = threshold
    threshold = threshold + p_paste_artifact
    self.p_paste_artifact = threshold
    # Keeps a non-zero chance that a slice is left untouched.
    assert threshold < 1.0

    self.contrast_scale = contrast_scale
    self.mean_val = mean_val
    self.std_val = std_val

    # Normalize the deformation mode; a pair naming both modes means "all".
    if isinstance(deformation_mode, (list, tuple)):
        assert len(deformation_mode) == 2
        assert 'undirected' in deformation_mode and 'compress' in deformation_mode
        self.deformation_mode = 'all'
    elif isinstance(deformation_mode, str):
        assert deformation_mode in ('all', 'undirected', 'compress')
        self.deformation_mode = deformation_mode
    self.deformation_strength = deformation_strength