
  1from typing import Dict, List, Optional, Sequence, Tuple, Union
  3import kornia
  4import numpy as np
  5import torch
  6from skimage.transform import resize
  8from ..util import ensure_tensor
 11class RandomElasticDeformationStacked(kornia.augmentation.AugmentationBase3D):
 12    """Random elastic deformations implemented with kornia.
 14    This transformation can be applied to 3D data, the same deformation is applied to each plane.
 16    Args:
 17        control_point_spacing: The control point spacing for the deformation field.
 18        sigma: Sigma for smoothing the deformation field.
 19        alpha: Alpha value.
 20        interpolation: Interpolation order for applying the transformation to the data.
 21        p: Probability for applying the transformation.
 22        keepdim:
 23        same_on_batch:
 24    """
 25    def __init__(
 26        self,
 27        control_point_spacing: Union[int, Sequence[int]] = 1,
 28        sigma: Tuple[float, float] = (32.0, 32.0),
 29        alpha: Tuple[float, float] = (4.0, 4.0),
 30        interpolation=kornia.constants.Resample.BILINEAR,
 31        p: float = 0.5,
 32        keepdim: bool = False,
 33        same_on_batch: bool = True,
 34    ):
 35        super().__init__(p=p, same_on_batch=same_on_batch)
 36        if isinstance(control_point_spacing, int):
 37            self.control_point_spacing = [control_point_spacing] * 2
 38        else:
 39            self.control_point_spacing = control_point_spacing
 40        assert len(self.control_point_spacing) == 2
 41        self.interpolation = interpolation
 42        self.flags = dict(interpolation=torch.tensor(self.interpolation.value), sigma=sigma, alpha=alpha)
 44    def generate_parameters(self, batch_shape):
 45        """@private
 46        """
 47        assert len(batch_shape) == 5
 48        shape = batch_shape[3:]
 49        control_shape = tuple(
 50            sh // spacing for sh, spacing in zip(shape, self.control_point_spacing)
 51        )
 52        deformation_fields = [
 53            np.random.uniform(-1, 1, control_shape),
 54            np.random.uniform(-1, 1, control_shape)
 55        ]
 56        deformation_fields = [
 57            resize(df, shape, order=3)[None] for df in deformation_fields
 58        ]
 59        noise = np.concatenate(deformation_fields, axis=0)[None].astype("float32")
 60        noise = torch.from_numpy(noise)
 61        return {"noise": noise}
 63    def __call__(self, input: torch.Tensor, params: Optional[Dict] = None) -> torch.Tensor:
 64        """Apply the augmentation to a tensor.
 66        Args:
 67            input: The input tensor.
 68            params: The transformation parameters.
 70        Returns:
 71            The transformed tensor.
 72        """
 73        assert len(input.shape) == 5
 74        if params is None:
 75            params = self.generate_parameters(input.shape)
 76            self._params = params
 78        noise = params["noise"]
 79        mode = "bilinear" if (self.flags["interpolation"] == 1).all() else "nearest"
 80        noise_ch = noise.expand(input.shape[1], -1, -1, -1)
 81        input_transformed = []
 82        for i, x in enumerate(torch.unbind(input, dim=0)):
 83            x_transformed = kornia.geometry.transform.elastic_transform2d(
 84                x, noise_ch, sigma=self.flags["sigma"], alpha=self.flags["alpha"], mode=mode, padding_mode="reflection"
 85            )
 86            input_transformed.append(x_transformed)
 87        input_transformed = torch.stack(input_transformed)
 88        return input_transformed
 91class RandomElasticDeformation(kornia.augmentation.AugmentationBase2D):
 92    """Random elastic deformations implemented with kornia.
 94    Args:
 95        control_point_spacing: The control point spacing for the deformation field.
 96        sigma: Sigma for smoothing the deformation field.
 97        alpha: Alpha value.
 98        resample: Interpolation order for applying the transformation to the data.
 99        p: Probability for applying the transformation.
100        keepdim:
101        same_on_batch:
102    """
103    def __init__(
104        self,
105        control_point_spacing: Union[int, Sequence[int]] = 1,
106        sigma: Tuple[float, float] = (32.0, 32.0),
107        alpha: Tuple[float, float] = (4.0, 4.0),
108        resample=kornia.constants.Resample.BILINEAR,
109        p: float = 0.5,
110        keepdim: bool = False,
111        same_on_batch: bool = True,
112    ):
113        super().__init__(p=p, same_on_batch=same_on_batch)
114        if isinstance(control_point_spacing, int):
115            self.control_point_spacing = [control_point_spacing] * 2
116        else:
117            self.control_point_spacing = control_point_spacing
118        assert len(self.control_point_spacing) == 2
119        self.resample = resample
120        self.flags = dict(resample=torch.tensor(self.resample.value), sigma=sigma, alpha=alpha)
122    def generate_parameters(self, batch_shape):
123        """@private
124        """
125        assert len(batch_shape) == 4, f"{len(batch_shape)}"
126        shape = batch_shape[2:]
127        control_shape = tuple(sh // spacing for sh, spacing in zip(shape, self.control_point_spacing))
128        deformation_fields = [np.random.uniform(-1, 1, control_shape), np.random.uniform(-1, 1, control_shape)]
129        deformation_fields = [resize(df, shape, order=3)[None] for df in deformation_fields]
130        noise = np.concatenate(deformation_fields, axis=0)[None].astype("float32")
131        noise = torch.from_numpy(noise)
132        return {"noise": noise}
134    def __call__(self, input: torch.Tensor, params: Optional[Dict] = None) -> torch.Tensor:
135        """Apply the augmentation to a tensor.
137        Args:
138            input: The input tensor.
139            params: The transformation parameters.
141        Returns:
142            The transformed tensor.
143        """
144        if params is None:
145            params = self.generate_parameters(input.shape)
146            self._params = params
147        noise = params["noise"]
148        mode = "bilinear" if (self.flags["resample"] == 1).all() else "nearest"
149        return kornia.geometry.transform.elastic_transform2d(
150            input, noise, sigma=self.flags["sigma"], alpha=self.flags["alpha"], mode=mode, padding_mode="reflection"
151        )
154# TODO: Implement 'require_halo', and estimate the halo from the transformations
155# so that we can load a bigger block and cut it away.
class KorniaAugmentationPipeline(torch.nn.Module):
    """Pipeline to apply multiple kornia augmentations to data.

    All tensors passed to one call are transformed with the same random parameters: the
    parameters drawn for the first tensor are reused for the others. Tensors with an
    interpolatable (floating point) dtype are resampled bilinearly, all others with
    nearest neighbor interpolation.

    Args:
        kornia_augmentations: Augmentations implemented with kornia.
        dtype: The data type of the return data.
    """
    interpolatable_torch_types = [torch.float16, torch.float32, torch.float64]
    interpolatable_numpy_types = [np.dtype("float16"), np.dtype("float32"), np.dtype("float64")]
    # Keys under which augmentations keep their resampling mode in `flags`:
    # RandomElasticDeformationStacked uses "interpolation", RandomElasticDeformation uses "resample".
    resample_flag_keys = ("interpolation", "resample")

    def __init__(self, *kornia_augmentations, dtype: Union[str, torch.dtype] = torch.float32):
        super().__init__()
        self.augmentations = torch.nn.ModuleList(kornia_augmentations)
        self.dtype = dtype
        # NOTE: this instance attribute shadows the `halo` method below,
        # so `pipeline.halo` evaluates to this value (a list or None).
        self.halo = self.compute_halo()

    # for now we only add a halo for the random rotation trafos and
    # also don't compute the halo dynamically based on the input shape
    def compute_halo(self):
        """@private
        """
        halo = None
        for aug in self.augmentations:
            if isinstance(aug, kornia.augmentation.RandomRotation):
                halo = [32, 32]
            if isinstance(aug, kornia.augmentation.RandomRotation3D):
                halo = [32, 32, 32]
        return halo

    def is_interpolatable(self, tensor):
        """@private
        """
        if torch.is_tensor(tensor):
            return tensor.dtype in self.interpolatable_torch_types
        else:
            return tensor.dtype in self.interpolatable_numpy_types

    def transform_tensor(self, augmentation, tensor, interpolatable, params=None):
        """@private
        """
        flags = getattr(augmentation, "flags", None)
        if flags:
            resampler = kornia.constants.Resample.get("BILINEAR" if interpolatable else "NEAREST")
            for key in self.resample_flag_keys:
                if key not in flags:
                    continue
                # Keep the representation the augmentation already uses:
                # the enum member itself or a tensor holding its value.
                if isinstance(flags[key], kornia.constants.Resample):
                    flags[key] = resampler
                else:
                    flags[key] = torch.tensor(resampler.value)
        transformed = augmentation(tensor, params)
        return transformed, augmentation._params

    def forward(self, *tensors: torch.Tensor) -> List[torch.Tensor]:
        """Apply augmentations to a list of tensors.

        Args:
            tensors: The input tensors.

        Returns:
            List of transformed tensors. Empty if no tensors are given.
        """
        if not tensors:
            return []
        interpolatable = [self.is_interpolatable(tensor) for tensor in tensors]
        tensors = [ensure_tensor(tensor, self.dtype) for tensor in tensors]
        for aug in self.augmentations:
            first, params = self.transform_tensor(aug, tensors[0], interpolatable[0])
            # Reuse the parameters of the first tensor so all tensors are transformed identically.
            others = [
                self.transform_tensor(aug, tensor, interpolate, params=params)[0]
                for tensor, interpolate in zip(tensors[1:], interpolatable[1:])
            ]
            tensors = [first, *others]
        return tensors

    def halo(self, shape):
        """@private

        Shadowed on instances by the `halo` attribute set in `__init__`; only reachable via the
        class, e.g. `KorniaAugmentationPipeline.halo(pipeline, shape)`.
        """
        return self.halo
# Try out:
# - RandomPerspective
AUGMENTATIONS = {
    "RandomAffine": {"degrees": 90, "scale": (0.9, 1.1)},
    # A lower scale bound of 0.0 could shrink the volume to nothing; use the same range as the 2D affine.
    "RandomAffine3D": {"degrees": (90, 90, 90), "scale": (0.9, 1.1)},
    "RandomDepthicalFlip3D": {},
    "RandomHorizontalFlip": {},
    "RandomHorizontalFlip3D": {},
    "RandomRotation": {"degrees": 90},
    "RandomRotation3D": {"degrees": (90, 90, 90)},
    "RandomVerticalFlip": {},
    "RandomVerticalFlip3D": {},
    "RandomElasticDeformation3D": {"alpha": (5, 5), "sigma": (30, 30)},
}
"""All available augmentations and their default parameters.
"""

DEFAULT_2D_AUGMENTATIONS = [
    "RandomHorizontalFlip",
    "RandomVerticalFlip",
]
"""The default augmentations for 2D data.
"""

DEFAULT_3D_AUGMENTATIONS = [
    "RandomHorizontalFlip3D",
    "RandomVerticalFlip3D",
    "RandomDepthicalFlip3D",
]
"""The default augmentations for 3D data.
"""

DEFAULT_ANISOTROPIC_AUGMENTATIONS = [
    "RandomHorizontalFlip3D",
    "RandomVerticalFlip3D",
    "RandomDepthicalFlip3D",
]
"""The default augmentations for anisotropic 3D data.
"""

# "RandomElasticDeformation3D" is listed in AUGMENTATIONS but was not defined anywhere, so
# create_augmentation could not build it; expose the stacked elastic deformation under that name.
RandomElasticDeformation3D = RandomElasticDeformationStacked
def create_augmentation(trafo):
    """@private

    Instantiate the augmentation named `trafo`. The name is looked up in `kornia.augmentation`
    first and then in this module. Its default parameters are taken from `AUGMENTATIONS`;
    augmentations without an entry there are created with their own defaults.
    """
    if trafo in dir(kornia.augmentation):
        augmentation_class = getattr(kornia.augmentation, trafo)
    else:
        assert trafo in globals().keys(), f"Transformation {trafo} not defined"
        augmentation_class = globals()[trafo]
    return augmentation_class(**AUGMENTATIONS.get(trafo, {}))
def get_augmentations(ndim: Union[int, str] = 2, transforms=None, dtype: Union[str, torch.dtype] = torch.float32):
    """Get augmentation pipeline.

    Args:
        ndim: The dimensionality for the augmentations. One of 2, 3 or "anisotropic".
        transforms: The transformations to use for the augmentations. Each entry is either the
            name of an augmentation (see `AUGMENTATIONS`) or an already constructed kornia augmentation.
            If None, the default augmentations for the given data dimensionality will be used.
        dtype: The data type of the output data of the augmentation.

    Returns:
        The augmentation pipeline.
    """
    if transforms is None:
        assert ndim in (2, 3, "anisotropic"), f"Expect ndim to be one of (2, 3, 'anisotropic'), got {ndim}"
        if ndim == 2:
            transforms = DEFAULT_2D_AUGMENTATIONS
        elif ndim == 3:
            transforms = DEFAULT_3D_AUGMENTATIONS
        else:
            transforms = DEFAULT_ANISOTROPIC_AUGMENTATIONS
    augmentation_base = kornia.augmentation.base._AugmentationBase
    # Names are instantiated; ready-made augmentation objects are used as they are.
    transforms = [
        trafo if isinstance(trafo, augmentation_base) else create_augmentation(trafo)
        for trafo in transforms
    ]
    assert all(isinstance(trafo, augmentation_base) for trafo in transforms)
    return KorniaAugmentationPipeline(*transforms, dtype=dtype)
class RandomElasticDeformationStacked(kornia.augmentation.AugmentationBase3D):
    """Random elastic deformations implemented with kornia.

    This transformation can be applied to 3D data, the same deformation is applied to each plane.
    One deformation field is drawn per parameter set and shared by all samples and channels of the batch.

    Args:
        control_point_spacing: The control point spacing for the deformation field.
            A single int is used for both in-plane axes.
        sigma: Sigma for smoothing the deformation field, passed to kornia's `elastic_transform2d`.
        alpha: Alpha value, passed to kornia's `elastic_transform2d`.
        interpolation: Interpolation order for applying the transformation to the data.
            A `kornia.constants.Resample` member or its name / integer value.
        p: Probability for applying the transformation.
        keepdim: Forwarded to the kornia base class; has no effect on the output of this augmentation.
        same_on_batch: Forwarded to the kornia base class.
    """
    def __init__(
        self,
        control_point_spacing: Union[int, Sequence[int]] = 1,
        sigma: Tuple[float, float] = (32.0, 32.0),
        alpha: Tuple[float, float] = (4.0, 4.0),
        interpolation=kornia.constants.Resample.BILINEAR,
        p: float = 0.5,
        keepdim: bool = False,
        same_on_batch: bool = True,
    ):
        super().__init__(p=p, same_on_batch=same_on_batch, keepdim=keepdim)
        # __call__ below bypasses kornia's forward logic, so the probability is evaluated
        # in generate_parameters instead; keep it explicitly.
        self.p = p
        if isinstance(control_point_spacing, int):
            self.control_point_spacing = [control_point_spacing] * 2
        else:
            self.control_point_spacing = control_point_spacing
        assert len(self.control_point_spacing) == 2
        self.interpolation = kornia.constants.Resample.get(interpolation)
        # Normalize lists (e.g. from a config) to the annotated tuple type.
        sigma = tuple(sigma) if isinstance(sigma, list) else sigma
        alpha = tuple(alpha) if isinstance(alpha, list) else alpha
        self.flags = dict(interpolation=torch.tensor(self.interpolation.value), sigma=sigma, alpha=alpha)

    def generate_parameters(self, batch_shape):
        """@private
        """
        assert len(batch_shape) == 5
        # The deformation acts on the last two axes only.
        shape = tuple(batch_shape[3:])
        # Random displacements on a coarse control grid, upsampled with cubic interpolation.
        # Clamp to one control point so that a spacing larger than the data does not give an empty grid.
        control_shape = tuple(
            max(1, sh // spacing) for sh, spacing in zip(shape, self.control_point_spacing)
        )
        deformation_fields = [
            resize(np.random.uniform(-1, 1, control_shape), shape, order=3)[None] for _ in range(2)
        ]
        noise = np.concatenate(deformation_fields, axis=0)[None].astype("float32")
        noise = torch.from_numpy(noise)
        # Decide once per parameter set, so all tensors sharing these parameters are treated alike.
        apply = bool(np.random.uniform() < self.p)
        return {"noise": noise, "apply": apply}

    def __call__(self, input: torch.Tensor, params: Optional[Dict] = None) -> torch.Tensor:
        """Apply the augmentation to a tensor.

        Args:
            input: The input tensor, with five axes.
            params: The transformation parameters. If None, new parameters are drawn and stored in `_params`.

        Returns:
            The transformed tensor, or the input itself if the transformation is not applied.
        """
        assert len(input.shape) == 5
        if params is None:
            params = self.generate_parameters(input.shape)
            self._params = params
        # Parameters without an "apply" entry (e.g. built by the caller) are always applied.
        if not params.get("apply", True):
            return input
        # The noise is created on the CPU; move it to the device of the data.
        noise = params["noise"].to(input.device)
        mode = "bilinear" if (self.flags["interpolation"] == 1).all() else "nearest"
        # Merge the first two axes: every resulting 2D "image" carries the planes of one volume
        # as its channels, so the same in-plane deformation is applied to each plane.
        n_images = input.shape[0] * input.shape[1]
        images = input.reshape(n_images, *input.shape[2:])
        transformed = kornia.geometry.transform.elastic_transform2d(
            images,
            noise.expand(n_images, -1, -1, -1),
            sigma=self.flags["sigma"],
            alpha=self.flags["alpha"],
            mode=mode,
            padding_mode="reflection",
        )
        return transformed.reshape(input.shape)

Random elastic deformations implemented with kornia.

This transformation can be applied to 3D data, the same deformation is applied to each plane.

  • control_point_spacing: The control point spacing for the deformation field.
  • sigma: Sigma for smoothing the deformation field.
  • alpha: Alpha value.
  • interpolation: Interpolation order for applying the transformation to the data.
  • p: Probability for applying the transformation.
  • keepdim: Has no effect on the output of this augmentation.
  • same_on_batch: Forwarded to the kornia base class.
RandomElasticDeformationStacked( control_point_spacing: Union[int, Sequence[int]] = 1, sigma: Tuple[float, float] = (32.0, 32.0), alpha: Tuple[float, float] = (4.0, 4.0), interpolation=<Resample.BILINEAR: 1>, p: float = 0.5, keepdim: bool = False, same_on_batch: bool = True)
def __init__(
    self,
    control_point_spacing: Union[int, Sequence[int]] = 1,
    sigma: Tuple[float, float] = (32.0, 32.0),
    alpha: Tuple[float, float] = (4.0, 4.0),
    interpolation=kornia.constants.Resample.BILINEAR,
    p: float = 0.5,
    keepdim: bool = False,
    same_on_batch: bool = True,
):
    """Store the control point spacing, interpolation mode, sigma / alpha and apply probability.

    A single int `control_point_spacing` is used for both in-plane axes; `interpolation` may be a
    `kornia.constants.Resample` member or its name / integer value.
    """
    super().__init__(p=p, same_on_batch=same_on_batch, keepdim=keepdim)
    # __call__ bypasses kornia's forward logic, so the probability is evaluated
    # in generate_parameters instead; keep it explicitly.
    self.p = p
    if isinstance(control_point_spacing, int):
        self.control_point_spacing = [control_point_spacing] * 2
    else:
        self.control_point_spacing = control_point_spacing
    assert len(self.control_point_spacing) == 2
    self.interpolation = kornia.constants.Resample.get(interpolation)
    # Normalize lists (e.g. from a config) to the annotated tuple type.
    sigma = tuple(sigma) if isinstance(sigma, list) else sigma
    alpha = tuple(alpha) if isinstance(alpha, list) else alpha
    self.flags = dict(interpolation=torch.tensor(self.interpolation.value), sigma=sigma, alpha=alpha)

Store the control point spacing (a single int is used for both in-plane axes), the interpolation mode and the sigma / alpha values used for the deformation.

class RandomElasticDeformation(kornia.augmentation.AugmentationBase2D):
    """Random elastic deformations implemented with kornia.

    One deformation field is drawn per parameter set and shared by all samples and channels of the batch.

    Args:
        control_point_spacing: The control point spacing for the deformation field.
            A single int is used for both axes.
        sigma: Sigma for smoothing the deformation field, passed to kornia's `elastic_transform2d`.
        alpha: Alpha value, passed to kornia's `elastic_transform2d`.
        resample: Interpolation order for applying the transformation to the data.
            A `kornia.constants.Resample` member or its name / integer value.
        p: Probability for applying the transformation.
        keepdim: Forwarded to the kornia base class; has no effect on the output of this augmentation.
        same_on_batch: Forwarded to the kornia base class.
    """
    def __init__(
        self,
        control_point_spacing: Union[int, Sequence[int]] = 1,
        sigma: Tuple[float, float] = (32.0, 32.0),
        alpha: Tuple[float, float] = (4.0, 4.0),
        resample=kornia.constants.Resample.BILINEAR,
        p: float = 0.5,
        keepdim: bool = False,
        same_on_batch: bool = True,
    ):
        super().__init__(p=p, same_on_batch=same_on_batch, keepdim=keepdim)
        # __call__ below bypasses kornia's forward logic, so the probability is evaluated
        # in generate_parameters instead; keep it explicitly.
        self.p = p
        if isinstance(control_point_spacing, int):
            self.control_point_spacing = [control_point_spacing] * 2
        else:
            self.control_point_spacing = control_point_spacing
        assert len(self.control_point_spacing) == 2
        self.resample = kornia.constants.Resample.get(resample)
        # Normalize lists (e.g. from a config) to the annotated tuple type.
        sigma = tuple(sigma) if isinstance(sigma, list) else sigma
        alpha = tuple(alpha) if isinstance(alpha, list) else alpha
        self.flags = dict(resample=torch.tensor(self.resample.value), sigma=sigma, alpha=alpha)

    def generate_parameters(self, batch_shape):
        """@private
        """
        assert len(batch_shape) == 4, f"{len(batch_shape)}"
        shape = tuple(batch_shape[2:])
        # Random displacements on a coarse control grid, upsampled with cubic interpolation.
        # Clamp to one control point so that a spacing larger than the data does not give an empty grid.
        control_shape = tuple(
            max(1, sh // spacing) for sh, spacing in zip(shape, self.control_point_spacing)
        )
        deformation_fields = [
            resize(np.random.uniform(-1, 1, control_shape), shape, order=3)[None] for _ in range(2)
        ]
        noise = np.concatenate(deformation_fields, axis=0)[None].astype("float32")
        noise = torch.from_numpy(noise)
        # Decide once per parameter set, so all tensors sharing these parameters are treated alike.
        apply = bool(np.random.uniform() < self.p)
        return {"noise": noise, "apply": apply}

    def __call__(self, input: torch.Tensor, params: Optional[Dict] = None) -> torch.Tensor:
        """Apply the augmentation to a tensor.

        Args:
            input: The input tensor.
            params: The transformation parameters. If None, new parameters are drawn and stored in `_params`.

        Returns:
            The transformed tensor, or the input itself if the transformation is not applied.
        """
        if params is None:
            params = self.generate_parameters(input.shape)
            self._params = params
        # Parameters without an "apply" entry (e.g. built by the caller) are always applied.
        if not params.get("apply", True):
            return input
        # The noise is created on the CPU for a single sample; move it to the data's device and
        # share it across the batch (the stacked variant in this module expands it the same way).
        noise = params["noise"].to(input.device).expand(input.shape[0], -1, -1, -1)
        mode = "bilinear" if (self.flags["resample"] == 1).all() else "nearest"
        return kornia.geometry.transform.elastic_transform2d(
            input, noise, sigma=self.flags["sigma"], alpha=self.flags["alpha"], mode=mode, padding_mode="reflection"
        )

Random elastic deformations implemented with kornia.

  • control_point_spacing: The control point spacing for the deformation field.
  • sigma: Sigma for smoothing the deformation field.
  • alpha: Alpha value.
  • resample: Interpolation order for applying the transformation to the data.
  • p: Probability for applying the transformation.
  • keepdim: Has no effect on the output of this augmentation.
  • same_on_batch: Forwarded to the kornia base class.
RandomElasticDeformation( control_point_spacing: Union[int, Sequence[int]] = 1, sigma: Tuple[float, float] = (32.0, 32.0), alpha: Tuple[float, float] = (4.0, 4.0), resample=<Resample.BILINEAR: 1>, p: float = 0.5, keepdim: bool = False, same_on_batch: bool = True)
def __init__(
    self,
    control_point_spacing: Union[int, Sequence[int]] = 1,
    sigma: Tuple[float, float] = (32.0, 32.0),
    alpha: Tuple[float, float] = (4.0, 4.0),
    resample=kornia.constants.Resample.BILINEAR,
    p: float = 0.5,
    keepdim: bool = False,
    same_on_batch: bool = True,
):
    """Store the control point spacing, resampling mode, sigma / alpha and apply probability.

    A single int `control_point_spacing` is used for both axes; `resample` may be a
    `kornia.constants.Resample` member or its name / integer value.
    """
    super().__init__(p=p, same_on_batch=same_on_batch, keepdim=keepdim)
    # __call__ bypasses kornia's forward logic, so the probability is evaluated
    # in generate_parameters instead; keep it explicitly.
    self.p = p
    if isinstance(control_point_spacing, int):
        self.control_point_spacing = [control_point_spacing] * 2
    else:
        self.control_point_spacing = control_point_spacing
    assert len(self.control_point_spacing) == 2
    self.resample = kornia.constants.Resample.get(resample)
    # Normalize lists (e.g. from a config) to the annotated tuple type.
    sigma = tuple(sigma) if isinstance(sigma, list) else sigma
    alpha = tuple(alpha) if isinstance(alpha, list) else alpha
    self.flags = dict(resample=torch.tensor(self.resample.value), sigma=sigma, alpha=alpha)

Store the control point spacing (a single int is used for both axes), the resampling mode and the sigma / alpha values used for the deformation.

class KorniaAugmentationPipeline(torch.nn.Module):
    """Pipeline to apply multiple kornia augmentations to data.

    All tensors passed to one call are transformed with the same random parameters: the
    parameters drawn for the first tensor are reused for the others. Tensors with an
    interpolatable (floating point) dtype are resampled bilinearly, all others with
    nearest neighbor interpolation.

    Args:
        kornia_augmentations: Augmentations implemented with kornia.
        dtype: The data type of the return data.
    """
    interpolatable_torch_types = [torch.float16, torch.float32, torch.float64]
    interpolatable_numpy_types = [np.dtype("float16"), np.dtype("float32"), np.dtype("float64")]
    # Keys under which augmentations keep their resampling mode in `flags`:
    # RandomElasticDeformationStacked uses "interpolation", RandomElasticDeformation uses "resample".
    resample_flag_keys = ("interpolation", "resample")

    def __init__(self, *kornia_augmentations, dtype: Union[str, torch.dtype] = torch.float32):
        super().__init__()
        self.augmentations = torch.nn.ModuleList(kornia_augmentations)
        self.dtype = dtype
        # NOTE: this instance attribute shadows the `halo` method below,
        # so `pipeline.halo` evaluates to this value (a list or None).
        self.halo = self.compute_halo()

    # for now we only add a halo for the random rotation trafos and
    # also don't compute the halo dynamically based on the input shape
    def compute_halo(self):
        """@private
        """
        halo = None
        for aug in self.augmentations:
            if isinstance(aug, kornia.augmentation.RandomRotation):
                halo = [32, 32]
            if isinstance(aug, kornia.augmentation.RandomRotation3D):
                halo = [32, 32, 32]
        return halo

    def is_interpolatable(self, tensor):
        """@private
        """
        if torch.is_tensor(tensor):
            return tensor.dtype in self.interpolatable_torch_types
        else:
            return tensor.dtype in self.interpolatable_numpy_types

    def transform_tensor(self, augmentation, tensor, interpolatable, params=None):
        """@private
        """
        flags = getattr(augmentation, "flags", None)
        if flags:
            resampler = kornia.constants.Resample.get("BILINEAR" if interpolatable else "NEAREST")
            for key in self.resample_flag_keys:
                if key not in flags:
                    continue
                # Keep the representation the augmentation already uses:
                # the enum member itself or a tensor holding its value.
                if isinstance(flags[key], kornia.constants.Resample):
                    flags[key] = resampler
                else:
                    flags[key] = torch.tensor(resampler.value)
        transformed = augmentation(tensor, params)
        return transformed, augmentation._params

    def forward(self, *tensors: torch.Tensor) -> List[torch.Tensor]:
        """Apply augmentations to a list of tensors.

        Args:
            tensors: The input tensors.

        Returns:
            List of transformed tensors. Empty if no tensors are given.
        """
        if not tensors:
            return []
        interpolatable = [self.is_interpolatable(tensor) for tensor in tensors]
        tensors = [ensure_tensor(tensor, self.dtype) for tensor in tensors]
        for aug in self.augmentations:
            first, params = self.transform_tensor(aug, tensors[0], interpolatable[0])
            # Reuse the parameters of the first tensor so all tensors are transformed identically.
            others = [
                self.transform_tensor(aug, tensor, interpolate, params=params)[0]
                for tensor, interpolate in zip(tensors[1:], interpolatable[1:])
            ]
            tensors = [first, *others]
        return tensors

    def halo(self, shape):
        """@private

        Shadowed on instances by the `halo` attribute set in `__init__`; only reachable via the
        class, e.g. `KorniaAugmentationPipeline.halo(pipeline, shape)`.
        """
        return self.halo

Pipeline to apply multiple kornia augmentations to data.

  • kornia_augmentations: Augmentations implemented with kornia.
  • dtype: The data type of the return data.
KorniaAugmentationPipeline( *kornia_augmentations, dtype: Union[str, torch.dtype] = torch.float32)
def __init__(self, *kornia_augmentations, dtype: Union[str, torch.dtype] = torch.float32):
    """Register the augmentations and precompute the halo.

    Args:
        kornia_augmentations: The augmentations, applied one after the other in the given order.
        dtype: The data type handed to `ensure_tensor` for every input.
    """
    super().__init__()
    self.dtype = dtype
    self.augmentations = torch.nn.ModuleList(kornia_augmentations)
    # compute_halo iterates over self.augmentations, so it must run after they are registered.
    self.halo = self.compute_halo()

Register the augmentations as submodules, store the output dtype and compute the halo from the augmentations.

# Floating point dtypes that are resampled with bilinear interpolation;
# data of any other dtype is resampled with nearest neighbor interpolation.
interpolatable_torch_types = [torch.float16, torch.float32, torch.float64]
interpolatable_numpy_types = [np.dtype("float16"), np.dtype("float32"), np.dtype("float64")]
def forward(self, *tensors: torch.Tensor) -> List[torch.Tensor]:
    """Apply augmentations to a list of tensors.

    The random parameters drawn for the first tensor are reused for all other tensors.

    Args:
        tensors: The input tensors.

    Returns:
        List of transformed tensors. Empty if no tensors are given.
    """
    if not tensors:
        return []
    interpolatable = [self.is_interpolatable(tensor) for tensor in tensors]
    tensors = [ensure_tensor(tensor, self.dtype) for tensor in tensors]
    for aug in self.augmentations:
        first, params = self.transform_tensor(aug, tensors[0], interpolatable[0])
        others = [
            self.transform_tensor(aug, tensor, interpolate, params=params)[0]
            for tensor, interpolate in zip(tensors[1:], interpolatable[1:])
        ]
        tensors = [first, *others]
    return tensors

Apply augmentations to a list of tensors.

  • tensors: The input tensors.

List of transformed tensors.

AUGMENTATIONS = {
    "RandomAffine": {"degrees": 90, "scale": (0.9, 1.1)},
    # A lower scale bound of 0.0 could shrink the volume to nothing; use the same range as the 2D affine.
    "RandomAffine3D": {"degrees": (90, 90, 90), "scale": (0.9, 1.1)},
    "RandomDepthicalFlip3D": {},
    "RandomHorizontalFlip": {},
    "RandomHorizontalFlip3D": {},
    "RandomRotation": {"degrees": 90},
    "RandomRotation3D": {"degrees": (90, 90, 90)},
    "RandomVerticalFlip": {},
    "RandomVerticalFlip3D": {},
    "RandomElasticDeformation3D": {"alpha": (5, 5), "sigma": (30, 30)},
}

All available augmentations and their default parameters.

# Names of the augmentations used by get_augmentations when ndim == 2.
DEFAULT_2D_AUGMENTATIONS = [
    "RandomHorizontalFlip",
    "RandomVerticalFlip",
]

The default augmentations for 2D data.

# Names of the augmentations used by get_augmentations when ndim == 3.
DEFAULT_3D_AUGMENTATIONS = [
    "RandomHorizontalFlip3D",
    "RandomVerticalFlip3D",
    "RandomDepthicalFlip3D",
]

The default augmentations for 3D data.

# Names of the augmentations used by get_augmentations when ndim == "anisotropic"
# (currently the same list as for isotropic 3D data).
DEFAULT_ANISOTROPIC_AUGMENTATIONS = [
    "RandomHorizontalFlip3D",
    "RandomVerticalFlip3D",
    "RandomDepthicalFlip3D",
]

The default augmentations for anisotropic 3D data.

def get_augmentations(ndim: Union[int, str] = 2, transforms=None, dtype: Union[str, torch.dtype] = torch.float32):
    """Get augmentation pipeline.

    Args:
        ndim: The dimensionality for the augmentations. One of 2, 3 or "anisotropic".
        transforms: The transformations to use for the augmentations. Each entry is either the
            name of an augmentation (see `AUGMENTATIONS`) or an already constructed kornia augmentation.
            If None, the default augmentations for the given data dimensionality will be used.
        dtype: The data type of the output data of the augmentation.

    Returns:
        The augmentation pipeline.
    """
    if transforms is None:
        assert ndim in (2, 3, "anisotropic"), f"Expect ndim to be one of (2, 3, 'anisotropic'), got {ndim}"
        if ndim == 2:
            transforms = DEFAULT_2D_AUGMENTATIONS
        elif ndim == 3:
            transforms = DEFAULT_3D_AUGMENTATIONS
        else:
            transforms = DEFAULT_ANISOTROPIC_AUGMENTATIONS
    augmentation_base = kornia.augmentation.base._AugmentationBase
    # Names are instantiated; ready-made augmentation objects are used as they are.
    transforms = [
        trafo if isinstance(trafo, augmentation_base) else create_augmentation(trafo)
        for trafo in transforms
    ]
    assert all(isinstance(trafo, augmentation_base) for trafo in transforms)
    return KorniaAugmentationPipeline(*transforms, dtype=dtype)

Get augmentation pipeline.

  • ndim: The dimensionality for the augmentations. One of 2, 3 or "anisotropic".
  • transforms: The transformations to use for the augmentations. If None, the default augmentations for the given data dimensionality will be used.
  • dtype: The data type of the output data of the augmentation.

The augmentation pipeline.