torch_em.transform.invertible_augmentations

  1from typing import Callable
  2import torch
  3import kornia.augmentation as K
  4
  5
  6DEFAULT_WEAK_AUGMENTATIONS = {
  7    "intensity": {},
  8    "geometrical": {
  9        "RandomHorizontalFlip": {},
 10        "RandomVerticalFlip": {},
 11        "RandomRotation90": {"times": (-1, 2)},
 12    }
 13}
 14
 15DEFAULT_STRONG_AUGMENTATIONS = {
 16    "intensity": {
 17        "RandomGaussianBlur": {"kernel_size": (3, 3), "sigma": (0.1, 1.0)},
 18        "RandomGaussianNoise": {"mean": (0.0), "std": (0.1)},
 19    },
 20    "geometrical": {
 21        "RandomHorizontalFlip": {},
 22        "RandomVerticalFlip": {},
 23        "RandomRotation90": {"times": (-1, 2)},
 24    }
 25}
 26
 27
 28def get_intensity_augmentations(ndim, aug_name: str = None, aug_dict: dict = None, p: float = 0.75) -> callable:
 29    assert ndim == 2 or ndim == 3, f"Number of dimensions must be 2 or 3, got ndim={ndim}"
 30    assert aug_name is not None or aug_dict is not None, "Either aug_name or aug_dict must be provided."
 31
 32    if aug_dict is None:
 33        if aug_name == "weak":
 34            aug_dict = DEFAULT_WEAK_AUGMENTATIONS["intensity"]
 35        elif aug_name == "strong":
 36            aug_dict = DEFAULT_STRONG_AUGMENTATIONS["intensity"]
 37        else:
 38            raise ValueError(f"Augmentation name needs to be \"weak\" or \"strong\", got aug_name={aug_name}")
 39
 40    transforms_list = []
 41    for trafo, kwargs in aug_dict.items():
 42        assert trafo in dir(K), f"{trafo} not found in kornia.augmentation"
 43        transforms_list.append(getattr(K, trafo)(p=p, **kwargs))
 44
 45    if ndim == 2:
 46        return K.AugmentationSequential(*transforms_list, data_keys=["input"], same_on_batch=False)
 47    elif ndim == 3:
 48        return AugmentationSequential3D(*transforms_list)
 49
 50
 51def get_geometrical_augmentations(ndim, aug_name: str = None, aug_dict: dict = None, p: float = 0.75) -> callable:
 52    assert ndim == 2 or ndim == 3, f"Number of dimensions must be 2 or 3, got ndim={ndim}"
 53    assert aug_name is not None or aug_dict is not None, "Either aug_name or aug_dict must be provided."
 54
 55    if aug_dict is None:
 56        if aug_name == "weak":
 57            aug_dict = DEFAULT_WEAK_AUGMENTATIONS["geometrical"]
 58        elif aug_name == "strong":
 59            aug_dict = DEFAULT_STRONG_AUGMENTATIONS["geometrical"]
 60        else:
 61            raise ValueError(f"Augmentation name needs to be \"weak\" or \"strong\", got aug_name={aug_name}")
 62
 63    transforms_list = []
 64    for trafo, kwargs in aug_dict.items():
 65        assert trafo in dir(K), f"{trafo} not found in kornia.augmentation"
 66        transforms_list.append(getattr(K, trafo)(p=p, **kwargs))
 67
 68    if ndim == 2:
 69        return K.AugmentationSequential(*transforms_list, data_keys=["input"], same_on_batch=False)
 70    elif ndim == 3:
 71        return AugmentationSequential3D(*transforms_list)
 72
 73
 74def get_default_augmentations(aug_name: str, ndim: int, p: float = 0.75):
 75    assert ndim == 2 or ndim == 3, f"Number of dimensions must be 2 or 3, got ndim={ndim}"
 76
 77    if aug_name == "weak":
 78        intensity_transforms = get_intensity_augmentations(ndim, aug_name=aug_name, p=p)
 79        geometrical_transforms = get_geometrical_augmentations(ndim, aug_name=aug_name, p=p)
 80    elif aug_name == "strong":
 81        intensity_transforms = get_intensity_augmentations(ndim, aug_name=aug_name, p=p)
 82        geometrical_transforms = get_geometrical_augmentations(ndim, aug_name=aug_name, p=p)
 83    else:
 84        raise ValueError(f"Augmentation name needs to be \"weak\" or \"strong\", got aug_name={aug_name}")
 85
 86    return intensity_transforms, geometrical_transforms
 87
 88
 89class AugmentationSequential3D(torch.nn.Module):
 90    def __init__(self, *augmentations: torch.nn.Module):
 91        super().__init__()
 92        self.augmentations = torch.nn.ModuleList(augmentations)
 93        self._params = None
 94
 95    @staticmethod
 96    def _flatten(x):
 97        """
 98        (B, C, D, H, W) -> (B, C*D, H, W)
 99        """
100        if x.ndim != 5:
101            raise RuntimeError(f"Expected 5D tensor, got {x.shape}")
102        b, c, d, h, w = x.shape
103        x = x.reshape(b, c * d, h, w)
104        return x, (b, c, d, h, w)
105
106    @staticmethod
107    def _unflatten(x, shape):
108        """
109        (B, C*D, H, W) -> (B, C, D, H, W)
110        """
111        b, c, d, h, w = shape
112        x = x.reshape(b, c, d, h, w)
113        return x
114
115    def forward(self, x: torch.Tensor) -> torch.Tensor:
116        params_all = []
117
118        flat_x, shape = self._flatten(x)
119        for aug in self.augmentations:
120            flat_x = aug(flat_x)
121            params_all.append(aug._params)
122        out = self._unflatten(flat_x, shape)
123        self._params = params_all
124        return out
125
126    def inverse(self, x: torch.Tensor, params) -> torch.Tensor:
127
128        flat_x, shape = self._flatten(x)
129        for aug, p in reversed(list(zip(self.augmentations, params))):
130            flat_x = aug.inverse(flat_x, params=p)
131        out = self._unflatten(flat_x, shape)
132
133        return out
134
135
class InvertibleAugmenter(torch.nn.Module):
    """Intensity + geometric augmentation whose geometric part can be undone.

    `transform` applies the intensity transforms, optionally clamps to `[0, clip_max]`,
    then applies the geometric transforms and records their parameters.
    `reverse_transform` uses those parameters to map a tensor back to the original
    geometry. Intensity changes are not reverted.

    Args:
        intensity_transforms: Callable applied first.
        geometrical_transforms: Callable applied second. It must expose `_params` after a
            call and provide `inverse(x, params=...)`.
        clip_max: Optional upper clamp bound after the intensity transforms (lower bound 0).
        **kwargs: Forwarded to `torch.nn.Module.__init__`.
    """

    def __init__(
        self,
        intensity_transforms: Callable[[torch.Tensor], torch.Tensor],
        geometrical_transforms: Callable[[torch.Tensor], torch.Tensor],
        clip_max=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.intensity_transforms = intensity_transforms
        self.geometrical_transforms = geometrical_transforms
        self.clip_max = clip_max
        # Geometric parameters of the last transform() call. Initialised here so that
        # reverse_transform() never hits an AttributeError before the first transform().
        self.params = None

    def reset(self):
        """Forget the geometric parameters recorded by the last `transform` call."""
        self.params = None

    def transform(self, x: torch.Tensor) -> torch.Tensor:
        """Augment `x` and remember the geometric parameters for `reverse_transform`."""
        x = self.intensity_transforms(x)
        if self.clip_max is not None:
            x = torch.clamp(x, 0.0, self.clip_max)
        x = self.geometrical_transforms(x)

        self.params = self.geometrical_transforms._params

        return x

    def reverse_transform(self, x: torch.Tensor) -> torch.Tensor:
        """Invert the geometric transforms using the parameters stored by `transform`."""
        return self.geometrical_transforms.inverse(x, params=self.params)
166
167
class MeanTeacherAugmenters:
    """Augmenter pair for mean-teacher training.

    By default both teacher and student get an `InvertibleAugmenter` built from the
    "weak" recipe of `get_default_augmentations`.

    Args:
        ndim: 2 or 3; only used when a default augmenter has to be built.
        teacher: Optional custom augmenter for the teacher; must provide `reset()`.
        student: Optional custom augmenter for the student; must provide `reset()`.
        clip_max: Forwarded to default-constructed augmenters only.
    """

    def __init__(
        self,
        ndim: int,
        teacher=None,
        student=None,
        clip_max=None,
    ):
        # Compare with None (not truthiness) so a caller-supplied augmenter is never
        # replaced just because it happens to be falsy (e.g. defines __len__ == 0).
        if teacher is None:
            teacher = InvertibleAugmenter(*get_default_augmentations("weak", ndim=ndim), clip_max=clip_max)
        if student is None:
            student = InvertibleAugmenter(*get_default_augmentations("weak", ndim=ndim), clip_max=clip_max)
        self.teacher = teacher
        self.student = student

    def reset_all(self):
        """Call `reset()` on the teacher, then on the student."""
        self.teacher.reset()
        self.student.reset()
182
183
class FixMatchAugmenters:
    """Augmenter pair for FixMatch training.

    By default the teacher gets an `InvertibleAugmenter` built from the "weak" recipe and
    the student one built from the "strong" recipe of `get_default_augmentations`.

    Args:
        ndim: 2 or 3; only used when a default augmenter has to be built.
        teacher: Optional custom augmenter for the teacher; must provide `reset()`.
        student: Optional custom augmenter for the student; must provide `reset()`.
        clip_max: Forwarded to default-constructed augmenters only.
    """

    def __init__(
        self,
        ndim: int,
        teacher=None,
        student=None,
        clip_max=None,
    ):
        # Compare with None (not truthiness) so a caller-supplied augmenter is never
        # replaced just because it happens to be falsy.
        if teacher is None:
            teacher = InvertibleAugmenter(*get_default_augmentations("weak", ndim=ndim), clip_max=clip_max)
        if student is None:
            student = InvertibleAugmenter(*get_default_augmentations("strong", ndim=ndim), clip_max=clip_max)
        self.teacher = teacher
        self.student = student

    def reset_all(self):
        """Call `reset()` on the teacher, then on the student."""
        self.teacher.reset()
        self.student.reset()
198
199
class UniMatchv2Augmenters:
    """Three augmenters for UniMatch-v2 style training: one weak view, two strong views.

    Defaults are `InvertibleAugmenter`s built from the "weak" (for `weak`) and "strong"
    (for `strong1` / `strong2`) recipes of `get_default_augmentations`.

    Args:
        ndim: 2 or 3; only used when a default augmenter has to be built.
        weak: Optional custom augmenter for the weak view; must provide `reset()`.
        strong1: Optional custom augmenter for the first strong view.
        strong2: Optional custom augmenter for the second strong view.
        clip_max: Forwarded to default-constructed augmenters only.
    """

    def __init__(
        self,
        ndim: int,
        weak=None,
        strong1=None,
        strong2=None,
        clip_max=None,
    ):
        # Compare with None (not truthiness) so a caller-supplied augmenter is never
        # replaced just because it happens to be falsy.
        if weak is None:
            weak = InvertibleAugmenter(*get_default_augmentations("weak", ndim=ndim), clip_max=clip_max)
        if strong1 is None:
            strong1 = InvertibleAugmenter(*get_default_augmentations("strong", ndim=ndim), clip_max=clip_max)
        if strong2 is None:
            strong2 = InvertibleAugmenter(*get_default_augmentations("strong", ndim=ndim), clip_max=clip_max)
        self.weak = weak
        self.strong1 = strong1
        self.strong2 = strong2

    def reset_all(self):
        """Call `reset()` on the weak, first strong and second strong augmenters, in that order."""
        self.weak.reset()
        self.strong1.reset()
        self.strong2.reset()
# Default augmentation recipes: kornia.augmentation class name -> constructor kwargs,
# grouped into "intensity" and "geometrical" transforms.
DEFAULT_WEAK_AUGMENTATIONS = {
    "intensity": {},
    "geometrical": {
        "RandomHorizontalFlip": {},
        "RandomVerticalFlip": {},
        "RandomRotation90": {"times": (-1, 2)},
    },
}
DEFAULT_STRONG_AUGMENTATIONS = {
    "intensity": {
        "RandomGaussianBlur": {"kernel_size": (3, 3), "sigma": (0.1, 1.0)},
        "RandomGaussianNoise": {"mean": 0.0, "std": 0.1},
    },
    "geometrical": {
        "RandomHorizontalFlip": {},
        "RandomVerticalFlip": {},
        "RandomRotation90": {"times": (-1, 2)},
    },
}
def get_intensity_augmentations(
    ndim: int, aug_name: str = None, aug_dict: dict = None, p: float = 0.75
) -> Callable[[torch.Tensor], torch.Tensor]:
    """Build the intensity (non-geometric) augmentation pipeline.

    Args:
        ndim: Spatial dimensionality of the data, 2 or 3.
        aug_name: Name of a default recipe, "weak" or "strong". Only used when `aug_dict` is None.
        aug_dict: Mapping from a `kornia.augmentation` class name to its constructor kwargs.
            A "p" entry in the kwargs overrides the shared probability for that transform.
        p: Probability with which each transform is applied.

    Returns:
        A `K.AugmentationSequential` for ndim == 2, an `AugmentationSequential3D` for ndim == 3.

    Raises:
        AssertionError: If ndim is not 2 or 3, if neither recipe source is given,
            or if a transform name is not found in `kornia.augmentation`.
        ValueError: If `aug_dict` is None and `aug_name` is not "weak" or "strong".
    """
    assert ndim == 2 or ndim == 3, f"Number of dimensions must be 2 or 3, got ndim={ndim}"
    assert aug_name is not None or aug_dict is not None, "Either aug_name or aug_dict must be provided."

    if aug_dict is None:
        if aug_name == "weak":
            aug_dict = DEFAULT_WEAK_AUGMENTATIONS["intensity"]
        elif aug_name == "strong":
            aug_dict = DEFAULT_STRONG_AUGMENTATIONS["intensity"]
        else:
            raise ValueError(f"Augmentation name needs to be \"weak\" or \"strong\", got aug_name={aug_name}")

    transforms_list = []
    for trafo, kwargs in aug_dict.items():
        assert hasattr(K, trafo), f"{trafo} not found in kornia.augmentation"
        # Let a per-transform "p" override the shared probability instead of
        # failing with "got multiple values for keyword argument 'p'".
        transforms_list.append(getattr(K, trafo)(**{"p": p, **kwargs}))

    if ndim == 2:
        return K.AugmentationSequential(*transforms_list, data_keys=["input"], same_on_batch=False)
    return AugmentationSequential3D(*transforms_list)
def get_geometrical_augmentations(
    ndim: int, aug_name: str = None, aug_dict: dict = None, p: float = 0.75
) -> Callable[[torch.Tensor], torch.Tensor]:
    """Build the geometric augmentation pipeline (the part that gets inverted later).

    Args:
        ndim: Spatial dimensionality of the data, 2 or 3.
        aug_name: Name of a default recipe, "weak" or "strong". Only used when `aug_dict` is None.
        aug_dict: Mapping from a `kornia.augmentation` class name to its constructor kwargs.
            A "p" entry in the kwargs overrides the shared probability for that transform.
        p: Probability with which each transform is applied.

    Returns:
        A `K.AugmentationSequential` for ndim == 2, an `AugmentationSequential3D` for ndim == 3.

    Raises:
        AssertionError: If ndim is not 2 or 3, if neither recipe source is given,
            or if a transform name is not found in `kornia.augmentation`.
        ValueError: If `aug_dict` is None and `aug_name` is not "weak" or "strong".
    """
    assert ndim == 2 or ndim == 3, f"Number of dimensions must be 2 or 3, got ndim={ndim}"
    assert aug_name is not None or aug_dict is not None, "Either aug_name or aug_dict must be provided."

    if aug_dict is None:
        if aug_name == "weak":
            aug_dict = DEFAULT_WEAK_AUGMENTATIONS["geometrical"]
        elif aug_name == "strong":
            aug_dict = DEFAULT_STRONG_AUGMENTATIONS["geometrical"]
        else:
            raise ValueError(f"Augmentation name needs to be \"weak\" or \"strong\", got aug_name={aug_name}")

    transforms_list = []
    for trafo, kwargs in aug_dict.items():
        assert hasattr(K, trafo), f"{trafo} not found in kornia.augmentation"
        # Let a per-transform "p" override the shared probability instead of
        # failing with "got multiple values for keyword argument 'p'".
        transforms_list.append(getattr(K, trafo)(**{"p": p, **kwargs}))

    if ndim == 2:
        return K.AugmentationSequential(*transforms_list, data_keys=["input"], same_on_batch=False)
    return AugmentationSequential3D(*transforms_list)
def get_default_augmentations(aug_name: str, ndim: int, p: float = 0.75):
    """Build the default (intensity, geometrical) augmentation pipelines for a recipe.

    Args:
        aug_name: "weak" or "strong"; selects the DEFAULT_*_AUGMENTATIONS recipe.
        ndim: Spatial dimensionality of the data, 2 or 3.
        p: Probability with which each transform is applied.

    Returns:
        Tuple `(intensity_transforms, geometrical_transforms)`.

    Raises:
        AssertionError: If ndim is not 2 or 3.
        ValueError: If aug_name is not "weak" or "strong".
    """
    assert ndim == 2 or ndim == 3, f"Number of dimensions must be 2 or 3, got ndim={ndim}"

    # Both recipes are built the same way; only the name differs, so validate once.
    if aug_name not in ("weak", "strong"):
        raise ValueError(f"Augmentation name needs to be \"weak\" or \"strong\", got aug_name={aug_name}")

    intensity_transforms = get_intensity_augmentations(ndim, aug_name=aug_name, p=p)
    geometrical_transforms = get_geometrical_augmentations(ndim, aug_name=aug_name, p=p)
    return intensity_transforms, geometrical_transforms
class AugmentationSequential3D(torch.nn.Module):
    """Apply 2D augmentations to 5D volumes by folding the depth axis into channels.

    A `(B, C, D, H, W)` tensor is reshaped to `(B, C*D, H, W)`, passed through each
    augmentation in order, and reshaped back. Presumably this makes all slices of one
    sample share the same sampled 2D parameters; that depends on the wrapped
    augmentations (assumed to be kornia 2D augmentations).

    Args:
        *augmentations: Modules applied in order. Each must expose `_params` after a
            forward call, and must implement `inverse(x, params=...)` if `inverse` is used.
    """

    def __init__(self, *augmentations: torch.nn.Module):
        super().__init__()
        self.augmentations = torch.nn.ModuleList(augmentations)
        # Per-augmentation parameters recorded by the most recent forward() call.
        self._params = None

    @staticmethod
    def _flatten(x):
        """Reshape `(B, C, D, H, W)` -> `(B, C*D, H, W)`; also return the original shape."""
        if x.ndim != 5:
            raise RuntimeError(f"Expected 5D tensor, got {x.shape}")
        b, c, d, h, w = x.shape
        x = x.reshape(b, c * d, h, w)
        return x, (b, c, d, h, w)

    @staticmethod
    def _unflatten(x, shape):
        """Reshape `(B, C*D, H', W')` -> `(B, C, D, H', W')`.

        The in-plane size is taken from `x` rather than from `shape`. An augmentation that
        changes H/W (e.g. a transpose-like op on non-square slices) is then unfolded
        correctly, instead of being silently scrambled by a same-numel reshape.
        """
        b, c, d = shape[:3]
        return x.reshape(b, c, d, *x.shape[-2:])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Augment a 5D tensor and store each augmentation's parameters in `self._params`."""
        params_all = []

        flat_x, shape = self._flatten(x)
        for aug in self.augmentations:
            flat_x = aug(flat_x)
            params_all.append(aug._params)
        out = self._unflatten(flat_x, shape)
        self._params = params_all
        return out

    def inverse(self, x: torch.Tensor, params=None) -> torch.Tensor:
        """Undo the augmentations, last one first.

        Args:
            x: 5D tensor in the augmented geometry.
            params: Per-augmentation parameters, in the format stored in `self._params`.
                Defaults to the parameters of the most recent forward() call.

        Raises:
            RuntimeError: If no parameters are given and forward() has not been called.
        """
        if params is None:
            params = self._params
        if params is None:
            raise RuntimeError("No augmentation parameters available: call forward() first or pass params.")

        flat_x, shape = self._flatten(x)
        for aug, aug_params in reversed(list(zip(self.augmentations, params))):
            flat_x = aug.inverse(flat_x, params=aug_params)
        return self._unflatten(flat_x, shape)

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    """Example module: two 5x5 convolutions (1 -> 20 -> 20 channels), each followed by ReLU."""

    def __init__(self) -> None:
        # The parent __init__ must run before any submodule is assigned.
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean representing whether this module is in training or evaluation mode.
:vartype training: bool

def __init__(self, *augmentations: torch.nn.Module):
    """Register the 2D augmentations (applied in order) and clear the recorded parameters."""
    super().__init__()
    registered = torch.nn.ModuleList(augmentations)
    self.augmentations = registered
    self._params = None  # populated by forward()

Initialize internal Module state, shared by both nn.Module and ScriptModule.

augmentations
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Run every augmentation on the depth-folded tensor and keep each one's sampled parameters."""
    folded, original_shape = self._flatten(x)
    recorded = []
    for augmentation in self.augmentations:
        folded = augmentation(folded)
        recorded.append(augmentation._params)
    result = self._unflatten(folded, original_shape)
    self._params = recorded
    return result

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

def inverse(self, x: torch.Tensor, params=None) -> torch.Tensor:
    """Undo the augmentations, last one first.

    Args:
        x: 5D tensor in the augmented geometry.
        params: Per-augmentation parameters, in the format stored in `self._params`.
            Defaults to the parameters of the most recent forward() call.

    Raises:
        RuntimeError: If no parameters are given and none were recorded.
    """
    if params is None:
        params = self._params
    if params is None:
        raise RuntimeError("No augmentation parameters available: call forward() first or pass params.")

    flat_x, shape = self._flatten(x)
    for aug, aug_params in reversed(list(zip(self.augmentations, params))):
        flat_x = aug.inverse(flat_x, params=aug_params)
    return self._unflatten(flat_x, shape)
class InvertibleAugmenter(torch.nn.Module):
    """Intensity + geometric augmentation whose geometric part can be undone.

    `transform` applies the intensity transforms, optionally clamps to `[0, clip_max]`,
    then applies the geometric transforms and records their parameters.
    `reverse_transform` uses those parameters to map a tensor back to the original
    geometry. Intensity changes are not reverted.

    Args:
        intensity_transforms: Callable applied first.
        geometrical_transforms: Callable applied second. It must expose `_params` after a
            call and provide `inverse(x, params=...)`.
        clip_max: Optional upper clamp bound after the intensity transforms (lower bound 0).
        **kwargs: Forwarded to `torch.nn.Module.__init__`.
    """

    def __init__(
        self,
        intensity_transforms: Callable[[torch.Tensor], torch.Tensor],
        geometrical_transforms: Callable[[torch.Tensor], torch.Tensor],
        clip_max=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.intensity_transforms = intensity_transforms
        self.geometrical_transforms = geometrical_transforms
        self.clip_max = clip_max
        # Geometric parameters of the last transform() call. Initialised here so that
        # reverse_transform() never hits an AttributeError before the first transform().
        self.params = None

    def reset(self):
        """Forget the geometric parameters recorded by the last `transform` call."""
        self.params = None

    def transform(self, x: torch.Tensor) -> torch.Tensor:
        """Augment `x` and remember the geometric parameters for `reverse_transform`."""
        x = self.intensity_transforms(x)
        if self.clip_max is not None:
            x = torch.clamp(x, 0.0, self.clip_max)
        x = self.geometrical_transforms(x)

        self.params = self.geometrical_transforms._params

        return x

    def reverse_transform(self, x: torch.Tensor) -> torch.Tensor:
        """Invert the geometric transforms using the parameters stored by `transform`."""
        return self.geometrical_transforms.inverse(x, params=self.params)

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    """Example module: two 5x5 convolutions (1 -> 20 -> 20 channels), each followed by ReLU."""

    def __init__(self) -> None:
        # The parent __init__ must run before any submodule is assigned.
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean representing whether this module is in training or evaluation mode.
:vartype training: bool

def __init__(
    self,
    intensity_transforms: Callable[[torch.Tensor], torch.Tensor],
    geometrical_transforms: Callable[[torch.Tensor], torch.Tensor],
    clip_max=None,
    **kwargs,
):
    """Store the transform callables and the optional clamp bound.

    Args:
        intensity_transforms: Callable applied first by `transform`.
        geometrical_transforms: Callable applied second; it is inverted by `reverse_transform`.
        clip_max: Optional upper clamp bound after the intensity transforms (lower bound 0).
        **kwargs: Forwarded to `torch.nn.Module.__init__`.
    """
    super().__init__(**kwargs)
    self.intensity_transforms = intensity_transforms
    self.geometrical_transforms = geometrical_transforms
    self.clip_max = clip_max
    # Initialised so reverse_transform() never raises AttributeError before transform().
    self.params = None

Initialize internal Module state, shared by both nn.Module and ScriptModule.

intensity_transforms
geometrical_transforms
clip_max
def reset(self):
    """Drop the geometric parameters stored by the previous transform() call."""
    setattr(self, "params", None)
def transform(self, x: torch.Tensor) -> torch.Tensor:
    """Apply intensity transforms, optional clamping, then geometric transforms; record their params."""
    augmented = self.intensity_transforms(x)
    upper = self.clip_max
    if upper is not None:
        augmented = torch.clamp(augmented, min=0.0, max=upper)
    geometric = self.geometrical_transforms
    augmented = geometric(augmented)
    self.params = geometric._params
    return augmented
def reverse_transform(self, x: torch.Tensor) -> torch.Tensor:
    """Map `x` back through the inverse of the last geometric transform."""
    geometric = self.geometrical_transforms
    return geometric.inverse(x, params=self.params)
class MeanTeacherAugmenters:
    """Augmenter pair for mean-teacher training.

    By default both teacher and student get an `InvertibleAugmenter` built from the
    "weak" recipe of `get_default_augmentations`.

    Args:
        ndim: 2 or 3; only used when a default augmenter has to be built.
        teacher: Optional custom augmenter for the teacher; must provide `reset()`.
        student: Optional custom augmenter for the student; must provide `reset()`.
        clip_max: Forwarded to default-constructed augmenters only.
    """

    def __init__(
        self,
        ndim: int,
        teacher=None,
        student=None,
        clip_max=None,
    ):
        # Compare with None (not truthiness) so a caller-supplied augmenter is never
        # replaced just because it happens to be falsy (e.g. defines __len__ == 0).
        if teacher is None:
            teacher = InvertibleAugmenter(*get_default_augmentations("weak", ndim=ndim), clip_max=clip_max)
        if student is None:
            student = InvertibleAugmenter(*get_default_augmentations("weak", ndim=ndim), clip_max=clip_max)
        self.teacher = teacher
        self.student = student

    def reset_all(self):
        """Call `reset()` on the teacher, then on the student."""
        self.teacher.reset()
        self.student.reset()
def __init__(
    self,
    ndim: int,
    teacher=None,
    student=None,
    clip_max=None,
):
    """Use the given teacher/student augmenters, or build weak default InvertibleAugmenters."""
    # Compare with None (not truthiness) so a falsy caller-supplied augmenter is kept.
    if teacher is None:
        teacher = InvertibleAugmenter(*get_default_augmentations("weak", ndim=ndim), clip_max=clip_max)
    if student is None:
        student = InvertibleAugmenter(*get_default_augmentations("weak", ndim=ndim), clip_max=clip_max)
    self.teacher = teacher
    self.student = student
teacher
student
def reset_all(self):
    """Clear the stored parameters of the teacher and then the student augmenter."""
    for augmenter in (self.teacher, self.student):
        augmenter.reset()
class FixMatchAugmenters:
    """Augmenter pair for FixMatch training.

    By default the teacher gets an `InvertibleAugmenter` built from the "weak" recipe and
    the student one built from the "strong" recipe of `get_default_augmentations`.

    Args:
        ndim: 2 or 3; only used when a default augmenter has to be built.
        teacher: Optional custom augmenter for the teacher; must provide `reset()`.
        student: Optional custom augmenter for the student; must provide `reset()`.
        clip_max: Forwarded to default-constructed augmenters only.
    """

    def __init__(
        self,
        ndim: int,
        teacher=None,
        student=None,
        clip_max=None,
    ):
        # Compare with None (not truthiness) so a caller-supplied augmenter is never
        # replaced just because it happens to be falsy.
        if teacher is None:
            teacher = InvertibleAugmenter(*get_default_augmentations("weak", ndim=ndim), clip_max=clip_max)
        if student is None:
            student = InvertibleAugmenter(*get_default_augmentations("strong", ndim=ndim), clip_max=clip_max)
        self.teacher = teacher
        self.student = student

    def reset_all(self):
        """Call `reset()` on the teacher, then on the student."""
        self.teacher.reset()
        self.student.reset()
def __init__(
    self,
    ndim: int,
    teacher=None,
    student=None,
    clip_max=None,
):
    """Use the given augmenters, or build a weak teacher and a strong student InvertibleAugmenter."""
    # Compare with None (not truthiness) so a falsy caller-supplied augmenter is kept.
    if teacher is None:
        teacher = InvertibleAugmenter(*get_default_augmentations("weak", ndim=ndim), clip_max=clip_max)
    if student is None:
        student = InvertibleAugmenter(*get_default_augmentations("strong", ndim=ndim), clip_max=clip_max)
    self.teacher = teacher
    self.student = student
teacher
student
def reset_all(self):
    """Clear the stored parameters of the teacher and then the student augmenter."""
    for augmenter in (self.teacher, self.student):
        augmenter.reset()
class UniMatchv2Augmenters:
    """Three augmenters for UniMatch-v2 style training: one weak view, two strong views.

    Defaults are `InvertibleAugmenter`s built from the "weak" (for `weak`) and "strong"
    (for `strong1` / `strong2`) recipes of `get_default_augmentations`.

    Args:
        ndim: 2 or 3; only used when a default augmenter has to be built.
        weak: Optional custom augmenter for the weak view; must provide `reset()`.
        strong1: Optional custom augmenter for the first strong view.
        strong2: Optional custom augmenter for the second strong view.
        clip_max: Forwarded to default-constructed augmenters only.
    """

    def __init__(
        self,
        ndim: int,
        weak=None,
        strong1=None,
        strong2=None,
        clip_max=None,
    ):
        # Compare with None (not truthiness) so a caller-supplied augmenter is never
        # replaced just because it happens to be falsy.
        if weak is None:
            weak = InvertibleAugmenter(*get_default_augmentations("weak", ndim=ndim), clip_max=clip_max)
        if strong1 is None:
            strong1 = InvertibleAugmenter(*get_default_augmentations("strong", ndim=ndim), clip_max=clip_max)
        if strong2 is None:
            strong2 = InvertibleAugmenter(*get_default_augmentations("strong", ndim=ndim), clip_max=clip_max)
        self.weak = weak
        self.strong1 = strong1
        self.strong2 = strong2

    def reset_all(self):
        """Call `reset()` on the weak, first strong and second strong augmenters, in that order."""
        self.weak.reset()
        self.strong1.reset()
        self.strong2.reset()
def __init__(
    self,
    ndim: int,
    weak=None,
    strong1=None,
    strong2=None,
    clip_max=None,
):
    """Use the given augmenters, or build one weak and two strong default InvertibleAugmenters."""
    # Compare with None (not truthiness) so a falsy caller-supplied augmenter is kept.
    if weak is None:
        weak = InvertibleAugmenter(*get_default_augmentations("weak", ndim=ndim), clip_max=clip_max)
    if strong1 is None:
        strong1 = InvertibleAugmenter(*get_default_augmentations("strong", ndim=ndim), clip_max=clip_max)
    if strong2 is None:
        strong2 = InvertibleAugmenter(*get_default_augmentations("strong", ndim=ndim), clip_max=clip_max)
    self.weak = weak
    self.strong1 = strong1
    self.strong2 = strong2
weak
strong1
strong2
def reset_all(self):
    """Clear the stored parameters of the weak view, then both strong views."""
    for augmenter in (self.weak, self.strong1, self.strong2):
        augmenter.reset()