torch_em.transform.augmentation

import torch
import numpy as np
import kornia
from skimage.transform import resize

from ..util import ensure_tensor

class RandomElasticDeformationStacked(kornia.augmentation.AugmentationBase3D):
    def __init__(self,
                 control_point_spacing=1,
                 sigma=(32.0, 32.0),
                 alpha=(4.0, 4.0),
                 interpolation=kornia.constants.Resample.BILINEAR,
                 p=0.5,
                 keepdim=False,
                 same_on_batch=True):
        super().__init__(p=p,  # keepdim=keepdim,
                         same_on_batch=same_on_batch)
        if isinstance(control_point_spacing, int):
            self.control_point_spacing = [control_point_spacing] * 2
        else:
            self.control_point_spacing = control_point_spacing
        assert len(self.control_point_spacing) == 2
        self.interpolation = interpolation
        self.flags = dict(
            interpolation=torch.tensor(self.interpolation.value),
            sigma=sigma,
            alpha=alpha
        )

    # The same transformation is applied to all samples in the batch.
    def generate_parameters(self, batch_shape):
        assert len(batch_shape) == 5
        shape = batch_shape[3:]
        control_shape = tuple(
            sh // spacing for sh, spacing in zip(shape, self.control_point_spacing)
        )
        # Sample a coarse displacement field per spatial axis and upsample it
        # to the full resolution with cubic interpolation, so that the
        # deformation varies smoothly.
        deformation_fields = [
            np.random.uniform(-1, 1, control_shape),
            np.random.uniform(-1, 1, control_shape)
        ]
        deformation_fields = [
            resize(df, shape, order=3)[None] for df in deformation_fields
        ]
        noise = np.concatenate(deformation_fields, axis=0)[None].astype("float32")
        noise = torch.from_numpy(noise)
        return {"noise": noise}

    def __call__(self, input, params=None):
        assert len(input.shape) == 5
        if params is None:
            params = self.generate_parameters(input.shape)
            self._params = params

        noise = params["noise"]
        mode = "bilinear" if (self.flags["interpolation"] == 1).all() else "nearest"
        noise_ch = noise.expand(input.shape[1], -1, -1, -1)
        input_transformed = []
        # Each sample of shape (C, D, H, W) is passed to elastic_transform2d as
        # a batch of C images with D channels, so all z-slices of a volume are
        # deformed with the same displacement field.
        for x in torch.unbind(input, dim=0):
            x_transformed = kornia.geometry.transform.elastic_transform2d(
                x, noise_ch, sigma=self.flags["sigma"],
                alpha=self.flags["alpha"], mode=mode,
                padding_mode="reflection"
            )
            input_transformed.append(x_transformed)
        input_transformed = torch.stack(input_transformed)
        return input_transformed


class RandomElasticDeformation(kornia.augmentation.AugmentationBase2D):
    def __init__(self,
                 control_point_spacing=1,
                 sigma=(4.0, 4.0),
                 alpha=(32.0, 32.0),
                 resample=kornia.constants.Resample.BILINEAR,
                 p=0.5,
                 keepdim=False,
                 same_on_batch=False):
        super().__init__(p=p,  # keepdim=keepdim,
                         same_on_batch=same_on_batch)
        if isinstance(control_point_spacing, int):
            self.control_point_spacing = [control_point_spacing] * 2
        else:
            self.control_point_spacing = control_point_spacing
        assert len(self.control_point_spacing) == 2
        self.resample = resample
        self.flags = dict(
            resample=torch.tensor(self.resample.value),
            sigma=sigma,
            alpha=alpha
        )

    # TODO do we need special treatment for batches, channels > 1?
    def generate_parameters(self, batch_shape):
        assert len(batch_shape) == 4, f"{len(batch_shape)}"
        shape = batch_shape[2:]
        control_shape = tuple(
            sh // spacing for sh, spacing in zip(shape, self.control_point_spacing)
        )
        deformation_fields = [
            np.random.uniform(-1, 1, control_shape),
            np.random.uniform(-1, 1, control_shape)
        ]
        deformation_fields = [
            resize(df, shape, order=3)[None] for df in deformation_fields
        ]
        noise = np.concatenate(deformation_fields, axis=0)[None].astype("float32")
        noise = torch.from_numpy(noise)
        return {"noise": noise}

    def __call__(self, input, params=None):
        if params is None:
            params = self.generate_parameters(input.shape)
            self._params = params
        noise = params["noise"]
        mode = "bilinear" if (self.flags["resample"] == 1).all() else "nearest"
        return kornia.geometry.transform.elastic_transform2d(
            input, noise, sigma=self.flags["sigma"], alpha=self.flags["alpha"], mode=mode,
            padding_mode="reflection"
        )


# TODO implement 'require_halo', and estimate the halo from the transformations
# so that we can load a bigger block and cut it away
class KorniaAugmentationPipeline(torch.nn.Module):
    interpolatable_torch_types = [torch.float16, torch.float32, torch.float64]
    interpolatable_numpy_types = [np.dtype("float32"), np.dtype("float64")]

    def __init__(self, *kornia_augmentations, dtype=torch.float32):
        super().__init__()
        self.augmentations = torch.nn.ModuleList(kornia_augmentations)
        self.dtype = dtype
        # Store the halo in a private attribute; assigning it to self.halo
        # would shadow the halo method below.
        self._halo = self.compute_halo()

    # For now we only add a halo for the random rotation trafos and
    # don't compute the halo dynamically based on the input shape.
    def compute_halo(self):
        halo = None
        for aug in self.augmentations:
            if isinstance(aug, kornia.augmentation.RandomRotation):
                halo = [32, 32]
            if isinstance(aug, kornia.augmentation.RandomRotation3D):
                halo = [32, 32, 32]
        return halo

    def is_interpolatable(self, tensor):
        if torch.is_tensor(tensor):
            return tensor.dtype in self.interpolatable_torch_types
        else:
            return tensor.dtype in self.interpolatable_numpy_types

    def transform_tensor(self, augmentation, tensor, interpolatable, params=None):
        # Fall back to nearest-neighbor resampling for tensors that must not
        # be interpolated, e.g. integer-valued label images.
        interpolating = "interpolation" in getattr(augmentation, "flags", {})
        if interpolating:
            resampler = kornia.constants.Resample.get("BILINEAR" if interpolatable else "NEAREST")
            augmentation.flags["interpolation"] = torch.tensor(resampler.value)
        transformed = augmentation(tensor, params)
        return transformed, augmentation._params

    def forward(self, *tensors):
        interpolatable = [self.is_interpolatable(tensor) for tensor in tensors]
        tensors = [ensure_tensor(tensor, self.dtype) for tensor in tensors]
        for aug in self.augmentations:
            # Sample the random parameters on the first tensor, then apply the
            # same parameters to all other tensors.
            t0, params = self.transform_tensor(aug, tensors[0], interpolatable[0])
            transformed_tensors = [t0]
            for tensor, interpolate in zip(tensors[1:], interpolatable[1:]):
                tensor, _ = self.transform_tensor(aug, tensor, interpolate, params=params)
                transformed_tensors.append(tensor)
            tensors = transformed_tensors
        return tensors

    def halo(self, shape):
        return self._halo


# TODO elastic deformation
# Try out:
# - RandomPerspective
AUGMENTATIONS = {
    "RandomAffine": {"degrees": 90, "scale": (0.9, 1.1)},
    "RandomAffine3D": {"degrees": (90, 90, 90), "scale": (0.0, 1.1)},
    "RandomDepthicalFlip3D": {},
    "RandomHorizontalFlip": {},
    "RandomHorizontalFlip3D": {},
    "RandomRotation": {"degrees": 90},
    "RandomRotation3D": {"degrees": (90, 90, 90)},
    "RandomVerticalFlip": {},
    "RandomVerticalFlip3D": {},
    "RandomElasticDeformation3D": {"alpha": [5, 5], "sigma": [30, 30]}
}


DEFAULT_2D_AUGMENTATIONS = [
    "RandomHorizontalFlip",
    "RandomVerticalFlip"
]
DEFAULT_3D_AUGMENTATIONS = [
    "RandomHorizontalFlip3D",
    "RandomVerticalFlip3D",
    "RandomDepthicalFlip3D",
]
DEFAULT_ANISOTROPIC_AUGMENTATIONS = [
    "RandomHorizontalFlip3D",
    "RandomVerticalFlip3D",
    "RandomDepthicalFlip3D",
]


def create_augmentation(trafo):
    assert trafo in dir(kornia.augmentation) or trafo in globals().keys(), f"Transformation {trafo} not defined"
    if trafo in dir(kornia.augmentation):
        return getattr(kornia.augmentation, trafo)(**AUGMENTATIONS[trafo])
    return globals()[trafo](**AUGMENTATIONS[trafo])


def get_augmentations(ndim=2,
                      transforms=None,
                      dtype=torch.float32):
    if transforms is None:
        assert ndim in (2, 3, "anisotropic"), f"Expect ndim to be one of (2, 3, 'anisotropic'), got {ndim}"
        if ndim == 2:
            transforms = DEFAULT_2D_AUGMENTATIONS
        elif ndim == 3:
            transforms = DEFAULT_3D_AUGMENTATIONS
        else:
            transforms = DEFAULT_ANISOTROPIC_AUGMENTATIONS
    transforms = [create_augmentation(trafo) for trafo in transforms]

    assert all(isinstance(trafo, kornia.augmentation.base._AugmentationBase)
               for trafo in transforms)
    augmentations = KorniaAugmentationPipeline(
        *transforms,
        dtype=dtype
    )
    return augmentations
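
Both elastic deformation classes build their displacement field the same way: uniform noise is sampled on a coarse control-point grid and upsampled to the full image resolution with cubic interpolation, which yields a smooth deformation. A minimal standalone sketch of that construction (the shapes are illustrative, not prescribed by the module):

import numpy as np
from skimage.transform import resize

shape = (256, 256)                 # full resolution (H, W)
control_point_spacing = (64, 64)   # one control point every 64 pixels

control_shape = tuple(sh // sp for sh, sp in zip(shape, control_point_spacing))
# -> (4, 4): a coarse grid of random displacements per spatial axis

# one random field per spatial axis, with values in [-1, 1]
fields = [np.random.uniform(-1, 1, control_shape) for _ in range(2)]
# cubic upsampling (order=3) turns the coarse grids into smooth dense fields
noise = np.stack([resize(f, shape, order=3) for f in fields])[None].astype("float32")
print(noise.shape)  # (1, 2, 256, 256), the layout expected by elastic_transform2d
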
class RandomElasticDeformationStacked(kornia.augmentation._3d.base.AugmentationBase3D):

Random elastic deformation for volumetric data: the same 2d displacement field is applied to every z-slice of a sample and, since the parameters are generated once, to every sample in the batch. Subclasses AugmentationBase3D, the kornia base class for customized augmentation implementations.

Arguments:
  • p: probability for applying the augmentation, element-wise for a batch. Note that the overridden __call__ applies the deformation unconditionally, so p currently has no effect.
  • same_on_batch: apply the same transformation across the batch.
  • keepdim: accepted for API compatibility, but currently not forwarded to the base class (see the commented-out argument in __init__).

RandomElasticDeformationStacked(control_point_spacing=1, sigma=(32.0, 32.0), alpha=(4.0, 4.0), interpolation=kornia.constants.Resample.BILINEAR, p=0.5, keepdim=False, same_on_batch=True)
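
A hypothetical usage sketch (the tensor shapes are illustrative). The input has to be 5D, (batch, channels, depth, height, width); re-using the sampled parameters deforms a second tensor, e.g. the corresponding labels, identically:

import torch

aug = RandomElasticDeformationStacked(control_point_spacing=32)
raw = torch.rand(1, 1, 8, 128, 128)  # (B, C, D, H, W)
raw_deformed = aug(raw)              # all z-slices deform identically

labels = torch.rand(1, 1, 8, 128, 128)
labels_deformed = aug(labels, params=aug._params)  # same field as for raw
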
class RandomElasticDeformation(kornia.augmentation._2d.base.AugmentationBase2D):

Random elastic deformation for 2d data, implemented via a displacement field that is sampled on a coarse control-point grid and smoothly upsampled. Subclasses AugmentationBase2D, the kornia base class for customized augmentation implementations.

Arguments:
  • p: probability for applying the augmentation, element-wise for a batch. Note that the overridden __call__ applies the deformation unconditionally, so p currently has no effect.
  • same_on_batch: apply the same transformation across the batch.
  • keepdim: accepted for API compatibility, but currently not forwarded to the base class (see the commented-out argument in __init__).

RandomElasticDeformation(control_point_spacing=1, sigma=(4.0, 4.0), alpha=(32.0, 32.0), resample=kornia.constants.Resample.BILINEAR, p=0.5, keepdim=False, same_on_batch=False)
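
A hypothetical usage sketch for the 2d variant; the input has to be 4D, (batch, channels, height, width):

import torch

aug = RandomElasticDeformation(control_point_spacing=16)
image = torch.rand(1, 1, 256, 256)
deformed = aug(image)

# passing the stored parameters deforms a second tensor in the same way
mask = torch.rand(1, 1, 256, 256)
deformed_mask = aug(mask, params=aug._params)
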
class KorniaAugmentationPipeline(torch.nn.Module):

Pipeline that wraps kornia augmentations in a torch.nn.Module and applies them jointly to multiple tensors, e.g. raw data and the corresponding labels. For each augmentation the random parameters are sampled once, on the first tensor, and re-used for all other tensors so that the inputs stay aligned. Tensors with a floating-point dtype are resampled bilinearly; all other dtypes (e.g. integer label images) fall back to nearest-neighbor interpolation.

KorniaAugmentationPipeline(*kornia_augmentations, dtype=torch.float32)
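
A hypothetical usage sketch, building a pipeline directly from kornia augmentations and applying it to an image together with an integer label image:

import torch
import kornia

pipeline = KorniaAugmentationPipeline(
    kornia.augmentation.RandomHorizontalFlip(p=0.5),
    kornia.augmentation.RandomRotation(degrees=90, p=0.5),
)
image = torch.rand(1, 1, 128, 128)
labels = torch.randint(0, 5, (1, 1, 128, 128))
# both tensors are transformed consistently; note that the outputs are cast
# to the pipeline dtype (float32 by default) by ensure_tensor
image_t, labels_t = pipeline(image, labels)
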
def create_augmentation(trafo):

Instantiate the augmentation with the given name, using the keyword arguments registered in AUGMENTATIONS. The name is looked up in kornia.augmentation first and falls back to the classes defined in this module.
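
For example (both names are keys of AUGMENTATIONS; both resolve to kornia classes):

rot = create_augmentation("RandomRotation")         # kornia RandomRotation(degrees=90)
flip = create_augmentation("RandomHorizontalFlip")  # kornia RandomHorizontalFlip()
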
def get_augmentations(ndim=2, transforms=None, dtype=torch.float32):

Build the default KorniaAugmentationPipeline for 2d, 3d, or anisotropic data. Alternatively, an explicit list of transformation names can be passed via transforms; each name must be a key of AUGMENTATIONS.
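
A hypothetical usage sketch; numpy arrays are accepted as well, since the pipeline converts its inputs with ensure_tensor:

import numpy as np

# default 2d pipeline: horizontal and vertical flips
aug = get_augmentations(ndim=2)
raw = np.random.rand(1, 1, 64, 64).astype("float32")
labels = np.zeros((1, 1, 64, 64), dtype="int64")
raw_t, labels_t = aug(raw, labels)

# custom pipeline from explicit names (each must be a key of AUGMENTATIONS)
aug = get_augmentations(transforms=["RandomHorizontalFlip", "RandomRotation"])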