torch_em.transform.raw
from typing import Union, Optional, Tuple, Dict, Callable

import numpy as np

import torch
from torchvision import transforms


#
# normalization functions
#


TORCH_DTYPES = {
    "float16": torch.float16,
    "float32": torch.float32,
    "float64": torch.float64,
    "complex64": torch.complex64,
    "complex128": torch.complex128,
    "uint8": torch.uint8,
    "int8": torch.int8,
    "int16": torch.int16,
    "int32": torch.int32,
    "int64": torch.int64,
    "bool": torch.bool,
}
"""@private
"""


def cast(inpt: Union[np.ndarray, torch.Tensor], typestring: str) -> Union[np.ndarray, torch.Tensor]:
    """@private
    """
    # Always return a new object so callers can operate in place without touching the input:
    # numpy's `astype` copies by default, and `copy=True` makes `Tensor.to` behave the same
    # (without it, `to` returns the tensor itself when the dtype already matches).
    if torch.is_tensor(inpt):
        assert typestring in TORCH_DTYPES, f"{typestring} not in TORCH_DTYPES"
        return inpt.to(TORCH_DTYPES[typestring], copy=True)
    return inpt.astype(typestring)


def standardize(
    raw: np.ndarray,
    mean: Optional[float] = None,
    std: Optional[float] = None,
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    eps: float = 1e-7,
) -> np.ndarray:
    """Standardize the input data by subtracting its mean and dividing by its standard deviation.

    Args:
        raw: The input data.
        mean: The mean value. If None, it will be computed from the data.
        std: The standard deviation. If None, it will be computed from the data.
        axis: The axis along which to compute the mean and standard deviation.
        eps: The epsilon value for numerical stability.

    Returns:
        The standardized input data.
    """
    # `cast` returns a float32 copy, so the in-place operations below leave the input untouched.
    raw = cast(raw, "float32")
    mean = raw.mean(axis=axis, keepdims=True) if mean is None else mean
    raw -= mean

    std = raw.std(axis=axis, keepdims=True) if std is None else std
    raw /= (std + eps)
    return raw


def _normalize_torch(tensor, minval, maxval, axis, eps):
    # Explicit None check: `axis=0` is a valid axis and must not fall through to the global branch.
    if axis is not None:
        # torch.amin / torch.amax return plain tensors, whereas Tensor.min / Tensor.max with a dim
        # return torch.return_types.min / torch.return_types.max named tuples.
        minval = torch.amin(tensor, dim=axis, keepdim=True) if minval is None else minval
        tensor -= minval

        maxval = torch.amax(tensor, dim=axis, keepdim=True) if maxval is None else maxval
        tensor /= (maxval + eps)

        return tensor

    # keepdim can only be used in combination with dim
    minval = tensor.min() if minval is None else minval
    tensor -= minval

    maxval = tensor.max() if maxval is None else maxval
    tensor /= (maxval + eps)

    return tensor


def normalize(
    raw: Union[torch.Tensor, np.ndarray],
    minval: Optional[float] = None,
    maxval: Optional[float] = None,
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    eps: float = 1e-7,
) -> Union[torch.Tensor, np.ndarray]:
    """Normalize the input data so that it is in range [0, 1].

    The data is shifted by `minval` and then divided by `maxval + eps`. `maxval` is applied to the
    already shifted data, so it corresponds to the value range (maximum minus minimum), not to the
    maximum of the original data.

    Args:
        raw: The input data, a numpy array or a torch tensor.
        minval: The value subtracted from the data. If None, the minimum is computed from the data.
        maxval: The value the shifted data is divided by. If None, the maximum of the shifted data
            (i.e. the data range) is computed.
        axis: The axis along which to compute the min and max value.
        eps: The epsilon value for numerical stability.

    Returns:
        The normalized input data as float32, of the same kind (numpy / torch) as the input.
    """
    raw = cast(raw, "float32")
    if torch.is_tensor(raw):
        return _normalize_torch(raw, minval=minval, maxval=maxval, axis=axis, eps=eps)

    minval = raw.min(axis=axis, keepdims=True) if minval is None else minval
    raw -= minval

    maxval = raw.max(axis=axis, keepdims=True) if maxval is None else maxval
    raw /= (maxval + eps)
    return raw


def normalize_percentile(
    raw: np.ndarray,
    lower: float = 1.0,
    upper: float = 99.0,
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    eps: float = 1e-7,
) -> np.ndarray:
    """Normalize the input data based on percentile values.

    Args:
        raw: The input data.
        lower: The lower percentile.
        upper: The upper percentile.
        axis: The axis along which to compute the percentiles.
        eps: The epsilon value for numerical stability.

    Returns:
        The normalized input data.
    """
    v_lower = np.percentile(raw, lower, axis=axis, keepdims=True)
    # `normalize` divides the shifted data by `maxval`, so pass the percentile range, not the upper value.
    v_upper = np.percentile(raw, upper, axis=axis, keepdims=True) - v_lower
    return normalize(raw, v_lower, v_upper, eps=eps)


#
# Intensity Augmentations / Noise Augmentations.
#

# modified from https://github.com/kreshuklab/spoco/blob/main/spoco/transforms.py
class RandomContrast:
    """Transformation to adjust contrast by scaling image to `mean + alpha * (image - mean)`.

    Args:
        alpha: Minimal and maximal alpha value for adjusting the contrast.
            The value for the transformation will be drawn uniformly from the corresponding interval.
        mean: Mean value for the image data.
        clip_kwargs: Keyword arguments for clipping the data after the contrast augmentation.
            Pass None or an empty dict to disable clipping.
    """
    def __init__(
        self, alpha: Tuple[float, float] = (0.5, 2), mean: float = 0.5, clip_kwargs: Dict = {"a_min": 0, "a_max": 1}
    ):
        self.alpha = alpha
        self.mean = mean
        # Copy so that instances never share (and cannot mutate) the default dict.
        self.clip_kwargs = dict(clip_kwargs) if clip_kwargs else clip_kwargs

    def __call__(self, img: np.ndarray) -> np.ndarray:
        """Apply the augmentation to data.

        Args:
            img: The input image.

        Returns:
            The transformed image.
        """
        alpha = np.random.uniform(self.alpha[0], self.alpha[1])
        result = self.mean + alpha * (img - self.mean)
        if self.clip_kwargs:
            return np.clip(result, **self.clip_kwargs)
        return result


class AdditiveGaussianNoise:
    """Transformation to add random Gaussian noise to image.

    Args:
        scale: Minimal and maximal standard deviation of the noise.
            The value for the transformation will be drawn uniformly from this range.
        clip_kwargs: Keyword arguments for clipping the data after the transformation.
            Pass None or an empty dict to disable clipping.
    """
    def __init__(self, scale: Tuple[float, float] = (0.0, 0.3), clip_kwargs: Dict = {"a_min": 0, "a_max": 1}):
        self.scale = scale
        # Copy so that instances never share (and cannot mutate) the default dict.
        self.clip_kwargs = dict(clip_kwargs) if clip_kwargs else clip_kwargs

    def __call__(self, img: np.ndarray) -> np.ndarray:
        """Apply the augmentation to data.

        Args:
            img: The input image.

        Returns:
            The transformed image.
        """
        std = np.random.uniform(self.scale[0], self.scale[1])
        gaussian_noise = np.random.normal(0, std, size=img.shape)
        result = img + gaussian_noise

        if self.clip_kwargs:
            # Clip to the configured bounds instead of a fixed [0, 1] range.
            return np.clip(result, **self.clip_kwargs)
        return result


class AdditivePoissonNoise:
    """Transformation to add random additive Poisson noise to image.

    Args:
        lam: Minimal and maximal lambda value for the Poisson noise.
            The value for the transformation will be drawn uniformly from this range.
        clip_kwargs: Keyword arguments for clipping the data after the transformation.
            Pass None or an empty dict to disable clipping.
    """
    # NOTE: Not sure if Poisson noise like this does make sense for data that is already normalized.
    def __init__(self, lam: Tuple[float, float] = (0.0, 0.1), clip_kwargs: Dict = {"a_min": 0, "a_max": 1}):
        self.lam = lam
        # Copy so that instances never share (and cannot mutate) the default dict.
        self.clip_kwargs = dict(clip_kwargs) if clip_kwargs else clip_kwargs

    def __call__(self, img: np.ndarray) -> np.ndarray:
        """Apply the augmentation to data.

        Args:
            img: The input image.

        Returns:
            The transformed image.
        """
        lam = np.random.uniform(self.lam[0], self.lam[1])
        if lam == 0:
            # A zero rate only produces zero counts; dividing them by lam would yield NaN.
            poisson_noise = np.zeros(img.shape)
        else:
            poisson_noise = np.random.poisson(lam, size=img.shape) / lam
        result = img + poisson_noise

        if self.clip_kwargs:
            # Clip to the configured bounds instead of a fixed [0, 1] range.
            return np.clip(result, **self.clip_kwargs)
        return result


class PoissonNoise:
    """Transformation to add random data-dependent Poisson noise to image.

    Args:
        multiplier: Multiplicative factors for deriving the lambda factor from the data.
            The factor used for the transformation will be uniformly sampled from the range of this parameter.
        clip_kwargs: Keyword arguments for clipping the data after the transformation.
            Pass None or an empty dict to disable clipping.
    """
    def __init__(self, multiplier: Tuple[float, float] = (5.0, 10.0), clip_kwargs: Dict = {"a_min": 0, "a_max": 1}):
        self.multiplier = multiplier
        # Copy so that instances never share (and cannot mutate) the default dict.
        self.clip_kwargs = dict(clip_kwargs) if clip_kwargs else clip_kwargs

    def __call__(self, img: np.ndarray) -> np.ndarray:
        """Apply the augmentation to data.

        Args:
            img: The input image.

        Returns:
            The transformed image.
        """
        multiplier = np.random.uniform(self.multiplier[0], self.multiplier[1])
        # Shift the data to be non-negative, since it is used as the Poisson rate.
        offset = img.min()
        poisson_noise = np.random.poisson((img - offset) * multiplier)

        if isinstance(img, torch.Tensor):
            poisson_noise = torch.Tensor(poisson_noise)
        # Scale the counts back to the data range and undo the shift.
        poisson_noise = poisson_noise / multiplier + offset

        if self.clip_kwargs:
            return np.clip(poisson_noise, **self.clip_kwargs)
        return poisson_noise


class GaussianBlur:
    """Transformation to blur the image with a randomly drawn sigma value.

    Args:
        sigma: Minimal and maximal sigma value for the transformation.
            The value used in the transformation will be uniformly drawn from `(sigma[0], sigma[1]]`,
            i.e. the lower bound is excluded.
    """
    def __init__(self, sigma: Tuple[float, float] = (0.0, 3.0)):
        self.sigma = sigma

    def _sample_sigma(self) -> float:
        """Draw sigma uniformly from the half-open interval `(sigma[0], sigma[1]]`."""
        low, high = self.sigma[0], self.sigma[1]
        # Sampling downwards from the upper bound excludes the lower bound (zero by default, which is
        # not a valid blur sigma). This yields the same value as `np.random.uniform(high, low)`, without
        # relying on `uniform` with high < low, which numpy documents as officially undefined.
        return high - (high - low) * np.random.random_sample()

    def __call__(self, img: np.ndarray) -> np.ndarray:
        """Apply the augmentation to data.

        Args:
            img: The input image. A numpy array is converted to a torch tensor before blurring.

        Returns:
            The transformed image.
        """
        sigma = self._sample_sigma()
        # 2 * ceil(3 * sigma) + 1 is always odd and covers roughly +/- 3 sigma.
        kernel_size = int(2 * np.ceil(3 * sigma) + 1)
        if isinstance(img, np.ndarray):
            img = torch.from_numpy(img)

        return transforms.GaussianBlur(kernel_size, sigma=sigma)(img)


#
# Default Transformation: Apply intensity augmentations and normalize.
#

class RawTransform:
    """The transformation for raw data during training.

    Args:
        normalizer: The normalization function.
        augmentation1: Intensity augmentation applied before the normalization.
        augmentation2: Intensity augmentation applied after the normalization.
    """
    def __init__(
        self, normalizer: Callable, augmentation1: Optional[Callable] = None, augmentation2: Optional[Callable] = None
    ):
        self.normalizer = normalizer
        self.augmentation1 = augmentation1
        self.augmentation2 = augmentation2

    def __call__(self, raw: np.ndarray) -> np.ndarray:
        """Apply the raw transformation.

        Args:
            raw: The raw data.

        Returns:
            The transformed raw data.
        """
        if self.augmentation1 is not None:
            raw = self.augmentation1(raw)

        raw = self.normalizer(raw)

        if self.augmentation2 is not None:
            raw = self.augmentation2(raw)
        return raw


def get_raw_transform(
    normalizer: Callable = standardize,
    augmentation1: Optional[Callable] = None,
    augmentation2: Optional[Callable] = None
) -> Callable:
    """Get the raw transformation.

    Args:
        normalizer: The normalization function.
        augmentation1: Intensity augmentation applied before the normalization.
        augmentation2: Intensity augmentation applied after the normalization.

    Returns:
        The raw transformation.
    """
    return RawTransform(normalizer, augmentation1=augmentation1, augmentation2=augmentation2)


def get_default_mean_teacher_augmentations(
    p: float = 0.3,
    norm: Optional[Callable] = None,
    blur_kwargs: Optional[Dict] = None,
    poisson_kwargs: Optional[Dict] = None,
    gaussian_kwargs: Optional[Dict] = None,
) -> Callable:
    """Get the default augmentations for mean teacher training.

    The default values for the augmentations are designed for an image with pixel values in range [0, 1].
    By default, a normalization transformation is applied for this reason.

    Args:
        p: The probability for applying the individual intensity transformations.
        norm: The normalization function.
        blur_kwargs: The keyword arguments for `GaussianBlur`.
        poisson_kwargs: The keyword arguments for `PoissonNoise`.
        gaussian_kwargs: The keyword arguments for `AdditiveGaussianNoise`.

    Returns:
        The raw transformation with augmentations.
    """
    if norm is None:
        norm = normalize

    # Normalize first so the noise / blur augmentations operate on data in the range they are designed for.
    aug1 = transforms.Compose([
        norm,
        transforms.RandomApply([GaussianBlur(**({} if blur_kwargs is None else blur_kwargs))], p=p),
        transforms.RandomApply([PoissonNoise(**({} if poisson_kwargs is None else poisson_kwargs))], p=p/2),
        transforms.RandomApply([AdditiveGaussianNoise(**({} if gaussian_kwargs is None else gaussian_kwargs))], p=p/2),
    ])

    aug2 = transforms.RandomApply([RandomContrast(clip_kwargs={"a_min": 0, "a_max": 1})], p=p)
    return get_raw_transform(normalizer=norm, augmentation1=aug1, augmentation2=aug2)
def standardize(
    raw: np.ndarray,
    mean: Optional[float] = None,
    std: Optional[float] = None,
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    eps: float = 1e-7,
) -> np.ndarray:
    """Standardize the input data by subtracting its mean and dividing by its standard deviation.

    Args:
        raw: The input data.
        mean: The mean value. If None, it will be computed from the data.
        std: The standard deviation. If None, it will be computed from the data.
        axis: The axis along which to compute the mean and standard deviation.
        eps: The epsilon value for numerical stability.

    Returns:
        The standardized input data, as float32. The input itself is not modified.
    """
    data = cast(raw, "float32")
    if data is raw and torch.is_tensor(data):
        # `cast` hands back the very same tensor when it already is float32 (`Tensor.to` does not copy
        # then); clone it so the in-place arithmetic below does not modify the caller's tensor.
        data = data.clone()

    mean = data.mean(axis=axis, keepdims=True) if mean is None else mean
    data -= mean

    std = data.std(axis=axis, keepdims=True) if std is None else std
    data /= (std + eps)
    return data
Standardize the input data by subtracting its mean and dividing by its standard deviation.
Arguments:
- raw: The input data.
- mean: The mean value. If None, it will be computed from the data.
- std: The standard deviation. If None, it will be computed from the data.
- axis: The axis along which to compute the mean and standard deviation.
- eps: The epsilon value for numerical stability.
Returns:
The standardized input data.
def normalize(
    raw: Union[torch.Tensor, np.ndarray],
    minval: Optional[float] = None,
    maxval: Optional[float] = None,
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    eps: float = 1e-7,
) -> Union[torch.Tensor, np.ndarray]:
    """Normalize the input data so that it is in range [0, 1].

    The data is shifted by `minval` and then divided by `maxval + eps`. `maxval` is applied to the
    already shifted data, so it corresponds to the value range (maximum minus minimum), not to the
    maximum of the original data.

    Args:
        raw: The input data, a numpy array or a torch tensor.
        minval: The value subtracted from the data. If None, the minimum is computed from the data.
        maxval: The value the shifted data is divided by. If None, the maximum of the shifted data
            (i.e. the data range) is computed.
        axis: The axis along which to compute the min and max value.
        eps: The epsilon value for numerical stability.

    Returns:
        The normalized input data as float32, of the same kind (numpy / torch) as the input.
        The input itself is not modified.
    """
    data = cast(raw, "float32")
    if torch.is_tensor(data):
        if data is raw:
            # `cast` hands back the very same tensor when it already is float32 (`Tensor.to` does not
            # copy then); clone it so the in-place arithmetic does not modify the caller's tensor.
            data = data.clone()
        return _normalize_torch(data, minval=minval, maxval=maxval, axis=axis, eps=eps)

    minval = data.min(axis=axis, keepdims=True) if minval is None else minval
    data -= minval

    maxval = data.max(axis=axis, keepdims=True) if maxval is None else maxval
    data /= (maxval + eps)
    return data
Normalize the input data so that it is in range [0, 1].
Arguments:
- raw: The input data.
- minval: The minimum data value. If None, it will be computed from the data.
- maxval: The value the data is divided by after subtracting minval, i.e. the data range (maximum minus minimum). If None, it will be computed from the data.
- axis: The axis along which to compute the min and max value.
- eps: The epsilon value for numerical stability.
Returns:
The normalized input data.
def normalize_percentile(
    raw: np.ndarray,
    lower: float = 1.0,
    upper: float = 99.0,
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    eps: float = 1e-7,
) -> np.ndarray:
    """Rescale the input data using its lower and upper percentiles as the value range.

    Args:
        raw: The input data.
        lower: The lower percentile, mapped to 0.
        upper: The upper percentile, mapped to (approximately) 1.
        axis: The axis along which to compute the percentiles.
        eps: The epsilon value for numerical stability.

    Returns:
        The normalized input data.
    """
    def _pct(q):
        return np.percentile(raw, q, axis=axis, keepdims=True)

    low = _pct(lower)
    # `normalize` divides the shifted data by `maxval`, so hand it the percentile spread.
    spread = _pct(upper) - low
    return normalize(raw, minval=low, maxval=spread, eps=eps)
Normalize the input data based on percentile values.
Arguments:
- raw: The input data.
- lower: The lower percentile.
- upper: The upper percentile.
- axis: The axis along which to compute the percentiles.
- eps: The epsilon value for numerical stability.
Returns:
The normalized input data.
class RandomContrast:
    """Transformation to adjust contrast by scaling image to `mean + alpha * (image - mean)`.

    Args:
        alpha: Minimal and maximal alpha value for adjusting the contrast.
            The value for the transformation will be drawn uniformly from the corresponding interval.
        mean: Mean value for the image data.
        clip_kwargs: Keyword arguments for clipping the data after the contrast augmentation.
            Pass None or an empty dict to disable clipping.
    """
    def __init__(
        self, alpha: Tuple[float, float] = (0.5, 2), mean: float = 0.5, clip_kwargs: Dict = {"a_min": 0, "a_max": 1}
    ):
        self.alpha = alpha
        self.mean = mean
        # Copy so that instances never share (and cannot mutate) the dict used as the default argument.
        self.clip_kwargs = dict(clip_kwargs) if clip_kwargs else clip_kwargs

    def __call__(self, img: np.ndarray) -> np.ndarray:
        """Apply the augmentation to data.

        Args:
            img: The input image.

        Returns:
            The transformed image.
        """
        alpha = np.random.uniform(self.alpha[0], self.alpha[1])
        result = self.mean + alpha * (img - self.mean)
        if self.clip_kwargs:
            return np.clip(result, **self.clip_kwargs)
        return result
Transformation to adjust contrast by scaling image to `mean + alpha * (image - mean)`.
Arguments:
- alpha: Minimal and maximal alpha value for adjusting the contrast. The value for the transformation will be drawn uniformly from the corresponding interval.
- mean: Mean value for the image data.
- clip_kwargs: Keyword arguments for clipping the data after the contrast augmentation.
class AdditiveGaussianNoise:
    """Transformation to add random Gaussian noise to image.

    Args:
        scale: Minimal and maximal standard deviation of the noise.
            The value for the transformation will be drawn uniformly from this range.
        clip_kwargs: Keyword arguments for clipping the data after the transformation.
            Pass None or an empty dict to disable clipping.
    """
    def __init__(self, scale: Tuple[float, float] = (0.0, 0.3), clip_kwargs: Dict = {"a_min": 0, "a_max": 1}):
        self.scale = scale
        # Copy so that instances never share (and cannot mutate) the dict used as the default argument.
        self.clip_kwargs = dict(clip_kwargs) if clip_kwargs else clip_kwargs

    def __call__(self, img: np.ndarray) -> np.ndarray:
        """Apply the augmentation to data.

        Args:
            img: The input image.

        Returns:
            The transformed image.
        """
        std = np.random.uniform(self.scale[0], self.scale[1])
        gaussian_noise = np.random.normal(0, std, size=img.shape)
        result = img + gaussian_noise

        if self.clip_kwargs:
            # Clip to the configured bounds instead of a fixed [0, 1] range.
            return np.clip(result, **self.clip_kwargs)
        return result
Transformation to add random Gaussian noise to image.
Arguments:
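- scale: Minimal and maximal standard deviation of the noise. The value for the transformation will be drawn uniformly from this range.
- clip_kwargs: Keyword arguments for clipping the data after the transformation.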
class AdditivePoissonNoise:
    """Transformation to add random additive Poisson noise to image.

    Args:
        lam: Minimal and maximal lambda value for the Poisson noise.
            The value for the transformation will be drawn uniformly from this range.
        clip_kwargs: Keyword arguments for clipping the data after the transformation.
            Pass None or an empty dict to disable clipping.
    """
    # NOTE: Not sure if Poisson noise like this does make sense for data that is already normalized.
    def __init__(self, lam: Tuple[float, float] = (0.0, 0.1), clip_kwargs: Dict = {"a_min": 0, "a_max": 1}):
        self.lam = lam
        # Copy so that instances never share (and cannot mutate) the dict used as the default argument.
        self.clip_kwargs = dict(clip_kwargs) if clip_kwargs else clip_kwargs

    def __call__(self, img: np.ndarray) -> np.ndarray:
        """Apply the augmentation to data.

        Args:
            img: The input image.

        Returns:
            The transformed image.
        """
        lam = np.random.uniform(self.lam[0], self.lam[1])
        if lam == 0:
            # A zero rate only produces zero counts; dividing them by lam would yield NaN.
            poisson_noise = np.zeros(img.shape)
        else:
            poisson_noise = np.random.poisson(lam, size=img.shape) / lam
        result = img + poisson_noise

        if self.clip_kwargs:
            # Clip to the configured bounds instead of a fixed [0, 1] range.
            return np.clip(result, **self.clip_kwargs)
        return result
Transformation to add random additive Poisson noise to image.
Arguments:
- lam: Minimal and maximal lambda value for the Poisson noise. The value for the transformation will be drawn uniformly from this range.
- clip_kwargs: Keyword arguments for clipping the data after the transformation.
class PoissonNoise:
    """Transformation to add random data-dependent Poisson noise to image.

    Args:
        multiplier: Multiplicative factors for deriving the lambda factor from the data.
            The factor used for the transformation will be uniformly sampled from the range of this parameter.
        clip_kwargs: Keyword arguments for clipping the data after the transformation.
            Pass None or an empty dict to disable clipping.
    """
    def __init__(self, multiplier: Tuple[float, float] = (5.0, 10.0), clip_kwargs: Dict = {"a_min": 0, "a_max": 1}):
        self.multiplier = multiplier
        # Copy so that instances never share (and cannot mutate) the dict used as the default argument.
        self.clip_kwargs = dict(clip_kwargs) if clip_kwargs else clip_kwargs

    def __call__(self, img: np.ndarray) -> np.ndarray:
        """Apply the augmentation to data.

        Args:
            img: The input image.

        Returns:
            The transformed image.
        """
        multiplier = np.random.uniform(self.multiplier[0], self.multiplier[1])
        # Shift the data to be non-negative, since it is used as the Poisson rate.
        offset = img.min()
        poisson_noise = np.random.poisson((img - offset) * multiplier)

        if isinstance(img, torch.Tensor):
            poisson_noise = torch.Tensor(poisson_noise)
        # Scale the counts back to the data range and undo the shift.
        poisson_noise = poisson_noise / multiplier + offset

        if self.clip_kwargs:
            return np.clip(poisson_noise, **self.clip_kwargs)
        return poisson_noise
Transformation to add random data-dependent Poisson noise to image.
Arguments:
- multiplier: Multiplicative factors for deriving the lambda factor from the data. The factor used for the transformation will be uniformly sampled from the range of this parameter.
- clip_kwargs: Keyword arguments for clipping the data after the transformation.
class GaussianBlur:
    """Transformation to blur the image with a randomly drawn sigma value.

    Args:
        sigma: Minimal and maximal sigma value for the transformation.
            The value used in the transformation will be uniformly drawn from `(sigma[0], sigma[1]]`,
            i.e. the lower bound is excluded.
    """
    def __init__(self, sigma: Tuple[float, float] = (0.0, 3.0)):
        self.sigma = sigma

    def _sample_sigma(self) -> float:
        """Draw sigma uniformly from the half-open interval `(sigma[0], sigma[1]]`."""
        low, high = self.sigma[0], self.sigma[1]
        # Sampling downwards from the upper bound excludes the lower bound (zero by default, which is
        # not a valid blur sigma). This yields the same value as `np.random.uniform(high, low)`, without
        # relying on `uniform` with high < low, which numpy documents as officially undefined.
        return high - (high - low) * np.random.random_sample()

    def __call__(self, img: np.ndarray) -> np.ndarray:
        """Apply the augmentation to data.

        Args:
            img: The input image. A numpy array is converted to a torch tensor before blurring.

        Returns:
            The transformed image.
        """
        sigma = self._sample_sigma()
        # 2 * ceil(3 * sigma) + 1 is always odd and covers roughly +/- 3 sigma.
        kernel_size = int(2 * np.ceil(3 * sigma) + 1)
        if isinstance(img, np.ndarray):
            img = torch.from_numpy(img)

        return transforms.GaussianBlur(kernel_size, sigma=sigma)(img)
Transformation to blur the image with a randomly drawn sigma value.
Arguments:
- sigma: The sigma value for the transformation. The value used in the transformation will be uniformly drawn from the range specified here.
class RawTransform:
    """Training-time transformation for raw data: augment, normalize, then augment again.

    Args:
        normalizer: Function that normalizes the (possibly augmented) raw data.
        augmentation1: Optional intensity augmentation run before the normalizer.
        augmentation2: Optional intensity augmentation run after the normalizer.
    """
    def __init__(
        self, normalizer: Callable, augmentation1: Optional[Callable] = None, augmentation2: Optional[Callable] = None
    ):
        self.normalizer = normalizer
        self.augmentation1 = augmentation1
        self.augmentation2 = augmentation2

    def __call__(self, raw: np.ndarray) -> np.ndarray:
        """Run the augment -> normalize -> augment pipeline on the raw data.

        Args:
            raw: The raw data.

        Returns:
            The transformed raw data.
        """
        before = self.augmentation1
        augmented = raw if before is None else before(raw)

        normalized = self.normalizer(augmented)

        after = self.augmentation2
        return normalized if after is None else after(normalized)
The transformation for raw data during training.
Arguments:
- normalizer: The normalization function.
- augmentation1: Intensity augmentation applied before the normalization.
- augmentation2: Intensity augmentation applied after the normalization.
def get_raw_transform(
    normalizer: Callable = standardize,
    augmentation1: Optional[Callable] = None,
    augmentation2: Optional[Callable] = None
) -> Callable:
    """Build a `RawTransform` from a normalizer and two optional intensity augmentations.

    Args:
        normalizer: The normalization function.
        augmentation1: Intensity augmentation applied before the normalization.
        augmentation2: Intensity augmentation applied after the normalization.

    Returns:
        The raw transformation.
    """
    transform = RawTransform(
        normalizer=normalizer,
        augmentation1=augmentation1,
        augmentation2=augmentation2,
    )
    return transform
Get the raw transformation.
Arguments:
- normalizer: The normalization function.
- augmentation1: Intensity augmentation applied before the normalization.
- augmentation2: Intensity augmentation applied after the normalization.
Returns:
The raw transformation.
def get_default_mean_teacher_augmentations(
    p: float = 0.3,
    norm: Optional[Callable] = None,
    blur_kwargs: Optional[Dict] = None,
    poisson_kwargs: Optional[Dict] = None,
    gaussian_kwargs: Optional[Dict] = None,
) -> Callable:
    """Build the default intensity augmentation pipeline for mean teacher training.

    The default augmentation parameters assume pixel values in the range [0, 1], which is why the
    data is normalized before the blur / noise augmentations are applied.

    Args:
        p: The probability for applying the individual intensity transformations.
        norm: The normalization function. Defaults to `normalize`.
        blur_kwargs: The keyword arguments for `GaussianBlur`.
        poisson_kwargs: The keyword arguments for `PoissonNoise`.
        gaussian_kwargs: The keyword arguments for `AdditiveGaussianNoise`.

    Returns:
        The raw transformation with augmentations.
    """
    normalizer = normalize if norm is None else norm

    def _kwargs(user_kwargs):
        # A missing kwargs dict means "use the augmentation's own defaults".
        return {} if user_kwargs is None else user_kwargs

    blur = GaussianBlur(**_kwargs(blur_kwargs))
    poisson = PoissonNoise(**_kwargs(poisson_kwargs))
    gaussian = AdditiveGaussianNoise(**_kwargs(gaussian_kwargs))

    pre_norm_augmentation = transforms.Compose([
        normalizer,
        transforms.RandomApply([blur], p=p),
        transforms.RandomApply([poisson], p=p / 2),
        transforms.RandomApply([gaussian], p=p / 2),
    ])

    contrast = RandomContrast(clip_kwargs={"a_min": 0, "a_max": 1})
    post_norm_augmentation = transforms.RandomApply([contrast], p=p)

    return get_raw_transform(
        normalizer=normalizer,
        augmentation1=pre_norm_augmentation,
        augmentation2=post_norm_augmentation,
    )
Get the default augmentations for mean teacher training.
The default values for the augmentations are designed for an image with pixel values in range [0, 1]. By default, a normalization transformation is applied for this reason.
Arguments:
- p: The probability for applying the individual intensity transformations.
- norm: The normalization function.
- blur_kwargs: The keyword arguments for `GaussianBlur`.
- poisson_kwargs: The keyword arguments for `PoissonNoise`.
- gaussian_kwargs: The keyword arguments for `AdditiveGaussianNoise`.
Returns:
The raw transformation with augmentations.