torch_em.transform.defect
from typing import Optional

import numpy as np
import torch

from scipy.ndimage import binary_dilation, map_coordinates
from skimage.draw import line
from skimage.filters import gaussian

import bioimage_cpp as bic

from .augmentation import get_augmentations
from .raw import standardize
from ..data import SegmentationDataset, MinForegroundSampler


#
# defect augmentations
#
# TODO
# - alignment jitter


def get_artifact_source(artifact_path, patch_shape, min_mask_fraction,
                        normalizer=standardize,
                        raw_key="artifacts", mask_key="alpha_mask"):
    """@private

    Build the dataset that serves artifact patches together with their alpha masks.

    Args:
        artifact_path: Path handed to `SegmentationDataset` for both the artifact data and the mask.
        patch_shape: Patch shape handed to the dataset.
        min_mask_fraction: Argument for the `MinForegroundSampler` used to sample patches.
        normalizer: Transformation applied to the artifact raw data.
        raw_key: Key of the artifact data inside `artifact_path`.
        mask_key: Key of the alpha mask inside `artifact_path`.

    Returns:
        The `SegmentationDataset` for the artifacts.
    """
    augmentation = get_augmentations(ndim=2)
    sampler = MinForegroundSampler(min_mask_fraction)
    return SegmentationDataset(
        artifact_path, raw_key,
        artifact_path, mask_key,
        patch_shape=patch_shape,
        # Use the caller-supplied normalizer; it defaults to `standardize`.
        raw_transform=normalizer,
        transform=augmentation,
        sampler=sampler
    )


class EMDefectAugmentation:
    """Augment raw data with transformations similar to defects common in EM data.

    For every slice along the first axis at most one defect is applied; which one is
    decided by a single random draw compared against the cumulative probabilities.

    Args:
        p_drop_slice: Probability for a missing slice.
        p_low_contrast: Probability for a low contrast slice.
        p_deform_slice: Probability for a deformed slice.
        p_paste_artifact: Probability for inserting an artifact from data source.
            Requires `artifact_source` if larger than zero.
        contrast_scale: Scale of low contrast transformation.
        deformation_mode: Deformation mode that should be used. One of 'undirected',
            'compress' or 'all' (choose one of the two at random for each deformed slice).
            A list or tuple containing both 'undirected' and 'compress' is treated as 'all'.
        deformation_strength: Deformation strength in pixels.
        artifact_source: Data source for additional artifacts.
        mean_val: Mean value for artifact normalization. Within this class it is only used
            as the fill value for out-of-image samples of the 'compress' deformation
            (0 is used if it is None).
        std_val: Std value for artifact normalization. It is stored, but not used by the
            methods of this class.
    """
    def __init__(
        self,
        p_drop_slice: float,
        p_low_contrast: float,
        p_deform_slice: float,
        p_paste_artifact: float = 0.0,
        contrast_scale: float = 0.1,
        deformation_mode: str = "undirected",
        deformation_strength: float = 10.0,
        artifact_source: Optional[torch.utils.data.Dataset] = None,
        mean_val: Optional[float] = None,
        std_val: Optional[float] = None,
    ):
        if p_paste_artifact > 0.0:
            assert artifact_source is not None
        self.artifact_source = artifact_source

        # use cumulative probabilities
        self.p_drop_slice = p_drop_slice
        self.p_low_contrast = self.p_drop_slice + p_low_contrast
        self.p_deform_slice = self.p_low_contrast + p_deform_slice
        self.p_paste_artifact = self.p_deform_slice + p_paste_artifact
        assert self.p_paste_artifact < 1.0

        self.contrast_scale = contrast_scale
        self.mean_val = mean_val
        self.std_val = std_val

        # set the parameters for deformation augments
        if isinstance(deformation_mode, str):
            assert deformation_mode in ('all', 'undirected', 'compress')
            self.deformation_mode = deformation_mode
        elif isinstance(deformation_mode, (list, tuple)):
            assert len(deformation_mode) == 2
            assert 'undirected' in deformation_mode
            assert 'compress' in deformation_mode
            self.deformation_mode = 'all'
        self.deformation_strength = deformation_strength

    def drop_slice(self, raw):
        """@private

        Set all values of the slice to zero (in-place) and return it.
        """
        raw[:] = 0
        return raw

    def low_contrast(self, raw):
        """@private

        Shrink the deviation from the slice mean by `contrast_scale` (in-place) and return the slice.
        """
        mean = raw.mean()
        raw -= mean
        raw *= self.contrast_scale
        raw += mean
        return raw

    # this simulates a typical defect:
    # missing line of data with rest of data compressed towards the line
    def compress_slice(self, raw):
        """@private

        Black out a random line through the slice and warp the data on both sides along the line normal.
        """
        shape = raw.shape
        # randomly choose fixed x or fixed y with p = 1/2
        fixed_x = np.random.rand() < .5
        if fixed_x:
            x0, y0 = 0, np.random.randint(1, shape[1] - 2)
            x1, y1 = shape[0] - 1, np.random.randint(1, shape[1] - 2)
        else:
            x0, y0 = np.random.randint(1, shape[0] - 2), 0
            x1, y1 = np.random.randint(1, shape[0] - 2), shape[1] - 1

        # generate the mask of the line that should be blacked out
        line_mask = np.zeros_like(raw, dtype='bool')
        rr, cc = line(x0, y0, x1, y1)
        line_mask[rr, cc] = 1

        # generate vectorfield pointing towards the line to compress the image
        # first we get the unit vector representing the line
        line_vector = np.array([x1 - x0, y1 - y0], dtype='float32')
        line_vector /= np.linalg.norm(line_vector)
        # next, we generate the normal to the line
        normal_vector = np.zeros_like(line_vector)
        normal_vector[0] = - line_vector[1]
        normal_vector[1] = line_vector[0]

        # make meshgrid
        x, y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]))
        # generate the vector field
        flow_x, flow_y = np.zeros_like(raw), np.zeros_like(raw)

        # find the 2 components where coordinates are bigger / smaller than the line
        # to apply normal vector in the correct direction
        components = bic.segmentation.label(np.logical_not(line_mask))
        assert len(np.unique(components)) == 3, "%i" % len(np.unique(components))
        neg_val = components[0, 0] if fixed_x else components[-1, -1]
        pos_val = components[-1, -1] if fixed_x else components[0, 0]

        flow_x[components == pos_val] = self.deformation_strength * normal_vector[1]
        flow_y[components == pos_val] = self.deformation_strength * normal_vector[0]
        flow_x[components == neg_val] = - self.deformation_strength * normal_vector[1]
        flow_y[components == neg_val] = - self.deformation_strength * normal_vector[0]

        # add small random noise
        flow_x += np.random.uniform(-1, 1, shape) * (self.deformation_strength / 8.)
        flow_y += np.random.uniform(-1, 1, shape) * (self.deformation_strength / 8.)

        # apply the flow fields
        flow_x, flow_y = (x + flow_x).reshape(-1, 1), (y + flow_y).reshape(-1, 1)
        cval = 0.0 if self.mean_val is None else self.mean_val
        raw = map_coordinates(
            raw, (flow_y, flow_x), mode='constant', order=3, cval=cval
        ).reshape(shape)

        # dilate the line mask and zero out the raw below it
        line_mask = binary_dilation(line_mask, iterations=10)
        raw[line_mask] = 0.
        return raw

    def undirected_deformation(self, raw):
        """@private

        Warp the slice with a smoothed random displacement field and return the result.
        """
        shape = raw.shape

        # make meshgrid: x holds the column index and y the row index of each pixel.
        # The column range has to come first so that both grids have the shape of `raw`;
        # with the ranges swapped the grids can't be added to the flow fields
        # for non-square slices.
        x, y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]))

        # generate random vector field and smooth it
        flow_x = np.random.uniform(-1, 1, shape) * self.deformation_strength
        flow_y = np.random.uniform(-1, 1, shape) * self.deformation_strength
        flow_x = gaussian(flow_x, sigma=3.)  # sigma is hard-coded for now
        flow_y = gaussian(flow_y, sigma=3.)  # sigma is hard-coded for now

        # apply the flow fields
        flow_x, flow_y = (x + flow_x).reshape(-1, 1), (y + flow_y).reshape(-1, 1)
        raw = map_coordinates(raw, (flow_y, flow_x), mode='constant').reshape(shape)
        return raw

    def deform_slice(self, raw):
        """@private

        Deform the slice according to `deformation_mode`; for 'all' the mode is drawn at random.
        """
        if self.deformation_mode in ('undirected', 'compress'):
            mode = self.deformation_mode
        else:
            mode = 'undirected' if np.random.rand() < .5 else 'compress'
        if mode == 'compress':
            raw = self.compress_slice(raw)
        else:
            raw = self.undirected_deformation(raw)
        return raw

    def paste_artifact(self, raw):
        """@private

        Blend a randomly drawn artifact from `artifact_source` into the slice using its alpha mask.
        """
        # draw a random artifact location
        artifact_index = np.random.randint(len(self.artifact_source))
        # the source has to return (artifact, alpha_mask) pairs of objects with a `numpy()` method
        artifact, alpha_mask = self.artifact_source[artifact_index]
        artifact = artifact.numpy().squeeze()
        alpha_mask = alpha_mask.numpy().squeeze()
        assert artifact.shape == raw.shape, f"{artifact.shape}, {raw.shape}"
        assert alpha_mask.shape == raw.shape
        assert alpha_mask.min() >= 0., f"{alpha_mask.min()}"
        assert alpha_mask.max() <= 1., f"{alpha_mask.max()}"

        # blend the raw data and the artifact according to the alpha mask
        raw = raw * (1. - alpha_mask) + artifact * alpha_mask
        return raw

    def __call__(self, raw: np.ndarray) -> np.ndarray:
        """Apply defect augmentations to input data.

        Args:
            raw: The input data. The defects are applied independently to each slice `raw[z]`.

        Returns:
            The augmented data, as a new float32 array.
        """
        raw = raw.astype("float32")  # needs to be floating point to avoid errors
        for z in range(raw.shape[0]):
            # a single draw per slice selects at most one of the defects
            r = np.random.rand()
            if r < self.p_drop_slice:
                raw[z] = self.drop_slice(raw[z])
            elif r < self.p_low_contrast:
                raw[z] = self.low_contrast(raw[z])
            elif r < self.p_deform_slice:
                raw[z] = self.deform_slice(raw[z])
            elif r < self.p_paste_artifact:
                raw[z] = self.paste_artifact(raw[z])
        return raw
class
EMDefectAugmentation:
class EMDefectAugmentation:
    """Augment raw data with transformations similar to defects common in EM data.

    For every slice along the first axis at most one defect is applied; which one is
    decided by a single random draw compared against the cumulative probabilities.

    Args:
        p_drop_slice: Probability for a missing slice.
        p_low_contrast: Probability for a low contrast slice.
        p_deform_slice: Probability for a deformed slice.
        p_paste_artifact: Probability for inserting an artifact from data source.
            Requires `artifact_source` if larger than zero.
        contrast_scale: Scale of low contrast transformation.
        deformation_mode: Deformation mode that should be used. One of 'undirected',
            'compress' or 'all' (choose one of the two at random for each deformed slice).
            A list or tuple containing both 'undirected' and 'compress' is treated as 'all'.
        deformation_strength: Deformation strength in pixels.
        artifact_source: Data source for additional artifacts.
        mean_val: Mean value for artifact normalization. Within this class it is only used
            as the fill value for out-of-image samples of the 'compress' deformation
            (0 is used if it is None).
        std_val: Std value for artifact normalization. It is stored, but not used by the
            methods of this class.
    """
    def __init__(
        self,
        p_drop_slice: float,
        p_low_contrast: float,
        p_deform_slice: float,
        p_paste_artifact: float = 0.0,
        contrast_scale: float = 0.1,
        deformation_mode: str = "undirected",
        deformation_strength: float = 10.0,
        artifact_source: Optional[torch.utils.data.Dataset] = None,
        mean_val: Optional[float] = None,
        std_val: Optional[float] = None,
    ):
        if p_paste_artifact > 0.0:
            assert artifact_source is not None
        self.artifact_source = artifact_source

        # use cumulative probabilities
        self.p_drop_slice = p_drop_slice
        self.p_low_contrast = self.p_drop_slice + p_low_contrast
        self.p_deform_slice = self.p_low_contrast + p_deform_slice
        self.p_paste_artifact = self.p_deform_slice + p_paste_artifact
        assert self.p_paste_artifact < 1.0

        self.contrast_scale = contrast_scale
        self.mean_val = mean_val
        self.std_val = std_val

        # set the parameters for deformation augments
        if isinstance(deformation_mode, str):
            assert deformation_mode in ('all', 'undirected', 'compress')
            self.deformation_mode = deformation_mode
        elif isinstance(deformation_mode, (list, tuple)):
            assert len(deformation_mode) == 2
            assert 'undirected' in deformation_mode
            assert 'compress' in deformation_mode
            self.deformation_mode = 'all'
        self.deformation_strength = deformation_strength

    def drop_slice(self, raw):
        """@private

        Set all values of the slice to zero (in-place) and return it.
        """
        raw[:] = 0
        return raw

    def low_contrast(self, raw):
        """@private

        Shrink the deviation from the slice mean by `contrast_scale` (in-place) and return the slice.
        """
        mean = raw.mean()
        raw -= mean
        raw *= self.contrast_scale
        raw += mean
        return raw

    # this simulates a typical defect:
    # missing line of data with rest of data compressed towards the line
    def compress_slice(self, raw):
        """@private

        Black out a random line through the slice and warp the data on both sides along the line normal.
        """
        shape = raw.shape
        # randomly choose fixed x or fixed y with p = 1/2
        fixed_x = np.random.rand() < .5
        if fixed_x:
            x0, y0 = 0, np.random.randint(1, shape[1] - 2)
            x1, y1 = shape[0] - 1, np.random.randint(1, shape[1] - 2)
        else:
            x0, y0 = np.random.randint(1, shape[0] - 2), 0
            x1, y1 = np.random.randint(1, shape[0] - 2), shape[1] - 1

        # generate the mask of the line that should be blacked out
        line_mask = np.zeros_like(raw, dtype='bool')
        rr, cc = line(x0, y0, x1, y1)
        line_mask[rr, cc] = 1

        # generate vectorfield pointing towards the line to compress the image
        # first we get the unit vector representing the line
        line_vector = np.array([x1 - x0, y1 - y0], dtype='float32')
        line_vector /= np.linalg.norm(line_vector)
        # next, we generate the normal to the line
        normal_vector = np.zeros_like(line_vector)
        normal_vector[0] = - line_vector[1]
        normal_vector[1] = line_vector[0]

        # make meshgrid
        x, y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]))
        # generate the vector field
        flow_x, flow_y = np.zeros_like(raw), np.zeros_like(raw)

        # find the 2 components where coordinates are bigger / smaller than the line
        # to apply normal vector in the correct direction
        components = bic.segmentation.label(np.logical_not(line_mask))
        assert len(np.unique(components)) == 3, "%i" % len(np.unique(components))
        neg_val = components[0, 0] if fixed_x else components[-1, -1]
        pos_val = components[-1, -1] if fixed_x else components[0, 0]

        flow_x[components == pos_val] = self.deformation_strength * normal_vector[1]
        flow_y[components == pos_val] = self.deformation_strength * normal_vector[0]
        flow_x[components == neg_val] = - self.deformation_strength * normal_vector[1]
        flow_y[components == neg_val] = - self.deformation_strength * normal_vector[0]

        # add small random noise
        flow_x += np.random.uniform(-1, 1, shape) * (self.deformation_strength / 8.)
        flow_y += np.random.uniform(-1, 1, shape) * (self.deformation_strength / 8.)

        # apply the flow fields
        flow_x, flow_y = (x + flow_x).reshape(-1, 1), (y + flow_y).reshape(-1, 1)
        cval = 0.0 if self.mean_val is None else self.mean_val
        raw = map_coordinates(
            raw, (flow_y, flow_x), mode='constant', order=3, cval=cval
        ).reshape(shape)

        # dilate the line mask and zero out the raw below it
        line_mask = binary_dilation(line_mask, iterations=10)
        raw[line_mask] = 0.
        return raw

    def undirected_deformation(self, raw):
        """@private

        Warp the slice with a smoothed random displacement field and return the result.
        """
        shape = raw.shape

        # make meshgrid: x holds the column index and y the row index of each pixel.
        # The column range has to come first so that both grids have the shape of `raw`;
        # with the ranges swapped the grids can't be added to the flow fields
        # for non-square slices.
        x, y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]))

        # generate random vector field and smooth it
        flow_x = np.random.uniform(-1, 1, shape) * self.deformation_strength
        flow_y = np.random.uniform(-1, 1, shape) * self.deformation_strength
        flow_x = gaussian(flow_x, sigma=3.)  # sigma is hard-coded for now
        flow_y = gaussian(flow_y, sigma=3.)  # sigma is hard-coded for now

        # apply the flow fields
        flow_x, flow_y = (x + flow_x).reshape(-1, 1), (y + flow_y).reshape(-1, 1)
        raw = map_coordinates(raw, (flow_y, flow_x), mode='constant').reshape(shape)
        return raw

    def deform_slice(self, raw):
        """@private

        Deform the slice according to `deformation_mode`; for 'all' the mode is drawn at random.
        """
        if self.deformation_mode in ('undirected', 'compress'):
            mode = self.deformation_mode
        else:
            mode = 'undirected' if np.random.rand() < .5 else 'compress'
        if mode == 'compress':
            raw = self.compress_slice(raw)
        else:
            raw = self.undirected_deformation(raw)
        return raw

    def paste_artifact(self, raw):
        """@private

        Blend a randomly drawn artifact from `artifact_source` into the slice using its alpha mask.
        """
        # draw a random artifact location
        artifact_index = np.random.randint(len(self.artifact_source))
        # the source has to return (artifact, alpha_mask) pairs of objects with a `numpy()` method
        artifact, alpha_mask = self.artifact_source[artifact_index]
        artifact = artifact.numpy().squeeze()
        alpha_mask = alpha_mask.numpy().squeeze()
        assert artifact.shape == raw.shape, f"{artifact.shape}, {raw.shape}"
        assert alpha_mask.shape == raw.shape
        assert alpha_mask.min() >= 0., f"{alpha_mask.min()}"
        assert alpha_mask.max() <= 1., f"{alpha_mask.max()}"

        # blend the raw data and the artifact according to the alpha mask
        raw = raw * (1. - alpha_mask) + artifact * alpha_mask
        return raw

    def __call__(self, raw: np.ndarray) -> np.ndarray:
        """Apply defect augmentations to input data.

        Args:
            raw: The input data. The defects are applied independently to each slice `raw[z]`.

        Returns:
            The augmented data, as a new float32 array.
        """
        raw = raw.astype("float32")  # needs to be floating point to avoid errors
        for z in range(raw.shape[0]):
            # a single draw per slice selects at most one of the defects
            r = np.random.rand()
            if r < self.p_drop_slice:
                raw[z] = self.drop_slice(raw[z])
            elif r < self.p_low_contrast:
                raw[z] = self.low_contrast(raw[z])
            elif r < self.p_deform_slice:
                raw[z] = self.deform_slice(raw[z])
            elif r < self.p_paste_artifact:
                raw[z] = self.paste_artifact(raw[z])
        return raw
Augment raw data with transformations similar to defects common in EM data.
Arguments:
- p_drop_slice: Probability for a missing slice.
- p_low_contrast: Probability for a low contrast slice.
- p_deform_slice: Probability for a deformed slice.
- p_paste_artifact: Probability for inserting an artifact from data source.
- contrast_scale: Scale of low contrast transformation.
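- deformation_mode: Deformation mode that should be used: 'undirected', 'compress' or 'all' (a list or tuple containing both 'undirected' and 'compress' is treated as 'all').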
- deformation_strength: Deformation strength in pixels.
- artifact_source: Data source for additional artifacts.
- mean_val: Mean value for artifact normalization.
- std_val: Std value for artifact normalization.
EMDefectAugmentation( p_drop_slice: float, p_low_contrast: float, p_deform_slice: float, p_paste_artifact: float = 0.0, contrast_scale: float = 0.1, deformation_mode: str = 'undirected', deformation_strength: float = 10.0, artifact_source: Optional[torch.utils.data.dataset.Dataset] = None, mean_val: Optional[float] = None, std_val: Optional[float] = None)
def __init__(
    self,
    p_drop_slice: float,
    p_low_contrast: float,
    p_deform_slice: float,
    p_paste_artifact: float = 0.0,
    contrast_scale: float = 0.1,
    deformation_mode: str = "undirected",
    deformation_strength: float = 10.0,
    artifact_source: Optional[torch.utils.data.Dataset] = None,
    mean_val: Optional[float] = None,
    std_val: Optional[float] = None,
):
    """Store the augmentation parameters.

    The four probabilities are stored as cumulative thresholds, in the order
    drop, low contrast, deform, paste artifact. A list or tuple given as
    `deformation_mode` must contain both 'undirected' and 'compress' and is
    stored as 'all'.
    """
    # Pasting artifacts is only possible if there is a source to draw them from.
    if p_paste_artifact > 0.0:
        assert artifact_source is not None
    self.artifact_source = artifact_source

    # Accumulate the probabilities step by step into thresholds.
    threshold = p_drop_slice
    self.p_drop_slice = threshold
    threshold = threshold + p_low_contrast
    self.p_low_contrast = threshold
    threshold = threshold + p_deform_slice
    self.p_deform_slice = threshold
    threshold = threshold + p_paste_artifact
    self.p_paste_artifact = threshold
    assert self.p_paste_artifact < 1.0

    self.mean_val = mean_val
    self.std_val = std_val
    self.contrast_scale = contrast_scale
    self.deformation_strength = deformation_strength

    # Normalize the deformation mode; other types leave the attribute unset.
    if isinstance(deformation_mode, (list, tuple)):
        assert len(deformation_mode) == 2
        assert "undirected" in deformation_mode
        assert "compress" in deformation_mode
        self.deformation_mode = "all"
    elif isinstance(deformation_mode, str):
        assert deformation_mode in ("all", "undirected", "compress")
        self.deformation_mode = deformation_mode