torch_em.transform.generic
"""Generic transforms for tiling, composing, rescaling, resizing and padding data."""
from typing import Any, Dict, Optional, Sequence, Union

import numpy as np
from skimage.transform import rescale, resize

import torch


class Tile(torch.nn.Module):
    """Repeat the input along its axes; works for torch tensors and numpy arrays.

    Args:
        reps: Number of repetitions for each axis.
        match_shape_exactly: If True, `forward` asserts that the input has
            exactly `len(reps)` dimensions.
    """
    _params = None

    def __init__(self, reps: Sequence[int] = (2,), match_shape_exactly: bool = True):
        super().__init__()
        self.reps = reps
        self.match_shape_exactly = match_shape_exactly

    def forward(self, input: Union[torch.Tensor, np.ndarray], params: Optional[Dict[str, Any]] = None):
        """Tile `input` according to `self.reps`.

        Args:
            input: The tensor or array to tile.
            params: Not used by this implementation.

        Returns:
            The tiled data, of the same container type as `input`.

        Raises:
            AssertionError: If `match_shape_exactly` is set and the number of
                input dimensions differs from `len(self.reps)`.
            NotImplementedError: If `input` is neither a tensor nor an array.
        """
        assert not self.match_shape_exactly or len(input.shape) == len(self.reps), (input.shape, self.reps)
        if isinstance(input, torch.Tensor):
            # TODO: use torch.tile(input, self.reps) instead (for pytorch >=1.8?)
            reps = list(self.reps)
            # Fewer reps than dimensions: leave the leading axes untouched.
            for _ in range(max(0, len(input.shape) - len(reps))):
                reps.insert(0, 1)

            # More reps than dimensions: add leading singleton axes to the input.
            for _ in range(max(0, len(reps) - len(input.shape))):
                input = input.unsqueeze(0)

            return input.repeat(*reps)
        elif isinstance(input, np.ndarray):
            return np.tile(input, self.reps)
        else:
            raise NotImplementedError(type(input))


# a simple way to compose transforms
class Compose:
    """Chain several transforms into a single callable.

    Args:
        transforms: The transforms to apply, in order. At least one is needed;
            calling an empty composition raises an IndexError.
        is_multi_tensor: If True (the default), the result of each transform is
            unpacked into positional arguments for the next one, so every
            transform except the last must return a tuple or list. Set it to
            False for transforms that take and return a single object;
            otherwise a bare array result would be unpacked along its first
            axis.
    """
    def __init__(self, *transforms, is_multi_tensor=True):
        self.transforms = transforms
        self.is_multi_tensor = is_multi_tensor

    def __call__(self, *inputs):
        """Apply all transforms in sequence and return the last result."""
        outputs = self.transforms[0](*inputs)
        for trafo in self.transforms[1:]:
            if self.is_multi_tensor:
                outputs = trafo(*outputs)
            else:
                outputs = trafo(outputs)
        return outputs


class Rescale:
    """Rescale inputs by a fixed factor with `skimage.transform.rescale`.

    Args:
        scale: The scale passed on to `rescale`.
        with_channels: None to rescale every input as a whole. Otherwise a
            single flag applied to all inputs, or a tuple / list with one flag
            per input; inputs whose flag is truthy are rescaled slice by slice
            along their first axis.
    """
    def __init__(self, scale, with_channels=None):
        self.scale = scale
        self.with_channels = with_channels

    def _rescale_with_channels(self, input_, **kwargs):
        """Rescale each slice along the first axis and stack the results on that axis."""
        out = [rescale(inp, **kwargs)[None] for inp in input_]
        return np.concatenate(out, axis=0)

    def __call__(self, *inputs):
        """Rescale all inputs.

        Returns:
            The single rescaled array for one input, otherwise a tuple of them.
        """
        if self.with_channels is None:
            outputs = tuple(rescale(inp, scale=self.scale, preserve_range=True) for inp in inputs)
        else:
            if isinstance(self.with_channels, (tuple, list)):
                assert len(self.with_channels) == len(inputs)
                with_channels = self.with_channels
            else:
                # A single flag applies to every input.
                with_channels = [self.with_channels] * len(inputs)
            outputs = tuple(
                self._rescale_with_channels(inp, scale=self.scale, preserve_range=True) if wc else
                rescale(inp, scale=self.scale, preserve_range=True)
                for inp, wc in zip(inputs, with_channels)
            )
        if len(outputs) == 1:
            return outputs[0]
        return outputs


class ResizeInputs:
    """Resize an array to a fixed shape with `skimage.transform.resize`.

    Args:
        target_shape: The output shape (without the channel axis if `is_rgb`).
        is_label: If True, resize with `order=0` and without anti-aliasing.
        is_rgb: If True, the input must be 3D with 3 leading channels, and the
            output shape becomes `(3, *target_shape)`.
    """
    def __init__(self, target_shape, is_label=False, is_rgb=False):
        self.target_shape = target_shape
        self.is_label = is_label
        self.is_rgb = is_rgb

    def __call__(self, inputs):
        """Resize `inputs` and cast the result back to the input dtype."""
        if self.is_label:  # kwargs needed for int data
            kwargs = {"order": 0, "anti_aliasing": False}
        else:  # we use the default settings for float data
            kwargs = {}

        if self.is_rgb:
            assert inputs.ndim == 3 and inputs.shape[0] == 3
            patch_shape = (3, *self.target_shape)
        else:
            patch_shape = self.target_shape

        inputs = resize(
            image=inputs,
            output_shape=patch_shape,
            preserve_range=True,
            **kwargs
        ).astype(inputs.dtype)

        return inputs


class PadIfNecessary:
    """Pad arrays at the end of each axis, with reflection, up to a target shape.

    Args:
        shape: The target shape for the trailing axes of the data.
    """
    def __init__(self, shape):
        self.shape = tuple(shape)

    def _pad_if_necessary(self, data):
        """Return `data` unchanged if it already has the target shape, otherwise pad it.

        Raises:
            AssertionError: If the data is larger than the target shape along
                any axis, or has fewer dimensions than the target shape.
        """
        if data.ndim == len(self.shape):
            pad_shape = self.shape
        else:
            # Extra leading axes keep their size; only the trailing axes are padded.
            dim_diff = data.ndim - len(self.shape)
            pad_shape = data.shape[:dim_diff] + self.shape
            assert len(pad_shape) == data.ndim

        data_shape = data.shape
        if all(dsh == sh for dsh, sh in zip(data_shape, pad_shape)):
            return data

        pad_width = [sh - dsh for dsh, sh in zip(data_shape, pad_shape)]
        assert all(pw >= 0 for pw in pad_width)
        pad_width = [(0, pw) for pw in pad_width]
        return np.pad(data, pad_width, mode="reflect")

    def __call__(self, *inputs):
        """Pad all inputs.

        Returns:
            The single padded array for one input, otherwise a tuple of them.
        """
        outputs = tuple(self._pad_if_necessary(input_) for input_ in inputs)
        if len(outputs) == 1:
            return outputs[0]
        return outputs
class Tile(torch.nn.Module):
    """Repeat a tensor or array a fixed number of times along its axes.

    Args:
        reps: Number of repetitions per axis.
        match_shape_exactly: If True, `forward` asserts that the input has
            exactly `len(reps)` dimensions.
    """
    _params = None

    def __init__(self, reps: Sequence[int] = (2,), match_shape_exactly: bool = True):
        super().__init__()
        self.match_shape_exactly = match_shape_exactly
        self.reps = reps

    def forward(self, input: Union[torch.Tensor, np.ndarray], params: Optional[Dict[str, Any]] = None):
        """Tile `input` according to `self.reps`.

        Args:
            input: The tensor or array to tile.
            params: Not used by this implementation.

        Returns:
            The tiled data, of the same container type as `input`.

        Raises:
            AssertionError: If `match_shape_exactly` is set and the number of
                input dimensions differs from `len(self.reps)`.
            NotImplementedError: If `input` is neither a tensor nor an array.
        """
        if self.match_shape_exactly:
            assert len(input.shape) == len(self.reps), (input.shape, self.reps)

        if isinstance(input, np.ndarray):
            return np.tile(input, self.reps)
        if not isinstance(input, torch.Tensor):
            raise NotImplementedError(type(input))

        # TODO: use torch.tile(input, self.reps) instead (for pytorch >=1.8?)
        # Fewer reps than dimensions: the leading axes are repeated once, i.e. kept.
        n_untouched = max(0, len(input.shape) - len(self.reps))
        full_reps = [1] * n_untouched + list(self.reps)

        # More reps than dimensions: grow the input with leading singleton axes.
        while input.dim() < len(full_reps):
            input = input.unsqueeze(0)

        return input.repeat(*full_reps)
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing to nest them in a tree structure. You can assign the submodules as regular attributes::
    import torch.nn as nn
    import torch.nn.functional as F

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will have their parameters converted too when you call `to()`, etc.
As per the example above, an `__init__()` call to the parent class must be made before assignment on the child.
:ivar training: Boolean that represents whether this module is in training or evaluation mode.
:vartype training: bool
13 def __init__(self, reps: Sequence[int] = (2,), match_shape_exactly: bool = True): 14 super().__init__() 15 self.reps = reps 16 self.match_shape_exactly = match_shape_exactly
Initializes internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, input: Union[torch.Tensor, np.ndarray], params: Optional[Dict[str, Any]] = None):
    """Tile `input` according to `self.reps`.

    Args:
        input: The tensor or array to tile.
        params: Not used by this implementation.

    Returns:
        The tiled data, of the same container type as `input`.

    Raises:
        AssertionError: If `match_shape_exactly` is set and the number of
            input dimensions differs from `len(self.reps)`.
        NotImplementedError: If `input` is neither a tensor nor an array.
    """
    if self.match_shape_exactly:
        assert len(input.shape) == len(self.reps), (input.shape, self.reps)

    if isinstance(input, np.ndarray):
        return np.tile(input, self.reps)
    if not isinstance(input, torch.Tensor):
        raise NotImplementedError(type(input))

    # TODO: use torch.tile(input, self.reps) instead (for pytorch >=1.8?)
    # Fewer reps than dimensions: the leading axes are repeated once, i.e. kept.
    n_untouched = max(0, len(input.shape) - len(self.reps))
    full_reps = [1] * n_untouched + list(self.reps)

    # More reps than dimensions: grow the input with leading singleton axes.
    while input.dim() < len(full_reps):
        input = input.unsqueeze(0)

    return input.repeat(*full_reps)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for the forward pass needs to be defined within this function, one should call the `Module` instance afterwards instead of calling this function directly, since the former takes care of running the registered hooks while the latter silently ignores them.
Inherited Members
- torch.nn.modules.module.Module
- dump_patches
- training
- call_super_init
- register_buffer
- register_parameter
- add_module
- register_module
- get_submodule
- get_parameter
- get_buffer
- get_extra_state
- set_extra_state
- apply
- cuda
- ipu
- xpu
- cpu
- type
- float
- double
- half
- bfloat16
- to_empty
- to
- register_full_backward_pre_hook
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- register_state_dict_pre_hook
- state_dict
- register_load_state_dict_post_hook
- load_state_dict
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr
- compile
class Rescale:
    """Rescale inputs by a fixed factor with `skimage.transform.rescale`.

    Args:
        scale: The scale passed on to `rescale`.
        with_channels: None to rescale every input as a whole. Otherwise a
            single flag applied to all inputs, or a tuple / list with one flag
            per input; inputs whose flag is truthy are rescaled slice by slice
            along their first axis.
    """
    def __init__(self, scale, with_channels=None):
        self.with_channels = with_channels
        self.scale = scale

    def _rescale_with_channels(self, input_, **kwargs):
        """Rescale each slice along the first axis and stack the results on that axis."""
        rescaled = []
        for channel in input_:
            rescaled.append(rescale(channel, **kwargs)[None])
        return np.concatenate(rescaled, axis=0)

    def __call__(self, *inputs):
        """Rescale all inputs.

        Returns:
            The single rescaled array for one input, otherwise a tuple of them.
        """
        flags = self.with_channels
        if flags is None:
            # No channel handling requested: every input is rescaled as a whole.
            per_input = [False] * len(inputs)
        elif isinstance(flags, (tuple, list)):
            assert len(flags) == len(inputs)
            per_input = flags
        else:
            # A single flag applies to every input.
            per_input = [flags] * len(inputs)

        results = []
        for inp, has_channels in zip(inputs, per_input):
            if has_channels:
                results.append(self._rescale_with_channels(inp, scale=self.scale, preserve_range=True))
            else:
                results.append(rescale(inp, scale=self.scale, preserve_range=True))

        if len(results) == 1:
            return results[0]
        return tuple(results)
class ResizeInputs:
    """Resize an array to a fixed shape with `skimage.transform.resize`.

    Args:
        target_shape: The output shape (without the channel axis if `is_rgb`).
        is_label: If True, resize with `order=0` and without anti-aliasing.
        is_rgb: If True, the input must be 3D with 3 leading channels, and the
            output shape becomes `(3, *target_shape)`.
    """
    def __init__(self, target_shape, is_label=False, is_rgb=False):
        self.is_rgb = is_rgb
        self.is_label = is_label
        self.target_shape = target_shape

    def __call__(self, inputs):
        """Resize `inputs` and cast the result back to the input dtype."""
        # Label (int) data needs these kwargs; float data uses the defaults.
        extra_kwargs = {"order": 0, "anti_aliasing": False} if self.is_label else {}

        output_shape = self.target_shape
        if self.is_rgb:
            assert inputs.ndim == 3 and inputs.shape[0] == 3
            output_shape = (3, *self.target_shape)

        resized = resize(image=inputs, output_shape=output_shape, preserve_range=True, **extra_kwargs)
        return resized.astype(inputs.dtype)
class PadIfNecessary:
    """Pad arrays at the end of each axis, with reflection, up to a target shape.

    Args:
        shape: The target shape for the trailing axes of the data.
    """
    def __init__(self, shape):
        self.shape = tuple(shape)

    def _pad_if_necessary(self, data):
        """Return `data` unchanged if it already has the target shape, otherwise pad it.

        Raises:
            AssertionError: If the data is larger than the target shape along
                any axis, or has fewer dimensions than the target shape.
        """
        n_leading = data.ndim - len(self.shape)
        if n_leading == 0:
            target = self.shape
        else:
            # Extra leading axes keep their size; only the trailing axes are padded.
            target = data.shape[:n_leading] + self.shape
            assert len(target) == data.ndim

        # Number of missing entries per axis; all zero means nothing to do.
        missing = [want - have for have, want in zip(data.shape, target)]
        if not any(missing):
            return data

        assert all(n >= 0 for n in missing)
        return np.pad(data, [(0, n) for n in missing], mode="reflect")

    def __call__(self, *inputs):
        """Pad all inputs.

        Returns:
            The single padded array for one input, otherwise a tuple of them.
        """
        padded = tuple(self._pad_if_necessary(arr) for arr in inputs)
        return padded[0] if len(padded) == 1 else padded