torch_em.transform.augmentation
"""Kornia-based augmentations and an augmentation pipeline for torch_em.

The module defines two custom elastic deformations, a pipeline that applies a
sequence of kornia augmentations to several tensors with shared random
parameters, and helpers that build such a pipeline from augmentation names.
"""
import torch
import numpy as np
import kornia
from skimage.transform import resize

from ..util import ensure_tensor


def _random_deformation_noise(shape, control_point_spacing):
    """Sample a smooth random 2D displacement field.

    Uniform noise in [-1, 1] is drawn on a coarse grid with one control point
    every ``control_point_spacing`` pixels per axis, then upsampled to ``shape``
    with cubic (``order=3``) interpolation.

    Args:
        shape: the two spatial extents of the field.
        control_point_spacing: pair of control point spacings, one per axis.

    Returns:
        float32 tensor of shape (1, 2, *shape), one field per spatial axis.
    """
    # At least one control point per axis, so that a spacing larger than the
    # image does not yield an empty grid that cannot be resized.
    control_shape = tuple(
        max(1, sh // spacing) for sh, spacing in zip(shape, control_point_spacing)
    )
    deformation_fields = [
        np.random.uniform(-1, 1, control_shape),
        np.random.uniform(-1, 1, control_shape),
    ]
    deformation_fields = [resize(df, shape, order=3)[None] for df in deformation_fields]
    noise = np.concatenate(deformation_fields, axis=0)[None].astype("float32")
    return torch.from_numpy(noise)


class RandomElasticDeformationStacked(kornia.augmentation.AugmentationBase3D):
    """Elastic deformation of 5D inputs that warps all slices identically.

    One random 2D displacement field over the last two axes is sampled per call
    and applied to every sample, channel and slice along the third axis.

    NOTE(review): ``__call__`` is overridden and never reads ``p``, so the
    deformation is applied on every call; ``keepdim`` is accepted but not
    forwarded to the base class.
    """

    def __init__(self,
                 control_point_spacing=1,
                 sigma=(32.0, 32.0),
                 alpha=(4.0, 4.0),
                 interpolation=kornia.constants.Resample.BILINEAR,
                 p=0.5,
                 keepdim=False,
                 same_on_batch=True):
        """
        Args:
            control_point_spacing: spacing in pixels between the control points of
                the random grid; an int is used for both spatial axes.
            sigma: forwarded as ``sigma`` to ``elastic_transform2d``.
            alpha: forwarded as ``alpha`` to ``elastic_transform2d``.
            interpolation: initial resampling mode (kornia ``Resample`` member).
            p: passed to the kornia base class (not read by ``__call__``).
            keepdim: accepted but not forwarded.
            same_on_batch: passed to the kornia base class.
        """
        super().__init__(p=p,  # keepdim=keepdim,
                         same_on_batch=same_on_batch)
        if isinstance(control_point_spacing, int):
            self.control_point_spacing = [control_point_spacing] * 2
        else:
            self.control_point_spacing = control_point_spacing
        assert len(self.control_point_spacing) == 2
        self.interpolation = interpolation
        # KorniaAugmentationPipeline.transform_tensor overwrites
        # flags["interpolation"] for every tensor it processes.
        self.flags = dict(
            interpolation=torch.tensor(self.interpolation.value),
            sigma=sigma,
            alpha=alpha,
        )

    # The same transformation applied to all samples in a batch
    def generate_parameters(self, batch_shape):
        """Sample one displacement field for the last two axes of a 5D shape."""
        assert len(batch_shape) == 5
        return {"noise": _random_deformation_noise(batch_shape[3:], self.control_point_spacing)}

    def __call__(self, input, params=None):
        """Warp the 5D ``input``; sample ``params`` if not given.

        The parameters are stored in ``self._params`` so they can be replayed on
        other tensors.
        """
        assert len(input.shape) == 5
        if params is None:
            params = self.generate_parameters(input.shape)
        self._params = params

        # The noise is created on the CPU; move it to the device of the data.
        noise = params["noise"].to(input.device)
        # 1 is presumably Resample.BILINEAR.value; every other mode -> nearest.
        mode = "bilinear" if (self.flags["interpolation"] == 1).all() else "nearest"
        # Each 4D sample is passed as a batch of input.shape[1] images, all
        # sharing the same field, so every slice is warped identically.
        noise_ch = noise.expand(input.shape[1], -1, -1, -1)
        input_transformed = [
            kornia.geometry.transform.elastic_transform2d(
                x, noise_ch, sigma=self.flags["sigma"],
                alpha=self.flags["alpha"], mode=mode,
                padding_mode="reflection",
            )
            for x in torch.unbind(input, dim=0)
        ]
        return torch.stack(input_transformed)


# The AUGMENTATIONS entry "RandomElasticDeformation3D" (2-tuple sigma / alpha)
# refers to the stacked deformation; this alias lets create_augmentation find it.
RandomElasticDeformation3D = RandomElasticDeformationStacked


class RandomElasticDeformation(kornia.augmentation.AugmentationBase2D):
    """Elastic deformation of 4D inputs with a random smooth displacement field.

    NOTE(review): ``__call__`` is overridden and never reads ``p``, so the
    deformation is applied on every call; ``keepdim`` is accepted but not
    forwarded to the base class.
    """

    def __init__(self,
                 control_point_spacing=1,
                 sigma=(4.0, 4.0),
                 alpha=(32.0, 32.0),
                 resample=kornia.constants.Resample.BILINEAR,
                 p=0.5,
                 keepdim=False,
                 same_on_batch=False):
        """
        Args:
            control_point_spacing: spacing in pixels between the control points of
                the random grid; an int is used for both spatial axes.
            sigma: forwarded as ``sigma`` to ``elastic_transform2d``.
            alpha: forwarded as ``alpha`` to ``elastic_transform2d``.
            resample: initial resampling mode (kornia ``Resample`` member).
            p: passed to the kornia base class (not read by ``__call__``).
            keepdim: accepted but not forwarded.
            same_on_batch: passed to the kornia base class.
        """
        super().__init__(p=p,  # keepdim=keepdim,
                         same_on_batch=same_on_batch)
        if isinstance(control_point_spacing, int):
            self.control_point_spacing = [control_point_spacing] * 2
        else:
            self.control_point_spacing = control_point_spacing
        assert len(self.control_point_spacing) == 2
        self.resample = resample
        self.flags = dict(
            resample=torch.tensor(self.resample.value),
            sigma=sigma,
            alpha=alpha,
        )

    # TODO do we need special treatment for batches, channels > 1?
    def generate_parameters(self, batch_shape):
        """Sample one displacement field for the last two axes of a 4D shape."""
        assert len(batch_shape) == 4, f"{len(batch_shape)}"
        return {"noise": _random_deformation_noise(batch_shape[2:], self.control_point_spacing)}

    def __call__(self, input, params=None):
        """Warp ``input``; sample ``params`` if not given and store them in ``self._params``."""
        if params is None:
            params = self.generate_parameters(input.shape)
        self._params = params
        # The noise is created on the CPU; move it to the device of the data.
        noise = params["noise"].to(input.device)
        # 1 is presumably Resample.BILINEAR.value; every other mode -> nearest.
        mode = "bilinear" if (self.flags["resample"] == 1).all() else "nearest"
        return kornia.geometry.transform.elastic_transform2d(
            input, noise, sigma=self.flags["sigma"], alpha=self.flags["alpha"], mode=mode,
            padding_mode="reflection"
        )


# TODO implement 'require_halo', and estimate the halo from the transformations
# so that we can load a bigger block and cut it away
class KorniaAugmentationPipeline(torch.nn.Module):
    """Apply a sequence of kornia augmentations to several tensors consistently.

    For each augmentation, random parameters are drawn on the first tensor and
    reused for all other tensors. For augmentations exposing an "interpolation"
    or "resample" flag, floating point inputs are resampled bilinearly and all
    other inputs with nearest-neighbor interpolation.
    """
    interpolatable_torch_types = [torch.float16, torch.float32, torch.float64]
    interpolatable_numpy_types = [np.dtype("float16"), np.dtype("float32"), np.dtype("float64")]

    def __init__(self, *kornia_augmentations, dtype=torch.float32):
        """
        Args:
            *kornia_augmentations: augmentation modules, applied in the given order.
            dtype: passed to ``ensure_tensor`` for every input in ``forward``.
        """
        super().__init__()
        self.augmentations = torch.nn.ModuleList(kornia_augmentations)
        self.dtype = dtype
        # NOTE: this instance attribute shadows the ``halo`` method below, so
        # ``pipeline.halo`` is the precomputed value, not a callable.
        self.halo = self.compute_halo()

    # for now we only add a halo for the random rotation trafos and
    # also don't compute the halo dynamically based on the input shape
    def compute_halo(self):
        """Return a fixed halo if a rotation augmentation is present, else None.

        If several rotations are present, the last one determines the result.
        """
        halo = None
        for aug in self.augmentations:
            if isinstance(aug, kornia.augmentation.RandomRotation):
                halo = [32, 32]
            if isinstance(aug, kornia.augmentation.RandomRotation3D):
                halo = [32, 32, 32]
        return halo

    def is_interpolatable(self, tensor):
        """Return True if ``tensor`` (torch tensor or numpy array) has a float dtype listed above."""
        if torch.is_tensor(tensor):
            return tensor.dtype in self.interpolatable_torch_types
        else:
            return tensor.dtype in self.interpolatable_numpy_types

    def transform_tensor(self, augmentation, tensor, interpolatable, params=None):
        """Apply ``augmentation`` to ``tensor`` with a dtype-appropriate resampling mode.

        Args:
            augmentation: a kornia augmentation (custom or built-in).
            tensor: the tensor to transform.
            interpolatable: select bilinear (True) or nearest (False) resampling.
            params: parameters to replay; if None the augmentation samples new ones.

        Returns:
            tuple of the transformed tensor and the parameters that were used.
        """
        flags = getattr(augmentation, "flags", None) or {}
        resampler = kornia.constants.Resample.get("BILINEAR" if interpolatable else "NEAREST")
        # RandomElasticDeformationStacked stores its mode as a tensor under
        # "interpolation", RandomElasticDeformation under "resample". kornia
        # built-ins use "resample" (a tensor or a Resample enum, depending on
        # the version). Update every such key so label tensors get nearest.
        for key in ("interpolation", "resample"):
            if key in flags:
                if torch.is_tensor(flags[key]):
                    flags[key] = torch.tensor(resampler.value)
                else:
                    flags[key] = resampler
        transformed = augmentation(tensor, params)
        return transformed, augmentation._params

    def forward(self, *tensors):
        """Augment all ``tensors`` with the same random transformations.

        Returns:
            list of the transformed tensors, in input order.
        """
        # Evaluated on the raw inputs, before ensure_tensor (presumably) casts
        # them to self.dtype.
        interpolatable = [self.is_interpolatable(tensor) for tensor in tensors]
        tensors = [ensure_tensor(tensor, self.dtype) for tensor in tensors]
        for aug in self.augmentations:
            # Sample parameters on the first tensor, replay them on the others.
            t0, params = self.transform_tensor(aug, tensors[0], interpolatable[0])
            transformed_tensors = [t0]
            for tensor, interpolate in zip(tensors[1:], interpolatable[1:]):
                tensor, _ = self.transform_tensor(aug, tensor, interpolate, params=params)
                transformed_tensors.append(tensor)
            tensors = transformed_tensors
        return tensors

    def halo(self, shape):
        # NOTE(review): shadowed on instances by the ``self.halo`` attribute set
        # in __init__; only reachable as KorniaAugmentationPipeline.halo(obj, shape).
        return self.halo


# TODO elastic deformation
# Try out:
# - RandomPerspective
# Keyword arguments used by create_augmentation, keyed by augmentation name.
AUGMENTATIONS = {
    "RandomAffine": {"degrees": 90, "scale": (0.9, 1.1)},
    # Scale range matches the 2D entry; a lower bound of 0 would allow collapsing the volume.
    "RandomAffine3D": {"degrees": (90, 90, 90), "scale": (0.9, 1.1)},
    "RandomDepthicalFlip3D": {},
    "RandomHorizontalFlip": {},
    "RandomHorizontalFlip3D": {},
    "RandomRotation": {"degrees": 90},
    "RandomRotation3D": {"degrees": (90, 90, 90)},
    "RandomVerticalFlip": {},
    "RandomVerticalFlip3D": {},
    "RandomElasticDeformation3D": {"alpha": [5, 5], "sigma": [30, 30]}
}


DEFAULT_2D_AUGMENTATIONS = [
    "RandomHorizontalFlip",
    "RandomVerticalFlip"
]
DEFAULT_3D_AUGMENTATIONS = [
    "RandomHorizontalFlip3D",
    "RandomVerticalFlip3D",
    "RandomDepthicalFlip3D",
]
DEFAULT_ANISOTROPIC_AUGMENTATIONS = [
    "RandomHorizontalFlip3D",
    "RandomVerticalFlip3D",
    "RandomDepthicalFlip3D",
]


def create_augmentation(trafo):
    """Instantiate the augmentation called ``trafo``.

    The name is looked up in ``kornia.augmentation`` first, then in this module.
    Keyword arguments come from ``AUGMENTATIONS``. Names without an entry there
    are constructed with their default arguments.
    """
    in_kornia = trafo in dir(kornia.augmentation)
    assert in_kornia or trafo in globals().keys(), f"Transformation {trafo} not defined"
    kwargs = AUGMENTATIONS.get(trafo, {})
    if in_kornia:
        return getattr(kornia.augmentation, trafo)(**kwargs)

    return globals()[trafo](**kwargs)


def get_augmentations(ndim=2,
                      transforms=None,
                      dtype=torch.float32):
    """Build a KorniaAugmentationPipeline.

    Args:
        ndim: selects the default augmentation names if ``transforms`` is None;
            one of 2, 3 or "anisotropic".
        transforms: iterable of augmentation names and/or kornia augmentation
            instances. Names are instantiated with ``create_augmentation``.
        dtype: forwarded to ``KorniaAugmentationPipeline``.

    Returns:
        the augmentation pipeline.
    """
    if transforms is None:
        assert ndim in (2, 3, "anisotropic"), f"Expect ndim to be one of (2, 3, 'anisotropic'), got {ndim}"
        if ndim == 2:
            transforms = DEFAULT_2D_AUGMENTATIONS
        elif ndim == 3:
            transforms = DEFAULT_3D_AUGMENTATIONS
        else:
            transforms = DEFAULT_ANISOTROPIC_AUGMENTATIONS
    transforms = [
        create_augmentation(trafo) if isinstance(trafo, str) else trafo
        for trafo in transforms
    ]

    assert all(isinstance(trafo, kornia.augmentation.base._AugmentationBase)
               for trafo in transforms)
    augmentations = KorniaAugmentationPipeline(
        *transforms,
        dtype=dtype
    )
    return augmentations
class RandomElasticDeformationStacked(kornia.augmentation.AugmentationBase3D):
    """Elastic deformation of 5D inputs that warps all slices identically.

    One random 2D displacement field over the last two axes is sampled per call
    and applied to every sample, channel and slice along the third axis.

    NOTE(review): ``__call__`` is overridden and never reads ``p``, so the
    deformation is applied on every call; ``keepdim`` is accepted but not
    forwarded to the base class.
    """

    def __init__(self,
                 control_point_spacing=1,
                 sigma=(32.0, 32.0),
                 alpha=(4.0, 4.0),
                 interpolation=kornia.constants.Resample.BILINEAR,
                 p=0.5,
                 keepdim=False,
                 same_on_batch=True):
        super().__init__(p=p,  # keepdim=keepdim,
                         same_on_batch=same_on_batch)
        # An int spacing is used for both spatial axes.
        if isinstance(control_point_spacing, int):
            self.control_point_spacing = [control_point_spacing] * 2
        else:
            self.control_point_spacing = control_point_spacing
        assert len(self.control_point_spacing) == 2
        self.interpolation = interpolation
        self.flags = dict(
            interpolation=torch.tensor(self.interpolation.value),
            sigma=sigma,
            alpha=alpha,
        )

    # The same transformation applied to all samples in a batch
    def generate_parameters(self, batch_shape):
        """Sample one smooth displacement field for the last two axes.

        Returns:
            dict with key "noise": float32 tensor of shape (1, 2, H, W).
        """
        assert len(batch_shape) == 5
        shape = batch_shape[3:]
        # At least one control point per axis, so that a spacing larger than the
        # image does not yield an empty grid that cannot be resized.
        control_shape = tuple(
            max(1, sh // spacing) for sh, spacing in zip(shape, self.control_point_spacing)
        )
        deformation_fields = [
            np.random.uniform(-1, 1, control_shape),
            np.random.uniform(-1, 1, control_shape),
        ]
        # Cubic (order=3) upsampling turns the coarse noise into a smooth field.
        deformation_fields = [resize(df, shape, order=3)[None] for df in deformation_fields]
        noise = np.concatenate(deformation_fields, axis=0)[None].astype("float32")
        return {"noise": torch.from_numpy(noise)}

    def __call__(self, input, params=None):
        """Warp the 5D ``input``; sample ``params`` if not given.

        The parameters are stored in ``self._params`` so they can be replayed.
        """
        assert len(input.shape) == 5
        if params is None:
            params = self.generate_parameters(input.shape)
        self._params = params

        # The noise is created on the CPU; move it to the device of the data.
        noise = params["noise"].to(input.device)
        # 1 is presumably Resample.BILINEAR.value; every other mode -> nearest.
        mode = "bilinear" if (self.flags["interpolation"] == 1).all() else "nearest"
        # Each 4D sample is passed as a batch of input.shape[1] images that all
        # share the same field, so every slice is warped identically.
        noise_ch = noise.expand(input.shape[1], -1, -1, -1)
        input_transformed = [
            kornia.geometry.transform.elastic_transform2d(
                x, noise_ch, sigma=self.flags["sigma"],
                alpha=self.flags["alpha"], mode=mode,
                padding_mode="reflection",
            )
            for x in torch.unbind(input, dim=0)
        ]
        return torch.stack(input_transformed)
AugmentationBase3D base class for customized augmentation implementations.
Arguments:
- p: probability for applying an augmentation. This param controls the augmentation probabilities element-wise for a batch.
- p_batch: probability for applying an augmentation to a batch. This param controls the augmentation probabilities batch-wise.
- same_on_batch: apply the same transformation across the batch.
def __init__(
    self,
    control_point_spacing=1,
    sigma=(32.0, 32.0),
    alpha=(4.0, 4.0),
    interpolation=kornia.constants.Resample.BILINEAR,
    p=0.5,
    keepdim=False,
    same_on_batch=True,
):
    """Configure the stacked elastic deformation.

    Args:
        control_point_spacing: spacing in pixels between the control points of
            the random grid; an int is used for both spatial axes.
        sigma: stored in ``self.flags["sigma"]``.
        alpha: stored in ``self.flags["alpha"]``.
        interpolation: initial resampling mode (kornia ``Resample`` member),
            stored as a tensor in ``self.flags["interpolation"]``.
        p: passed to the kornia base class.
        keepdim: accepted but not forwarded to the base class.
        same_on_batch: passed to the kornia base class.
    """
    super().__init__(p=p, same_on_batch=same_on_batch)
    spacing = (
        [control_point_spacing] * 2
        if isinstance(control_point_spacing, int)
        else control_point_spacing
    )
    self.control_point_spacing = spacing
    assert len(self.control_point_spacing) == 2
    self.interpolation = interpolation
    self.flags = {
        "interpolation": torch.tensor(interpolation.value),
        "sigma": sigma,
        "alpha": alpha,
    }
Initializes internal Module state, shared by both nn.Module and ScriptModule.
def generate_parameters(self, batch_shape):
    """Sample one smooth displacement field shared by the whole 5D batch.

    Args:
        batch_shape: shape of the 5D input; the last two entries are the
            spatial extents of the field.

    Returns:
        dict with key "noise": float32 tensor of shape (1, 2, *batch_shape[3:]).
    """
    assert len(batch_shape) == 5
    shape = batch_shape[3:]
    # At least one control point per axis, so that a spacing larger than the
    # image does not yield an empty grid that cannot be resized.
    control_shape = tuple(
        max(1, sh // spacing) for sh, spacing in zip(shape, self.control_point_spacing)
    )
    deformation_fields = [
        np.random.uniform(-1, 1, control_shape),
        np.random.uniform(-1, 1, control_shape),
    ]
    # Cubic (order=3) upsampling turns the coarse noise into a smooth field.
    deformation_fields = [resize(df, shape, order=3)[None] for df in deformation_fields]
    noise = np.concatenate(deformation_fields, axis=0)[None].astype("float32")
    return {"noise": torch.from_numpy(noise)}
Inherited Members
- kornia.augmentation._3d.base.AugmentationBase3D
- validate_tensor
- transform_tensor
- identity_matrix
- kornia.augmentation.base._AugmentationBase
- apply_transform
- apply_non_transform
- transform_inputs
- transform_masks
- transform_boxes
- transform_keypoints
- transform_classes
- apply_non_transform_mask
- apply_transform_mask
- apply_non_transform_box
- apply_transform_box
- apply_non_transform_keypoint
- apply_transform_keypoint
- apply_non_transform_class
- apply_transform_class
- apply_func
- kornia.augmentation.base._BasicAugmentationBase
- p
- p_batch
- same_on_batch
- keepdim
- transform_output_tensor
- set_rng_device_and_dtype
- forward_parameters
- forward
- torch.nn.modules.module.Module
- dump_patches
- training
- call_super_init
- register_buffer
- register_parameter
- add_module
- register_module
- get_submodule
- get_parameter
- get_buffer
- get_extra_state
- set_extra_state
- apply
- cuda
- ipu
- xpu
- cpu
- type
- float
- double
- half
- bfloat16
- to_empty
- to
- register_full_backward_pre_hook
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- register_state_dict_pre_hook
- state_dict
- register_load_state_dict_post_hook
- load_state_dict
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr
- compile
class RandomElasticDeformation(kornia.augmentation.AugmentationBase2D):
    """Elastic deformation of 4D inputs with a random smooth displacement field.

    NOTE(review): ``__call__`` is overridden and never reads ``p``, so the
    deformation is applied on every call; ``keepdim`` is accepted but not
    forwarded to the base class. The resampling mode is read from
    ``flags["resample"]``; a caller that wants nearest-neighbor for label
    tensors must update that key.
    """

    def __init__(self,
                 control_point_spacing=1,
                 sigma=(4.0, 4.0),
                 alpha=(32.0, 32.0),
                 resample=kornia.constants.Resample.BILINEAR,
                 p=0.5,
                 keepdim=False,
                 same_on_batch=False):
        super().__init__(p=p,  # keepdim=keepdim,
                         same_on_batch=same_on_batch)
        # An int spacing is used for both spatial axes.
        if isinstance(control_point_spacing, int):
            self.control_point_spacing = [control_point_spacing] * 2
        else:
            self.control_point_spacing = control_point_spacing
        assert len(self.control_point_spacing) == 2
        self.resample = resample
        self.flags = dict(
            resample=torch.tensor(self.resample.value),
            sigma=sigma,
            alpha=alpha,
        )

    # TODO do we need special treatment for batches, channels > 1?
    def generate_parameters(self, batch_shape):
        """Sample one smooth displacement field for the last two axes.

        Returns:
            dict with key "noise": float32 tensor of shape (1, 2, H, W).
        """
        assert len(batch_shape) == 4, f"{len(batch_shape)}"
        shape = batch_shape[2:]
        # At least one control point per axis, so that a spacing larger than the
        # image does not yield an empty grid that cannot be resized.
        control_shape = tuple(
            max(1, sh // spacing) for sh, spacing in zip(shape, self.control_point_spacing)
        )
        deformation_fields = [
            np.random.uniform(-1, 1, control_shape),
            np.random.uniform(-1, 1, control_shape),
        ]
        # Cubic (order=3) upsampling turns the coarse noise into a smooth field.
        deformation_fields = [resize(df, shape, order=3)[None] for df in deformation_fields]
        noise = np.concatenate(deformation_fields, axis=0)[None].astype("float32")
        return {"noise": torch.from_numpy(noise)}

    def __call__(self, input, params=None):
        """Warp ``input``; sample ``params`` if not given and store them in ``self._params``."""
        if params is None:
            params = self.generate_parameters(input.shape)
        self._params = params
        # The noise is created on the CPU; move it to the device of the data.
        noise = params["noise"].to(input.device)
        # 1 is presumably Resample.BILINEAR.value; every other mode -> nearest.
        mode = "bilinear" if (self.flags["resample"] == 1).all() else "nearest"
        return kornia.geometry.transform.elastic_transform2d(
            input, noise, sigma=self.flags["sigma"], alpha=self.flags["alpha"], mode=mode,
            padding_mode="reflection"
        )
AugmentationBase2D base class for customized augmentation implementations.
AugmentationBase2D aims at offering a generic base class for a greater level of customization.
If the subclass contains routine matrix-based transformations, `RigidAffineAugmentationBase2D` might be a better fit.
Arguments:
- p: probability for applying an augmentation. This param controls the augmentation probabilities element-wise for a batch.
- p_batch: probability for applying an augmentation to a batch. This param controls the augmentation probabilities batch-wise.
- same_on_batch: apply the same transformation across the batch.
- keepdim: whether to keep the output shape the same as the input (`True`) or broadcast it to the batch form (`False`).
def __init__(
    self,
    control_point_spacing=1,
    sigma=(4.0, 4.0),
    alpha=(32.0, 32.0),
    resample=kornia.constants.Resample.BILINEAR,
    p=0.5,
    keepdim=False,
    same_on_batch=False,
):
    """Configure the 2D elastic deformation.

    Args:
        control_point_spacing: spacing in pixels between the control points of
            the random grid; an int is used for both spatial axes.
        sigma: stored in ``self.flags["sigma"]``.
        alpha: stored in ``self.flags["alpha"]``.
        resample: initial resampling mode (kornia ``Resample`` member), stored
            as a tensor in ``self.flags["resample"]``.
        p: passed to the kornia base class.
        keepdim: accepted but not forwarded to the base class.
        same_on_batch: passed to the kornia base class.
    """
    super().__init__(p=p, same_on_batch=same_on_batch)
    spacing = (
        [control_point_spacing] * 2
        if isinstance(control_point_spacing, int)
        else control_point_spacing
    )
    self.control_point_spacing = spacing
    assert len(self.control_point_spacing) == 2
    self.resample = resample
    self.flags = {
        "resample": torch.tensor(resample.value),
        "sigma": sigma,
        "alpha": alpha,
    }
Initializes internal Module state, shared by both nn.Module and ScriptModule.
def generate_parameters(self, batch_shape):
    """Sample one smooth displacement field for a 4D batch.

    Args:
        batch_shape: shape of the 4D input; the last two entries are the
            spatial extents of the field.

    Returns:
        dict with key "noise": float32 tensor of shape (1, 2, *batch_shape[2:]).
    """
    assert len(batch_shape) == 4, f"{len(batch_shape)}"
    shape = batch_shape[2:]
    # At least one control point per axis, so that a spacing larger than the
    # image does not yield an empty grid that cannot be resized.
    control_shape = tuple(
        max(1, sh // spacing) for sh, spacing in zip(shape, self.control_point_spacing)
    )
    deformation_fields = [
        np.random.uniform(-1, 1, control_shape),
        np.random.uniform(-1, 1, control_shape),
    ]
    # Cubic (order=3) upsampling turns the coarse noise into a smooth field.
    deformation_fields = [resize(df, shape, order=3)[None] for df in deformation_fields]
    noise = np.concatenate(deformation_fields, axis=0)[None].astype("float32")
    return {"noise": torch.from_numpy(noise)}
Inherited Members
- kornia.augmentation._2d.base.AugmentationBase2D
- validate_tensor
- transform_tensor
- kornia.augmentation.base._AugmentationBase
- apply_transform
- apply_non_transform
- transform_inputs
- transform_masks
- transform_boxes
- transform_keypoints
- transform_classes
- apply_non_transform_mask
- apply_transform_mask
- apply_non_transform_box
- apply_transform_box
- apply_non_transform_keypoint
- apply_transform_keypoint
- apply_non_transform_class
- apply_transform_class
- apply_func
- kornia.augmentation.base._BasicAugmentationBase
- p
- p_batch
- same_on_batch
- keepdim
- transform_output_tensor
- set_rng_device_and_dtype
- forward_parameters
- forward
- torch.nn.modules.module.Module
- dump_patches
- training
- call_super_init
- register_buffer
- register_parameter
- add_module
- register_module
- get_submodule
- get_parameter
- get_buffer
- get_extra_state
- set_extra_state
- apply
- cuda
- ipu
- xpu
- cpu
- type
- float
- double
- half
- bfloat16
- to_empty
- to
- register_full_backward_pre_hook
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- register_state_dict_pre_hook
- state_dict
- register_load_state_dict_post_hook
- load_state_dict
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr
- compile
class KorniaAugmentationPipeline(torch.nn.Module):
    """Apply a sequence of kornia augmentations to several tensors consistently.

    For each augmentation, random parameters are drawn on the first tensor and
    reused for all other tensors. For augmentations exposing an "interpolation"
    or "resample" flag, floating point inputs are resampled bilinearly and all
    other inputs with nearest-neighbor interpolation.
    """
    interpolatable_torch_types = [torch.float16, torch.float32, torch.float64]
    interpolatable_numpy_types = [np.dtype("float16"), np.dtype("float32"), np.dtype("float64")]

    def __init__(self, *kornia_augmentations, dtype=torch.float32):
        """
        Args:
            *kornia_augmentations: augmentation modules, applied in the given order.
            dtype: passed to ``ensure_tensor`` for every input in ``forward``.
        """
        super().__init__()
        self.augmentations = torch.nn.ModuleList(kornia_augmentations)
        self.dtype = dtype
        # NOTE: this instance attribute shadows the ``halo`` method below, so
        # ``pipeline.halo`` is the precomputed value, not a callable.
        self.halo = self.compute_halo()

    # for now we only add a halo for the random rotation trafos and
    # also don't compute the halo dynamically based on the input shape
    def compute_halo(self):
        """Return a fixed halo if a rotation augmentation is present, else None.

        If several rotations are present, the last one determines the result.
        """
        halo = None
        for aug in self.augmentations:
            if isinstance(aug, kornia.augmentation.RandomRotation):
                halo = [32, 32]
            if isinstance(aug, kornia.augmentation.RandomRotation3D):
                halo = [32, 32, 32]
        return halo

    def is_interpolatable(self, tensor):
        """Return True if ``tensor`` (torch tensor or numpy array) has a float dtype listed above."""
        if torch.is_tensor(tensor):
            return tensor.dtype in self.interpolatable_torch_types
        else:
            return tensor.dtype in self.interpolatable_numpy_types

    def transform_tensor(self, augmentation, tensor, interpolatable, params=None):
        """Apply ``augmentation`` to ``tensor`` with a dtype-appropriate resampling mode.

        Returns:
            tuple of the transformed tensor and the parameters that were used.
        """
        flags = getattr(augmentation, "flags", None) or {}
        resampler = kornia.constants.Resample.get("BILINEAR" if interpolatable else "NEAREST")
        # Custom augmentations store the mode as a tensor under "interpolation"
        # or "resample"; kornia built-ins use "resample" (a tensor or a Resample
        # enum, depending on the version). Update every such key so that
        # non-float tensors are resampled with nearest-neighbor.
        for key in ("interpolation", "resample"):
            if key in flags:
                if torch.is_tensor(flags[key]):
                    flags[key] = torch.tensor(resampler.value)
                else:
                    flags[key] = resampler
        transformed = augmentation(tensor, params)
        return transformed, augmentation._params

    def forward(self, *tensors):
        """Augment all ``tensors`` with the same random transformations.

        Returns:
            list of the transformed tensors, in input order.
        """
        # Evaluated on the raw inputs, before ensure_tensor (presumably) casts
        # them to self.dtype.
        interpolatable = [self.is_interpolatable(tensor) for tensor in tensors]
        tensors = [ensure_tensor(tensor, self.dtype) for tensor in tensors]
        for aug in self.augmentations:
            # Sample parameters on the first tensor, replay them on the others.
            t0, params = self.transform_tensor(aug, tensors[0], interpolatable[0])
            transformed_tensors = [t0]
            for tensor, interpolate in zip(tensors[1:], interpolatable[1:]):
                tensor, _ = self.transform_tensor(aug, tensor, interpolate, params=params)
                transformed_tensors.append(tensor)
            tensors = transformed_tensors
        return tensors

    def halo(self, shape):
        # NOTE(review): shadowed on instances by the ``self.halo`` attribute set
        # in __init__; only reachable as KorniaAugmentationPipeline.halo(obj, shape).
        return self.halo
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will have their parameters converted too when you call `to()`, etc.
As per the example above, an `__init__()` call to the parent class must be made before assignment on the child.
:ivar training: Boolean representing whether this module is in training or evaluation mode.
:vartype training: bool
def __init__(self, *kornia_augmentations, dtype=torch.float32):
    """Register the augmentations and precompute the halo.

    Args:
        *kornia_augmentations: augmentation modules, applied in the given order.
        dtype: passed to ``ensure_tensor`` for every input in ``forward``.
    """
    super().__init__()
    self.dtype = dtype
    # Wrapping in a ModuleList registers the augmentations as submodules.
    self.augmentations = torch.nn.ModuleList(kornia_augmentations)
    # compute_halo iterates over self.augmentations, so it must run afterwards.
    self.halo = self.compute_halo()
Initializes internal Module state, shared by both nn.Module and ScriptModule.
def transform_tensor(self, augmentation, tensor, interpolatable, params=None):
    """Apply ``augmentation`` to ``tensor`` with a dtype-appropriate resampling mode.

    Args:
        augmentation: a kornia augmentation (custom or built-in).
        tensor: the tensor to transform.
        interpolatable: select bilinear (True) or nearest (False) resampling.
        params: parameters to replay; if None the augmentation samples new ones.

    Returns:
        tuple of the transformed tensor and the parameters that were used.
    """
    flags = getattr(augmentation, "flags", None) or {}
    resampler = kornia.constants.Resample.get("BILINEAR" if interpolatable else "NEAREST")
    # RandomElasticDeformationStacked keeps its mode as a tensor under
    # "interpolation", RandomElasticDeformation under "resample". kornia
    # built-ins use "resample" (a tensor or a Resample enum, depending on the
    # version). Previously only "interpolation" was updated, so label tensors
    # were resampled bilinearly by augmentations that use "resample".
    for key in ("interpolation", "resample"):
        if key in flags:
            if torch.is_tensor(flags[key]):
                flags[key] = torch.tensor(resampler.value)
            else:
                flags[key] = resampler
    transformed = augmentation(tensor, params)
    return transformed, augmentation._params
def forward(self, *tensors):
    """Augment all ``tensors`` with the same random transformations.

    For every augmentation, parameters are sampled on the first tensor and
    replayed on the remaining ones.

    Returns:
        list of the transformed tensors, in input order.
    """
    # Decided on the raw inputs, before ensure_tensor (presumably) casts them
    # to self.dtype.
    interp_flags = [self.is_interpolatable(t) for t in tensors]
    current = [ensure_tensor(t, self.dtype) for t in tensors]
    for augmentation in self.augmentations:
        first, params = self.transform_tensor(augmentation, current[0], interp_flags[0])
        rest = [
            self.transform_tensor(augmentation, t, interp, params=params)[0]
            for t, interp in zip(current[1:], interp_flags[1:])
        ]
        current = [first] + rest
    return current
Defines the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for the forward pass needs to be defined within this function, one should call the `Module` instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
Inherited Members
- torch.nn.modules.module.Module
- dump_patches
- training
- call_super_init
- register_buffer
- register_parameter
- add_module
- register_module
- get_submodule
- get_parameter
- get_buffer
- get_extra_state
- set_extra_state
- apply
- cuda
- ipu
- xpu
- cpu
- type
- float
- double
- half
- bfloat16
- to_empty
- to
- register_full_backward_pre_hook
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- register_state_dict_pre_hook
- state_dict
- register_load_state_dict_post_hook
- load_state_dict
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr
- compile
def create_augmentation(trafo):
    """Instantiate the augmentation called ``trafo``.

    The name is looked up in ``kornia.augmentation`` first, then in this module.
    Keyword arguments come from ``AUGMENTATIONS``. Names without an entry there
    are constructed with their default arguments instead of raising KeyError.
    """
    in_kornia = trafo in dir(kornia.augmentation)
    assert in_kornia or trafo in globals().keys(), f"Transformation {trafo} not defined"
    kwargs = AUGMENTATIONS.get(trafo, {})
    if in_kornia:
        return getattr(kornia.augmentation, trafo)(**kwargs)

    return globals()[trafo](**kwargs)
def get_augmentations(ndim=2,
                      transforms=None,
                      dtype=torch.float32):
    """Build a KorniaAugmentationPipeline.

    Args:
        ndim: selects the default augmentation names if ``transforms`` is None;
            one of 2, 3 or "anisotropic".
        transforms: iterable of augmentation names and/or kornia augmentation
            instances. Names are instantiated with ``create_augmentation``,
            and instances are used as given.
        dtype: forwarded to ``KorniaAugmentationPipeline``.

    Returns:
        the augmentation pipeline.
    """
    if transforms is None:
        assert ndim in (2, 3, "anisotropic"), f"Expect ndim to be one of (2, 3, 'anisotropic'), got {ndim}"
        if ndim == 2:
            transforms = DEFAULT_2D_AUGMENTATIONS
        elif ndim == 3:
            transforms = DEFAULT_3D_AUGMENTATIONS
        else:
            transforms = DEFAULT_ANISOTROPIC_AUGMENTATIONS
    transforms = [
        create_augmentation(trafo) if isinstance(trafo, str) else trafo
        for trafo in transforms
    ]

    assert all(isinstance(trafo, kornia.augmentation.base._AugmentationBase)
               for trafo in transforms)
    augmentations = KorniaAugmentationPipeline(
        *transforms,
        dtype=dtype
    )
    return augmentations