torch_em.transform.raw

  1from typing import Union, Optional, Tuple, Dict, Callable
  2
  3import numpy as np
  4
  5import torch
  6from torchvision import transforms
  7
  8
  9#
 10# normalization functions
 11#
 12
 13
# Lookup table from a dtype name to the corresponding torch dtype.
# `cast` uses it to convert torch tensors when the target type is given as a string.
# NOTE(review): the string literal after the dict is an attribute docstring; its `@private`
# marker presumably hides this constant from the generated API documentation - confirm.
TORCH_DTYPES = {
    "float16": torch.float16,
    "float32": torch.float32,
    "float64": torch.float64,
    "complex64": torch.complex64,
    "complex128": torch.complex128,
    "uint8": torch.uint8,
    "int8": torch.int8,
    "int16": torch.int16,
    "int32": torch.int32,
    "int64": torch.int64,
    "bool": torch.bool,
}
"""@private
"""
 29
 30
 31def cast(inpt: Union[np.ndarray, torch.tensor], typestring: torch.dtype):
 32    """@private
 33    """
 34    if torch.is_tensor(inpt):
 35        assert typestring in TORCH_DTYPES, f"{typestring} not in TORCH_DTYPES"
 36        return inpt.to(TORCH_DTYPES[typestring])
 37    return inpt.astype(typestring)
 38
 39
 40def standardize(
 41    raw: np.ndarray,
 42    mean: Optional[float] = None,
 43    std: Optional[float] = None,
 44    axis: Optional[Union[int, Tuple[int, ...]]] = None,
 45    eps: float = 1e-7,
 46) -> np.ndarray:
 47    """Standardize the input data by subtracting its mean and dividing by its standard deviation.
 48
 49    Args:
 50        raw: The input data.
 51        mean: The mean value. If None, it will be computed from the data.
 52        std: The standard deviation. If None, it will be computed from the data.
 53        axis: The axis along which to compute the mean and standard deviation.
 54        eps: The epsilon value for numerical stability.
 55
 56    Returns:
 57        The standardized input data.
 58    """
 59    raw = cast(raw, "float32")
 60    mean = raw.mean(axis=axis, keepdims=True) if mean is None else mean
 61    raw -= mean
 62
 63    std = raw.std(axis=axis, keepdims=True) if std is None else std
 64    raw /= (std + eps)
 65    return raw
 66
 67
 68def _normalize_torch(tensor, minval, maxval, axis, eps):
 69    if axis:  # torch returns torch.return_types.min or torch.return_types.max
 70        minval = torch.amin(tensor, dim=axis, keepdim=True) if minval is None else minval
 71        tensor -= minval
 72
 73        maxval = torch.amax(tensor, dim=axis, keepdim=True) if maxval is None else maxval
 74        tensor /= (maxval + eps)
 75
 76        return tensor
 77
 78    # keepdim can only be used in combination with dim
 79    minval = tensor.min() if minval is None else minval
 80    tensor -= minval
 81
 82    maxval = tensor.max() if maxval is None else maxval
 83    tensor /= (maxval + eps)
 84
 85    return tensor
 86
 87
 88def normalize(
 89    raw: Union[torch.tensor, np.ndarray],
 90    minval: Optional[float] = None,
 91    maxval: Optional[float] = None,
 92    axis: Optional[Union[int, Tuple[int, ...]]] = None,
 93    eps: float = 1e-7,
 94) -> np.ndarray:
 95    """Normalize the input data so that it is in range [0, 1].
 96
 97    Args:
 98        raw: The input data.
 99        minval: The minimum data value. If None, it will be computed from the data.
100        maxval: The maximum data value. If None, it will be computed from the data.
101        axis: The axis along which to compute the min and max value.
102        eps: The epsilon value for numerical stability.
103
104    Returns:
105        The normalized input data.
106    """
107    raw = cast(raw, "float32")
108    if torch.is_tensor(raw):
109        return _normalize_torch(raw, minval=minval, maxval=maxval, axis=axis, eps=eps)
110
111    minval = raw.min(axis=axis, keepdims=True) if minval is None else minval
112    raw -= minval
113
114    maxval = raw.max(axis=axis, keepdims=True) if maxval is None else maxval
115    raw /= (maxval + eps)
116    return raw
117
118
def normalize_percentile(
    raw: np.ndarray,
    lower: float = 1.0,
    upper: float = 99.0,
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    eps: float = 1e-7,
) -> np.ndarray:
    """Normalize the input data using percentiles instead of the minimum and maximum.

    The data is shifted by the lower percentile value and divided by the
    difference between the upper and the lower percentile value.

    Args:
        raw: The input data.
        lower: The lower percentile.
        upper: The upper percentile.
        axis: The axis (or axes) along which the percentiles are computed.
        eps: Small constant added to the divisor to avoid division by zero.

    Returns:
        The normalized input data.
    """
    lower_value = np.percentile(raw, lower, axis=axis, keepdims=True)
    upper_value = np.percentile(raw, upper, axis=axis, keepdims=True)
    # `normalize` divides by its `maxval` argument after subtracting `minval`,
    # so it has to be given the percentile range rather than the upper percentile itself.
    value_range = upper_value - lower_value
    return normalize(raw, lower_value, value_range, eps=eps)
141
142
143#
144# Intensity Augmentations / Noise Augmentations.
145#
146
147# modified from https://github.com/kreshuklab/spoco/blob/main/spoco/transforms.py
class RandomContrast:
    """Contrast augmentation that rescales the image around a mean value.

    The image is mapped to `mean + alpha * (image - mean)`, where `alpha` is drawn anew for each call.

    Args:
        alpha: Lower and upper bound for alpha. The value used for a call is drawn uniformly from this interval.
        mean: The mean value around which the contrast is scaled.
        clip_kwargs: Keyword arguments passed to `np.clip` after the contrast change.
            If empty or None, the result is not clipped.
    """
    def __init__(
        self, alpha: Tuple[float, float] = (0.5, 2), mean: float = 0.5, clip_kwargs: Dict = {"a_min": 0, "a_max": 1}
    ):
        self.alpha = alpha
        self.mean = mean
        self.clip_kwargs = clip_kwargs

    def __call__(self, img: np.ndarray) -> np.ndarray:
        """Apply the contrast augmentation to an image.

        Args:
            img: The input image.

        Returns:
            The image with adjusted contrast, clipped if `clip_kwargs` is set.
        """
        low, high = self.alpha[0], self.alpha[1]
        factor = np.random.uniform(low, high)
        adjusted = self.mean + factor * (img - self.mean)
        if not self.clip_kwargs:
            return adjusted
        return np.clip(adjusted, **self.clip_kwargs)
178
179
class AdditiveGaussianNoise:
    """Transformation to add random Gaussian noise to image.

    Args:
        scale: Lower and upper bound for the standard deviation of the noise.
            The value used for a call is drawn uniformly from this interval.
        clip_kwargs: Keyword arguments passed to `np.clip` after the noise has been added.
            If empty or None, the result is not clipped.
    """
    def __init__(self, scale: Tuple[float, float] = (0.0, 0.3), clip_kwargs: Dict = {"a_min": 0, "a_max": 1}):
        self.scale = scale
        self.clip_kwargs = clip_kwargs

    def __call__(self, img: np.ndarray) -> np.ndarray:
        """Apply the augmentation to data.

        Args:
            img: The input image.

        Returns:
            The noisy image, clipped according to `clip_kwargs` if they are set.
        """
        std = np.random.uniform(self.scale[0], self.scale[1])
        noisy = img + np.random.normal(0, std, size=img.shape)

        if not self.clip_kwargs:
            return noisy
        # Use the configured bounds rather than a fixed [0, 1] range, so that custom
        # `clip_kwargs` take effect as they do in the other augmentations of this module.
        return np.clip(noisy, **self.clip_kwargs)
207
208
class AdditivePoissonNoise:
    """Transformation to add random additive Poisson noise to image.

    The noise is sampled from a Poisson distribution with rate `lam` and divided by `lam`.

    Args:
        lam: Lower and upper bound for the lambda value of the Poisson distribution.
            The value used for a call is drawn uniformly from this interval.
        clip_kwargs: Keyword arguments passed to `np.clip` after the noise has been added.
            If empty or None, the result is not clipped.
    """
    # Not sure if Poisson noise like this does make sense for data that is already normalized
    def __init__(self, lam: Tuple[float, float] = (0.0, 0.1), clip_kwargs: Dict = {"a_min": 0, "a_max": 1}):
        self.lam = lam
        self.clip_kwargs = clip_kwargs

    def __call__(self, img: np.ndarray) -> np.ndarray:
        """Apply the augmentation to data.

        Args:
            img: The input image.

        Returns:
            The noisy image, clipped according to `clip_kwargs` if they are set.
        """
        lam = np.random.uniform(self.lam[0], self.lam[1])
        counts = np.random.poisson(lam, size=img.shape)
        # A rate of zero yields only zero counts; dividing them by `lam` would give 0 / 0 = NaN,
        # so the noise is set to zero explicitly in that case.
        poisson_noise = counts / lam if lam > 0 else np.zeros(img.shape)
        noisy = img + poisson_noise

        if not self.clip_kwargs:
            return noisy
        # Use the configured bounds rather than a fixed [0, 1] range, so that custom
        # `clip_kwargs` take effect as they do in the other augmentations of this module.
        return np.clip(noisy, **self.clip_kwargs)
235
236
class PoissonNoise:
    """Transformation that replaces the image by data-dependent Poisson samples.

    The image is shifted by its minimum and scaled by a random multiplier to obtain the Poisson rates.
    The sampled values are then divided by the multiplier and shifted back by the minimum.

    Args:
        multiplier: Lower and upper bound for the factor that turns the shifted data into Poisson rates.
            The factor used for a call is drawn uniformly from this interval.
        clip_kwargs: Keyword arguments passed to `np.clip` after the transformation.
            If empty or None, the result is not clipped.
    """
    def __init__(self, multiplier: Tuple[float, float] = (5.0, 10.0), clip_kwargs: Dict = {"a_min": 0, "a_max": 1}):
        self.multiplier = multiplier
        self.clip_kwargs = clip_kwargs

    def __call__(self, img: np.ndarray) -> np.ndarray:
        """Apply the augmentation to an image.

        Args:
            img: The input image.

        Returns:
            The noisy image, clipped if `clip_kwargs` is set.
        """
        factor = np.random.uniform(self.multiplier[0], self.multiplier[1])
        baseline = img.min()
        # Poisson rates must not be negative, hence the shift by the image minimum.
        rates = (img - baseline) * factor
        noisy = np.random.poisson(rates)

        # The sampling goes through numpy; convert back if the input was a tensor.
        if isinstance(img, torch.Tensor):
            noisy = torch.Tensor(noisy)
        noisy = noisy / factor + baseline

        if not self.clip_kwargs:
            return noisy
        return np.clip(noisy, **self.clip_kwargs)
269
270
class GaussianBlur:
    """Transformation to blur the image with a randomly drawn sigma value.

    Args:
        sigma: Lower and upper bound for the sigma value of the Gaussian kernel.
            The value used for a call is drawn uniformly from the half-open interval
            (sigma[0], sigma[1]], i.e. the lower bound itself is excluded.
    """
    def __init__(self, sigma: Tuple[float, float] = (0.0, 3.0)):
        self.sigma = sigma

    def __call__(self, img: np.ndarray) -> np.ndarray:
        """Apply the augmentation to data.

        Args:
            img: The input image. Numpy arrays are converted to torch tensors before blurring.

        Returns:
            The blurred image, as returned by the torchvision `GaussianBlur` transform.
        """
        low, high = self.sigma[0], self.sigma[1]
        # Sample from (low, high] so that a lower bound of zero is never drawn.
        # `np.random.random()` lies in [0, 1), so subtracting from the upper bound excludes `low`.
        # This is written out explicitly because `np.random.uniform` with swapped bounds
        # is documented as undefined behavior.
        sigma = high - (high - low) * np.random.random()

        # Determine the kernel size based on the sigma value: three sigma to each side, always odd.
        kernel_size = int(2 * np.ceil(3 * sigma) + 1)
        if isinstance(img, np.ndarray):
            img = torch.from_numpy(img)

        return transforms.GaussianBlur(kernel_size, sigma=sigma)(img)
298
299
300#
301# Default Transformation: Apply intensity augmentations and normalize.
302#
303
class RawTransform:
    """Transformation pipeline for raw data during training.

    The pipeline applies an optional augmentation, then the normalization
    and then a second optional augmentation.

    Args:
        normalizer: The normalization function.
        augmentation1: Intensity augmentation applied before the normalization.
        augmentation2: Intensity augmentation applied after the normalization.
    """
    def __init__(
        self, normalizer: Callable, augmentation1: Optional[Callable] = None, augmentation2: Optional[Callable] = None
    ):
        self.normalizer = normalizer
        self.augmentation1 = augmentation1
        self.augmentation2 = augmentation2

    def __call__(self, raw: np.ndarray) -> np.ndarray:
        """Run the raw data through the pipeline.

        Args:
            raw: The raw data.

        Returns:
            The transformed raw data.
        """
        pre, post = self.augmentation1, self.augmentation2

        data = raw if pre is None else pre(raw)
        data = self.normalizer(data)
        return data if post is None else post(data)
336
337
def get_raw_transform(
    normalizer: Callable = standardize,
    augmentation1: Optional[Callable] = None,
    augmentation2: Optional[Callable] = None
) -> Callable:
    """Create a `RawTransform` from a normalizer and optional augmentations.

    Args:
        normalizer: The normalization function.
        augmentation1: Intensity augmentation applied before the normalization.
        augmentation2: Intensity augmentation applied after the normalization.

    Returns:
        The raw transformation.
    """
    raw_transform = RawTransform(
        normalizer,
        augmentation1=augmentation1,
        augmentation2=augmentation2,
    )
    return raw_transform
354
355
def get_default_mean_teacher_augmentations(
    p: float = 0.3,
    norm: Optional[Callable] = None,
    blur_kwargs: Optional[Dict] = None,
    poisson_kwargs: Optional[Dict] = None,
    gaussian_kwargs: Optional[Dict] = None,
) -> Callable:
    """Build the default augmentation pipeline for mean teacher training.

    The default values for the augmentations are designed for an image with pixel values in range [0, 1].
    For this reason the normalization is applied as the first step of the first augmentation stage,
    in addition to being used as the normalizer of the returned transformation.

    Args:
        p: The probability for applying the blur and the contrast augmentation.
            The two noise augmentations are each applied with probability `p / 2`.
        norm: The normalization function. If None, `normalize` is used.
        blur_kwargs: The keyword arguments for `GaussianBlur`.
        poisson_kwargs: The keyword arguments for `PoissonNoise`.
        gaussian_kwargs: The keyword arguments for `AdditiveGaussianNoise`.

    Returns:
        The raw transformation with augmentations.
    """
    normalizer = normalize if norm is None else norm

    blur = GaussianBlur(**({} if blur_kwargs is None else blur_kwargs))
    poisson = PoissonNoise(**({} if poisson_kwargs is None else poisson_kwargs))
    gaussian = AdditiveGaussianNoise(**({} if gaussian_kwargs is None else gaussian_kwargs))
    contrast = RandomContrast(clip_kwargs={"a_min": 0, "a_max": 1})

    # First stage: normalize, then apply each intensity augmentation independently at random.
    pre_augmentation = transforms.Compose([
        normalizer,
        transforms.RandomApply([blur], p=p),
        transforms.RandomApply([poisson], p=p/2),
        transforms.RandomApply([gaussian], p=p/2),
    ])

    # Second stage: random contrast change, applied after the normalizer has run again.
    post_augmentation = transforms.RandomApply([contrast], p=p)
    return get_raw_transform(normalizer=normalizer, augmentation1=pre_augmentation, augmentation2=post_augmentation)
def standardize( raw: numpy.ndarray, mean: Optional[float] = None, std: Optional[float] = None, axis: Union[int, Tuple[int, ...], NoneType] = None, eps: float = 1e-07) -> numpy.ndarray:
41def standardize(
42    raw: np.ndarray,
43    mean: Optional[float] = None,
44    std: Optional[float] = None,
45    axis: Optional[Union[int, Tuple[int, ...]]] = None,
46    eps: float = 1e-7,
47) -> np.ndarray:
48    """Standardize the input data by subtracting its mean and dividing by its standard deviation.
49
50    Args:
51        raw: The input data.
52        mean: The mean value. If None, it will be computed from the data.
53        std: The standard deviation. If None, it will be computed from the data.
54        axis: The axis along which to compute the mean and standard deviation.
55        eps: The epsilon value for numerical stability.
56
57    Returns:
58        The standardized input data.
59    """
60    raw = cast(raw, "float32")
61    mean = raw.mean(axis=axis, keepdims=True) if mean is None else mean
62    raw -= mean
63
64    std = raw.std(axis=axis, keepdims=True) if std is None else std
65    raw /= (std + eps)
66    return raw

Standardize the input data by subtracting its mean and dividing by its standard deviation.

Arguments:
  • raw: The input data.
  • mean: The mean value. If None, it will be computed from the data.
  • std: The standard deviation. If None, it will be computed from the data.
  • axis: The axis along which to compute the mean and standard deviation.
  • eps: The epsilon value for numerical stability.
Returns:

The standardized input data.

def normalize( raw: Union[<built-in method tensor of type object>, numpy.ndarray], minval: Optional[float] = None, maxval: Optional[float] = None, axis: Union[int, Tuple[int, ...], NoneType] = None, eps: float = 1e-07) -> numpy.ndarray:
 89def normalize(
 90    raw: Union[torch.tensor, np.ndarray],
 91    minval: Optional[float] = None,
 92    maxval: Optional[float] = None,
 93    axis: Optional[Union[int, Tuple[int, ...]]] = None,
 94    eps: float = 1e-7,
 95) -> np.ndarray:
 96    """Normalize the input data so that it is in range [0, 1].
 97
 98    Args:
 99        raw: The input data.
100        minval: The minimum data value. If None, it will be computed from the data.
101        maxval: The maximum data value. If None, it will be computed from the data.
102        axis: The axis along which to compute the min and max value.
103        eps: The epsilon value for numerical stability.
104
105    Returns:
106        The normalized input data.
107    """
108    raw = cast(raw, "float32")
109    if torch.is_tensor(raw):
110        return _normalize_torch(raw, minval=minval, maxval=maxval, axis=axis, eps=eps)
111
112    minval = raw.min(axis=axis, keepdims=True) if minval is None else minval
113    raw -= minval
114
115    maxval = raw.max(axis=axis, keepdims=True) if maxval is None else maxval
116    raw /= (maxval + eps)
117    return raw

Normalize the input data so that it is in range [0, 1].

Arguments:
  • raw: The input data.
  • minval: The minimum data value. If None, it will be computed from the data.
  • maxval: The maximum data value. If None, it will be computed from the data.
  • axis: The axis along which to compute the min and max value.
  • eps: The epsilon value for numerical stability.
Returns:

The normalized input data.

def normalize_percentile( raw: numpy.ndarray, lower: float = 1.0, upper: float = 99.0, axis: Union[int, Tuple[int, ...], NoneType] = None, eps: float = 1e-07) -> numpy.ndarray:
120def normalize_percentile(
121    raw: np.ndarray,
122    lower: float = 1.0,
123    upper: float = 99.0,
124    axis: Optional[Union[int, Tuple[int, ...]]] = None,
125    eps: float = 1e-7,
126) -> np.ndarray:
127    """Normalize the input data based on percentile values.
128
129    Args:
130        raw: The input data.
131        lower: The lower percentile.
132        upper: The upper percentile.
133        axis: The axis along which to compute the percentiles.
134        eps: The epsilon value for numerical stability.
135
136    Returns:
137        The normalized input data.
138    """
139    v_lower = np.percentile(raw, lower, axis=axis, keepdims=True)
140    v_upper = np.percentile(raw, upper, axis=axis, keepdims=True) - v_lower
141    return normalize(raw, v_lower, v_upper, eps=eps)

Normalize the input data based on percentile values.

Arguments:
  • raw: The input data.
  • lower: The lower percentile.
  • upper: The upper percentile.
  • axis: The axis along which to compute the percentiles.
  • eps: The epsilon value for numerical stability.
Returns:

The normalized input data.

class RandomContrast:
149class RandomContrast:
150    """Transformation to adjust contrast by scaling image to `mean + alpha * (image - mean)`.
151
152    Args:
153        alpha: Minimal and maximal alpha value for adjusting the contrast.
154            The value for the transformation will be drawn uniformly from the corresponding interval.
155        mean: Mean value for the image data.
156        clip_kwargs: Keyword arguments for clipping the data after the contrast augmentation.
157    """
158    def __init__(
159        self, alpha: Tuple[float, float] = (0.5, 2), mean: float = 0.5, clip_kwargs: Dict = {"a_min": 0, "a_max": 1}
160    ):
161        self.alpha = alpha
162        self.mean = mean
163        self.clip_kwargs = clip_kwargs
164
165    def __call__(self, img: np.ndarray) -> np.ndarray:
166        """Apply the augmentation to data.
167
168        Args:
169            img: The input image.
170
171        Returns:
172            The transformed image.
173        """
174        alpha = np.random.uniform(self.alpha[0], self.alpha[1])
175        result = self.mean + alpha * (img - self.mean)
176        if self.clip_kwargs:
177            return np.clip(result, **self.clip_kwargs)
178        return result

Transformation to adjust contrast by scaling image to mean + alpha * (image - mean).

Arguments:
  • alpha: Minimal and maximal alpha value for adjusting the contrast. The value for the transformation will be drawn uniformly from the corresponding interval.
  • mean: Mean value for the image data.
  • clip_kwargs: Keyword arguments for clipping the data after the contrast augmentation.
RandomContrast( alpha: Tuple[float, float] = (0.5, 2), mean: float = 0.5, clip_kwargs: Dict = {'a_min': 0, 'a_max': 1})
158    def __init__(
159        self, alpha: Tuple[float, float] = (0.5, 2), mean: float = 0.5, clip_kwargs: Dict = {"a_min": 0, "a_max": 1}
160    ):
161        self.alpha = alpha
162        self.mean = mean
163        self.clip_kwargs = clip_kwargs
alpha
mean
clip_kwargs
class AdditiveGaussianNoise:
181class AdditiveGaussianNoise:
182    """Transformation to add random Gaussian noise to image.
183
184    Args:
185        scale: Scale for the noise.
186        clip_kwargs: Keyword arguments for clipping the data after the tranformation.
187    """
188    def __init__(self, scale: Tuple[float, float] = (0.0, 0.3), clip_kwargs: Dict = {"a_min": 0, "a_max": 1}):
189        self.scale = scale
190        self.clip_kwargs = clip_kwargs
191
192    def __call__(self, img: np.ndarray) -> np.ndarray:
193        """Apply the augmentation to data.
194
195        Args:
196            img: The input image.
197
198        Returns:
199            The transformed image.
200        """
201        std = np.random.uniform(self.scale[0], self.scale[1])
202        gaussian_noise = np.random.normal(0, std, size=img.shape)
203
204        if self.clip_kwargs:
205            return np.clip(img + gaussian_noise, 0, 1)
206
207        return img + gaussian_noise

Transformation to add random Gaussian noise to image.

Arguments:
  • scale: Scale for the noise.
  • clip_kwargs: Keyword arguments for clipping the data after the transformation.
AdditiveGaussianNoise( scale: Tuple[float, float] = (0.0, 0.3), clip_kwargs: Dict = {'a_min': 0, 'a_max': 1})
188    def __init__(self, scale: Tuple[float, float] = (0.0, 0.3), clip_kwargs: Dict = {"a_min": 0, "a_max": 1}):
189        self.scale = scale
190        self.clip_kwargs = clip_kwargs
scale
clip_kwargs
class AdditivePoissonNoise:
210class AdditivePoissonNoise:
211    """Transformation to add random additive Poisson noise to image.
212
213    Args:
214        lam: Lambda value for the Poisson transformation.
215        clip_kwargs: Keyword arguments for clipping the data after the tranformation.
216    """
217    # Not sure if Poisson noise like this does make sense for data that is already normalized
218    def __init__(self, lam: Tuple[float, float] = (0.0, 0.1), clip_kwargs: Dict = {"a_min": 0, "a_max": 1}):
219        self.lam = lam
220        self.clip_kwargs = clip_kwargs
221
222    def __call__(self, img: np.ndarray) -> np.ndarray:
223        """Apply the augmentation to data.
224
225        Args:
226            img: The input image.
227
228        Returns:
229            The transformed image.
230        """
231        lam = np.random.uniform(self.lam[0], self.lam[1])
232        poisson_noise = np.random.poisson(lam, size=img.shape) / lam
233        if self.clip_kwargs:
234            return np.clip(img + poisson_noise, 0, 1)
235        return img + poisson_noise

Transformation to add random additive Poisson noise to image.

Arguments:
  • lam: Lambda value for the Poisson transformation.
  • clip_kwargs: Keyword arguments for clipping the data after the transformation.
AdditivePoissonNoise( lam: Tuple[float, float] = (0.0, 0.1), clip_kwargs: Dict = {'a_min': 0, 'a_max': 1})
218    def __init__(self, lam: Tuple[float, float] = (0.0, 0.1), clip_kwargs: Dict = {"a_min": 0, "a_max": 1}):
219        self.lam = lam
220        self.clip_kwargs = clip_kwargs
lam
clip_kwargs
class PoissonNoise:
238class PoissonNoise:
239    """Transformation to add random data-dependant Poisson noise to image.
240
241    Args:
242        multiplier: Multiplicative factors for deriving the lambda factor from the data.
243            The factor used for the transformation will be uniformly sampled form the range of this parameter.
244        clip_kwargs: Keyword arguments for clipping the data after the tranformation.
245    """
246    def __init__(self, multiplier: Tuple[float, float] = (5.0, 10.0), clip_kwargs: Dict = {"a_min": 0, "a_max": 1}):
247        self.multiplier = multiplier
248        self.clip_kwargs = clip_kwargs
249
250    def __call__(self, img: np.ndarray) -> np.ndarray:
251        """Apply the augmentation to data.
252
253        Args:
254            img: The input image.
255
256        Returns:
257            The transformed image.
258        """
259        multiplier = np.random.uniform(self.multiplier[0], self.multiplier[1])
260        offset = img.min()
261        poisson_noise = np.random.poisson((img - offset) * multiplier)
262
263        if isinstance(img, torch.Tensor):
264            poisson_noise = torch.Tensor(poisson_noise)
265        poisson_noise = poisson_noise / multiplier + offset
266
267        if self.clip_kwargs:
268            return np.clip(poisson_noise, **self.clip_kwargs)
269        return poisson_noise

Transformation to add random data-dependent Poisson noise to image.

Arguments:
  • multiplier: Multiplicative factors for deriving the lambda factor from the data. The factor used for the transformation will be uniformly sampled from the range of this parameter.
  • clip_kwargs: Keyword arguments for clipping the data after the transformation.
PoissonNoise( multiplier: Tuple[float, float] = (5.0, 10.0), clip_kwargs: Dict = {'a_min': 0, 'a_max': 1})
246    def __init__(self, multiplier: Tuple[float, float] = (5.0, 10.0), clip_kwargs: Dict = {"a_min": 0, "a_max": 1}):
247        self.multiplier = multiplier
248        self.clip_kwargs = clip_kwargs
multiplier
clip_kwargs
class GaussianBlur:
272class GaussianBlur:
273    """Transformation to blur the image with a randomly drawn sigma value.
274
275    Args:
276        sigma: The sigma value for the transformation.
277            The value used in the transformation will be uniformly drawn from the range specified here.
278    """
279    def __init__(self, sigma: Tuple[float, float] = (0.0, 3.0)):
280        self.sigma = sigma
281
282    def __call__(self, img: np.ndarray) -> np.ndarray:
283        """Apply the augmentation to data.
284
285        Args:
286            img: The input image.
287
288        Returns:
289            The transformed image.
290        """
291        # Sample the sigma value. Note that we switch the bounds to ensure zero is excluded from sampling.
292        sigma = np.random.uniform(self.sigma[1], self.sigma[0])
293        # Determine the kernel size based on the sigma value.
294        kernel_size = int(2 * np.ceil(3 * sigma) + 1)
295        if isinstance(img, np.ndarray):
296            img = torch.from_numpy(img)
297
298        return transforms.GaussianBlur(kernel_size, sigma=sigma)(img)

Transformation to blur the image with a randomly drawn sigma value.

Arguments:
  • sigma: The sigma value for the transformation. The value used in the transformation will be uniformly drawn from the range specified here.
GaussianBlur(sigma: Tuple[float, float] = (0.0, 3.0))
279    def __init__(self, sigma: Tuple[float, float] = (0.0, 3.0)):
280        self.sigma = sigma
sigma
class RawTransform:
305class RawTransform:
306    """The transformation for raw data during training.
307
308    Args:
309        normalizer: The normalization function.
310        augmentation1: Intensity augmentation applied before the normalization.
311        augmentation2: Intensity augmentation applied after the normalization.
312    """
313    def __init__(
314        self, normalizer: Callable, augmentation1: Optional[Callable] = None, augmentation2: Optional[Callable] = None
315    ):
316        self.normalizer = normalizer
317        self.augmentation1 = augmentation1
318        self.augmentation2 = augmentation2
319
320    def __call__(self, raw: np.ndarray) -> np.ndarray:
321        """Apply the raw transformation.
322
323        Args:
324            raw: The raw data.
325
326        Returns:
327            The transformed raw data.
328        """
329        if self.augmentation1 is not None:
330            raw = self.augmentation1(raw)
331
332        raw = self.normalizer(raw)
333
334        if self.augmentation2 is not None:
335            raw = self.augmentation2(raw)
336        return raw

The transformation for raw data during training.

Arguments:
  • normalizer: The normalization function.
  • augmentation1: Intensity augmentation applied before the normalization.
  • augmentation2: Intensity augmentation applied after the normalization.
RawTransform( normalizer: Callable, augmentation1: Optional[Callable] = None, augmentation2: Optional[Callable] = None)
313    def __init__(
314        self, normalizer: Callable, augmentation1: Optional[Callable] = None, augmentation2: Optional[Callable] = None
315    ):
316        self.normalizer = normalizer
317        self.augmentation1 = augmentation1
318        self.augmentation2 = augmentation2
normalizer
augmentation1
augmentation2
def get_raw_transform( normalizer: Callable = <function standardize>, augmentation1: Optional[Callable] = None, augmentation2: Optional[Callable] = None) -> Callable:
339def get_raw_transform(
340    normalizer: Callable = standardize,
341    augmentation1: Optional[Callable] = None,
342    augmentation2: Optional[Callable] = None
343) -> Callable:
344    """Get the raw transformation.
345
346    Args:
347        normalizer: The normalization function.
348        augmentation1: Intensity augmentation applied before the normalization.
349        augmentation2: Intensity augmentation applied after the normalization.
350
351    Returns:
352        The raw transformation.
353    """
354    return RawTransform(normalizer, augmentation1=augmentation1, augmentation2=augmentation2)

Get the raw transformation.

Arguments:
  • normalizer: The normalization function.
  • augmentation1: Intensity augmentation applied before the normalization.
  • augmentation2: Intensity augmentation applied after the normalization.
Returns:

The raw transformation.

def get_default_mean_teacher_augmentations( p: float = 0.3, norm: Optional[Callable] = None, blur_kwargs: Optional[Dict] = None, poisson_kwargs: Optional[Dict] = None, gaussian_kwargs: Optional[Dict] = None) -> Callable:
357def get_default_mean_teacher_augmentations(
358    p: float = 0.3,
359    norm: Optional[Callable] = None,
360    blur_kwargs: Optional[Dict] = None,
361    poisson_kwargs: Optional[Dict] = None,
362    gaussian_kwargs: Optional[Dict] = None,
363) -> Callable:
364    """Get the default augmentations for mean teacher training.
365
366    The default values for the augmentations are designed for an image with pixel values in range [0, 1].
367    By default, a normalization transformation is applied for this reason.
368
369    Args:
370        p: The probability for applying the individual intensity transformations.
371        norm: The noromaization function.
372        blur_kwargs: The keyword arguments for `GaussianBlur`.
373        poisson_kwargs: The keyword arguments for `PoissonNoise`.
374        gaussian_kwargs: The keyword arguments for `AdditiveGaussianNoise`.
375
376    Returns:
377        The raw transformation with augmentations.
378    """
379    if norm is None:
380        norm = normalize
381
382    aug1 = transforms.Compose([
383        norm,
384        transforms.RandomApply([GaussianBlur(**({} if blur_kwargs is None else blur_kwargs))], p=p),
385        transforms.RandomApply([PoissonNoise(**({} if poisson_kwargs is None else poisson_kwargs))], p=p/2),
386        transforms.RandomApply([AdditiveGaussianNoise(**({} if gaussian_kwargs is None else gaussian_kwargs))], p=p/2),
387    ])
388
389    aug2 = transforms.RandomApply([RandomContrast(clip_kwargs={"a_min": 0, "a_max": 1})], p=p)
390    return get_raw_transform(normalizer=norm, augmentation1=aug1, augmentation2=aug2)

Get the default augmentations for mean teacher training.

The default values for the augmentations are designed for an image with pixel values in range [0, 1]. By default, a normalization transformation is applied for this reason.

Arguments:
  • p: The probability for applying the individual intensity transformations.
  • norm: The normalization function.
  • blur_kwargs: The keyword arguments for GaussianBlur.
  • poisson_kwargs: The keyword arguments for PoissonNoise.
  • gaussian_kwargs: The keyword arguments for AdditiveGaussianNoise.
Returns:

The raw transformation with augmentations.