torch_em.transform.invertible_augmentations

"""Invertible weak and strong augmentations built on kornia, for semi-supervised
training strategies such as mean teacher, FixMatch, and UniMatch v2.
"""
from typing import Callable

import torch
import kornia.augmentation as K


DEFAULT_WEAK_AUGMENTATIONS = {
    "intensity": {},
    "geometrical": {
        "RandomHorizontalFlip": {},
        "RandomVerticalFlip": {},
        "RandomRotation90": {"times": (-1, 2)},
    }
}

DEFAULT_STRONG_AUGMENTATIONS = {
    "intensity": {
        "RandomGaussianBlur": {"kernel_size": (3, 3), "sigma": (0.1, 1.0)},
        "RandomGaussianNoise": {"mean": 0.0, "std": 0.1},
    },
    "geometrical": {
        "RandomHorizontalFlip": {},
        "RandomVerticalFlip": {},
        "RandomRotation90": {"times": (-1, 2)},
    }
}


def get_intensity_augmentations(aug_name: str, ndim: int, p: float = 0.75) -> Callable:
    """Build the intensity augmentations for a 'weak' or 'strong' augmentation setting."""
    if aug_name == "weak":
        aug_dict = DEFAULT_WEAK_AUGMENTATIONS["intensity"]
    elif aug_name == "strong":
        aug_dict = DEFAULT_STRONG_AUGMENTATIONS["intensity"]
    else:
        raise ValueError(f"aug_name must be 'weak' or 'strong', got {aug_name}")

    transforms_list = []
    for trafo, kwargs in aug_dict.items():
        assert trafo in dir(K), f"{trafo} not found in kornia.augmentation"
        transforms_list.append(getattr(K, trafo)(p=p, **kwargs))

    if ndim == 2:
        return K.AugmentationSequential(*transforms_list, data_keys=["input"], same_on_batch=False)
    elif ndim == 3:
        return AugmentationSequential3D(*transforms_list)
    else:
        raise ValueError(f"Number of dimensions must be 2 or 3, got ndim={ndim}")


def get_geometrical_augmentations(aug_name: str, ndim: int, p: float = 0.75) -> Callable:
    """Build the geometrical augmentations for a 'weak' or 'strong' augmentation setting."""
    if aug_name == "weak":
        aug_dict = DEFAULT_WEAK_AUGMENTATIONS["geometrical"]
    elif aug_name == "strong":
        aug_dict = DEFAULT_STRONG_AUGMENTATIONS["geometrical"]
    else:
        raise ValueError(f"aug_name must be 'weak' or 'strong', got {aug_name}")

    transforms_list = []
    for trafo, kwargs in aug_dict.items():
        assert trafo in dir(K), f"{trafo} not found in kornia.augmentation"
        transforms_list.append(getattr(K, trafo)(p=p, **kwargs))

    if ndim == 2:
        return K.AugmentationSequential(*transforms_list, data_keys=["input"], same_on_batch=False)
    elif ndim == 3:
        return AugmentationSequential3D(*transforms_list)
    else:
        raise ValueError(f"Number of dimensions must be 2 or 3, got ndim={ndim}")


def get_augmentations(aug_name: str, ndim: int, p: float = 0.75):
    """Build the matching intensity and geometrical augmentations for a 'weak' or 'strong' setting."""
    if aug_name not in ("weak", "strong"):
        raise ValueError(f"aug_name must be 'weak' or 'strong', got {aug_name}")

    intensity_transforms = get_intensity_augmentations(aug_name, ndim=ndim, p=p)
    geometrical_transforms = get_geometrical_augmentations(aug_name, ndim=ndim, p=p)
    return intensity_transforms, geometrical_transforms


class AugmentationSequential3D(torch.nn.Module):
    """Apply 2D kornia augmentations to volumetric data.

    The volume is flattened from (B, C, D, H, W) to (B, C*D, H, W), so that each 2D
    augmentation is applied slice-wise, with the same sampled parameters for all
    slices of a sample.
    """
    def __init__(self, *augmentations: torch.nn.Module):
        super().__init__()
        self.augmentations = torch.nn.ModuleList(augmentations)
        self._params = None

    @staticmethod
    def _flatten(x):
        """(B, C, D, H, W) -> (B, C*D, H, W)"""
        if x.ndim != 5:
            raise RuntimeError(f"Expected 5D tensor, got {x.shape}")
        b, c, d, h, w = x.shape
        x = x.reshape(b, c * d, h, w)
        return x, (b, c, d, h, w)

    @staticmethod
    def _unflatten(x, shape):
        """(B, C*D, H, W) -> (B, C, D, H, W)"""
        b, c, d, h, w = shape
        x = x.reshape(b, c, d, h, w)
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        flat_x, shape = self._flatten(x)
        params_all = []
        for aug in self.augmentations:
            flat_x = aug(flat_x)
            params_all.append(aug._params)
        out = self._unflatten(flat_x, shape)
        # Store the sampled parameters so that the augmentations can be inverted.
        self._params = params_all
        return out

    def inverse(self, x: torch.Tensor, params) -> torch.Tensor:
        flat_x, shape = self._flatten(x)
        # Undo the augmentations in reverse order of application.
        for aug, p in reversed(list(zip(self.augmentations, params))):
            flat_x = aug.inverse(flat_x, params=p)
        return self._unflatten(flat_x, shape)


class InvertibleAugmenter(torch.nn.Module):
    """Applies intensity and geometrical augmentations and can undo the geometrical part.

    The parameters sampled for the geometrical transforms are stored, so that
    `reverse_transform` can map augmented tensors (or predictions on them) back
    into the frame of the unaugmented input.
    """

    def __init__(
        self,
        intensity_transforms: Callable[[torch.Tensor], torch.Tensor],
        geometrical_transforms: Callable[[torch.Tensor], torch.Tensor],
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.intensity_transforms = intensity_transforms
        self.geometrical_transforms = geometrical_transforms
        self.params = None

    def reset(self):
        self.params = None

    def transform(self, x: torch.Tensor) -> torch.Tensor:
        x = self.intensity_transforms(x)
        x = self.geometrical_transforms(x)
        # Keep the sampled geometrical parameters for reverse_transform.
        self.params = self.geometrical_transforms._params
        return x

    def reverse_transform(self, x: torch.Tensor) -> torch.Tensor:
        if self.params is None:
            raise RuntimeError("reverse_transform was called before transform")
        return self.geometrical_transforms.inverse(x, params=self.params)


class MeanTeacherAugmenters:
    """Augmenter pair for mean teacher training: weak augmentations for both teacher and student."""
    def __init__(
        self,
        ndim: int,
        teacher=None,
        student=None,
    ):
        self.teacher = teacher or InvertibleAugmenter(*get_augmentations("weak", ndim=ndim))
        self.student = student or InvertibleAugmenter(*get_augmentations("weak", ndim=ndim))

    def reset_all(self):
        self.teacher.reset()
        self.student.reset()


class FixMatchAugmenters:
    """Augmenter pair for FixMatch training: weak augmentations for the teacher, strong for the student."""
    def __init__(
        self,
        ndim: int,
        teacher=None,
        student=None,
    ):
        self.teacher = teacher or InvertibleAugmenter(*get_augmentations("weak", ndim=ndim))
        self.student = student or InvertibleAugmenter(*get_augmentations("strong", ndim=ndim))

    def reset_all(self):
        self.teacher.reset()
        self.student.reset()


class UniMatchv2Augmenters:
    """Augmenters for UniMatch v2 training: one weakly and two strongly augmented views."""
    def __init__(
        self,
        ndim: int,
        weak=None,
        strong1=None,
        strong2=None,
    ):
        self.weak = weak or InvertibleAugmenter(*get_augmentations("weak", ndim=ndim))
        self.strong1 = strong1 or InvertibleAugmenter(*get_augmentations("strong", ndim=ndim))
        self.strong2 = strong2 or InvertibleAugmenter(*get_augmentations("strong", ndim=ndim))

    def reset_all(self):
        self.weak.reset()
        self.strong1.reset()
        self.strong2.reset()
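
The 3D path works the same way; a short sketch (shapes and the chosen augmentations are illustrative): AugmentationSequential3D applies 2D kornia augmentations slice-wise and supports the same parameter-based inverse.

import torch
import kornia.augmentation as K

from torch_em.transform.invertible_augmentations import AugmentationSequential3D

aug3d = AugmentationSequential3D(
    K.RandomHorizontalFlip(p=0.75),
    K.RandomVerticalFlip(p=0.75),
)

volume = torch.rand(2, 1, 32, 128, 128)  # (B, C, D, H, W)
augmented = aug3d(volume)  # the forward pass stores the sampled parameters
restored = aug3d.inverse(augmented, params=aug3d._params)

print(torch.allclose(volume, restored))  # flips are undone exactly -> True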