torch_em.transform.augmentation

  1from typing import Dict, List, Optional, Sequence, Tuple, Union
  2
  3import kornia
  4import numpy as np
  5import torch
  6from skimage.transform import resize
  7
  8from ..util import ensure_tensor
  9
 10
 11class RandomElasticDeformationStacked(kornia.augmentation.AugmentationBase3D):
 12    """Random elastic deformations implemented with kornia.
 13
 14    This transformation can be applied to 3D data, the same deformation is applied to each plane.
 15
 16    Args:
 17        control_point_spacing: The control point spacing for the deformation field.
 18        sigma: Sigma for smoothing the deformation field.
 19        alpha: Alpha value.
 20        interpolation: Interpolation order for applying the transformation to the data.
 21        p: Probability for applying the transformation.
 22        keepdim:
 23        same_on_batch:
 24    """
 25    def __init__(
 26        self,
 27        control_point_spacing: Union[int, Sequence[int]] = 1,
 28        sigma: Tuple[float, float] = (32.0, 32.0),
 29        alpha: Tuple[float, float] = (4.0, 4.0),
 30        interpolation=kornia.constants.Resample.BILINEAR,
 31        p: float = 0.5,
 32        keepdim: bool = False,
 33        same_on_batch: bool = True,
 34    ):
 35        super().__init__(p=p, same_on_batch=same_on_batch)
 36        if isinstance(control_point_spacing, int):
 37            self.control_point_spacing = [control_point_spacing] * 2
 38        else:
 39            self.control_point_spacing = control_point_spacing
 40        assert len(self.control_point_spacing) == 2
 41        self.interpolation = interpolation
 42        self.flags = dict(interpolation=torch.tensor(self.interpolation.value), sigma=sigma, alpha=alpha)
 43
 44    def generate_parameters(self, batch_shape):
 45        """@private
 46        """
 47        assert len(batch_shape) == 5
 48        shape = batch_shape[3:]
 49        control_shape = tuple(
 50            sh // spacing for sh, spacing in zip(shape, self.control_point_spacing)
 51        )
 52        deformation_fields = [
 53            np.random.uniform(-1, 1, control_shape),
 54            np.random.uniform(-1, 1, control_shape)
 55        ]
 56        deformation_fields = [
 57            resize(df, shape, order=3)[None] for df in deformation_fields
 58        ]
 59        noise = np.concatenate(deformation_fields, axis=0)[None].astype("float32")
 60        noise = torch.from_numpy(noise)
 61        return {"noise": noise}
 62
 63    def __call__(self, input: torch.Tensor, params: Optional[Dict] = None) -> torch.Tensor:
 64        """Apply the augmentation to a tensor.
 65
 66        Args:
 67            input: The input tensor.
 68            params: The transformation parameters.
 69
 70        Returns:
 71            The transformed tensor.
 72        """
 73        assert len(input.shape) == 5
 74        if params is None:
 75            params = self.generate_parameters(input.shape)
 76            self._params = params
 77
 78        noise = params["noise"]
 79        mode = "bilinear" if (self.flags["interpolation"] == 1).all() else "nearest"
 80        noise_ch = noise.expand(input.shape[1], -1, -1, -1)
 81        input_transformed = []
 82        for i, x in enumerate(torch.unbind(input, dim=0)):
 83            x_transformed = kornia.geometry.transform.elastic_transform2d(
 84                x, noise_ch, sigma=self.flags["sigma"], alpha=self.flags["alpha"], mode=mode, padding_mode="reflection"
 85            )
 86            input_transformed.append(x_transformed)
 87        input_transformed = torch.stack(input_transformed)
 88        return input_transformed
 89
 90
 91class RandomElasticDeformation(kornia.augmentation.AugmentationBase2D):
 92    """Random elastic deformations implemented with kornia.
 93
 94    Args:
 95        control_point_spacing: The control point spacing for the deformation field.
 96        sigma: Sigma for smoothing the deformation field.
 97        alpha: Alpha value.
 98        resample: Interpolation order for applying the transformation to the data.
 99        p: Probability for applying the transformation.
100        keepdim:
101        same_on_batch:
102    """
103    def __init__(
104        self,
105        control_point_spacing: Union[int, Sequence[int]] = 1,
106        sigma: Tuple[float, float] = (32.0, 32.0),
107        alpha: Tuple[float, float] = (4.0, 4.0),
108        resample=kornia.constants.Resample.BILINEAR,
109        p: float = 0.5,
110        keepdim: bool = False,
111        same_on_batch: bool = True,
112    ):
113        super().__init__(p=p, same_on_batch=same_on_batch)
114        if isinstance(control_point_spacing, int):
115            self.control_point_spacing = [control_point_spacing] * 2
116        else:
117            self.control_point_spacing = control_point_spacing
118        assert len(self.control_point_spacing) == 2
119        self.resample = resample
120        self.flags = dict(resample=torch.tensor(self.resample.value), sigma=sigma, alpha=alpha)
121
122    def generate_parameters(self, batch_shape):
123        """@private
124        """
125        assert len(batch_shape) == 4, f"{len(batch_shape)}"
126        shape = batch_shape[2:]
127        control_shape = tuple(sh // spacing for sh, spacing in zip(shape, self.control_point_spacing))
128        deformation_fields = [np.random.uniform(-1, 1, control_shape), np.random.uniform(-1, 1, control_shape)]
129        deformation_fields = [resize(df, shape, order=3)[None] for df in deformation_fields]
130        noise = np.concatenate(deformation_fields, axis=0)[None].astype("float32")
131        noise = torch.from_numpy(noise)
132        return {"noise": noise}
133
134    def __call__(self, input: torch.Tensor, params: Optional[Dict] = None) -> torch.Tensor:
135        """Apply the augmentation to a tensor.
136
137        Args:
138            input: The input tensor.
139            params: The transformation parameters.
140
141        Returns:
142            The transformed tensor.
143        """
144        if params is None:
145            params = self.generate_parameters(input.shape)
146            self._params = params
147        noise = params["noise"]
148        mode = "bilinear" if (self.flags["resample"] == 1).all() else "nearest"
149        return kornia.geometry.transform.elastic_transform2d(
150            input, noise, sigma=self.flags["sigma"], alpha=self.flags["alpha"], mode=mode, padding_mode="reflection"
151        )
152
153
# TODO: Implement 'require_halo', and estimate the halo from the transformations
# so that we can load a bigger block and cut it away.
class KorniaAugmentationPipeline(torch.nn.Module):
    """Pipeline to apply multiple kornia augmentations to data.

    For each augmentation, the random parameters are drawn while transforming the first tensor
    and then reused for all other tensors, so that e.g. raw data and labels are transformed
    consistently. Tensors without a floating point dtype are transformed with nearest neighbor
    interpolation by augmentations that expose a tensor-valued interpolation flag.

    Args:
        kornia_augmentations: Augmentations implemented with kornia.
        dtype: The data type of the return data.
    """
    interpolatable_torch_types = [torch.float16, torch.float32, torch.float64]
    # Kept in line with the torch types above.
    interpolatable_numpy_types = [np.dtype("float16"), np.dtype("float32"), np.dtype("float64")]

    def __init__(self, *kornia_augmentations, dtype: Union[str, torch.dtype] = torch.float32):
        super().__init__()
        self.augmentations = torch.nn.ModuleList(kornia_augmentations)
        self.dtype = dtype
        # NOTE: this instance attribute shadows the `halo` method defined below.
        self.halo = self.compute_halo()

    # for now we only add a halo for the random rotation trafos and
    # also don't compute the halo dynamically based on the input shape
    def compute_halo(self):
        """@private
        """
        halo = None
        for aug in self.augmentations:
            if isinstance(aug, kornia.augmentation.RandomRotation):
                halo = [32, 32]
            if isinstance(aug, kornia.augmentation.RandomRotation3D):
                halo = [32, 32, 32]
        return halo

    def is_interpolatable(self, tensor):
        """@private
        """
        if torch.is_tensor(tensor):
            return tensor.dtype in self.interpolatable_torch_types
        else:
            return tensor.dtype in self.interpolatable_numpy_types

    def transform_tensor(self, augmentation, tensor, interpolatable, params=None):
        """@private
        """
        flags = getattr(augmentation, "flags", None) or {}
        resampler = kornia.constants.Resample.get("BILINEAR" if interpolatable else "NEAREST")
        if "interpolation" in flags:
            flags["interpolation"] = torch.tensor(resampler.value)
        # Augmentations that store their mode as a tensor under "resample" (e.g. RandomElasticDeformation
        # in this module) are switched the same way. Non-tensor values (e.g. a `Resample` enum used by
        # kornia's own augmentations) are left untouched.
        if torch.is_tensor(flags.get("resample")):
            flags["resample"] = torch.tensor(resampler.value)
        transformed = augmentation(tensor, params)
        return transformed, augmentation._params

    def forward(self, *tensors: torch.Tensor) -> List[torch.Tensor]:
        """Apply augmentations to a list of tensors.

        Args:
            tensors: The input tensors. The first tensor determines the random parameters
                of each augmentation, which are reused for all other tensors.

        Returns:
            List of transformed tensors. Empty if no tensors were passed.
        """
        if not tensors:
            return []
        # Decide on the interpolation mode before the conversion to `self.dtype`.
        interpolatable = [self.is_interpolatable(tensor) for tensor in tensors]
        tensors = [ensure_tensor(tensor, self.dtype) for tensor in tensors]
        for aug in self.augmentations:
            first, params = self.transform_tensor(aug, tensors[0], interpolatable[0])
            others = [
                self.transform_tensor(aug, tensor, interpolate, params=params)[0]
                for tensor, interpolate in zip(tensors[1:], interpolatable[1:])
            ]
            tensors = [first] + others
        return tensors

    def halo(self, shape):
        """@private
        """
        # NOTE(review): __init__ assigns `self.halo`, so on instances this method is shadowed
        # by the computed halo value; it is only reachable via the class.
        return self.halo
229
230
# Try out:
# - RandomPerspective
AUGMENTATIONS = {
    "RandomAffine": {"degrees": 90, "scale": (0.9, 1.1)},
    # Same scale range as the 2D affine augmentation; a lower bound of 0.0 would allow
    # the volume to be shrunk to (almost) nothing.
    "RandomAffine3D": {"degrees": (90, 90, 90), "scale": (0.9, 1.1)},
    "RandomDepthicalFlip3D": {},
    "RandomHorizontalFlip": {},
    "RandomHorizontalFlip3D": {},
    "RandomRotation": {"degrees": 90},
    "RandomRotation3D": {"degrees": (90, 90, 90)},
    "RandomVerticalFlip": {},
    "RandomVerticalFlip3D": {},
    # NOTE(review): no class named "RandomElasticDeformation3D" is defined in this module, so
    # create_augmentation can only build it if kornia.augmentation provides one - confirm.
    "RandomElasticDeformation3D": {"alpha": [5, 5], "sigma": [30, 30]},
}
"""All available augmentations and their default parameters.
"""

DEFAULT_2D_AUGMENTATIONS = [
    "RandomHorizontalFlip",
    "RandomVerticalFlip",
]
"""The names of the default augmentations for 2D data.
"""
DEFAULT_3D_AUGMENTATIONS = [
    "RandomHorizontalFlip3D",
    "RandomVerticalFlip3D",
    "RandomDepthicalFlip3D",
]
"""The names of the default augmentations for 3D data.
"""
DEFAULT_ANISOTROPIC_AUGMENTATIONS = [
    "RandomHorizontalFlip3D",
    "RandomVerticalFlip3D",
    "RandomDepthicalFlip3D",
]
"""The names of the default augmentations for anisotropic 3D data.
"""
268
269
def create_augmentation(trafo):
    """@private

    Instantiate the augmentation named `trafo`, looked up first in `kornia.augmentation` and then
    among the names defined in this module. Keyword arguments are taken from `AUGMENTATIONS`;
    augmentations without an entry there are created with their own default arguments.
    """
    kornia_names = dir(kornia.augmentation)
    assert trafo in kornia_names or trafo in globals().keys(), f"Transformation {trafo} not defined"
    kwargs = AUGMENTATIONS.get(trafo, {})
    if trafo in kornia_names:
        return getattr(kornia.augmentation, trafo)(**kwargs)
    return globals()[trafo](**kwargs)
277
278
def get_augmentations(ndim: Union[int, str] = 2, transforms=None, dtype: Union[str, torch.dtype] = torch.float32):
    """Build a kornia augmentation pipeline from augmentation names.

    Args:
        ndim: The dimensionality for the augmentations. One of 2, 3 or "anisotropic".
            Only used to select the default augmentations when `transforms` is None.
        transforms: The names of the transformations to use for the augmentations.
            If None, the default augmentations for the given data dimensionality will be used.
        dtype: The data type of the output data of the augmentation.

    Returns:
        The augmentation pipeline.
    """
    if transforms is None:
        assert ndim in (2, 3, "anisotropic"), f"Expect ndim to be one of (2, 3, 'anisotropic'), got {ndim}"
        transforms = (
            DEFAULT_2D_AUGMENTATIONS if ndim == 2
            else DEFAULT_3D_AUGMENTATIONS if ndim == 3
            else DEFAULT_ANISOTROPIC_AUGMENTATIONS
        )

    pipeline_members = list(map(create_augmentation, transforms))
    # Every member must be a kornia augmentation so the pipeline can call and re-parametrize it.
    assert all(isinstance(member, kornia.augmentation.base._AugmentationBase) for member in pipeline_members)
    return KorniaAugmentationPipeline(*pipeline_members, dtype=dtype)
class RandomElasticDeformationStacked(kornia.augmentation._3d.base.AugmentationBase3D):
12class RandomElasticDeformationStacked(kornia.augmentation.AugmentationBase3D):
13    """Random elastic deformations implemented with kornia.
14
15    This transformation can be applied to 3D data, the same deformation is applied to each plane.
16
17    Args:
18        control_point_spacing: The control point spacing for the deformation field.
19        sigma: Sigma for smoothing the deformation field.
20        alpha: Alpha value.
21        interpolation: Interpolation order for applying the transformation to the data.
22        p: Probability for applying the transformation.
23        keepdim:
24        same_on_batch:
25    """
26    def __init__(
27        self,
28        control_point_spacing: Union[int, Sequence[int]] = 1,
29        sigma: Tuple[float, float] = (32.0, 32.0),
30        alpha: Tuple[float, float] = (4.0, 4.0),
31        interpolation=kornia.constants.Resample.BILINEAR,
32        p: float = 0.5,
33        keepdim: bool = False,
34        same_on_batch: bool = True,
35    ):
36        super().__init__(p=p, same_on_batch=same_on_batch)
37        if isinstance(control_point_spacing, int):
38            self.control_point_spacing = [control_point_spacing] * 2
39        else:
40            self.control_point_spacing = control_point_spacing
41        assert len(self.control_point_spacing) == 2
42        self.interpolation = interpolation
43        self.flags = dict(interpolation=torch.tensor(self.interpolation.value), sigma=sigma, alpha=alpha)
44
45    def generate_parameters(self, batch_shape):
46        """@private
47        """
48        assert len(batch_shape) == 5
49        shape = batch_shape[3:]
50        control_shape = tuple(
51            sh // spacing for sh, spacing in zip(shape, self.control_point_spacing)
52        )
53        deformation_fields = [
54            np.random.uniform(-1, 1, control_shape),
55            np.random.uniform(-1, 1, control_shape)
56        ]
57        deformation_fields = [
58            resize(df, shape, order=3)[None] for df in deformation_fields
59        ]
60        noise = np.concatenate(deformation_fields, axis=0)[None].astype("float32")
61        noise = torch.from_numpy(noise)
62        return {"noise": noise}
63
64    def __call__(self, input: torch.Tensor, params: Optional[Dict] = None) -> torch.Tensor:
65        """Apply the augmentation to a tensor.
66
67        Args:
68            input: The input tensor.
69            params: The transformation parameters.
70
71        Returns:
72            The transformed tensor.
73        """
74        assert len(input.shape) == 5
75        if params is None:
76            params = self.generate_parameters(input.shape)
77            self._params = params
78
79        noise = params["noise"]
80        mode = "bilinear" if (self.flags["interpolation"] == 1).all() else "nearest"
81        noise_ch = noise.expand(input.shape[1], -1, -1, -1)
82        input_transformed = []
83        for i, x in enumerate(torch.unbind(input, dim=0)):
84            x_transformed = kornia.geometry.transform.elastic_transform2d(
85                x, noise_ch, sigma=self.flags["sigma"], alpha=self.flags["alpha"], mode=mode, padding_mode="reflection"
86            )
87            input_transformed.append(x_transformed)
88        input_transformed = torch.stack(input_transformed)
89        return input_transformed

Random elastic deformations implemented with kornia.

This transformation can be applied to 3D data, the same deformation is applied to each plane.

Arguments:
  • control_point_spacing: The control point spacing for the deformation field.
  • sigma: Sigma for smoothing the deformation field.
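  • alpha: Scaling factor for the strength of the deformation.
  • interpolation: Interpolation mode (a kornia Resample value) for applying the transformation to the data; bilinear for Resample.BILINEAR, nearest neighbor otherwise.
  • p: Probability for applying the transformation.
  • keepdim: Accepted for compatibility with the kornia augmentation interface.
  • same_on_batch: Forwarded to the kornia base class.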
RandomElasticDeformationStacked( control_point_spacing: Union[int, Sequence[int]] = 1, sigma: Tuple[float, float] = (32.0, 32.0), alpha: Tuple[float, float] = (4.0, 4.0), interpolation=<Resample.BILINEAR: 1>, p: float = 0.5, keepdim: bool = False, same_on_batch: bool = True)
26    def __init__(
27        self,
28        control_point_spacing: Union[int, Sequence[int]] = 1,
29        sigma: Tuple[float, float] = (32.0, 32.0),
30        alpha: Tuple[float, float] = (4.0, 4.0),
31        interpolation=kornia.constants.Resample.BILINEAR,
32        p: float = 0.5,
33        keepdim: bool = False,
34        same_on_batch: bool = True,
35    ):
36        super().__init__(p=p, same_on_batch=same_on_batch)
37        if isinstance(control_point_spacing, int):
38            self.control_point_spacing = [control_point_spacing] * 2
39        else:
40            self.control_point_spacing = control_point_spacing
41        assert len(self.control_point_spacing) == 2
42        self.interpolation = interpolation
43        self.flags = dict(interpolation=torch.tensor(self.interpolation.value), sigma=sigma, alpha=alpha)

Initialize the augmentation with the arguments described in the class documentation.

interpolation
flags
class RandomElasticDeformation(kornia.augmentation._2d.base.AugmentationBase2D):
 92class RandomElasticDeformation(kornia.augmentation.AugmentationBase2D):
 93    """Random elastic deformations implemented with kornia.
 94
 95    Args:
 96        control_point_spacing: The control point spacing for the deformation field.
 97        sigma: Sigma for smoothing the deformation field.
 98        alpha: Alpha value.
 99        resample: Interpolation order for applying the transformation to the data.
100        p: Probability for applying the transformation.
101        keepdim:
102        same_on_batch:
103    """
104    def __init__(
105        self,
106        control_point_spacing: Union[int, Sequence[int]] = 1,
107        sigma: Tuple[float, float] = (32.0, 32.0),
108        alpha: Tuple[float, float] = (4.0, 4.0),
109        resample=kornia.constants.Resample.BILINEAR,
110        p: float = 0.5,
111        keepdim: bool = False,
112        same_on_batch: bool = True,
113    ):
114        super().__init__(p=p, same_on_batch=same_on_batch)
115        if isinstance(control_point_spacing, int):
116            self.control_point_spacing = [control_point_spacing] * 2
117        else:
118            self.control_point_spacing = control_point_spacing
119        assert len(self.control_point_spacing) == 2
120        self.resample = resample
121        self.flags = dict(resample=torch.tensor(self.resample.value), sigma=sigma, alpha=alpha)
122
123    def generate_parameters(self, batch_shape):
124        """@private
125        """
126        assert len(batch_shape) == 4, f"{len(batch_shape)}"
127        shape = batch_shape[2:]
128        control_shape = tuple(sh // spacing for sh, spacing in zip(shape, self.control_point_spacing))
129        deformation_fields = [np.random.uniform(-1, 1, control_shape), np.random.uniform(-1, 1, control_shape)]
130        deformation_fields = [resize(df, shape, order=3)[None] for df in deformation_fields]
131        noise = np.concatenate(deformation_fields, axis=0)[None].astype("float32")
132        noise = torch.from_numpy(noise)
133        return {"noise": noise}
134
135    def __call__(self, input: torch.Tensor, params: Optional[Dict] = None) -> torch.Tensor:
136        """Apply the augmentation to a tensor.
137
138        Args:
139            input: The input tensor.
140            params: The transformation parameters.
141
142        Returns:
143            The transformed tensor.
144        """
145        if params is None:
146            params = self.generate_parameters(input.shape)
147            self._params = params
148        noise = params["noise"]
149        mode = "bilinear" if (self.flags["resample"] == 1).all() else "nearest"
150        return kornia.geometry.transform.elastic_transform2d(
151            input, noise, sigma=self.flags["sigma"], alpha=self.flags["alpha"], mode=mode, padding_mode="reflection"
152        )

Random elastic deformations implemented with kornia.

Arguments:
  • control_point_spacing: The control point spacing for the deformation field.
  • sigma: Sigma for smoothing the deformation field.
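  • alpha: Scaling factor for the strength of the deformation.
  • resample: Interpolation mode (a kornia Resample value) for applying the transformation to the data; bilinear for Resample.BILINEAR, nearest neighbor otherwise.
  • p: Probability for applying the transformation.
  • keepdim: Accepted for compatibility with the kornia augmentation interface.
  • same_on_batch: Forwarded to the kornia base class.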
RandomElasticDeformation( control_point_spacing: Union[int, Sequence[int]] = 1, sigma: Tuple[float, float] = (32.0, 32.0), alpha: Tuple[float, float] = (4.0, 4.0), resample=<Resample.BILINEAR: 1>, p: float = 0.5, keepdim: bool = False, same_on_batch: bool = True)
104    def __init__(
105        self,
106        control_point_spacing: Union[int, Sequence[int]] = 1,
107        sigma: Tuple[float, float] = (32.0, 32.0),
108        alpha: Tuple[float, float] = (4.0, 4.0),
109        resample=kornia.constants.Resample.BILINEAR,
110        p: float = 0.5,
111        keepdim: bool = False,
112        same_on_batch: bool = True,
113    ):
114        super().__init__(p=p, same_on_batch=same_on_batch)
115        if isinstance(control_point_spacing, int):
116            self.control_point_spacing = [control_point_spacing] * 2
117        else:
118            self.control_point_spacing = control_point_spacing
119        assert len(self.control_point_spacing) == 2
120        self.resample = resample
121        self.flags = dict(resample=torch.tensor(self.resample.value), sigma=sigma, alpha=alpha)

Initialize the augmentation with the arguments described in the class documentation.

resample
flags
class KorniaAugmentationPipeline(torch.nn.modules.module.Module):
157class KorniaAugmentationPipeline(torch.nn.Module):
158    """Pipeline to apply multiple kornia augmentations to data.
159
160    Args:
161        kornia_augmentations: Augmentations implemented with kornia.
162        dtype: The data type of the return data.
163    """
164    interpolatable_torch_types = [torch.float16, torch.float32, torch.float64]
165    interpolatable_numpy_types = [np.dtype("float32"), np.dtype("float64")]
166
167    def __init__(self, *kornia_augmentations, dtype: Union[str, torch.dtype] = torch.float32):
168        super().__init__()
169        self.augmentations = torch.nn.ModuleList(kornia_augmentations)
170        self.dtype = dtype
171        self.halo = self.compute_halo()
172
173    # for now we only add a halo for the random rotation trafos and
174    # also don't compute the halo dynamically based on the input shape
175    def compute_halo(self):
176        """@private
177        """
178        halo = None
179        for aug in self.augmentations:
180            if isinstance(aug, kornia.augmentation.RandomRotation):
181                halo = [32, 32]
182            if isinstance(aug, kornia.augmentation.RandomRotation3D):
183                halo = [32, 32, 32]
184        return halo
185
186    def is_interpolatable(self, tensor):
187        """@private
188        """
189        if torch.is_tensor(tensor):
190            return tensor.dtype in self.interpolatable_torch_types
191        else:
192            return tensor.dtype in self.interpolatable_numpy_types
193
194    def transform_tensor(self, augmentation, tensor, interpolatable, params=None):
195        """@private
196        """
197        interpolating = "interpolation" in getattr(augmentation, "flags", [])
198        if interpolating:
199            resampler = kornia.constants.Resample.get("BILINEAR" if interpolatable else "NEAREST")
200            augmentation.flags["interpolation"] = torch.tensor(resampler.value)
201        transformed = augmentation(tensor, params)
202        return transformed, augmentation._params
203
204    def forward(self, *tensors: torch.Tensor) -> List[torch.Tensor]:
205        """Apply augmentations to a list of tensors.
206
207        Args:
208            tensors: The input tensors.
209
210        Returns:
211            List of transformed tensors.
212        """
213        interpolatable = [self.is_interpolatable(tensor) for tensor in tensors]
214        tensors = [ensure_tensor(tensor, self.dtype) for tensor in tensors]
215        for aug in self.augmentations:
216
217            t0, params = self.transform_tensor(aug, tensors[0], interpolatable[0])
218            transformed_tensors = [t0]
219            for tensor, interpolate in zip(tensors[1:], interpolatable[1:]):
220                tensor, _ = self.transform_tensor(aug, tensor, interpolate, params=params)
221                transformed_tensors.append(tensor)
222
223            tensors = transformed_tensors
224        return tensors
225
226    def halo(self, shape):
227        """@private
228        """
229        return self.halo

Pipeline to apply multiple kornia augmentations to data.

Arguments:
  • kornia_augmentations: Augmentations implemented with kornia.
  • dtype: The data type of the return data.
KorniaAugmentationPipeline( *kornia_augmentations, dtype: Union[str, torch.dtype] = torch.float32)
167    def __init__(self, *kornia_augmentations, dtype: Union[str, torch.dtype] = torch.float32):
168        super().__init__()
169        self.augmentations = torch.nn.ModuleList(kornia_augmentations)
170        self.dtype = dtype
171        self.halo = self.compute_halo()

Initialize the pipeline with the arguments described in the class documentation.

interpolatable_torch_types = [torch.float16, torch.float32, torch.float64]
interpolatable_numpy_types = [dtype('float32'), dtype('float64')]
augmentations
dtype
def forward(self, *tensors: torch.Tensor) -> List[torch.Tensor]:
204    def forward(self, *tensors: torch.Tensor) -> List[torch.Tensor]:
205        """Apply augmentations to a list of tensors.
206
207        Args:
208            tensors: The input tensors.
209
210        Returns:
211            List of transformed tensors.
212        """
213        interpolatable = [self.is_interpolatable(tensor) for tensor in tensors]
214        tensors = [ensure_tensor(tensor, self.dtype) for tensor in tensors]
215        for aug in self.augmentations:
216
217            t0, params = self.transform_tensor(aug, tensors[0], interpolatable[0])
218            transformed_tensors = [t0]
219            for tensor, interpolate in zip(tensors[1:], interpolatable[1:]):
220                tensor, _ = self.transform_tensor(aug, tensor, interpolate, params=params)
221                transformed_tensors.append(tensor)
222
223            tensors = transformed_tensors
224        return tensors

Apply augmentations to a list of tensors.

Arguments:
  • tensors: The input tensors.
Returns:

List of transformed tensors.

AUGMENTATIONS = {'RandomAffine': {'degrees': 90, 'scale': (0.9, 1.1)}, 'RandomAffine3D': {'degrees': (90, 90, 90), 'scale': (0.0, 1.1)}, 'RandomDepthicalFlip3D': {}, 'RandomHorizontalFlip': {}, 'RandomHorizontalFlip3D': {}, 'RandomRotation': {'degrees': 90}, 'RandomRotation3D': {'degrees': (90, 90, 90)}, 'RandomVerticalFlip': {}, 'RandomVerticalFlip3D': {}, 'RandomElasticDeformation3D': {'alpha': [5, 5], 'sigma': [30, 30]}}

All available augmentations and their default parameters.

DEFAULT_2D_AUGMENTATIONS = ['RandomHorizontalFlip', 'RandomVerticalFlip']

The names of the default augmentations for 2D data.

DEFAULT_3D_AUGMENTATIONS = ['RandomHorizontalFlip3D', 'RandomVerticalFlip3D', 'RandomDepthicalFlip3D']

The names of the default augmentations for 3D data.

DEFAULT_ANISOTROPIC_AUGMENTATIONS = ['RandomHorizontalFlip3D', 'RandomVerticalFlip3D', 'RandomDepthicalFlip3D']

The names of the default augmentations for anisotropic 3D data.

def get_augmentations( ndim: Union[int, str] = 2, transforms=None, dtype: Union[str, torch.dtype] = torch.float32):
280def get_augmentations(ndim: Union[int, str] = 2, transforms=None, dtype: Union[str, torch.dtype] = torch.float32):
281    """Get augmentation pipeline.
282
283    Args:
284        ndim: The dimensionality for the augmentations. One of 2, 3 or "anisotropic".
285        transforms: The transformations to use for the augmentations.
286            If None, the default augmentations for the given data dimensionality will be used.
287        dtype: The data type of the output data of the augmentation.
288
289    Returns:
290        The augmentation pipeline.
291    """
292    if transforms is None:
293        assert ndim in (2, 3, "anisotropic"), f"Expect ndim to be one of (2, 3, 'anisotropic'), got {ndim}"
294        if ndim == 2:
295            transforms = DEFAULT_2D_AUGMENTATIONS
296        elif ndim == 3:
297            transforms = DEFAULT_3D_AUGMENTATIONS
298        else:
299            transforms = DEFAULT_ANISOTROPIC_AUGMENTATIONS
300    transforms = [create_augmentation(trafo) for trafo in transforms]
301    assert all(isinstance(trafo, kornia.augmentation.base._AugmentationBase) for trafo in transforms)
302    augmentations = KorniaAugmentationPipeline(*transforms, dtype=dtype)
303    return augmentations

Get augmentation pipeline.

Arguments:
  • ndim: The dimensionality for the augmentations. One of 2, 3 or "anisotropic".
  • transforms: The transformations to use for the augmentations. If None, the default augmentations for the given data dimensionality will be used.
  • dtype: The data type of the output data of the augmentation.
Returns:

The augmentation pipeline.