torch_em.transform.raw
import numpy as np
import torch
from torchvision import transforms


#
# normalization functions
#


# Maps dtype names (as accepted by numpy's `astype`) to the matching torch dtype.
TORCH_DTYPES = {
    "float16": torch.float16,
    "float32": torch.float32,
    "float64": torch.float64,
    "complex64": torch.complex64,
    "complex128": torch.complex128,
    "uint8": torch.uint8,
    "int8": torch.int8,
    "int16": torch.int16,
    "int32": torch.int32,
    "int64": torch.int64,
    "bool": torch.bool,
}


def cast(inpt, typestring):
    """
    Cast a torch tensor or numpy-like array to the dtype named by `typestring`.

    Tensors are converted with `Tensor.to` and `typestring` must be a key of
    `TORCH_DTYPES`; everything else is converted with `inpt.astype(typestring)`.
    """
    if torch.is_tensor(inpt):
        assert typestring in TORCH_DTYPES, f"{typestring} not in TORCH_DTYPES"
        return inpt.to(TORCH_DTYPES[typestring])
    return inpt.astype(typestring)


def standardize(raw, mean=None, std=None, axis=None, eps=1e-7):
    """
    Cast `raw` to float32, subtract the mean and divide by `std + eps`.

    `mean` and `std` are computed from the data along `axis` (with kept
    dimensions) unless they are passed explicitly. The standard deviation is
    computed after the mean has been subtracted.
    """
    raw = cast(raw, "float32")

    mean = raw.mean(axis=axis, keepdims=True) if mean is None else mean
    raw -= mean

    # NOTE(review): torch and numpy may use different default corrections for
    # `std` (unbiased vs. biased) - confirm if both input types must agree.
    std = raw.std(axis=axis, keepdims=True) if std is None else std
    raw /= (std + eps)

    return raw


def _normalize_torch(tensor, minval, maxval, axis, eps):
    """
    Min-max normalize a float tensor in place and return it.

    `minval` / `maxval` are computed from the tensor when they are None;
    `maxval` is taken after `minval` has been subtracted.
    """
    # Compare against None explicitly: `axis=0` is a valid axis but falsy.
    if axis is not None:
        # torch.min / torch.max with a dim return torch.return_types.min / max,
        # so amin / amax are used to get plain value tensors.
        minval = torch.amin(tensor, dim=axis, keepdim=True) if minval is None else minval
        tensor -= minval

        maxval = torch.amax(tensor, dim=axis, keepdim=True) if maxval is None else maxval
        tensor /= (maxval + eps)

        return tensor

    # keepdim can only be used in combination with dim
    minval = tensor.min() if minval is None else minval
    tensor -= minval

    maxval = tensor.max() if maxval is None else maxval
    tensor /= (maxval + eps)

    return tensor


def normalize(raw, minval=None, maxval=None, axis=None, eps=1e-7):
    """
    Cast `raw` to float32 and min-max normalize it.

    The data is shifted by `minval` and divided by `maxval + eps`. When they
    are None, `minval` is the minimum along `axis` and `maxval` is the maximum
    of the already shifted data. Torch tensors are handled by
    `_normalize_torch`.
    """
    raw = cast(raw, "float32")

    if torch.is_tensor(raw):
        return _normalize_torch(raw, minval=minval, maxval=maxval, axis=axis, eps=eps)

    minval = raw.min(axis=axis, keepdims=True) if minval is None else minval
    raw -= minval

    maxval = raw.max(axis=axis, keepdims=True) if maxval is None else maxval
    raw /= (maxval + eps)

    return raw


def normalize_percentile(raw, lower=1.0, upper=99.0, axis=None, eps=1e-7):
    """
    Normalize `raw` so the `lower` percentile maps to 0 and the `upper` percentile to about 1.

    Computes `(raw - p_lower) / (p_upper - p_lower + eps)` with percentiles
    taken along `axis`; values outside the percentile range are not clipped.
    """
    v_lower = np.percentile(raw, lower, axis=axis, keepdims=True)
    # `normalize` divides the shifted data by `maxval`, so pass the range.
    v_upper = np.percentile(raw, upper, axis=axis, keepdims=True) - v_lower
    return normalize(raw, v_lower, v_upper, eps=eps)


#
# intensity augmentations / noise augmentations
#

# modified from https://github.com/kreshuklab/spoco/blob/main/spoco/transforms.py
class RandomContrast():
    """
    Adjust contrast by scaling image to `mean + alpha * (image - mean)`.

    `alpha` is drawn uniformly from the `alpha` range on every call. The result
    is passed to `np.clip(result, **clip_kwargs)` unless `clip_kwargs` is falsy.
    """
    def __init__(self, alpha=(0.5, 2), mean=0.5, clip_kwargs={'a_min': 0, 'a_max': 1}):
        self.alpha = alpha
        self.mean = mean
        self.clip_kwargs = clip_kwargs

    def __call__(self, img):
        alpha = np.random.uniform(self.alpha[0], self.alpha[1])
        result = self.mean + alpha * (img - self.mean)
        if self.clip_kwargs:
            return np.clip(result, **self.clip_kwargs)
        return result


class AdditiveGaussianNoise():
    """
    Add random Gaussian noise to image.

    The noise standard deviation is drawn uniformly from the `scale` range on
    every call. The result is passed to `np.clip(result, **clip_kwargs)` unless
    `clip_kwargs` is falsy.
    """
    def __init__(self, scale=(0.0, 0.3), clip_kwargs={'a_min': 0, 'a_max': 1}):
        self.scale = scale
        self.clip_kwargs = clip_kwargs

    def __call__(self, img):
        std = np.random.uniform(self.scale[0], self.scale[1])
        gaussian_noise = np.random.normal(0, std, size=img.shape)
        noisy = img + gaussian_noise
        if self.clip_kwargs:
            # Honour the configured bounds instead of a hard-coded [0, 1].
            return np.clip(noisy, **self.clip_kwargs)
        return noisy


class AdditivePoissonNoise():
    """
    Add random Poisson noise to image.

    `lam` is drawn uniformly from the `lam` range on every call and the added
    noise is `Poisson(lam) / lam`. The result is passed to
    `np.clip(result, **clip_kwargs)` unless `clip_kwargs` is falsy.
    """
    # TODO: not sure if Poisson noise like this does make sense
    # for data that is already normalized
    def __init__(self, lam=(0.0, 0.1), clip_kwargs={'a_min': 0, 'a_max': 1}):
        self.lam = lam
        self.clip_kwargs = clip_kwargs

    def __call__(self, img):
        lam = np.random.uniform(self.lam[0], self.lam[1])
        poisson_noise = np.random.poisson(lam, size=img.shape)
        # lam == 0 yields all-zero counts; skip the rescaling to avoid 0 / 0 = NaN.
        if lam > 0:
            poisson_noise = poisson_noise / lam
        noisy = img + poisson_noise
        if self.clip_kwargs:
            # Honour the configured bounds instead of a hard-coded [0, 1].
            return np.clip(noisy, **self.clip_kwargs)
        return noisy


class PoissonNoise():
    """
    Add random data-dependent Poisson noise to image.

    The image is shifted to start at zero, scaled by a multiplier drawn
    uniformly from the `multiplier` range, replaced by Poisson samples with
    those rates and mapped back by dividing by the multiplier and adding the
    original minimum.
    """
    def __init__(self, multiplier=(5.0, 10.0), clip_kwargs={'a_min': 0, 'a_max': 1}):
        self.multiplier = multiplier
        self.clip_kwargs = clip_kwargs

    def __call__(self, img):
        multiplier = np.random.uniform(self.multiplier[0], self.multiplier[1])
        offset = img.min()
        poisson_noise = np.random.poisson((img - offset) * multiplier)
        # Sampling goes through numpy; convert back when the input was a tensor.
        if isinstance(img, torch.Tensor):
            poisson_noise = torch.Tensor(poisson_noise)
        poisson_noise = poisson_noise / multiplier + offset
        if self.clip_kwargs:
            return np.clip(poisson_noise, **self.clip_kwargs)
        return poisson_noise


class GaussianBlur():
    """
    Blur the image.

    Kernel size and sigma are sampled on every call from the `kernel_size` and
    `sigma` ranges; numpy inputs are converted to torch tensors before the
    torchvision `GaussianBlur` transform is applied.
    """
    def __init__(self, kernel_size=(2, 12), sigma=(0, 2.5)):
        self.kernel_size = kernel_size
        self.sigma = sigma

    def __call__(self, img):
        # sample kernel_size and make sure it is odd
        kernel_size = 2 * (np.random.randint(self.kernel_size[0], self.kernel_size[1]) // 2) + 1
        # switch boundaries to make sure 0 is excluded from sampling
        sigma = np.random.uniform(self.sigma[1], self.sigma[0])
        if isinstance(img, np.ndarray):
            img = torch.from_numpy(img)
        out = transforms.GaussianBlur(kernel_size, sigma=sigma)(img)
        return out


#
# default transformation:
# apply intensity augmentations and normalize
#

class RawTransform:
    """
    Chain `augmentation1` -> `normalizer` -> `augmentation2` on the raw data.

    Either augmentation may be None, in which case that step is skipped.
    """
    def __init__(self, normalizer, augmentation1=None, augmentation2=None):
        self.normalizer = normalizer
        self.augmentation1 = augmentation1
        self.augmentation2 = augmentation2

    def __call__(self, raw):
        if self.augmentation1 is not None:
            raw = self.augmentation1(raw)
        raw = self.normalizer(raw)
        if self.augmentation2 is not None:
            raw = self.augmentation2(raw)
        return raw


def get_raw_transform(normalizer=standardize, augmentation1=None, augmentation2=None):
    """
    Build a `RawTransform` from a normalizer and two optional augmentations.
    """
    return RawTransform(normalizer,
                        augmentation1=augmentation1,
                        augmentation2=augmentation2)


# The default values are made for an image with pixel values in
# range [0, 1]. That the image is in this range is ensured by an
# initial normalization step.
def get_default_mean_teacher_augmentations(
    p=0.3, norm=None,
    blur_kwargs=None, poisson_kwargs=None, gaussian_kwargs=None
):
    """
    Build the default raw transform used for mean-teacher training.

    The first augmentation normalizes with `norm` (default: `normalize`) and
    then randomly applies blur (probability `p`), Poisson noise and additive
    Gaussian noise (probability `p / 2` each). `norm` is applied again as the
    normalizer of the returned `RawTransform`, followed by a random contrast
    change applied with probability `p`.
    """
    if norm is None:
        norm = normalize
    aug1 = transforms.Compose([
        norm,
        transforms.RandomApply([GaussianBlur(**({} if blur_kwargs is None else blur_kwargs))], p=p),
        transforms.RandomApply([PoissonNoise(**({} if poisson_kwargs is None else poisson_kwargs))], p=p/2),
        transforms.RandomApply([AdditiveGaussianNoise(**({} if gaussian_kwargs is None else gaussian_kwargs))], p=p/2),
    ])
    aug2 = transforms.RandomApply(
        [RandomContrast(clip_kwargs={"a_min": 0, "a_max": 1})], p=p
    )
    return get_raw_transform(
        normalizer=norm,
        augmentation1=aug1,
        augmentation2=aug2
    )
TORCH_DTYPES =
{'float16': torch.float16, 'float32': torch.float32, 'float64': torch.float64, 'complex64': torch.complex64, 'complex128': torch.complex128, 'uint8': torch.uint8, 'int8': torch.int8, 'int16': torch.int16, 'int32': torch.int32, 'int64': torch.int64, 'bool': torch.bool}
def
cast(inpt, typestring):
def
standardize(raw, mean=None, std=None, axis=None, eps=1e-07):
def
normalize(raw, minval=None, maxval=None, axis=None, eps=1e-07):
def normalize(raw, minval=None, maxval=None, axis=None, eps=1e-7):
    """
    Cast `raw` to float32 and min-max normalize it.

    The data is shifted by `minval` and then divided by `maxval + eps`.
    When `minval` is None the minimum along `axis` is used; when `maxval`
    is None the maximum of the already shifted data along `axis` is used.
    Torch tensors are forwarded to `_normalize_torch`.
    """
    data = cast(raw, "float32")

    # Tensors take the dedicated torch code path.
    if torch.is_tensor(data):
        return _normalize_torch(data, minval=minval, maxval=maxval, axis=axis, eps=eps)

    if minval is None:
        minval = data.min(axis=axis, keepdims=True)
    data -= minval

    # The maximum is taken after the shift, so it acts as the value range.
    if maxval is None:
        maxval = data.max(axis=axis, keepdims=True)
    data /= (maxval + eps)

    return data
def
normalize_percentile(raw, lower=1.0, upper=99.0, axis=None, eps=1e-07):
class
RandomContrast:
class RandomContrast():
    """
    Adjust contrast by scaling image to `mean + alpha * (image - mean)`.

    `alpha` is drawn uniformly from the `alpha` range on every call. The
    result is passed to `np.clip(result, **clip_kwargs)` unless `clip_kwargs`
    is falsy, in which case it is returned unclipped.
    """
    def __init__(self, alpha=(0.5, 2), mean=0.5, clip_kwargs={'a_min': 0, 'a_max': 1}):
        self.alpha, self.mean, self.clip_kwargs = alpha, mean, clip_kwargs

    def __call__(self, img):
        low, high = self.alpha[0], self.alpha[1]
        factor = np.random.uniform(low, high)
        stretched = self.mean + factor * (img - self.mean)
        # A falsy clip_kwargs (None or empty dict) disables clipping.
        if not self.clip_kwargs:
            return stretched
        return np.clip(stretched, **self.clip_kwargs)
Adjust contrast by scaling image to `mean + alpha * (image - mean)`.
class
AdditiveGaussianNoise:
class AdditiveGaussianNoise():
    """
    Add random Gaussian noise to image.

    The noise standard deviation is drawn uniformly from the `scale` range on
    every call. The noisy image is passed to `np.clip(result, **clip_kwargs)`
    unless `clip_kwargs` is falsy, in which case it is returned unclipped.
    """
    def __init__(self, scale=(0.0, 0.3), clip_kwargs={'a_min': 0, 'a_max': 1}):
        self.scale = scale
        self.clip_kwargs = clip_kwargs

    def __call__(self, img):
        std = np.random.uniform(self.scale[0], self.scale[1])
        gaussian_noise = np.random.normal(0, std, size=img.shape)
        noisy = img + gaussian_noise
        if self.clip_kwargs:
            # Honour the configured bounds instead of a hard-coded [0, 1].
            return np.clip(noisy, **self.clip_kwargs)
        return noisy
Add random Gaussian noise to image.
class
AdditivePoissonNoise:
class AdditivePoissonNoise():
    """
    Add random Poisson noise to image.

    `lam` is drawn uniformly from the `lam` range on every call and the added
    noise is `Poisson(lam) / lam`. The noisy image is passed to
    `np.clip(result, **clip_kwargs)` unless `clip_kwargs` is falsy, in which
    case it is returned unclipped.
    """
    # TODO: not sure if Poisson noise like this does make sense
    # for data that is already normalized
    def __init__(self, lam=(0.0, 0.1), clip_kwargs={'a_min': 0, 'a_max': 1}):
        self.lam = lam
        self.clip_kwargs = clip_kwargs

    def __call__(self, img):
        lam = np.random.uniform(self.lam[0], self.lam[1])
        poisson_noise = np.random.poisson(lam, size=img.shape)
        # lam == 0 yields all-zero counts; skip the rescaling to avoid 0 / 0 = NaN.
        if lam > 0:
            poisson_noise = poisson_noise / lam
        noisy = img + poisson_noise
        if self.clip_kwargs:
            # Honour the configured bounds instead of a hard-coded [0, 1].
            return np.clip(noisy, **self.clip_kwargs)
        return noisy
Add random Poisson noise to image.
class
PoissonNoise:
class PoissonNoise():
    """
    Add random data-dependent Poisson noise to image.

    The image is shifted so its minimum is zero, scaled by a multiplier drawn
    uniformly from the `multiplier` range, and replaced by Poisson samples with
    those rates. The samples are mapped back by dividing by the multiplier and
    adding the original minimum. The result is passed to
    `np.clip(result, **clip_kwargs)` unless `clip_kwargs` is falsy.
    """
    def __init__(self, multiplier=(5.0, 10.0), clip_kwargs={'a_min': 0, 'a_max': 1}):
        self.multiplier = multiplier
        self.clip_kwargs = clip_kwargs

    def __call__(self, img):
        scale = np.random.uniform(self.multiplier[0], self.multiplier[1])
        floor = img.min()
        counts = np.random.poisson((img - floor) * scale)
        # Sampling goes through numpy; convert back when the input was a tensor.
        if isinstance(img, torch.Tensor):
            counts = torch.Tensor(counts)
        noisy = counts / scale + floor
        return np.clip(noisy, **self.clip_kwargs) if self.clip_kwargs else noisy
Add random data-dependent Poisson noise to image.
class
GaussianBlur:
class GaussianBlur():
    """
    Blur the image.

    On every call an odd kernel size is derived from an integer drawn from the
    `kernel_size` range and a sigma is drawn from the `sigma` range. Numpy
    inputs are converted to torch tensors before the torchvision
    `GaussianBlur` transform is applied; its output is returned as is.
    """
    def __init__(self, kernel_size=(2, 12), sigma=(0, 2.5)):
        self.kernel_size = kernel_size
        self.sigma = sigma

    def __call__(self, img):
        # sample kernel_size and make sure it is odd
        drawn = np.random.randint(self.kernel_size[0], self.kernel_size[1])
        kernel_size = 2 * (drawn // 2) + 1
        # switch boundaries to make sure 0 is excluded from sampling
        sigma = np.random.uniform(self.sigma[1], self.sigma[0])
        tensor = torch.from_numpy(img) if isinstance(img, np.ndarray) else img
        blur = transforms.GaussianBlur(kernel_size, sigma=sigma)
        return blur(tensor)
Blur the image.
class
RawTransform:
class RawTransform:
    """
    Chain `augmentation1` -> `normalizer` -> `augmentation2` on the raw data.

    Either augmentation may be None, in which case that step is skipped.
    The normalizer is always applied.
    """
    def __init__(self, normalizer, augmentation1=None, augmentation2=None):
        self.normalizer = normalizer
        self.augmentation1 = augmentation1
        self.augmentation2 = augmentation2

    def __call__(self, raw):
        before, after = self.augmentation1, self.augmentation2
        out = raw if before is None else before(raw)
        out = self.normalizer(out)
        return out if after is None else after(out)
def
get_raw_transform( normalizer=<function standardize>, augmentation1=None, augmentation2=None):
def
get_default_mean_teacher_augmentations( p=0.3, norm=None, blur_kwargs=None, poisson_kwargs=None, gaussian_kwargs=None):
def get_default_mean_teacher_augmentations(
    p=0.3, norm=None,
    blur_kwargs=None, poisson_kwargs=None, gaussian_kwargs=None
):
    """
    Build the default raw transform used for mean-teacher training.

    The first augmentation normalizes with `norm` (default: `normalize`) and
    then randomly applies `GaussianBlur` (probability `p`), `PoissonNoise` and
    `AdditiveGaussianNoise` (probability `p / 2` each), each constructed from
    the matching `*_kwargs`. `norm` is applied again as the normalizer of the
    returned transform, followed by `RandomContrast` with probability `p`.
    """
    if norm is None:
        norm = normalize

    # Missing keyword dictionaries mean "use the class defaults".
    if blur_kwargs is None:
        blur_kwargs = {}
    if poisson_kwargs is None:
        poisson_kwargs = {}
    if gaussian_kwargs is None:
        gaussian_kwargs = {}

    stages = [
        norm,
        transforms.RandomApply([GaussianBlur(**blur_kwargs)], p=p),
        transforms.RandomApply([PoissonNoise(**poisson_kwargs)], p=p/2),
        transforms.RandomApply([AdditiveGaussianNoise(**gaussian_kwargs)], p=p/2),
    ]
    aug1 = transforms.Compose(stages)

    contrast = RandomContrast(clip_kwargs={"a_min": 0, "a_max": 1})
    aug2 = transforms.RandomApply([contrast], p=p)

    return get_raw_transform(normalizer=norm, augmentation1=aug1, augmentation2=aug2)