torch_em.transform.raw

  1import numpy as np
  2import torch
  3from torchvision import transforms
  4
  5
#
# normalization functions
#
  9
 10
 11TORCH_DTYPES = {
 12    "float16": torch.float16,
 13    "float32": torch.float32,
 14    "float64": torch.float64,
 15    "complex64": torch.complex64,
 16    "complex128": torch.complex128,
 17    "uint8": torch.uint8,
 18    "int8": torch.int8,
 19    "int16": torch.int16,
 20    "int32": torch.int32,
 21    "int64": torch.int64,
 22    "bool": torch.bool,
 23}
 24
 25
 26def cast(inpt, typestring):
 27    if torch.is_tensor(inpt):
 28        assert typestring in TORCH_DTYPES, f"{typestring} not in TORCH_DTYPES"
 29        return inpt.to(TORCH_DTYPES[typestring])
 30    return inpt.astype(typestring)
 31
 32
 33def standardize(raw, mean=None, std=None, axis=None, eps=1e-7):
 34    raw = cast(raw, "float32")
 35
 36    mean = raw.mean(axis=axis, keepdims=True) if mean is None else mean
 37    raw -= mean
 38
 39    std = raw.std(axis=axis, keepdims=True) if std is None else std
 40    raw /= (std + eps)
 41
 42    return raw
 43
 44
 45def _normalize_torch(tensor, minval, maxval, axis, eps):
 46    if axis:  # torch returns torch.return_types.min or torch.return_types.max
 47        minval = torch.amin(tensor, dim=axis, keepdim=True) if minval is None else minval
 48        tensor -= minval
 49
 50        maxval = torch.amax(tensor, dim=axis, keepdim=True) if maxval is None else maxval
 51        tensor /= (maxval + eps)
 52
 53        return tensor
 54
 55    # keepdim can only be used in combination with dim
 56    minval = tensor.min() if minval is None else minval
 57    tensor -= minval
 58
 59    maxval = tensor.max() if maxval is None else maxval
 60    tensor /= (maxval + eps)
 61
 62    return tensor
 63
 64
 65def normalize(raw, minval=None, maxval=None, axis=None, eps=1e-7):
 66    raw = cast(raw, "float32")
 67
 68    if torch.is_tensor(raw):
 69        return _normalize_torch(raw, minval=minval, maxval=maxval, axis=axis, eps=eps)
 70
 71    minval = raw.min(axis=axis, keepdims=True) if minval is None else minval
 72    raw -= minval
 73
 74    maxval = raw.max(axis=axis, keepdims=True) if maxval is None else maxval
 75    raw /= (maxval + eps)
 76
 77    return raw
 78
 79
 80def normalize_percentile(raw, lower=1.0, upper=99.0, axis=None, eps=1e-7):
 81    v_lower = np.percentile(raw, lower, axis=axis, keepdims=True)
 82    v_upper = np.percentile(raw, upper, axis=axis, keepdims=True) - v_lower
 83    return normalize(raw, v_lower, v_upper, eps=eps)
 84
 85
#
# intensity augmentations / noise augmentations
#

# modified from https://github.com/kreshuklab/spoco/blob/main/spoco/transforms.py
 91class RandomContrast():
 92    """
 93    Adjust contrast by scaling image to `mean + alpha * (image - mean)`.
 94    """
 95    def __init__(self, alpha=(0.5, 2), mean=0.5, clip_kwargs={'a_min': 0, 'a_max': 1}):
 96        self.alpha = alpha
 97        self.mean = mean
 98        self.clip_kwargs = clip_kwargs
 99
100    def __call__(self, img):
101        alpha = np.random.uniform(self.alpha[0], self.alpha[1])
102        result = self.mean + alpha * (img - self.mean)
103        if self.clip_kwargs:
104            return np.clip(result, **self.clip_kwargs)
105        return result
106
107
class AdditiveGaussianNoise:
    """
    Add random Gaussian noise to image.

    The noise standard deviation is drawn uniformly from `scale` on every
    call. If `clip_kwargs` is truthy the noisy image is clipped with
    `np.clip(**clip_kwargs)`.
    """
    def __init__(self, scale=(0.0, 0.3), clip_kwargs={'a_min': 0, 'a_max': 1}):
        self.scale = scale  # (low, high) range of the noise standard deviation
        self.clip_kwargs = clip_kwargs  # keyword args for np.clip; falsy disables clipping

    def __call__(self, img):
        std = np.random.uniform(self.scale[0], self.scale[1])
        gaussian_noise = np.random.normal(0, std, size=img.shape)
        noisy = img + gaussian_noise
        if self.clip_kwargs:
            # Honor the configured bounds rather than a hard-coded [0, 1].
            return np.clip(noisy, **self.clip_kwargs)
        return noisy
122
123
class AdditivePoissonNoise:
    """
    Add random Poisson noise to image.

    On every call a rate `lam` is drawn uniformly from the `lam` range.
    Per-pixel Poisson counts with that rate are divided by `lam` and added
    to the image. If `clip_kwargs` is truthy the noisy image is clipped with
    `np.clip(**clip_kwargs)`.
    """
    # TODO: not sure if Poisson noise like this does make sense
    # for data that is already normalized
    def __init__(self, lam=(0.0, 0.1), clip_kwargs={'a_min': 0, 'a_max': 1}):
        self.lam = lam  # (low, high) range of the Poisson rate
        self.clip_kwargs = clip_kwargs  # keyword args for np.clip; falsy disables clipping

    def __call__(self, img):
        lam = np.random.uniform(self.lam[0], self.lam[1])
        counts = np.random.poisson(lam, size=img.shape)
        # With lam == 0 every count is 0: use zero noise instead of the NaNs
        # that 0 / 0 would produce (negative lam is rejected by np.random.poisson).
        poisson_noise = counts / lam if lam > 0 else counts.astype("float64")
        noisy = img + poisson_noise
        if self.clip_kwargs:
            # Honor the configured bounds rather than a hard-coded [0, 1].
            return np.clip(noisy, **self.clip_kwargs)
        return noisy
140
141
class PoissonNoise:
    """
    Add random data-dependent Poisson noise to image.

    The image is shifted so its minimum is 0 and scaled by a `multiplier`
    drawn uniformly from the given range. The result is used as the
    per-pixel Poisson rate, and the sampled counts are mapped back to the
    original intensity scale. Larger multipliers therefore give relatively
    weaker noise. If `clip_kwargs` is truthy the result is clipped with
    `np.clip(**clip_kwargs)`.
    """
    def __init__(self, multiplier=(5.0, 10.0), clip_kwargs={'a_min': 0, 'a_max': 1}):
        self.multiplier = multiplier  # (low, high) range of the intensity-to-rate factor
        self.clip_kwargs = clip_kwargs  # keyword args for np.clip; falsy disables clipping

    def __call__(self, img):
        multiplier = np.random.uniform(self.multiplier[0], self.multiplier[1])
        # Shift so the Poisson rate is non-negative, as np.random.poisson requires.
        offset = img.min()
        poisson_noise = np.random.poisson((img - offset) * multiplier)
        if isinstance(img, torch.Tensor):
            # np.random.poisson returns an ndarray; keep tensor inputs as tensors.
            poisson_noise = torch.Tensor(poisson_noise)
        poisson_noise = poisson_noise / multiplier + offset
        if self.clip_kwargs:
            return np.clip(poisson_noise, **self.clip_kwargs)
        return poisson_noise
160
161
class GaussianBlur:
    """
    Blur the image.

    On every call, an odd kernel size is derived from an integer drawn from
    the half-open `kernel_size` range (as in `np.random.randint`). A sigma
    is drawn from `sigma` with the lower bound excluded, so the default
    `sigma=(0, 2.5)` does not yield 0. numpy inputs are converted to torch
    tensors before torchvision's GaussianBlur is applied, so the output is
    always a tensor.
    """
    def __init__(self, kernel_size=(2, 12), sigma=(0, 2.5)):
        self.kernel_size = kernel_size  # (low, high) range for the kernel size draw
        self.sigma = sigma  # (low, high) range for the blur sigma

    def __call__(self, img):
        # sample kernel_size and make sure it is odd
        kernel_size = 2 * (np.random.randint(self.kernel_size[0], self.kernel_size[1]) // 2) + 1
        # Draw sigma from (low, high] by subtracting a draw from [0, high - low)
        # from the upper bound. This avoids np.random.uniform(high, low), whose
        # result is officially undefined when high < low.
        low, high = self.sigma[0], self.sigma[1]
        sigma = high - np.random.uniform(0, high - low)
        if isinstance(img, np.ndarray):
            img = torch.from_numpy(img)
        return transforms.GaussianBlur(kernel_size, sigma=sigma)(img)
179
180
#
# default transformation:
# apply intensity augmentations and normalize
#
185
class RawTransform:
    """Apply an optional augmentation, a normalizer and a second optional
    augmentation to raw data, in that order.
    """
    def __init__(self, normalizer, augmentation1=None, augmentation2=None):
        self.normalizer = normalizer  # callable applied to every input
        self.augmentation1 = augmentation1  # optional callable run before normalization
        self.augmentation2 = augmentation2  # optional callable run after normalization

    def __call__(self, raw):
        if self.augmentation1 is not None:
            raw = self.augmentation1(raw)
        raw = self.normalizer(raw)
        if self.augmentation2 is not None:
            raw = self.augmentation2(raw)
        return raw
199
200
def get_raw_transform(normalizer=standardize, augmentation1=None, augmentation2=None):
    """Build a `RawTransform` from a normalizer and optional augmentations.

    With the defaults the data is only standardized (`standardize`) and no
    augmentation is applied.

    Args:
        normalizer: callable applied to every input.
        augmentation1: optional callable applied before the normalizer.
        augmentation2: optional callable applied after the normalizer.
    """
    return RawTransform(normalizer,
                        augmentation1=augmentation1,
                        augmentation2=augmentation2)
205
206
# The default values are made for an image with pixel values in
# range [0, 1]. That the image is in this range is ensured by an
# initial normalization step.
def get_default_mean_teacher_augmentations(
    p=0.3, norm=None,
    blur_kwargs=None, poisson_kwargs=None, gaussian_kwargs=None
):
    """Build the default raw transform used for mean-teacher training.

    The returned `RawTransform` applies, in order:
      1. `norm`, then Gaussian blur (probability `p`), data-dependent Poisson
         noise (probability `p / 2`) and additive Gaussian noise
         (probability `p / 2`);
      2. `norm` again as the transform's normalizer;
      3. random contrast clipped to [0, 1] (probability `p`).

    Args:
        p: base probability for applying the augmentations.
        norm: normalization callable; defaults to `normalize`.
        blur_kwargs: optional keyword arguments for `GaussianBlur`.
        poisson_kwargs: optional keyword arguments for `PoissonNoise`.
        gaussian_kwargs: optional keyword arguments for `AdditiveGaussianNoise`.
    """
    if norm is None:
        norm = normalize
    blur = GaussianBlur(**({} if blur_kwargs is None else blur_kwargs))
    poisson = PoissonNoise(**({} if poisson_kwargs is None else poisson_kwargs))
    gaussian = AdditiveGaussianNoise(**({} if gaussian_kwargs is None else gaussian_kwargs))
    aug1 = transforms.Compose([
        norm,
        transforms.RandomApply([blur], p=p),
        transforms.RandomApply([poisson], p=p/2),
        transforms.RandomApply([gaussian], p=p/2),
    ])
    aug2 = transforms.RandomApply(
        [RandomContrast(clip_kwargs={"a_min": 0, "a_max": 1})], p=p
    )
    return get_raw_transform(
        normalizer=norm,
        augmentation1=aug1,
        augmentation2=aug2
    )
# Numpy-style dtype names mapped to the matching torch dtypes (used by `cast`).
TORCH_DTYPES = {
    name: getattr(torch, name)
    for name in (
        "float16", "float32", "float64",
        "complex64", "complex128",
        "uint8", "int8", "int16", "int32", "int64",
        "bool",
    )
}
def cast(inpt, typestring):
    """Convert a numpy array or torch tensor to the dtype named by `typestring`.

    Both branches return a new object that does not share memory with `inpt`.
    numpy's `astype` copies by default. The torch branch passes `copy=True`,
    so callers that modify the result in place (e.g. `standardize`,
    `normalize`) never mutate the caller's tensor, even when it already has
    the requested dtype.

    Args:
        inpt: numpy array or torch tensor.
        typestring: dtype name such as "float32"; for tensors it must be a
            key of `TORCH_DTYPES`.

    Returns:
        A converted copy of `inpt`.
    """
    if torch.is_tensor(inpt):
        assert typestring in TORCH_DTYPES, f"{typestring} not in TORCH_DTYPES"
        return inpt.to(TORCH_DTYPES[typestring], copy=True)
    return inpt.astype(typestring)
def standardize(raw, mean=None, std=None, axis=None, eps=1e-7):
    """Shift `raw` to zero mean and scale it to unit standard deviation.

    Args:
        raw: numpy array or torch tensor; converted to float32 via `cast`.
        mean: value to subtract; computed from `raw` over `axis` when None.
        std: value to divide by; computed from the mean-subtracted data over
            `axis` when None.
        axis: axis (or axes) used for the statistics; None uses all values.
        eps: small constant added to the std to avoid division by zero.

    Returns:
        The standardized float32 data.
    """
    raw = cast(raw, "float32")

    mean = raw.mean(axis=axis, keepdims=True) if mean is None else mean
    raw -= mean

    std = raw.std(axis=axis, keepdims=True) if std is None else std
    raw /= (std + eps)

    return raw
def normalize(raw, minval=None, maxval=None, axis=None, eps=1e-7):
    """Rescale `raw` by subtracting `minval` and dividing by `maxval + eps`.

    With the default (None) bounds the data is mapped to roughly [0, 1].
    `maxval` is applied after the shift, so an explicit value must be the
    range (max - min) rather than the raw maximum.

    Args:
        raw: numpy array or torch tensor; converted to float32 via `cast`.
        minval: value to subtract; computed over `axis` when None.
        maxval: divisor applied after subtraction; computed over `axis` from
            the shifted data when None.
        axis: axis (or axes) used for computing missing bounds.
        eps: small constant added to the divisor to avoid division by zero.

    Returns:
        The normalized float32 data (same container type as the input).
    """
    raw = cast(raw, "float32")

    if torch.is_tensor(raw):
        return _normalize_torch(raw, minval=minval, maxval=maxval, axis=axis, eps=eps)

    minval = raw.min(axis=axis, keepdims=True) if minval is None else minval
    raw -= minval

    maxval = raw.max(axis=axis, keepdims=True) if maxval is None else maxval
    raw /= (maxval + eps)

    return raw
def normalize_percentile(raw, lower=1.0, upper=99.0, axis=None, eps=1e-7):
    """Normalize `raw` using its `lower` and `upper` percentiles as bounds.

    Values at the lower percentile map to 0 and values at the upper
    percentile map to about 1. Values outside that range are not clipped.

    Args:
        raw: array-like input accepted by `np.percentile` and `normalize`.
        lower, upper: percentiles (0-100) used as the bounds.
        axis: axis (or axes) over which the percentiles are computed.
        eps: forwarded to `normalize`.
    """
    v_lower = np.percentile(raw, lower, axis=axis, keepdims=True)
    # `normalize` divides after subtracting the minimum, so it takes the range.
    v_upper = np.percentile(raw, upper, axis=axis, keepdims=True) - v_lower
    return normalize(raw, v_lower, v_upper, eps=eps)
class RandomContrast:
    """
    Adjust contrast by scaling image to `mean + alpha * (image - mean)`.

    `alpha` is drawn uniformly from the `alpha` range on every call. If
    `clip_kwargs` is truthy the result is clipped with `np.clip(**clip_kwargs)`.
    """
    def __init__(self, alpha=(0.5, 2), mean=0.5, clip_kwargs={'a_min': 0, 'a_max': 1}):
        self.alpha = alpha  # (low, high) range of the contrast factor
        self.mean = mean  # intensity that stays fixed under the scaling
        self.clip_kwargs = clip_kwargs  # keyword args for np.clip; falsy disables clipping

    def __call__(self, img):
        alpha = np.random.uniform(self.alpha[0], self.alpha[1])
        result = self.mean + alpha * (img - self.mean)
        if self.clip_kwargs:
            return np.clip(result, **self.clip_kwargs)
        return result
class AdditiveGaussianNoise:
    """
    Add random Gaussian noise to image.

    The noise standard deviation is drawn uniformly from `scale` on every
    call. If `clip_kwargs` is truthy the noisy image is clipped with
    `np.clip(**clip_kwargs)`.
    """
    def __init__(self, scale=(0.0, 0.3), clip_kwargs={'a_min': 0, 'a_max': 1}):
        self.scale = scale  # (low, high) range of the noise standard deviation
        self.clip_kwargs = clip_kwargs  # keyword args for np.clip; falsy disables clipping

    def __call__(self, img):
        std = np.random.uniform(self.scale[0], self.scale[1])
        gaussian_noise = np.random.normal(0, std, size=img.shape)
        noisy = img + gaussian_noise
        if self.clip_kwargs:
            # Honor the configured bounds rather than a hard-coded [0, 1].
            return np.clip(noisy, **self.clip_kwargs)
        return noisy
class AdditivePoissonNoise:
    """
    Add random Poisson noise to image.

    On every call a rate `lam` is drawn uniformly from the `lam` range.
    Per-pixel Poisson counts with that rate are divided by `lam` and added
    to the image. If `clip_kwargs` is truthy the noisy image is clipped with
    `np.clip(**clip_kwargs)`.
    """
    # TODO: not sure if Poisson noise like this does make sense
    # for data that is already normalized
    def __init__(self, lam=(0.0, 0.1), clip_kwargs={'a_min': 0, 'a_max': 1}):
        self.lam = lam  # (low, high) range of the Poisson rate
        self.clip_kwargs = clip_kwargs  # keyword args for np.clip; falsy disables clipping

    def __call__(self, img):
        lam = np.random.uniform(self.lam[0], self.lam[1])
        counts = np.random.poisson(lam, size=img.shape)
        # With lam == 0 every count is 0: use zero noise instead of the NaNs
        # that 0 / 0 would produce (negative lam is rejected by np.random.poisson).
        poisson_noise = counts / lam if lam > 0 else counts.astype("float64")
        noisy = img + poisson_noise
        if self.clip_kwargs:
            # Honor the configured bounds rather than a hard-coded [0, 1].
            return np.clip(noisy, **self.clip_kwargs)
        return noisy
class PoissonNoise:
    """
    Add random data-dependent Poisson noise to image.

    The image is shifted so its minimum is 0 and scaled by a `multiplier`
    drawn uniformly from the given range. The result is used as the
    per-pixel Poisson rate, and the sampled counts are mapped back to the
    original intensity scale. Larger multipliers therefore give relatively
    weaker noise. If `clip_kwargs` is truthy the result is clipped with
    `np.clip(**clip_kwargs)`.
    """
    def __init__(self, multiplier=(5.0, 10.0), clip_kwargs={'a_min': 0, 'a_max': 1}):
        self.multiplier = multiplier  # (low, high) range of the intensity-to-rate factor
        self.clip_kwargs = clip_kwargs  # keyword args for np.clip; falsy disables clipping

    def __call__(self, img):
        multiplier = np.random.uniform(self.multiplier[0], self.multiplier[1])
        # Shift so the Poisson rate is non-negative, as np.random.poisson requires.
        offset = img.min()
        poisson_noise = np.random.poisson((img - offset) * multiplier)
        if isinstance(img, torch.Tensor):
            # np.random.poisson returns an ndarray; keep tensor inputs as tensors.
            poisson_noise = torch.Tensor(poisson_noise)
        poisson_noise = poisson_noise / multiplier + offset
        if self.clip_kwargs:
            return np.clip(poisson_noise, **self.clip_kwargs)
        return poisson_noise
class GaussianBlur:
    """
    Blur the image.

    On every call, an odd kernel size is derived from an integer drawn from
    the half-open `kernel_size` range (as in `np.random.randint`). A sigma
    is drawn from `sigma` with the lower bound excluded, so the default
    `sigma=(0, 2.5)` does not yield 0. numpy inputs are converted to torch
    tensors before torchvision's GaussianBlur is applied, so the output is
    always a tensor.
    """
    def __init__(self, kernel_size=(2, 12), sigma=(0, 2.5)):
        self.kernel_size = kernel_size  # (low, high) range for the kernel size draw
        self.sigma = sigma  # (low, high) range for the blur sigma

    def __call__(self, img):
        # sample kernel_size and make sure it is odd
        kernel_size = 2 * (np.random.randint(self.kernel_size[0], self.kernel_size[1]) // 2) + 1
        # Draw sigma from (low, high] by subtracting a draw from [0, high - low)
        # from the upper bound. This avoids np.random.uniform(high, low), whose
        # result is officially undefined when high < low.
        low, high = self.sigma[0], self.sigma[1]
        sigma = high - np.random.uniform(0, high - low)
        if isinstance(img, np.ndarray):
            img = torch.from_numpy(img)
        return transforms.GaussianBlur(kernel_size, sigma=sigma)(img)
class RawTransform:
    """Apply an optional augmentation, a normalizer and a second optional
    augmentation to raw data, in that order.
    """
    def __init__(self, normalizer, augmentation1=None, augmentation2=None):
        self.normalizer = normalizer  # callable applied to every input
        self.augmentation1 = augmentation1  # optional callable run before normalization
        self.augmentation2 = augmentation2  # optional callable run after normalization

    def __call__(self, raw):
        if self.augmentation1 is not None:
            raw = self.augmentation1(raw)
        raw = self.normalizer(raw)
        if self.augmentation2 is not None:
            raw = self.augmentation2(raw)
        return raw
def get_raw_transform(normalizer=standardize, augmentation1=None, augmentation2=None):
    """Build a `RawTransform` from a normalizer and optional augmentations.

    With the defaults the data is only standardized (`standardize`) and no
    augmentation is applied.

    Args:
        normalizer: callable applied to every input.
        augmentation1: optional callable applied before the normalizer.
        augmentation2: optional callable applied after the normalizer.
    """
    return RawTransform(normalizer,
                        augmentation1=augmentation1,
                        augmentation2=augmentation2)
def get_default_mean_teacher_augmentations(
    p=0.3, norm=None,
    blur_kwargs=None, poisson_kwargs=None, gaussian_kwargs=None
):
    """Build the default raw transform used for mean-teacher training.

    The augmentation defaults assume pixel values in [0, 1], which the
    initial `norm` step ensures. The returned `RawTransform` applies, in
    order:
      1. `norm`, then Gaussian blur (probability `p`), data-dependent Poisson
         noise (probability `p / 2`) and additive Gaussian noise
         (probability `p / 2`);
      2. `norm` again as the transform's normalizer;
      3. random contrast clipped to [0, 1] (probability `p`).

    Args:
        p: base probability for applying the augmentations.
        norm: normalization callable; defaults to `normalize`.
        blur_kwargs: optional keyword arguments for `GaussianBlur`.
        poisson_kwargs: optional keyword arguments for `PoissonNoise`.
        gaussian_kwargs: optional keyword arguments for `AdditiveGaussianNoise`.
    """
    if norm is None:
        norm = normalize
    blur = GaussianBlur(**({} if blur_kwargs is None else blur_kwargs))
    poisson = PoissonNoise(**({} if poisson_kwargs is None else poisson_kwargs))
    gaussian = AdditiveGaussianNoise(**({} if gaussian_kwargs is None else gaussian_kwargs))
    aug1 = transforms.Compose([
        norm,
        transforms.RandomApply([blur], p=p),
        transforms.RandomApply([poisson], p=p/2),
        transforms.RandomApply([gaussian], p=p/2),
    ])
    aug2 = transforms.RandomApply(
        [RandomContrast(clip_kwargs={"a_min": 0, "a_max": 1})], p=p
    )
    return get_raw_transform(
        normalizer=norm,
        augmentation1=aug1,
        augmentation2=aug2
    )