torch_em.transform.invertible_augmentations
from typing import Callable

import torch
import kornia.augmentation as K


DEFAULT_WEAK_AUGMENTATIONS = {
    "intensity": {},
    "geometrical": {
        "RandomHorizontalFlip": {},
        "RandomVerticalFlip": {},
        "RandomRotation90": {"times": (-1, 2)},
    },
}

DEFAULT_STRONG_AUGMENTATIONS = {
    "intensity": {
        "RandomGaussianBlur": {"kernel_size": (3, 3), "sigma": (0.1, 1.0)},
        # Plain floats; the former "(0.0)" / "(0.1)" were parenthesised floats, not tuples.
        "RandomGaussianNoise": {"mean": 0.0, "std": 0.1},
    },
    "geometrical": {
        "RandomHorizontalFlip": {},
        "RandomVerticalFlip": {},
        "RandomRotation90": {"times": (-1, 2)},
    },
}

# Resolves a preset name to its default configuration dictionary.
_DEFAULT_AUGMENTATIONS = {"weak": DEFAULT_WEAK_AUGMENTATIONS, "strong": DEFAULT_STRONG_AUGMENTATIONS}


def _build_augmentations(
    ndim, category: str, aug_name: str, aug_dict: dict, p: float
) -> Callable[[torch.Tensor], torch.Tensor]:
    """Build a sequential kornia pipeline for one augmentation category.

    Shared implementation of `get_intensity_augmentations` and `get_geometrical_augmentations`.

    Args:
        ndim: Number of spatial dimensions of the data, 2 or 3.
        category: Key into the preset dictionaries, "intensity" or "geometrical".
        aug_name: "weak" or "strong"; only consulted when `aug_dict` is None.
        aug_dict: Mapping from kornia.augmentation class names to constructor kwargs.
            A "p" key in the kwargs overrides the shared probability `p` for that augmentation.
        p: Probability with which each augmentation is applied, unless overridden.

    Returns:
        A K.AugmentationSequential for ndim == 2, or an AugmentationSequential3D for ndim == 3.

    Raises:
        ValueError: If `aug_dict` is None and `aug_name` is neither "weak" nor "strong".
    """
    assert ndim == 2 or ndim == 3, f"Number of dimensions must be 2 or 3, got ndim={ndim}"
    assert aug_name is not None or aug_dict is not None, "Either aug_name or aug_dict must be provided."

    if aug_dict is None:
        if aug_name not in ("weak", "strong"):
            raise ValueError(f"Augmentation name needs to be \"weak\" or \"strong\", got aug_name={aug_name}")
        aug_dict = _DEFAULT_AUGMENTATIONS[aug_name][category]

    transforms_list = []
    for trafo, kwargs in aug_dict.items():
        assert hasattr(K, trafo), f"{trafo} not found in kornia.augmentation"
        # Merge so a per-augmentation "p" wins instead of raising
        # "got multiple values for keyword argument 'p'".
        transforms_list.append(getattr(K, trafo)(**{"p": p, **kwargs}))

    if ndim == 2:
        return K.AugmentationSequential(*transforms_list, data_keys=["input"], same_on_batch=False)
    return AugmentationSequential3D(*transforms_list)


def get_intensity_augmentations(
    ndim, aug_name: str = None, aug_dict: dict = None, p: float = 0.75
) -> Callable[[torch.Tensor], torch.Tensor]:
    """Build the intensity-augmentation pipeline.

    Args:
        ndim: Number of spatial dimensions of the data, 2 or 3.
        aug_name: "weak" or "strong" preset; ignored when `aug_dict` is given.
        aug_dict: Mapping from kornia.augmentation class names to constructor kwargs
            (a "p" key overrides the shared probability for that augmentation).
        p: Probability with which each augmentation is applied, unless overridden.

    Returns:
        A K.AugmentationSequential for ndim == 2, or an AugmentationSequential3D for ndim == 3.

    Raises:
        ValueError: If `aug_dict` is None and `aug_name` is neither "weak" nor "strong".
    """
    return _build_augmentations(ndim, "intensity", aug_name, aug_dict, p)


def get_geometrical_augmentations(
    ndim, aug_name: str = None, aug_dict: dict = None, p: float = 0.75
) -> Callable[[torch.Tensor], torch.Tensor]:
    """Build the geometrical-augmentation pipeline.

    Args:
        ndim: Number of spatial dimensions of the data, 2 or 3.
        aug_name: "weak" or "strong" preset; ignored when `aug_dict` is given.
        aug_dict: Mapping from kornia.augmentation class names to constructor kwargs
            (a "p" key overrides the shared probability for that augmentation).
        p: Probability with which each augmentation is applied, unless overridden.

    Returns:
        A K.AugmentationSequential for ndim == 2, or an AugmentationSequential3D for ndim == 3.

    Raises:
        ValueError: If `aug_dict` is None and `aug_name` is neither "weak" nor "strong".
    """
    return _build_augmentations(ndim, "geometrical", aug_name, aug_dict, p)


def get_default_augmentations(aug_name: str, ndim: int, p: float = 0.75):
    """Return the (intensity, geometrical) pipelines of a preset.

    Args:
        aug_name: Preset name, "weak" or "strong".
        ndim: Number of spatial dimensions of the data, 2 or 3.
        p: Probability with which each augmentation is applied.

    Returns:
        Tuple ``(intensity_transforms, geometrical_transforms)``.

    Raises:
        ValueError: If `aug_name` is neither "weak" nor "strong".
    """
    assert ndim == 2 or ndim == 3, f"Number of dimensions must be 2 or 3, got ndim={ndim}"
    # Both presets are built identically, so one validation replaces the duplicated branches.
    if aug_name not in ("weak", "strong"):
        raise ValueError(f"Augmentation name needs to be \"weak\" or \"strong\", got aug_name={aug_name}")

    intensity_transforms = get_intensity_augmentations(ndim, aug_name=aug_name, p=p)
    geometrical_transforms = get_geometrical_augmentations(ndim, aug_name=aug_name, p=p)
    return intensity_transforms, geometrical_transforms


class AugmentationSequential3D(torch.nn.Module):
    """Apply 2D augmentations to 5D volumes by folding depth into the channel axis.

    A (B, C, D, H, W) tensor is reshaped to (B, C*D, H, W), passed through each
    augmentation in order, and reshaped back. The parameters every augmentation
    reports via its `_params` attribute are recorded so `inverse` can undo them.
    NOTE(review): this assumes each augmentation samples one transform per batch
    element and applies it to all channels, so all depth slices stay aligned — confirm
    for augmentations passed in via `aug_dict`.
    """

    def __init__(self, *augmentations: torch.nn.Module):
        super().__init__()
        self.augmentations = torch.nn.ModuleList(augmentations)
        # Per-augmentation parameters from the most recent forward() call.
        self._params = None

    @staticmethod
    def _flatten(x):
        """(B, C, D, H, W) -> (B, C*D, H, W); also returns the original 5D shape."""
        if x.ndim != 5:
            raise RuntimeError(f"Expected 5D tensor, got {x.shape}")
        b, c, d, h, w = x.shape
        return x.reshape(b, c * d, h, w), (b, c, d, h, w)

    @staticmethod
    def _unflatten(x, shape):
        """(B, C*D, H', W') -> (B, C, D, H', W').

        The spatial size is read from `x`, not from `shape`: an augmentation may have
        changed it, and reshaping to the pre-augmentation H, W would either fail or,
        when H' * W' == H * W, silently scramble the voxels.
        """
        b, c, d = shape[:3]
        return x.reshape(b, c, d, *x.shape[-2:])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Augment a (B, C, D, H, W) tensor and record each augmentation's parameters."""
        params_all = []
        flat_x, shape = self._flatten(x)
        for aug in self.augmentations:
            flat_x = aug(flat_x)
            params_all.append(aug._params)
        out = self._unflatten(flat_x, shape)
        self._params = params_all
        return out

    def inverse(self, x: torch.Tensor, params=None) -> torch.Tensor:
        """Undo the augmentations in reverse order.

        Args:
            x: Augmented (B, C, D, H, W) tensor.
            params: Per-augmentation parameters, one entry per augmentation. Defaults
                to those recorded by the most recent forward() call.

        Raises:
            ValueError: If no parameters are given and forward() has not been called.
        """
        if params is None:
            params = self._params
        if params is None:
            raise ValueError("No parameters available for inversion; run a forward pass first or pass params.")

        flat_x, shape = self._flatten(x)
        for aug, aug_params in reversed(list(zip(self.augmentations, params))):
            flat_x = aug.inverse(flat_x, params=aug_params)
        return self._unflatten(flat_x, shape)


class InvertibleAugmenter(torch.nn.Module):
    """Intensity + geometrical augmentation whose geometrical part can be undone.

    Args:
        intensity_transforms: Applied first; never inverted.
        geometrical_transforms: Applied second. It must expose `_params` after a call
            and an `inverse(x, params=...)` method.
        clip_max: If not None, the intensity-augmented input is clamped to [0, clip_max].
        **kwargs: Forwarded to torch.nn.Module.__init__.
    """

    def __init__(
        self,
        intensity_transforms: Callable[[torch.Tensor], torch.Tensor],
        geometrical_transforms: Callable[[torch.Tensor], torch.Tensor],
        clip_max=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.intensity_transforms = intensity_transforms
        self.geometrical_transforms = geometrical_transforms
        self.clip_max = clip_max
        # Initialised here so `params` exists before the first transform()/reset() call.
        self.params = None

    def reset(self):
        """Forget the geometrical parameters stored by the last transform() call."""
        self.params = None

    def transform(self, x: torch.Tensor) -> torch.Tensor:
        """Apply intensity transforms, optional clamping, then geometrical transforms.

        The geometrical parameters are stored in `self.params` for reverse_transform().
        """
        x = self.intensity_transforms(x)
        if self.clip_max is not None:
            x = torch.clamp(x, 0.0, self.clip_max)
        x = self.geometrical_transforms(x)
        self.params = self.geometrical_transforms._params
        return x

    def reverse_transform(self, x: torch.Tensor) -> torch.Tensor:
        """Undo the geometrical transforms using the parameters stored by transform()."""
        return self.geometrical_transforms.inverse(x, params=self.params)


class MeanTeacherAugmenters:
    """Teacher/student pair of invertible augmenters; both default to the "weak" preset.

    Args:
        ndim: Number of spatial dimensions of the data, 2 or 3.
        teacher: Teacher augmenter; built from the "weak" preset if None.
        student: Student augmenter; built from the "weak" preset if None.
        clip_max: Clamping bound forwarded to default-constructed augmenters.
    """

    def __init__(self, ndim: int, teacher=None, student=None, clip_max=None):
        # `is None` rather than `or`: keep a caller-supplied augmenter even if it is falsy.
        if teacher is None:
            teacher = InvertibleAugmenter(*get_default_augmentations("weak", ndim=ndim), clip_max=clip_max)
        if student is None:
            student = InvertibleAugmenter(*get_default_augmentations("weak", ndim=ndim), clip_max=clip_max)
        self.teacher = teacher
        self.student = student

    def reset_all(self):
        """Call reset() on the teacher and student augmenters."""
        self.teacher.reset()
        self.student.reset()


class FixMatchAugmenters:
    """Teacher/student pair of invertible augmenters: weak teacher, strong student.

    Args:
        ndim: Number of spatial dimensions of the data, 2 or 3.
        teacher: Teacher augmenter; built from the "weak" preset if None.
        student: Student augmenter; built from the "strong" preset if None.
        clip_max: Clamping bound forwarded to default-constructed augmenters.
    """

    def __init__(self, ndim: int, teacher=None, student=None, clip_max=None):
        # `is None` rather than `or`: keep a caller-supplied augmenter even if it is falsy.
        if teacher is None:
            teacher = InvertibleAugmenter(*get_default_augmentations("weak", ndim=ndim), clip_max=clip_max)
        if student is None:
            student = InvertibleAugmenter(*get_default_augmentations("strong", ndim=ndim), clip_max=clip_max)
        self.teacher = teacher
        self.student = student

    def reset_all(self):
        """Call reset() on the teacher and student augmenters."""
        self.teacher.reset()
        self.student.reset()


class UniMatchv2Augmenters:
    """One weak and two strong invertible augmenters.

    Args:
        ndim: Number of spatial dimensions of the data, 2 or 3.
        weak: Built from the "weak" preset if None.
        strong1: Built from the "strong" preset if None.
        strong2: Built from the "strong" preset if None.
        clip_max: Clamping bound forwarded to default-constructed augmenters.
    """

    def __init__(self, ndim: int, weak=None, strong1=None, strong2=None, clip_max=None):
        # `is None` rather than `or`: keep a caller-supplied augmenter even if it is falsy.
        if weak is None:
            weak = InvertibleAugmenter(*get_default_augmentations("weak", ndim=ndim), clip_max=clip_max)
        if strong1 is None:
            strong1 = InvertibleAugmenter(*get_default_augmentations("strong", ndim=ndim), clip_max=clip_max)
        if strong2 is None:
            strong2 = InvertibleAugmenter(*get_default_augmentations("strong", ndim=ndim), clip_max=clip_max)
        self.weak = weak
        self.strong1 = strong1
        self.strong2 = strong2

    def reset_all(self):
        """Call reset() on all three augmenters."""
        self.weak.reset()
        self.strong1.reset()
        self.strong2.reset()
def get_intensity_augmentations(
    ndim, aug_name: str = None, aug_dict: dict = None, p: float = 0.75
) -> Callable[[torch.Tensor], torch.Tensor]:
    """Build the intensity-augmentation pipeline.

    Args:
        ndim: Number of spatial dimensions of the data, 2 or 3.
        aug_name: "weak" or "strong" preset; only consulted when `aug_dict` is None.
        aug_dict: Mapping from kornia.augmentation class names to constructor kwargs.
            A "p" key in the kwargs overrides the shared probability `p` for that augmentation.
        p: Probability with which each augmentation is applied, unless overridden.

    Returns:
        A K.AugmentationSequential for ndim == 2, or an AugmentationSequential3D for ndim == 3.

    Raises:
        ValueError: If `aug_dict` is None and `aug_name` is neither "weak" nor "strong".
    """
    assert ndim == 2 or ndim == 3, f"Number of dimensions must be 2 or 3, got ndim={ndim}"
    assert aug_name is not None or aug_dict is not None, "Either aug_name or aug_dict must be provided."

    if aug_dict is None:
        if aug_name == "weak":
            aug_dict = DEFAULT_WEAK_AUGMENTATIONS["intensity"]
        elif aug_name == "strong":
            aug_dict = DEFAULT_STRONG_AUGMENTATIONS["intensity"]
        else:
            raise ValueError(f"Augmentation name needs to be \"weak\" or \"strong\", got aug_name={aug_name}")

    transforms_list = []
    for trafo, kwargs in aug_dict.items():
        assert hasattr(K, trafo), f"{trafo} not found in kornia.augmentation"
        # Merge so a per-augmentation "p" wins instead of raising
        # "got multiple values for keyword argument 'p'".
        transforms_list.append(getattr(K, trafo)(**{"p": p, **kwargs}))

    if ndim == 2:
        return K.AugmentationSequential(*transforms_list, data_keys=["input"], same_on_batch=False)
    return AugmentationSequential3D(*transforms_list)
def get_geometrical_augmentations(
    ndim, aug_name: str = None, aug_dict: dict = None, p: float = 0.75
) -> Callable[[torch.Tensor], torch.Tensor]:
    """Build the geometrical-augmentation pipeline.

    Args:
        ndim: Number of spatial dimensions of the data, 2 or 3.
        aug_name: "weak" or "strong" preset; only consulted when `aug_dict` is None.
        aug_dict: Mapping from kornia.augmentation class names to constructor kwargs.
            A "p" key in the kwargs overrides the shared probability `p` for that augmentation.
        p: Probability with which each augmentation is applied, unless overridden.

    Returns:
        A K.AugmentationSequential for ndim == 2, or an AugmentationSequential3D for ndim == 3.

    Raises:
        ValueError: If `aug_dict` is None and `aug_name` is neither "weak" nor "strong".
    """
    assert ndim == 2 or ndim == 3, f"Number of dimensions must be 2 or 3, got ndim={ndim}"
    assert aug_name is not None or aug_dict is not None, "Either aug_name or aug_dict must be provided."

    if aug_dict is None:
        if aug_name == "weak":
            aug_dict = DEFAULT_WEAK_AUGMENTATIONS["geometrical"]
        elif aug_name == "strong":
            aug_dict = DEFAULT_STRONG_AUGMENTATIONS["geometrical"]
        else:
            raise ValueError(f"Augmentation name needs to be \"weak\" or \"strong\", got aug_name={aug_name}")

    transforms_list = []
    for trafo, kwargs in aug_dict.items():
        assert hasattr(K, trafo), f"{trafo} not found in kornia.augmentation"
        # Merge so a per-augmentation "p" wins instead of raising
        # "got multiple values for keyword argument 'p'".
        transforms_list.append(getattr(K, trafo)(**{"p": p, **kwargs}))

    if ndim == 2:
        return K.AugmentationSequential(*transforms_list, data_keys=["input"], same_on_batch=False)
    return AugmentationSequential3D(*transforms_list)
def get_default_augmentations(aug_name: str, ndim: int, p: float = 0.75):
    """Return the (intensity, geometrical) augmentation pipelines of a preset.

    Args:
        aug_name: Preset name, "weak" or "strong".
        ndim: Number of spatial dimensions of the data, 2 or 3.
        p: Probability with which each augmentation is applied.

    Returns:
        Tuple ``(intensity_transforms, geometrical_transforms)``.

    Raises:
        ValueError: If `aug_name` is neither "weak" nor "strong".
    """
    assert ndim == 2 or ndim == 3, f"Number of dimensions must be 2 or 3, got ndim={ndim}"
    # Both presets are built identically, so one validation replaces the duplicated branches.
    if aug_name not in ("weak", "strong"):
        raise ValueError(f"Augmentation name needs to be \"weak\" or \"strong\", got aug_name={aug_name}")

    intensity_transforms = get_intensity_augmentations(ndim, aug_name=aug_name, p=p)
    geometrical_transforms = get_geometrical_augmentations(ndim, aug_name=aug_name, p=p)
    return intensity_transforms, geometrical_transforms
class AugmentationSequential3D(torch.nn.Module):
    """Apply 2D augmentations to 5D volumes by folding depth into the channel axis.

    A (B, C, D, H, W) tensor is reshaped to (B, C*D, H, W), passed through each
    augmentation in order, and reshaped back. The parameters every augmentation
    reports via its `_params` attribute are recorded so `inverse` can undo them.
    NOTE(review): this assumes each augmentation samples one transform per batch
    element and applies it to all channels, so all depth slices stay aligned — confirm
    for custom augmentations.
    """

    def __init__(self, *augmentations: torch.nn.Module):
        super().__init__()
        self.augmentations = torch.nn.ModuleList(augmentations)
        # Per-augmentation parameters from the most recent forward() call.
        self._params = None

    @staticmethod
    def _flatten(x):
        """(B, C, D, H, W) -> (B, C*D, H, W); also returns the original 5D shape."""
        if x.ndim != 5:
            raise RuntimeError(f"Expected 5D tensor, got {x.shape}")
        b, c, d, h, w = x.shape
        return x.reshape(b, c * d, h, w), (b, c, d, h, w)

    @staticmethod
    def _unflatten(x, shape):
        """(B, C*D, H', W') -> (B, C, D, H', W').

        The spatial size is read from `x`, not from `shape`: an augmentation may have
        changed it, and reshaping to the pre-augmentation H, W would either fail or,
        when H' * W' == H * W, silently scramble the voxels.
        """
        b, c, d = shape[:3]
        return x.reshape(b, c, d, *x.shape[-2:])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Augment a (B, C, D, H, W) tensor and record each augmentation's parameters."""
        params_all = []
        flat_x, shape = self._flatten(x)
        for aug in self.augmentations:
            flat_x = aug(flat_x)
            params_all.append(aug._params)
        out = self._unflatten(flat_x, shape)
        self._params = params_all
        return out

    def inverse(self, x: torch.Tensor, params=None) -> torch.Tensor:
        """Undo the augmentations in reverse order.

        Args:
            x: Augmented (B, C, D, H, W) tensor.
            params: Per-augmentation parameters, one entry per augmentation. Defaults
                to those recorded by the most recent forward() call.

        Raises:
            ValueError: If no parameters are given and forward() has not been called.
        """
        if params is None:
            params = self._params
        if params is None:
            raise ValueError("No parameters available for inversion; run a forward pass first or pass params.")

        flat_x, shape = self._flatten(x)
        for aug, aug_params in reversed(list(zip(self.augmentations, params))):
            flat_x = aug.inverse(flat_x, params=aug_params)
        return self._unflatten(flat_x, shape)
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean indicating whether this module is in training or evaluation mode.
:vartype training: bool
def __init__(self, *augmentations: torch.nn.Module):
    """Store the 2D augmentations, in application order, as registered submodules.

    Args:
        *augmentations: Augmentation modules applied one after another to the
            flattened (B, C*D, H, W) tensor.
    """
    super().__init__()
    # Wrapping in a ModuleList registers every augmentation as a submodule.
    augmentation_list = torch.nn.ModuleList(augmentations)
    self.augmentations = augmentation_list
    # Nothing has been sampled yet; forward() fills this in.
    self._params = None
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Augment a volume and remember the parameters every augmentation sampled.

    The input is flattened with `self._flatten`, passed through each module in
    `self.augmentations` in order, and restored with `self._unflatten`. The list of
    per-augmentation `_params` is stored on `self._params`.
    """
    flattened, original_shape = self._flatten(x)
    recorded = []
    for augmentation in self.augmentations:
        flattened = augmentation(flattened)
        recorded.append(augmentation._params)
    result = self._unflatten(flattened, original_shape)
    self._params = recorded
    return result
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
class InvertibleAugmenter(torch.nn.Module):
    """Intensity + geometrical augmentation whose geometrical part can be undone.

    Args:
        intensity_transforms: Applied first; never inverted.
        geometrical_transforms: Applied second. It must expose `_params` after a call
            and an `inverse(x, params=...)` method.
        clip_max: If not None, the intensity-augmented input is clamped to [0, clip_max].
        **kwargs: Forwarded to torch.nn.Module.__init__.
    """

    def __init__(
        self,
        intensity_transforms: Callable[[torch.Tensor], torch.Tensor],
        geometrical_transforms: Callable[[torch.Tensor], torch.Tensor],
        clip_max=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.intensity_transforms = intensity_transforms
        self.geometrical_transforms = geometrical_transforms
        self.clip_max = clip_max
        # Initialised here so `params` exists before the first transform()/reset() call.
        self.params = None

    def reset(self):
        """Forget the geometrical parameters stored by the last transform() call."""
        self.params = None

    def transform(self, x: torch.Tensor) -> torch.Tensor:
        """Apply intensity transforms, optional clamping, then geometrical transforms.

        The geometrical parameters are stored in `self.params` for reverse_transform().
        """
        x = self.intensity_transforms(x)
        if self.clip_max is not None:
            x = torch.clamp(x, 0.0, self.clip_max)
        x = self.geometrical_transforms(x)
        self.params = self.geometrical_transforms._params
        return x

    def reverse_transform(self, x: torch.Tensor) -> torch.Tensor:
        """Undo the geometrical transforms using the parameters stored by transform()."""
        return self.geometrical_transforms.inverse(x, params=self.params)
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean indicating whether this module is in training or evaluation mode.
:vartype training: bool
def __init__(
    self,
    intensity_transforms: Callable[[torch.Tensor], torch.Tensor],
    geometrical_transforms: Callable[[torch.Tensor], torch.Tensor],
    clip_max=None,
    **kwargs,
):
    """Store the transforms and clamping bound of an invertible augmenter.

    Args:
        intensity_transforms: Applied first; never inverted.
        geometrical_transforms: Applied second and inverted later.
        clip_max: If not None, intensity-augmented inputs are clamped to [0, clip_max].
        **kwargs: Forwarded to torch.nn.Module.__init__.
    """
    super().__init__(**kwargs)
    self.intensity_transforms = intensity_transforms
    self.geometrical_transforms = geometrical_transforms
    self.clip_max = clip_max
    # Initialised here so `params` exists before the first transform()/reset() call.
    self.params = None
Initialize internal Module state, shared by both nn.Module and ScriptModule.
class MeanTeacherAugmenters:
    """Teacher/student pair of invertible augmenters; both default to the "weak" preset.

    Args:
        ndim: Number of spatial dimensions of the data, 2 or 3.
        teacher: Teacher augmenter; built from the "weak" preset if None.
        student: Student augmenter; built from the "weak" preset if None.
        clip_max: Clamping bound forwarded to default-constructed augmenters.
    """

    def __init__(
        self,
        ndim: int,
        teacher=None,
        student=None,
        clip_max=None,
    ):
        # `is None` rather than `or`: keep a caller-supplied augmenter even if it is falsy.
        if teacher is None:
            teacher = InvertibleAugmenter(*get_default_augmentations("weak", ndim=ndim), clip_max=clip_max)
        if student is None:
            student = InvertibleAugmenter(*get_default_augmentations("weak", ndim=ndim), clip_max=clip_max)
        self.teacher = teacher
        self.student = student

    def reset_all(self):
        """Call reset() on the teacher and student augmenters."""
        self.teacher.reset()
        self.student.reset()
def __init__(
    self,
    ndim: int,
    teacher=None,
    student=None,
    clip_max=None,
):
    """Store the teacher/student augmenters, building "weak" defaults for missing ones.

    Args:
        ndim: Number of spatial dimensions of the data, 2 or 3.
        teacher: Teacher augmenter; built from the "weak" preset if None.
        student: Student augmenter; built from the "weak" preset if None.
        clip_max: Clamping bound forwarded to default-constructed augmenters.
    """
    # `is None` rather than `or`: keep a caller-supplied augmenter even if it is falsy.
    if teacher is None:
        teacher = InvertibleAugmenter(*get_default_augmentations("weak", ndim=ndim), clip_max=clip_max)
    if student is None:
        student = InvertibleAugmenter(*get_default_augmentations("weak", ndim=ndim), clip_max=clip_max)
    self.teacher = teacher
    self.student = student
class FixMatchAugmenters:
    """Teacher/student pair of invertible augmenters: weak teacher, strong student.

    Args:
        ndim: Number of spatial dimensions of the data, 2 or 3.
        teacher: Teacher augmenter; built from the "weak" preset if None.
        student: Student augmenter; built from the "strong" preset if None.
        clip_max: Clamping bound forwarded to default-constructed augmenters.
    """

    def __init__(
        self,
        ndim: int,
        teacher=None,
        student=None,
        clip_max=None,
    ):
        # `is None` rather than `or`: keep a caller-supplied augmenter even if it is falsy.
        if teacher is None:
            teacher = InvertibleAugmenter(*get_default_augmentations("weak", ndim=ndim), clip_max=clip_max)
        if student is None:
            student = InvertibleAugmenter(*get_default_augmentations("strong", ndim=ndim), clip_max=clip_max)
        self.teacher = teacher
        self.student = student

    def reset_all(self):
        """Call reset() on the teacher and student augmenters."""
        self.teacher.reset()
        self.student.reset()
def __init__(
    self,
    ndim: int,
    teacher=None,
    student=None,
    clip_max=None,
):
    """Store the teacher/student augmenters, building weak/strong defaults for missing ones.

    Args:
        ndim: Number of spatial dimensions of the data, 2 or 3.
        teacher: Teacher augmenter; built from the "weak" preset if None.
        student: Student augmenter; built from the "strong" preset if None.
        clip_max: Clamping bound forwarded to default-constructed augmenters.
    """
    # `is None` rather than `or`: keep a caller-supplied augmenter even if it is falsy.
    if teacher is None:
        teacher = InvertibleAugmenter(*get_default_augmentations("weak", ndim=ndim), clip_max=clip_max)
    if student is None:
        student = InvertibleAugmenter(*get_default_augmentations("strong", ndim=ndim), clip_max=clip_max)
    self.teacher = teacher
    self.student = student
class UniMatchv2Augmenters:
    """One weak and two strong invertible augmenters.

    Args:
        ndim: Number of spatial dimensions of the data, 2 or 3.
        weak: Built from the "weak" preset if None.
        strong1: Built from the "strong" preset if None.
        strong2: Built from the "strong" preset if None.
        clip_max: Clamping bound forwarded to default-constructed augmenters.
    """

    def __init__(
        self,
        ndim: int,
        weak=None,
        strong1=None,
        strong2=None,
        clip_max=None,
    ):
        # `is None` rather than `or`: keep a caller-supplied augmenter even if it is falsy.
        if weak is None:
            weak = InvertibleAugmenter(*get_default_augmentations("weak", ndim=ndim), clip_max=clip_max)
        if strong1 is None:
            strong1 = InvertibleAugmenter(*get_default_augmentations("strong", ndim=ndim), clip_max=clip_max)
        if strong2 is None:
            strong2 = InvertibleAugmenter(*get_default_augmentations("strong", ndim=ndim), clip_max=clip_max)
        self.weak = weak
        self.strong1 = strong1
        self.strong2 = strong2

    def reset_all(self):
        """Call reset() on all three augmenters."""
        self.weak.reset()
        self.strong1.reset()
        self.strong2.reset()
def __init__(
    self,
    ndim: int,
    weak=None,
    strong1=None,
    strong2=None,
    clip_max=None,
):
    """Store one weak and two strong augmenters, building defaults for missing ones.

    Args:
        ndim: Number of spatial dimensions of the data, 2 or 3.
        weak: Built from the "weak" preset if None.
        strong1: Built from the "strong" preset if None.
        strong2: Built from the "strong" preset if None.
        clip_max: Clamping bound forwarded to default-constructed augmenters.
    """
    # `is None` rather than `or`: keep a caller-supplied augmenter even if it is falsy.
    if weak is None:
        weak = InvertibleAugmenter(*get_default_augmentations("weak", ndim=ndim), clip_max=clip_max)
    if strong1 is None:
        strong1 = InvertibleAugmenter(*get_default_augmentations("strong", ndim=ndim), clip_max=clip_max)
    if strong2 is None:
        strong2 = InvertibleAugmenter(*get_default_augmentations("strong", ndim=ndim), clip_max=clip_max)
    self.weak = weak
    self.strong1 = strong1
    self.strong2 = strong2