torch_em.transform.generic

from typing import Any, Dict, Optional, Sequence, Union

import numpy as np
from skimage.transform import rescale, resize

import torch


class Tile(torch.nn.Module):
    """Tile the input along the given repetitions; works for torch tensors and numpy arrays."""
    _params = None

    def __init__(self, reps: Sequence[int] = (2,), match_shape_exactly: bool = True):
        super().__init__()
        self.reps = reps
        self.match_shape_exactly = match_shape_exactly

    def forward(self, input: Union[torch.Tensor, np.ndarray], params: Optional[Dict[str, Any]] = None):
        assert not self.match_shape_exactly or len(input.shape) == len(self.reps), (input.shape, self.reps)
        if isinstance(input, torch.Tensor):
            # return torch.tile(input, self.reps)  # todo: use torch.tile (for pytorch >= 1.8?)
            reps = list(self.reps)
            # If the input has more dimensions than reps, prepend singleton repetitions.
            for _ in range(max(0, len(input.shape) - len(reps))):
                reps.insert(0, 1)

            # If reps has more entries than the input has dimensions, prepend singleton dimensions.
            for _ in range(max(0, len(reps) - len(input.shape))):
                input = input.unsqueeze(0)

            return input.repeat(*reps)
        elif isinstance(input, np.ndarray):
            return np.tile(input, self.reps)
        else:
            raise NotImplementedError(type(input))


class Compose:
    """A simple way to compose transforms: they are applied in sequence."""
    def __init__(self, *transforms):
        self.transforms = transforms

    def __call__(self, *inputs):
        outputs = self.transforms[0](*inputs)
        for trafo in self.transforms[1:]:
            outputs = trafo(*outputs)
        return outputs


class Rescale:
    """Rescale the inputs by a given scale factor, using skimage.transform.rescale."""
    def __init__(self, scale, with_channels=None):
        self.scale = scale
        self.with_channels = with_channels

    def _rescale_with_channels(self, input_, **kwargs):
        # Rescale each channel separately and stack the results along a new channel axis.
        out = [rescale(inp, **kwargs)[None] for inp in input_]
        return np.concatenate(out, axis=0)

    def __call__(self, *inputs):
        if self.with_channels is None:
            outputs = tuple(rescale(inp, scale=self.scale, preserve_range=True) for inp in inputs)
        else:
            if isinstance(self.with_channels, (tuple, list)):
                assert len(self.with_channels) == len(inputs)
                with_channels = self.with_channels
            else:
                with_channels = [self.with_channels] * len(inputs)
            outputs = tuple(
                self._rescale_with_channels(inp, scale=self.scale, preserve_range=True) if wc else
                rescale(inp, scale=self.scale, preserve_range=True)
                for inp, wc in zip(inputs, with_channels)
            )
        if len(outputs) == 1:
            return outputs[0]
        return outputs


class ResizeInputs:
    """Resize the inputs to a fixed target shape, using skimage.transform.resize."""
    def __init__(self, target_shape, is_label=False, is_rgb=False):
        self.target_shape = target_shape
        self.is_label = is_label
        self.is_rgb = is_rgb

    def __call__(self, inputs):
        if self.is_label:  # Nearest-neighbor interpolation without anti-aliasing is needed for int (label) data.
            kwargs = {"order": 0, "anti_aliasing": False}
        else:  # We use the default settings for float data.
            kwargs = {}

        if self.is_rgb:
            assert inputs.ndim == 3 and inputs.shape[0] == 3
            patch_shape = (3, *self.target_shape)
        else:
            patch_shape = self.target_shape

        inputs = resize(
            image=inputs,
            output_shape=patch_shape,
            preserve_range=True,
            **kwargs
        ).astype(inputs.dtype)

        return inputs


class PadIfNecessary:
    """Pad the inputs to a minimal shape, using reflection padding."""
    def __init__(self, shape):
        self.shape = tuple(shape)

    def _pad_if_necessary(self, data):
        # If the data has more dimensions than the target shape, only the trailing dimensions are padded.
        if data.ndim == len(self.shape):
            pad_shape = self.shape
        else:
            dim_diff = data.ndim - len(self.shape)
            pad_shape = data.shape[:dim_diff] + self.shape
            assert len(pad_shape) == data.ndim

        data_shape = data.shape
        if all(dsh == sh for dsh, sh in zip(data_shape, pad_shape)):
            return data

        # Pad at the end of each axis that is smaller than the target shape.
        pad_width = [sh - dsh for dsh, sh in zip(data_shape, pad_shape)]
        assert all(pw >= 0 for pw in pad_width)
        pad_width = [(0, pw) for pw in pad_width]
        return np.pad(data, pad_width, mode="reflect")

    def __call__(self, *inputs):
        outputs = tuple(self._pad_if_necessary(input_) for input_ in inputs)
        if len(outputs) == 1:
            return outputs[0]
        return outputs

class Tile(torch.nn.modules.module.Module):

Tile the input along the given repetitions. Torch tensors and numpy arrays are supported: tensors are tiled with Tensor.repeat, numpy arrays with np.tile. If match_shape_exactly is True, the number of repetitions must match the number of input dimensions.

Tile(reps: Sequence[int] = (2,), match_shape_exactly: bool = True)

reps: the number of repetitions per dimension.
match_shape_exactly: whether the number of repetitions must exactly match the number of input dimensions.
def forward(self, input: Union[torch.Tensor, np.ndarray], params: Optional[Dict[str, Any]] = None):

Tile the input according to reps. For torch tensors, reps and the input are first padded with singleton entries so that their lengths agree, and the input is then tiled with Tensor.repeat; numpy arrays are tiled with np.tile. The params argument is not used.

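A short usage sketch (shapes and repetitions are made up for illustration):

import numpy as np
import torch

from torch_em.transform.generic import Tile

tile = Tile(reps=(2, 1, 1), match_shape_exactly=True)

x = torch.zeros(1, 64, 64)
y = tile(x)  # torch.Tensor of shape (2, 64, 64)

z = tile(np.zeros((1, 64, 64)))  # numpy array of shape (2, 64, 64)
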
class Compose:
Compose(*transforms)
Apply the given transforms in sequence: the first transform is called with the inputs, and each subsequent transform is called with the outputs of the previous one.

transforms: the transforms to compose.

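A short usage sketch that chains Rescale and PadIfNecessary from this module (array shapes, the scale factor and the target shape are made up for illustration):

import numpy as np

from torch_em.transform.generic import Compose, PadIfNecessary, Rescale

trafo = Compose(Rescale(scale=0.5), PadIfNecessary(shape=(256, 256)))

raw = np.random.rand(300, 300).astype("float32")
extra = np.random.rand(300, 300).astype("float32")
raw, extra = trafo(raw, extra)  # both rescaled to (150, 150), then padded to (256, 256)

Note that Compose unpacks the outputs of each transform before passing them on, so the composed transforms should accept and return the same number of values.
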
class Rescale:
Rescale(scale, with_channels=None)
Rescale the inputs by a given factor, using skimage.transform.rescale.

scale: the scale factor, passed to skimage.transform.rescale.
with_channels: whether an input has a leading channel axis; if so, each channel is rescaled separately. Can be given as a single value for all inputs or as one value per input.

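A short usage sketch for channel-wise rescaling (the input shape and scale factor are made up for illustration):

import numpy as np

from torch_em.transform.generic import Rescale

trafo = Rescale(scale=2.0, with_channels=True)

image = np.random.rand(3, 128, 128).astype("float32")
upscaled = trafo(image)  # shape (3, 256, 256); each channel is rescaled separately
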
class ResizeInputs:
ResizeInputs(target_shape, is_label=False, is_rgb=False)
Resize the inputs to a fixed target shape, using skimage.transform.resize.

target_shape: the shape to resize to.
is_label: whether the input contains labels; labels are resized with nearest-neighbor interpolation (order=0) and without anti-aliasing, so that no new label values are introduced.
is_rgb: whether the input is an RGB image with a leading channel axis; in that case only the spatial axes are resized and the channel axis is kept.

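A short usage sketch for resizing an image and the corresponding label image (shapes and label values are made up for illustration):

import numpy as np

from torch_em.transform.generic import ResizeInputs

resize_image = ResizeInputs(target_shape=(512, 512))
resize_labels = ResizeInputs(target_shape=(512, 512), is_label=True)

image = np.random.rand(256, 256).astype("float32")
labels = np.random.randint(0, 4, size=(256, 256))

image = resize_image(image)    # shape (512, 512), dtype preserved
labels = resize_labels(labels)  # shape (512, 512), label values preserved
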
class PadIfNecessary:
PadIfNecessary(shape)
Pad the inputs to a minimal shape, using reflection padding. Padding is added at the end of each axis that is smaller than the target shape; if an input has more dimensions than the given shape, only the trailing axes are padded.

shape: the minimal shape to pad to.

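A short usage sketch (input shapes and the target shape are made up for illustration):

import numpy as np

from torch_em.transform.generic import PadIfNecessary

pad = PadIfNecessary(shape=(256, 256))

image = np.random.rand(200, 240).astype("float32")
padded = pad(image)  # shape (256, 256), reflection-padded at the end of each axis

volume = np.random.rand(8, 200, 240).astype("float32")
padded_volume = pad(volume)  # shape (8, 256, 256); only the trailing axes are padded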