torch_em.transform.invertible_augmentations
"""Augmentation presets and invertible (geometric) augmenters built on kornia."""
from typing import Callable

import torch
import kornia.augmentation as K


# Augmentation presets: {"intensity" | "geometrical": {kornia class name: constructor kwargs}}.
# Only the geometrical transforms are inverted (see InvertibleAugmenter.reverse_transform).
DEFAULT_WEAK_AUGMENTATIONS = {
    "intensity": {},
    "geometrical": {
        "RandomHorizontalFlip": {},
        "RandomVerticalFlip": {},
        "RandomRotation90": {"times": (-1, 2)},
    },
}

DEFAULT_STRONG_AUGMENTATIONS = {
    "intensity": {
        "RandomGaussianBlur": {"kernel_size": (3, 3), "sigma": (0.1, 1.0)},
        "RandomGaussianNoise": {"mean": 0.0, "std": 0.1},
    },
    "geometrical": {
        "RandomHorizontalFlip": {},
        "RandomVerticalFlip": {},
        "RandomRotation90": {"times": (-1, 2)},
    },
}


def _select_preset(aug_name, ndim):
    """Validate `aug_name` and `ndim` and return the matching preset dictionary."""
    if aug_name == "weak":
        preset = DEFAULT_WEAK_AUGMENTATIONS
    elif aug_name == "strong":
        preset = DEFAULT_STRONG_AUGMENTATIONS
    else:
        raise ValueError(f"aug_name must be 'weak' or 'strong', got {aug_name}")
    if ndim not in (2, 3):
        raise ValueError(f"Number of dimensions must be 2 or 3, got ndim={ndim}")
    return preset


def _build_augmentations(aug_dict, ndim, p):
    """Instantiate the kornia augmentations in `aug_dict` and wrap them for 2D or 3D input."""
    transforms_list = []
    for trafo, kwargs in aug_dict.items():
        # The presets are module constants, so a missing name is a programming error.
        assert hasattr(K, trafo), f"{trafo} not found in kornia.augmentation"
        transforms_list.append(getattr(K, trafo)(p=p, **kwargs))
    if ndim == 2:
        return K.AugmentationSequential(*transforms_list, data_keys=["input"], same_on_batch=False)
    return AugmentationSequential3D(*transforms_list)


def get_intensity_augmentations(aug_name: str, ndim: int, p: float = 0.75) -> Callable:
    """Build the intensity augmentations of the "weak" or "strong" preset.

    Args:
        aug_name: Preset name, "weak" or "strong".
        ndim: Spatial dimensionality of the data, 2 or 3.
        p: Probability passed as ``p`` to every kornia augmentation.

    Returns:
        ``K.AugmentationSequential`` for 2D data, ``AugmentationSequential3D`` for 3D data.

    Raises:
        ValueError: If ``aug_name`` or ``ndim`` is not supported.
    """
    return _build_augmentations(_select_preset(aug_name, ndim)["intensity"], ndim, p)


def get_geometrical_augmentations(aug_name: str, ndim: int, p: float = 0.75) -> Callable:
    """Build the (invertible) geometrical augmentations of the "weak" or "strong" preset.

    Args:
        aug_name: Preset name, "weak" or "strong".
        ndim: Spatial dimensionality of the data, 2 or 3.
        p: Probability passed as ``p`` to every kornia augmentation.

    Returns:
        ``K.AugmentationSequential`` for 2D data, ``AugmentationSequential3D`` for 3D data.

    Raises:
        ValueError: If ``aug_name`` or ``ndim`` is not supported.
    """
    return _build_augmentations(_select_preset(aug_name, ndim)["geometrical"], ndim, p)


def get_augmentations(aug_name: str, ndim: int, p: float = 0.75):
    """Return the ``(intensity_transforms, geometrical_transforms)`` pair for a preset.

    Raises:
        ValueError: If ``aug_name`` is not "weak"/"strong" or ``ndim`` is not 2/3.
    """
    intensity_transforms = get_intensity_augmentations(aug_name, ndim=ndim, p=p)
    geometrical_transforms = get_geometrical_augmentations(aug_name, ndim=ndim, p=p)
    return intensity_transforms, geometrical_transforms


class AugmentationSequential3D(torch.nn.Module):
    """Apply 2D augmentations to 5D volumes by folding the depth axis into the channels.

    A (B, C, D, H, W) tensor is reshaped to (B, C * D, H, W), passed through each
    augmentation in order and reshaped back. Assuming each wrapped augmentation applies
    one transform per sample across all channels (as kornia's 2D augmentations do), all
    depth slices of a sample are transformed identically. The parameters recorded by each
    augmentation in the most recent forward call are kept in ``self._params`` so the
    transforms can be undone with :meth:`inverse`.

    Args:
        augmentations: Modules applied in order. Each must set ``_params`` when called and,
            to be invertible, provide ``inverse(x, params=...)``.
    """

    def __init__(self, *augmentations: torch.nn.Module):
        super().__init__()
        self.augmentations = torch.nn.ModuleList(augmentations)
        # Per-augmentation parameters of the most recent forward call.
        self._params = None

    @staticmethod
    def _flatten(x):
        """Reshape (B, C, D, H, W) -> (B, C*D, H, W); also return the original shape."""
        if x.ndim != 5:
            raise RuntimeError(f"Expected 5D tensor, got {x.shape}")
        b, c, d, h, w = x.shape
        return x.reshape(b, c * d, h, w), (b, c, d, h, w)

    @staticmethod
    def _unflatten(x, shape):
        """Reshape (B, C*D, H', W') -> (B, C, D, H', W').

        B, C and D come from ``shape``; the spatial size is taken from ``x`` so that
        augmentations which change the spatial size are supported.
        """
        b, c, d = shape[:3]
        return x.reshape(b, c, d, *x.shape[-2:])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Augment a (B, C, D, H, W) tensor and record every augmentation's parameters."""
        params_all = []
        flat_x, shape = self._flatten(x)
        for aug in self.augmentations:
            flat_x = aug(flat_x)
            params_all.append(aug._params)
        out = self._unflatten(flat_x, shape)
        self._params = params_all
        return out

    def inverse(self, x: torch.Tensor, params=None) -> torch.Tensor:
        """Undo the augmentations in reverse order.

        Args:
            x: (B, C, D, H, W) tensor in the augmented frame.
            params: One parameter set per augmentation, as stored in ``self._params``.
                Defaults to the parameters of the most recent forward call.

        Raises:
            RuntimeError: If no parameters are given and forward has not been called.
            ValueError: If the number of parameter sets does not match the augmentations.
        """
        if params is None:
            params = self._params
        if params is None:
            raise RuntimeError("No augmentation parameters available: call forward first or pass params.")
        params = list(params)
        if len(params) != len(self.augmentations):
            raise ValueError(
                f"Expected {len(self.augmentations)} parameter sets, got {len(params)}"
            )

        flat_x, shape = self._flatten(x)
        for aug, aug_params in reversed(list(zip(self.augmentations, params))):
            flat_x = aug.inverse(flat_x, params=aug_params)
        return self._unflatten(flat_x, shape)


class InvertibleAugmenter(torch.nn.Module):
    """Apply intensity then geometrical transforms and undo the geometrical part on demand.

    Args:
        intensity_transforms: Applied first by :meth:`transform`; never inverted.
        geometrical_transforms: Applied second; must expose ``_params`` after being called
            and provide ``inverse(x, params=...)``.
        **kwargs: Forwarded to ``torch.nn.Module.__init__``.
    """

    def __init__(
        self,
        intensity_transforms: Callable[[torch.Tensor], torch.Tensor],
        geometrical_transforms: Callable[[torch.Tensor], torch.Tensor],
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.intensity_transforms = intensity_transforms
        self.geometrical_transforms = geometrical_transforms
        # Parameters of the last geometrical transform; None means nothing to invert.
        self.params = None

    def reset(self):
        """Forget the parameters of the last transform."""
        self.params = None

    def transform(self, x: torch.Tensor) -> torch.Tensor:
        """Apply both transform stages and remember the geometrical parameters."""
        x = self.intensity_transforms(x)
        x = self.geometrical_transforms(x)
        self.params = self.geometrical_transforms._params
        return x

    def reverse_transform(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` back through the inverse of the last geometrical transform.

        Raises:
            RuntimeError: If called before :meth:`transform` or after :meth:`reset`.
        """
        if self.params is None:
            raise RuntimeError("reverse_transform called before transform (or after reset).")
        return self.geometrical_transforms.inverse(x, params=self.params)


def _augmenter_or_default(augmenter, aug_name, ndim):
    """Return `augmenter`, or a new InvertibleAugmenter for the `aug_name` preset if it is None."""
    # Explicit None check: `or` would also replace a falsy custom augmenter.
    if augmenter is not None:
        return augmenter
    return InvertibleAugmenter(*get_augmentations(aug_name, ndim=ndim))


class MeanTeacherAugmenters:
    """Teacher/student augmenters for mean-teacher training (both "weak" by default)."""

    def __init__(
        self,
        ndim: int,
        teacher=None,
        student=None,
    ):
        self.teacher = _augmenter_or_default(teacher, "weak", ndim)
        self.student = _augmenter_or_default(student, "weak", ndim)

    def reset_all(self):
        """Reset the recorded parameters of both augmenters."""
        self.teacher.reset()
        self.student.reset()


class FixMatchAugmenters:
    """Teacher/student augmenters for FixMatch ("weak" teacher, "strong" student by default)."""

    def __init__(
        self,
        ndim: int,
        teacher=None,
        student=None,
    ):
        self.teacher = _augmenter_or_default(teacher, "weak", ndim)
        self.student = _augmenter_or_default(student, "strong", ndim)

    def reset_all(self):
        """Reset the recorded parameters of both augmenters."""
        self.teacher.reset()
        self.student.reset()


class UniMatchv2Augmenters:
    """One weak and two strong augmenters for UniMatch v2 (presets used by default)."""

    def __init__(
        self,
        ndim: int,
        weak=None,
        strong1=None,
        strong2=None,
    ):
        self.weak = _augmenter_or_default(weak, "weak", ndim)
        self.strong1 = _augmenter_or_default(strong1, "strong", ndim)
        self.strong2 = _augmenter_or_default(strong2, "strong", ndim)

    def reset_all(self):
        """Reset the recorded parameters of all three augmenters."""
        self.weak.reset()
        self.strong1.reset()
        self.strong2.reset()
def get_intensity_augmentations(aug_name: str, ndim: int, p: float = 0.75) -> Callable:
    """Build the intensity augmentations of the "weak" or "strong" preset.

    Args:
        aug_name: Preset name, "weak" or "strong".
        ndim: Spatial dimensionality of the data, 2 or 3.
        p: Probability passed as ``p`` to every kornia augmentation.

    Returns:
        ``K.AugmentationSequential`` for 2D data, ``AugmentationSequential3D`` for 3D data.

    Raises:
        ValueError: If ``aug_name`` or ``ndim`` is not supported.
    """
    # Validate both arguments before building anything so the error names the bad one
    # and an unsupported ndim cannot fall through to an implicit `None` return.
    if aug_name not in ("weak", "strong"):
        raise ValueError(f"aug_name must be 'weak' or 'strong', got {aug_name}")
    if ndim not in (2, 3):
        raise ValueError(f"Number of dimensions must be 2 or 3, got ndim={ndim}")

    presets = DEFAULT_WEAK_AUGMENTATIONS if aug_name == "weak" else DEFAULT_STRONG_AUGMENTATIONS
    aug_dict = presets["intensity"]

    transforms_list = []
    for trafo, kwargs in aug_dict.items():
        # The presets are module constants, so a missing name is a programming error.
        assert hasattr(K, trafo), f"{trafo} not found in kornia.augmentation"
        transforms_list.append(getattr(K, trafo)(p=p, **kwargs))

    if ndim == 2:
        return K.AugmentationSequential(*transforms_list, data_keys=["input"], same_on_batch=False)
    return AugmentationSequential3D(*transforms_list)
def get_geometrical_augmentations(aug_name: str, ndim: int, p: float = 0.75) -> Callable:
    """Build the (invertible) geometrical augmentations of the "weak" or "strong" preset.

    Args:
        aug_name: Preset name, "weak" or "strong".
        ndim: Spatial dimensionality of the data, 2 or 3.
        p: Probability passed as ``p`` to every kornia augmentation.

    Returns:
        ``K.AugmentationSequential`` for 2D data, ``AugmentationSequential3D`` for 3D data.

    Raises:
        ValueError: If ``aug_name`` or ``ndim`` is not supported.
    """
    # Validate both arguments before building anything so the error names the bad one
    # and an unsupported ndim cannot fall through to an implicit `None` return.
    if aug_name not in ("weak", "strong"):
        raise ValueError(f"aug_name must be 'weak' or 'strong', got {aug_name}")
    if ndim not in (2, 3):
        raise ValueError(f"Number of dimensions must be 2 or 3, got ndim={ndim}")

    presets = DEFAULT_WEAK_AUGMENTATIONS if aug_name == "weak" else DEFAULT_STRONG_AUGMENTATIONS
    aug_dict = presets["geometrical"]

    transforms_list = []
    for trafo, kwargs in aug_dict.items():
        # The presets are module constants, so a missing name is a programming error.
        assert hasattr(K, trafo), f"{trafo} not found in kornia.augmentation"
        transforms_list.append(getattr(K, trafo)(p=p, **kwargs))

    if ndim == 2:
        return K.AugmentationSequential(*transforms_list, data_keys=["input"], same_on_batch=False)
    return AugmentationSequential3D(*transforms_list)
def get_augmentations(aug_name: str, ndim: int, p: float = 0.75):
    """Return the ``(intensity_transforms, geometrical_transforms)`` pair for a preset.

    Args:
        aug_name: Preset name, "weak" or "strong".
        ndim: Spatial dimensionality, forwarded to both builders.
        p: Augmentation probability, forwarded to both builders.

    Raises:
        ValueError: If ``aug_name`` is neither "weak" nor "strong".
    """
    # Both presets are built the same way, so a single guard replaces the per-name branches.
    if aug_name not in ("weak", "strong"):
        raise ValueError(f"aug_name must be 'weak' or 'strong', got {aug_name}")
    return (
        get_intensity_augmentations(aug_name, ndim=ndim, p=p),
        get_geometrical_augmentations(aug_name, ndim=ndim, p=p),
    )
class AugmentationSequential3D(torch.nn.Module):
    """Apply 2D augmentations to 5D volumes by folding the depth axis into the channels.

    A (B, C, D, H, W) tensor is reshaped to (B, C * D, H, W), passed through each
    augmentation in order and reshaped back. Assuming each wrapped augmentation applies
    one transform per sample across all channels (as kornia's 2D augmentations do), all
    depth slices of a sample are transformed identically. The parameters recorded by each
    augmentation in the most recent forward call are kept in ``self._params`` so the
    transforms can be undone with :meth:`inverse`.

    Args:
        augmentations: Modules applied in order. Each must set ``_params`` when called and,
            to be invertible, provide ``inverse(x, params=...)``.
    """

    def __init__(self, *augmentations: torch.nn.Module):
        super().__init__()
        self.augmentations = torch.nn.ModuleList(augmentations)
        # Per-augmentation parameters of the most recent forward call.
        self._params = None

    @staticmethod
    def _flatten(x):
        """Reshape (B, C, D, H, W) -> (B, C*D, H, W); also return the original shape."""
        if x.ndim != 5:
            raise RuntimeError(f"Expected 5D tensor, got {x.shape}")
        b, c, d, h, w = x.shape
        return x.reshape(b, c * d, h, w), (b, c, d, h, w)

    @staticmethod
    def _unflatten(x, shape):
        """Reshape (B, C*D, H', W') -> (B, C, D, H', W').

        B, C and D come from ``shape``; the spatial size is taken from ``x`` so that
        augmentations which change the spatial size are supported.
        """
        b, c, d = shape[:3]
        return x.reshape(b, c, d, *x.shape[-2:])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Augment a (B, C, D, H, W) tensor and record every augmentation's parameters."""
        params_all = []
        flat_x, shape = self._flatten(x)
        for aug in self.augmentations:
            flat_x = aug(flat_x)
            params_all.append(aug._params)
        out = self._unflatten(flat_x, shape)
        self._params = params_all
        return out

    def inverse(self, x: torch.Tensor, params=None) -> torch.Tensor:
        """Undo the augmentations in reverse order.

        Args:
            x: (B, C, D, H, W) tensor in the augmented frame.
            params: One parameter set per augmentation, as stored in ``self._params``.
                Defaults to the parameters of the most recent forward call.

        Raises:
            RuntimeError: If no parameters are given and forward has not been called.
            ValueError: If the number of parameter sets does not match the augmentations.
        """
        if params is None:
            params = self._params
        if params is None:
            raise RuntimeError("No augmentation parameters available: call forward first or pass params.")
        params = list(params)
        # zip() would silently drop augmentations if the lengths disagree.
        if len(params) != len(self.augmentations):
            raise ValueError(
                f"Expected {len(self.augmentations)} parameter sets, got {len(params)}"
            )

        flat_x, shape = self._flatten(x)
        for aug, aug_params in reversed(list(zip(self.augmentations, params))):
            flat_x = aug.inverse(flat_x, params=aug_params)
        return self._unflatten(flat_x, shape)
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean that represents whether this module is in training or evaluation mode.
:vartype training: bool
def __init__(self, *augmentations: torch.nn.Module):
    """Register the given augmentations, in application order, as submodules."""
    super().__init__()
    wrapped = torch.nn.ModuleList(augmentations)
    self.augmentations = wrapped
    # Filled in by forward(); None until the first call.
    self._params = None
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Flatten ``x`` with ``self._flatten``, run every augmentation in order and unflatten.

    The ``_params`` attribute of each augmentation, read right after it ran, is collected
    into a list stored as ``self._params``.
    """
    flattened, original_shape = self._flatten(x)
    collected = []
    for augmentation in self.augmentations:
        flattened = augmentation(flattened)
        collected.append(augmentation._params)
    result = self._unflatten(flattened, original_shape)
    self._params = collected
    return result
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
class InvertibleAugmenter(torch.nn.Module):
    """Apply intensity then geometrical transforms and undo the geometrical part on demand.

    The parameters read from ``geometrical_transforms._params`` during :meth:`transform`
    are stored in ``self.params`` and passed to ``geometrical_transforms.inverse`` by
    :meth:`reverse_transform`. Intensity transforms are not inverted.

    Args:
        intensity_transforms: Callable applied first.
        geometrical_transforms: Callable applied second; must expose ``_params`` after being
            called and provide ``inverse(x, params=...)``.
        **kwargs: Forwarded to ``torch.nn.Module.__init__``.
    """

    def __init__(
        self,
        intensity_transforms: Callable[[torch.Tensor], torch.Tensor],
        geometrical_transforms: Callable[[torch.Tensor], torch.Tensor],
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.intensity_transforms = intensity_transforms
        self.geometrical_transforms = geometrical_transforms
        # Initialised here so `params` exists before the first `transform` call;
        # None means there is nothing to invert.
        self.params = None

    def reset(self):
        """Forget the parameters of the last transform."""
        self.params = None

    def transform(self, x: torch.Tensor) -> torch.Tensor:
        """Apply both transform stages and remember the geometrical parameters."""
        x = self.intensity_transforms(x)
        x = self.geometrical_transforms(x)
        self.params = self.geometrical_transforms._params
        return x

    def reverse_transform(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` back through the inverse of the last geometrical transform.

        Raises:
            RuntimeError: If called before :meth:`transform` or after :meth:`reset`; otherwise
                an inverse with ``params=None`` could fail or reuse stale parameters.
        """
        if self.params is None:
            raise RuntimeError("reverse_transform called before transform (or after reset).")
        return self.geometrical_transforms.inverse(x, params=self.params)
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean that represents whether this module is in training or evaluation mode.
:vartype training: bool
def __init__(
    self,
    intensity_transforms: Callable[[torch.Tensor], torch.Tensor],
    geometrical_transforms: Callable[[torch.Tensor], torch.Tensor],
    **kwargs,
):
    """Store the two transform stages and start with no recorded parameters.

    Args:
        intensity_transforms: Transform stage stored as ``self.intensity_transforms``.
        geometrical_transforms: Transform stage stored as ``self.geometrical_transforms``.
        **kwargs: Forwarded to ``torch.nn.Module.__init__``.
    """
    super().__init__(**kwargs)
    self.intensity_transforms = intensity_transforms
    self.geometrical_transforms = geometrical_transforms
    # Initialise here so reading `params` before the first transform does not raise
    # AttributeError; None means there is nothing to invert.
    self.params = None
Initialize internal Module state, shared by both nn.Module and ScriptModule.
class MeanTeacherAugmenters:
    """Teacher/student augmenter pair for mean-teacher training.

    Both default augmenters use the "weak" preset.

    Args:
        ndim: Spatial dimensionality (2 or 3) used to build the default augmenters.
        teacher: Optional augmenter used instead of the default teacher.
        student: Optional augmenter used instead of the default student.
    """

    def __init__(
        self,
        ndim: int,
        teacher=None,
        student=None,
    ):
        # Explicit None checks: `or` would also discard a falsy custom augmenter
        # (e.g. an object defining __len__ == 0) and silently build the default.
        if teacher is None:
            teacher = InvertibleAugmenter(*get_augmentations("weak", ndim=ndim))
        if student is None:
            student = InvertibleAugmenter(*get_augmentations("weak", ndim=ndim))
        self.teacher = teacher
        self.student = student

    def reset_all(self):
        """Call ``reset()`` on both augmenters."""
        self.teacher.reset()
        self.student.reset()
class FixMatchAugmenters:
    """Teacher/student augmenter pair for FixMatch training.

    By default the teacher uses the "weak" preset and the student the "strong" preset.

    Args:
        ndim: Spatial dimensionality (2 or 3) used to build the default augmenters.
        teacher: Optional augmenter used instead of the default teacher.
        student: Optional augmenter used instead of the default student.
    """

    def __init__(
        self,
        ndim: int,
        teacher=None,
        student=None,
    ):
        # Explicit None checks: `or` would also discard a falsy custom augmenter
        # (e.g. an object defining __len__ == 0) and silently build the default.
        if teacher is None:
            teacher = InvertibleAugmenter(*get_augmentations("weak", ndim=ndim))
        if student is None:
            student = InvertibleAugmenter(*get_augmentations("strong", ndim=ndim))
        self.teacher = teacher
        self.student = student

    def reset_all(self):
        """Call ``reset()`` on both augmenters."""
        self.teacher.reset()
        self.student.reset()
class UniMatchv2Augmenters:
    """One weak and two strong augmenters for UniMatch v2 training.

    By default ``weak`` uses the "weak" preset and ``strong1``/``strong2`` the "strong" preset.

    Args:
        ndim: Spatial dimensionality (2 or 3) used to build the default augmenters.
        weak: Optional augmenter used instead of the default weak augmenter.
        strong1: Optional augmenter used instead of the first default strong augmenter.
        strong2: Optional augmenter used instead of the second default strong augmenter.
    """

    def __init__(
        self,
        ndim: int,
        weak=None,
        strong1=None,
        strong2=None,
    ):
        # Explicit None checks: `or` would also discard a falsy custom augmenter
        # (e.g. an object defining __len__ == 0) and silently build the default.
        if weak is None:
            weak = InvertibleAugmenter(*get_augmentations("weak", ndim=ndim))
        if strong1 is None:
            strong1 = InvertibleAugmenter(*get_augmentations("strong", ndim=ndim))
        if strong2 is None:
            strong2 = InvertibleAugmenter(*get_augmentations("strong", ndim=ndim))
        self.weak = weak
        self.strong1 = strong1
        self.strong2 = strong2

    def reset_all(self):
        """Call ``reset()`` on all three augmenters."""
        self.weak.reset()
        self.strong1.reset()
        self.strong2.reset()
def __init__(
    self,
    ndim: int,
    weak=None,
    strong1=None,
    strong2=None,
):
    """Store the given augmenters, building preset defaults for any that are None.

    Args:
        ndim: Spatial dimensionality used to build the default augmenters.
        weak: Optional augmenter; defaults to the "weak" preset.
        strong1: Optional augmenter; defaults to the "strong" preset.
        strong2: Optional augmenter; defaults to the "strong" preset.
    """
    # Explicit None checks: `or` would also discard a falsy custom augmenter
    # (e.g. an object defining __len__ == 0) and silently build the default.
    if weak is None:
        weak = InvertibleAugmenter(*get_augmentations("weak", ndim=ndim))
    if strong1 is None:
        strong1 = InvertibleAugmenter(*get_augmentations("strong", ndim=ndim))
    if strong2 is None:
        strong2 = InvertibleAugmenter(*get_augmentations("strong", ndim=ndim))
    self.weak = weak
    self.strong1 = strong1
    self.strong2 = strong2