torch_em.transform.augmentation
from typing import Dict, List, Optional, Sequence, Tuple, Union

import kornia
import numpy as np
import torch
from skimage.transform import resize

from ..util import ensure_tensor


class RandomElasticDeformationStacked(kornia.augmentation.AugmentationBase3D):
    """Random elastic deformations implemented with kornia.

    This transformation can be applied to 3D data, the same deformation is applied to each plane.

    Note that `__call__` is overridden below and does not consult `p`:
    calling the object always applies the deformation.

    Args:
        control_point_spacing: The control point spacing for the deformation field.
            A single int is used for both in-plane axes.
        sigma: Sigma for smoothing the deformation field.
        alpha: Alpha value.
        interpolation: Interpolation order for applying the transformation to the data.
        p: Probability for applying the transformation. Passed on to the kornia base class.
        keepdim: Currently unused.
        same_on_batch: Passed on to the kornia base class.
    """
    def __init__(
        self,
        control_point_spacing: Union[int, Sequence[int]] = 1,
        sigma: Tuple[float, float] = (32.0, 32.0),
        alpha: Tuple[float, float] = (4.0, 4.0),
        interpolation=kornia.constants.Resample.BILINEAR,
        p: float = 0.5,
        keepdim: bool = False,
        same_on_batch: bool = True,
    ):
        super().__init__(p=p, same_on_batch=same_on_batch)
        if isinstance(control_point_spacing, int):
            self.control_point_spacing = [control_point_spacing] * 2
        else:
            self.control_point_spacing = control_point_spacing
        assert len(self.control_point_spacing) == 2
        self.interpolation = interpolation
        self.flags = dict(interpolation=torch.tensor(self.interpolation.value), sigma=sigma, alpha=alpha)

    def generate_parameters(self, batch_shape):
        """@private

        Sample a random 2D deformation field for a 5D batch shape.

        Returns a dict with the key "noise", holding a float32 tensor of shape (1, 2, *batch_shape[3:]).
        """
        assert len(batch_shape) == 5
        # Only the last two axes are deformed.
        shape = batch_shape[3:]
        control_shape = tuple(
            sh // spacing for sh, spacing in zip(shape, self.control_point_spacing)
        )
        # One coarse random field per displacement component, values in [-1, 1].
        deformation_fields = [
            np.random.uniform(-1, 1, control_shape),
            np.random.uniform(-1, 1, control_shape)
        ]
        # Upsample the coarse fields to the plane shape with cubic interpolation.
        deformation_fields = [
            resize(df, shape, order=3)[None] for df in deformation_fields
        ]
        noise = np.concatenate(deformation_fields, axis=0)[None].astype("float32")
        noise = torch.from_numpy(noise)
        return {"noise": noise}

    def __call__(self, input: torch.Tensor, params: Optional[Dict] = None) -> torch.Tensor:
        """Apply the augmentation to a tensor.

        Args:
            input: The input tensor. Must have five dimensions.
            params: The transformation parameters. If None, new parameters are sampled.

        Returns:
            The transformed tensor.
        """
        assert len(input.shape) == 5
        if params is None:
            params = self.generate_parameters(input.shape)
        # Remember the parameters so that they can be re-used for further tensors.
        self._params = params

        noise = params["noise"]
        # A flag value of 1 selects bilinear interpolation, anything else nearest.
        mode = "bilinear" if (self.flags["interpolation"] == 1).all() else "nearest"
        # Repeat the single noise field along the first axis to match input.shape[1].
        noise_ch = noise.expand(input.shape[1], -1, -1, -1)
        input_transformed = []
        for x in torch.unbind(input, dim=0):
            x_transformed = kornia.geometry.transform.elastic_transform2d(
                x, noise_ch, sigma=self.flags["sigma"], alpha=self.flags["alpha"], mode=mode, padding_mode="reflection"
            )
            input_transformed.append(x_transformed)
        input_transformed = torch.stack(input_transformed)
        return input_transformed


class RandomElasticDeformation(kornia.augmentation.AugmentationBase2D):
    """Random elastic deformations implemented with kornia.

    Note that `__call__` is overridden below and does not consult `p`:
    calling the object always applies the deformation.

    Args:
        control_point_spacing: The control point spacing for the deformation field.
            A single int is used for both axes.
        sigma: Sigma for smoothing the deformation field.
        alpha: Alpha value.
        resample: Interpolation order for applying the transformation to the data.
        p: Probability for applying the transformation. Passed on to the kornia base class.
        keepdim: Currently unused.
        same_on_batch: Passed on to the kornia base class.
    """
    def __init__(
        self,
        control_point_spacing: Union[int, Sequence[int]] = 1,
        sigma: Tuple[float, float] = (32.0, 32.0),
        alpha: Tuple[float, float] = (4.0, 4.0),
        resample=kornia.constants.Resample.BILINEAR,
        p: float = 0.5,
        keepdim: bool = False,
        same_on_batch: bool = True,
    ):
        super().__init__(p=p, same_on_batch=same_on_batch)
        if isinstance(control_point_spacing, int):
            self.control_point_spacing = [control_point_spacing] * 2
        else:
            self.control_point_spacing = control_point_spacing
        assert len(self.control_point_spacing) == 2
        self.resample = resample
        self.flags = dict(resample=torch.tensor(self.resample.value), sigma=sigma, alpha=alpha)

    def generate_parameters(self, batch_shape):
        """@private

        Sample a random 2D deformation field for a 4D batch shape.

        Returns a dict with the key "noise", holding a float32 tensor of shape (1, 2, *batch_shape[2:]).
        """
        assert len(batch_shape) == 4, f"{len(batch_shape)}"
        shape = batch_shape[2:]
        control_shape = tuple(sh // spacing for sh, spacing in zip(shape, self.control_point_spacing))
        # One coarse random field per displacement component, values in [-1, 1].
        deformation_fields = [np.random.uniform(-1, 1, control_shape), np.random.uniform(-1, 1, control_shape)]
        # Upsample the coarse fields to the image shape with cubic interpolation.
        deformation_fields = [resize(df, shape, order=3)[None] for df in deformation_fields]
        noise = np.concatenate(deformation_fields, axis=0)[None].astype("float32")
        noise = torch.from_numpy(noise)
        return {"noise": noise}

    def __call__(self, input: torch.Tensor, params: Optional[Dict] = None) -> torch.Tensor:
        """Apply the augmentation to a tensor.

        Args:
            input: The input tensor.
            params: The transformation parameters. If None, new parameters are sampled.

        Returns:
            The transformed tensor.
        """
        if params is None:
            params = self.generate_parameters(input.shape)
        # Remember the parameters so that they can be re-used for further tensors.
        self._params = params
        noise = params["noise"]
        # A flag value of 1 selects bilinear interpolation, anything else nearest.
        mode = "bilinear" if (self.flags["resample"] == 1).all() else "nearest"
        return kornia.geometry.transform.elastic_transform2d(
            input, noise, sigma=self.flags["sigma"], alpha=self.flags["alpha"], mode=mode, padding_mode="reflection"
        )


# TODO: Implement 'require_halo', and estimate the halo from the transformations
# so that we can load a bigger block and cut it away.
class KorniaAugmentationPipeline(torch.nn.Module):
    """Pipeline to apply multiple kornia augmentations to data.

    Args:
        kornia_augmentations: Augmentations implemented with kornia.
        dtype: The data type of the return data.
    """
    interpolatable_torch_types = [torch.float16, torch.float32, torch.float64]
    interpolatable_numpy_types = [np.dtype("float32"), np.dtype("float64")]

    def __init__(self, *kornia_augmentations, dtype: Union[str, torch.dtype] = torch.float32):
        super().__init__()
        self.augmentations = torch.nn.ModuleList(kornia_augmentations)
        self.dtype = dtype
        # Stored under a private name: an instance attribute called `halo` would shadow
        # the `halo` method and make `pipeline.halo(shape)` fail with a TypeError.
        self._halo = self.compute_halo()

    # for now we only add a halo for the random rotation trafos and
    # also don't compute the halo dynamically based on the input shape
    def compute_halo(self):
        """@private

        Return [32, 32] / [32, 32, 32] if the pipeline contains a 2D / 3D random rotation, otherwise None.
        If both are present, the value of the last matching augmentation is returned.
        """
        halo = None
        for aug in self.augmentations:
            if isinstance(aug, kornia.augmentation.RandomRotation):
                halo = [32, 32]
            if isinstance(aug, kornia.augmentation.RandomRotation3D):
                halo = [32, 32, 32]
        return halo

    def is_interpolatable(self, tensor):
        """@private

        Check whether the dtype of a torch tensor or numpy array is one of the floating point types listed on the class.
        """
        if torch.is_tensor(tensor):
            return tensor.dtype in self.interpolatable_torch_types
        else:
            return tensor.dtype in self.interpolatable_numpy_types

    def transform_tensor(self, augmentation, tensor, interpolatable, params=None):
        """@private

        Apply a single augmentation to a tensor and return the result together with the parameters it used.
        """
        interpolating = "interpolation" in getattr(augmentation, "flags", [])
        if interpolating:
            # Use bilinear interpolation for floating point data and nearest neighbor otherwise.
            resampler = kornia.constants.Resample.get("BILINEAR" if interpolatable else "NEAREST")
            augmentation.flags["interpolation"] = torch.tensor(resampler.value)
        transformed = augmentation(tensor, params)
        return transformed, augmentation._params

    def forward(self, *tensors: torch.Tensor) -> List[torch.Tensor]:
        """Apply augmentations to a list of tensors.

        Args:
            tensors: The input tensors.

        Returns:
            List of transformed tensors.
        """
        # The dtypes have to be checked before the conversion to `self.dtype`.
        interpolatable = [self.is_interpolatable(tensor) for tensor in tensors]
        tensors = [ensure_tensor(tensor, self.dtype) for tensor in tensors]
        for aug in self.augmentations:

            # The first tensor determines the parameters, which are re-used for all other tensors.
            t0, params = self.transform_tensor(aug, tensors[0], interpolatable[0])
            transformed_tensors = [t0]
            for tensor, interpolate in zip(tensors[1:], interpolatable[1:]):
                tensor, _ = self.transform_tensor(aug, tensor, interpolate, params=params)
                transformed_tensors.append(tensor)

            tensors = transformed_tensors
        return tensors

    def halo(self, shape):
        """@private

        Return the halo computed at construction time. The `shape` argument is currently ignored.
        """
        return self._halo


# Try out:
# - RandomPerspective
AUGMENTATIONS = {
    "RandomAffine": {"degrees": 90, "scale": (0.9, 1.1)},
    "RandomAffine3D": {"degrees": (90, 90, 90), "scale": (0.0, 1.1)},
    "RandomDepthicalFlip3D": {},
    "RandomHorizontalFlip": {},
    "RandomHorizontalFlip3D": {},
    "RandomRotation": {"degrees": 90},
    "RandomRotation3D": {"degrees": (90, 90, 90)},
    "RandomVerticalFlip": {},
    "RandomVerticalFlip3D": {},
    "RandomElasticDeformation3D": {"alpha": [5, 5], "sigma": [30, 30]}
}
"""All available augmentations and their default parameters.
"""

DEFAULT_2D_AUGMENTATIONS = [
    "RandomHorizontalFlip",
    "RandomVerticalFlip"
]
"""The default augmentations for 2D data.
"""
DEFAULT_3D_AUGMENTATIONS = [
    "RandomHorizontalFlip3D",
    "RandomVerticalFlip3D",
    "RandomDepthicalFlip3D",
]
"""The default augmentations for 3D data.
"""
DEFAULT_ANISOTROPIC_AUGMENTATIONS = [
    "RandomHorizontalFlip3D",
    "RandomVerticalFlip3D",
    "RandomDepthicalFlip3D",
]
"""The default augmentations for anisotropic 3D data.
"""


def create_augmentation(trafo):
    """@private

    Instantiate the augmentation called `trafo` with its parameters from `AUGMENTATIONS`.
    The name is looked up in `kornia.augmentation` first and in this module otherwise.
    """
    assert trafo in dir(kornia.augmentation) or trafo in globals().keys(), f"Transformation {trafo} not defined"
    if trafo in dir(kornia.augmentation):
        return getattr(kornia.augmentation, trafo)(**AUGMENTATIONS[trafo])
    return globals()[trafo](**AUGMENTATIONS[trafo])


def get_augmentations(ndim: Union[int, str] = 2, transforms=None, dtype: Union[str, torch.dtype] = torch.float32):
    """Get augmentation pipeline.

    Args:
        ndim: The dimensionality for the augmentations. One of 2, 3 or "anisotropic".
        transforms: The transformations to use for the augmentations.
            If None, the default augmentations for the given data dimensionality will be used.
        dtype: The data type of the output data of the augmentation.

    Returns:
        The augmentation pipeline.
    """
    if transforms is None:
        assert ndim in (2, 3, "anisotropic"), f"Expect ndim to be one of (2, 3, 'anisotropic'), got {ndim}"
        if ndim == 2:
            transforms = DEFAULT_2D_AUGMENTATIONS
        elif ndim == 3:
            transforms = DEFAULT_3D_AUGMENTATIONS
        else:
            transforms = DEFAULT_ANISOTROPIC_AUGMENTATIONS
    transforms = [create_augmentation(trafo) for trafo in transforms]
    assert all(isinstance(trafo, kornia.augmentation.base._AugmentationBase) for trafo in transforms)
    augmentations = KorniaAugmentationPipeline(*transforms, dtype=dtype)
    return augmentations
class RandomElasticDeformationStacked(kornia.augmentation.AugmentationBase3D):
    """Elastic deformation for 3D data, built on kornia's 2D elastic transform.

    A single 2D deformation field is sampled and applied to every plane of the volume.

    Args:
        control_point_spacing: The control point spacing for the deformation field.
            A single int is used for both in-plane axes.
        sigma: Sigma for smoothing the deformation field.
        alpha: Alpha value.
        interpolation: Interpolation order for applying the transformation to the data.
        p: Probability for applying the transformation. Passed on to the kornia base class.
        keepdim: Currently unused.
        same_on_batch: Passed on to the kornia base class.
    """
    def __init__(
        self,
        control_point_spacing: Union[int, Sequence[int]] = 1,
        sigma: Tuple[float, float] = (32.0, 32.0),
        alpha: Tuple[float, float] = (4.0, 4.0),
        interpolation=kornia.constants.Resample.BILINEAR,
        p: float = 0.5,
        keepdim: bool = False,
        same_on_batch: bool = True,
    ):
        super().__init__(p=p, same_on_batch=same_on_batch)
        # Broadcast a scalar spacing to both in-plane axes.
        spacing = control_point_spacing
        if isinstance(spacing, int):
            spacing = [spacing] * 2
        self.control_point_spacing = spacing
        assert len(self.control_point_spacing) == 2
        self.interpolation = interpolation
        self.flags = {
            "interpolation": torch.tensor(self.interpolation.value),
            "sigma": sigma,
            "alpha": alpha,
        }

    def generate_parameters(self, batch_shape):
        """@private

        Sample the deformation field for a 5D batch shape; the result is stored under the key "noise".
        """
        assert len(batch_shape) == 5
        plane_shape = batch_shape[3:]
        grid_shape = tuple(
            size // step for size, step in zip(plane_shape, self.control_point_spacing)
        )
        # Draw both coarse displacement components first, then upsample them (cubic) to the plane shape.
        coarse_fields = [np.random.uniform(-1, 1, grid_shape) for _ in range(2)]
        upsampled = np.stack([resize(field, plane_shape, order=3) for field in coarse_fields], axis=0)
        noise = torch.from_numpy(upsampled[np.newaxis].astype("float32"))
        return {"noise": noise}

    def __call__(self, input: torch.Tensor, params: Optional[Dict] = None) -> torch.Tensor:
        """Deform a 5D tensor plane by plane.

        Args:
            input: The input tensor. Must have five dimensions.
            params: The transformation parameters. If None, new parameters are sampled.

        Returns:
            The transformed tensor.
        """
        assert len(input.shape) == 5
        if params is None:
            params = self.generate_parameters(input.shape)
        self._params = params

        # A flag value of 1 selects bilinear interpolation, anything else nearest.
        if (self.flags["interpolation"] == 1).all():
            mode = "bilinear"
        else:
            mode = "nearest"
        # Repeat the single noise field along the first axis to match input.shape[1].
        noise_ch = params["noise"].expand(input.shape[1], -1, -1, -1)
        sigma, alpha = self.flags["sigma"], self.flags["alpha"]
        deformed = [
            kornia.geometry.transform.elastic_transform2d(
                sample, noise_ch, sigma=sigma, alpha=alpha, mode=mode, padding_mode="reflection"
            )
            for sample in torch.unbind(input, dim=0)
        ]
        return torch.stack(deformed)
Random elastic deformations implemented with kornia.
This transformation can be applied to 3D data, the same deformation is applied to each plane.
Arguments:
- control_point_spacing: The control point spacing for the deformation field.
- sigma: Sigma for smoothing the deformation field.
- alpha: Alpha value.
- interpolation: Interpolation order for applying the transformation to the data.
- p: Probability for applying the transformation.
- keepdim: Currently unused by this class.
- same_on_batch: Passed on to the kornia base class.
def __init__(
    self,
    control_point_spacing: Union[int, Sequence[int]] = 1,
    sigma: Tuple[float, float] = (32.0, 32.0),
    alpha: Tuple[float, float] = (4.0, 4.0),
    interpolation=kornia.constants.Resample.BILINEAR,
    p: float = 0.5,
    keepdim: bool = False,
    same_on_batch: bool = True,
):
    """Store the deformation settings; `p` and `same_on_batch` are passed on to the base class."""
    super().__init__(p=p, same_on_batch=same_on_batch)
    # Broadcast a scalar spacing to both in-plane axes.
    spacing = control_point_spacing
    if isinstance(spacing, int):
        spacing = [spacing] * 2
    self.control_point_spacing = spacing
    assert len(self.control_point_spacing) == 2
    self.interpolation = interpolation
    self.flags = {
        "interpolation": torch.tensor(self.interpolation.value),
        "sigma": sigma,
        "alpha": alpha,
    }
Initialize internal Module state, shared by both nn.Module and ScriptModule.
class RandomElasticDeformation(kornia.augmentation.AugmentationBase2D):
    """Elastic deformation for 2D data, built on kornia's 2D elastic transform.

    Args:
        control_point_spacing: The control point spacing for the deformation field.
            A single int is used for both axes.
        sigma: Sigma for smoothing the deformation field.
        alpha: Alpha value.
        resample: Interpolation order for applying the transformation to the data.
        p: Probability for applying the transformation. Passed on to the kornia base class.
        keepdim: Currently unused.
        same_on_batch: Passed on to the kornia base class.
    """
    def __init__(
        self,
        control_point_spacing: Union[int, Sequence[int]] = 1,
        sigma: Tuple[float, float] = (32.0, 32.0),
        alpha: Tuple[float, float] = (4.0, 4.0),
        resample=kornia.constants.Resample.BILINEAR,
        p: float = 0.5,
        keepdim: bool = False,
        same_on_batch: bool = True,
    ):
        super().__init__(p=p, same_on_batch=same_on_batch)
        # Broadcast a scalar spacing to both axes.
        spacing = control_point_spacing
        if isinstance(spacing, int):
            spacing = [spacing] * 2
        self.control_point_spacing = spacing
        assert len(self.control_point_spacing) == 2
        self.resample = resample
        self.flags = {
            "resample": torch.tensor(self.resample.value),
            "sigma": sigma,
            "alpha": alpha,
        }

    def generate_parameters(self, batch_shape):
        """@private

        Sample the deformation field for a 4D batch shape; the result is stored under the key "noise".
        """
        assert len(batch_shape) == 4, f"{len(batch_shape)}"
        image_shape = batch_shape[2:]
        grid_shape = tuple(
            size // step for size, step in zip(image_shape, self.control_point_spacing)
        )
        # Draw both coarse displacement components first, then upsample them (cubic) to the image shape.
        coarse_fields = [np.random.uniform(-1, 1, grid_shape) for _ in range(2)]
        upsampled = np.stack([resize(field, image_shape, order=3) for field in coarse_fields], axis=0)
        noise = torch.from_numpy(upsampled[np.newaxis].astype("float32"))
        return {"noise": noise}

    def __call__(self, input: torch.Tensor, params: Optional[Dict] = None) -> torch.Tensor:
        """Deform a tensor with the given or freshly sampled deformation field.

        Args:
            input: The input tensor.
            params: The transformation parameters. If None, new parameters are sampled.

        Returns:
            The transformed tensor.
        """
        if params is None:
            params = self.generate_parameters(input.shape)
        self._params = params

        # A flag value of 1 selects bilinear interpolation, anything else nearest.
        if (self.flags["resample"] == 1).all():
            mode = "bilinear"
        else:
            mode = "nearest"
        deformed = kornia.geometry.transform.elastic_transform2d(
            input,
            params["noise"],
            sigma=self.flags["sigma"],
            alpha=self.flags["alpha"],
            mode=mode,
            padding_mode="reflection",
        )
        return deformed
Random elastic deformations implemented with kornia.
Arguments:
- control_point_spacing: The control point spacing for the deformation field.
- sigma: Sigma for smoothing the deformation field.
- alpha: Alpha value.
- resample: Interpolation order for applying the transformation to the data.
- p: Probability for applying the transformation.
- keepdim: Currently unused by this class.
- same_on_batch: Passed on to the kornia base class.
def __init__(
    self,
    control_point_spacing: Union[int, Sequence[int]] = 1,
    sigma: Tuple[float, float] = (32.0, 32.0),
    alpha: Tuple[float, float] = (4.0, 4.0),
    resample=kornia.constants.Resample.BILINEAR,
    p: float = 0.5,
    keepdim: bool = False,
    same_on_batch: bool = True,
):
    """Store the deformation settings; `p` and `same_on_batch` are passed on to the base class."""
    super().__init__(p=p, same_on_batch=same_on_batch)
    # Broadcast a scalar spacing to both axes.
    spacing = control_point_spacing
    if isinstance(spacing, int):
        spacing = [spacing] * 2
    self.control_point_spacing = spacing
    assert len(self.control_point_spacing) == 2
    self.resample = resample
    self.flags = {
        "resample": torch.tensor(self.resample.value),
        "sigma": sigma,
        "alpha": alpha,
    }
Initialize internal Module state, shared by both nn.Module and ScriptModule.
class KorniaAugmentationPipeline(torch.nn.Module):
    """Pipeline to apply multiple kornia augmentations to data.

    Args:
        kornia_augmentations: Augmentations implemented with kornia.
        dtype: The data type of the return data.
    """
    interpolatable_torch_types = [torch.float16, torch.float32, torch.float64]
    interpolatable_numpy_types = [np.dtype("float32"), np.dtype("float64")]

    def __init__(self, *kornia_augmentations, dtype: Union[str, torch.dtype] = torch.float32):
        super().__init__()
        self.augmentations = torch.nn.ModuleList(kornia_augmentations)
        self.dtype = dtype
        # Stored under a private name: an instance attribute called `halo` would shadow
        # the `halo` method and make `pipeline.halo(shape)` fail with a TypeError.
        self._halo = self.compute_halo()

    # for now we only add a halo for the random rotation trafos and
    # also don't compute the halo dynamically based on the input shape
    def compute_halo(self):
        """@private

        Return [32, 32] / [32, 32, 32] if the pipeline contains a 2D / 3D random rotation, otherwise None.
        If both are present, the value of the last matching augmentation is returned.
        """
        halo = None
        for aug in self.augmentations:
            if isinstance(aug, kornia.augmentation.RandomRotation):
                halo = [32, 32]
            if isinstance(aug, kornia.augmentation.RandomRotation3D):
                halo = [32, 32, 32]
        return halo

    def is_interpolatable(self, tensor):
        """@private

        Check whether the dtype of a torch tensor or numpy array is one of the floating point types listed on the class.
        """
        if torch.is_tensor(tensor):
            return tensor.dtype in self.interpolatable_torch_types
        else:
            return tensor.dtype in self.interpolatable_numpy_types

    def transform_tensor(self, augmentation, tensor, interpolatable, params=None):
        """@private

        Apply a single augmentation to a tensor and return the result together with the parameters it used.
        """
        interpolating = "interpolation" in getattr(augmentation, "flags", [])
        if interpolating:
            # Use bilinear interpolation for floating point data and nearest neighbor otherwise.
            resampler = kornia.constants.Resample.get("BILINEAR" if interpolatable else "NEAREST")
            augmentation.flags["interpolation"] = torch.tensor(resampler.value)
        transformed = augmentation(tensor, params)
        return transformed, augmentation._params

    def forward(self, *tensors: torch.Tensor) -> List[torch.Tensor]:
        """Apply augmentations to a list of tensors.

        Args:
            tensors: The input tensors.

        Returns:
            List of transformed tensors.
        """
        # The dtypes have to be checked before the conversion to `self.dtype`.
        interpolatable = [self.is_interpolatable(tensor) for tensor in tensors]
        tensors = [ensure_tensor(tensor, self.dtype) for tensor in tensors]
        for aug in self.augmentations:

            # The first tensor determines the parameters, which are re-used for all other tensors.
            t0, params = self.transform_tensor(aug, tensors[0], interpolatable[0])
            transformed_tensors = [t0]
            for tensor, interpolate in zip(tensors[1:], interpolatable[1:]):
                tensor, _ = self.transform_tensor(aug, tensor, interpolate, params=params)
                transformed_tensors.append(tensor)

            tensors = transformed_tensors
        return tensors

    def halo(self, shape):
        """@private

        Return the halo computed at construction time. The `shape` argument is currently ignored.
        """
        return self._halo
Pipeline to apply multiple kornia augmentations to data.
Arguments:
- kornia_augmentations: Augmentations implemented with kornia.
- dtype: The data type of the return data.
167 def __init__(self, *kornia_augmentations, dtype: Union[str, torch.dtype] = torch.float32): 168 super().__init__() 169 self.augmentations = torch.nn.ModuleList(kornia_augmentations) 170 self.dtype = dtype 171 self.halo = self.compute_halo()
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, *tensors: torch.Tensor) -> List[torch.Tensor]:
    """Run all augmentations of the pipeline on the given tensors.

    Args:
        tensors: The input tensors.

    Returns:
        List of transformed tensors.
    """
    # The dtypes have to be checked before the conversion to `self.dtype`.
    interpolatable = [self.is_interpolatable(item) for item in tensors]
    current = [ensure_tensor(item, self.dtype) for item in tensors]
    for aug in self.augmentations:
        # The first tensor determines the parameters, which are re-used for all other tensors.
        first, params = self.transform_tensor(aug, current[0], interpolatable[0])
        remaining = [
            self.transform_tensor(aug, item, flag, params=params)[0]
            for item, flag in zip(current[1:], interpolatable[1:])
        ]
        current = [first, *remaining]
    return current
Apply augmentations to a list of tensors.
Arguments:
- tensors: The input tensors.
Returns:
List of transformed tensors.
All available augmentations and their default parameters.
The default augmentations for 2D data.
The default augmentations for 3D data.
The default augmentations for anisotropic 3D data.
def get_augmentations(ndim: Union[int, str] = 2, transforms=None, dtype: Union[str, torch.dtype] = torch.float32):
    """Build the augmentation pipeline.

    Args:
        ndim: The dimensionality for the augmentations. One of 2, 3 or "anisotropic".
        transforms: The transformations to use for the augmentations.
            If None, the default augmentations for the given data dimensionality will be used.
        dtype: The data type of the output data of the augmentation.

    Returns:
        The augmentation pipeline.
    """
    names = transforms
    if names is None:
        # Fall back to the default augmentation names for the requested dimensionality.
        assert ndim in (2, 3, "anisotropic"), f"Expect ndim to be one of (2, 3, 'anisotropic'), got {ndim}"
        names = (
            DEFAULT_2D_AUGMENTATIONS if ndim == 2
            else DEFAULT_3D_AUGMENTATIONS if ndim == 3
            else DEFAULT_ANISOTROPIC_AUGMENTATIONS
        )
    members = [create_augmentation(name) for name in names]
    assert all(isinstance(member, kornia.augmentation.base._AugmentationBase) for member in members)
    return KorniaAugmentationPipeline(*members, dtype=dtype)
Get augmentation pipeline.
Arguments:
- ndim: The dimensionality for the augmentations. One of 2, 3 or "anisotropic".
- transforms: The transformations to use for the augmentations. If None, the default augmentations for the given data dimensionality will be used.
- dtype: The data type of the output data of the augmentation.
Returns:
The augmentation pipeline.