torch_em.model.unet

  1import numpy as np
  2import torch
  3import torch.nn as nn
  4
  5
  6#
  7# Model Internal Post-processing
  8#
  9# Note: these are mainly for bioimage.io models, where postprocessing has to be done
 10# inside of the model unless it's defined in the general spec
 11
 12
 13# TODO think about more (multicut-friendly) boundary postprocessing
 14# e.g. max preserving smoothing: bd = np.maximum(bd, gaussian(bd, sigma=1))
 15class AccumulateChannels(nn.Module):
 16    def __init__(
 17        self,
 18        invariant_channels,
 19        accumulate_channels,
 20        accumulator
 21    ):
 22        super().__init__()
 23        self.invariant_channels = invariant_channels
 24        self.accumulate_channels = accumulate_channels
 25        assert accumulator in ("mean", "min", "max")
 26        self.accumulator = getattr(torch, accumulator)
 27
 28    def _accumulate(self, x, c0, c1):
 29        res = self.accumulator(x[:, c0:c1], dim=1, keepdim=True)
 30        if not torch.is_tensor(res):
 31            res = res.values
 32        assert torch.is_tensor(res)
 33        return res
 34
 35    def forward(self, x):
 36        if self.invariant_channels is None:
 37            c0, c1 = self.accumulate_channels
 38            return self._accumulate(x, c0, c1)
 39        else:
 40            i0, i1 = self.invariant_channels
 41            c0, c1 = self.accumulate_channels
 42            return torch.cat([x[:, i0:i1], self._accumulate(x, c0, c1)], dim=1)
 43
 44
 45def affinities_to_boundaries(aff_channels, accumulator="max"):
 46    return AccumulateChannels(None, aff_channels, accumulator)
 47
 48
 49def affinities_with_foreground_to_boundaries(aff_channels, fg_channel=(0, 1), accumulator="max"):
 50    return AccumulateChannels(fg_channel, aff_channels, accumulator)
 51
 52
 53def affinities_to_boundaries2d():
 54    return affinities_to_boundaries((0, 2))
 55
 56
 57def affinities_with_foreground_to_boundaries2d():
 58    return affinities_with_foreground_to_boundaries((1, 3))
 59
 60
 61def affinities_to_boundaries3d():
 62    return affinities_to_boundaries((0, 3))
 63
 64
 65def affinities_with_foreground_to_boundaries3d():
 66    return affinities_with_foreground_to_boundaries((1, 4))
 67
 68
 69def affinities_to_boundaries_anisotropic():
 70    return AccumulateChannels(None, (1, 3), "max")
 71
 72
 73POSTPROCESSING = {
 74    "affinities_to_boundaries_anisotropic": affinities_to_boundaries_anisotropic,
 75    "affinities_to_boundaries2d": affinities_to_boundaries2d,
 76    "affinities_with_foreground_to_boundaries2d": affinities_with_foreground_to_boundaries2d,
 77    "affinities_to_boundaries3d": affinities_to_boundaries3d,
 78    "affinities_with_foreground_to_boundaries3d": affinities_with_foreground_to_boundaries3d,
 79}
 80
 81
 82#
 83# Base Implementations
 84#
 85
 86class UNetBase(nn.Module):
 87    """
 88    """
 89    def __init__(
 90        self,
 91        encoder,
 92        base,
 93        decoder,
 94        out_conv=None,
 95        final_activation=None,
 96        postprocessing=None,
 97        check_shape=True,
 98    ):
 99        super().__init__()
100        if len(encoder) != len(decoder):
101            raise ValueError(f"Incompatible depth of encoder (depth={len(encoder)}) and decoder (depth={len(decoder)})")
102
103        self.encoder = encoder
104        self.base = base
105        self.decoder = decoder
106
107        if out_conv is None:
108            self.return_decoder_outputs = False
109            self._out_channels = self.decoder.out_channels
110        elif isinstance(out_conv, nn.ModuleList):
111            if len(out_conv) != len(self.decoder):
112                raise ValueError(f"Invalid length of out_conv, expected {len(decoder)}, got {len(out_conv)}")
113            self.return_decoder_outputs = True
114            self._out_channels = [None if conv is None else conv.out_channels for conv in out_conv]
115        else:
116            self.return_decoder_outputs = False
117            self._out_channels = out_conv.out_channels
118        self.out_conv = out_conv
119        self.check_shape = check_shape
120        self.final_activation = self._get_activation(final_activation)
121        self.postprocessing = self._get_postprocessing(postprocessing)
122
123    @property
124    def in_channels(self):
125        return self.encoder.in_channels
126
127    @property
128    def out_channels(self):
129        return self._out_channels
130
131    @property
132    def depth(self):
133        return len(self.encoder)
134
135    def _get_activation(self, activation):
136        return_activation = None
137        if activation is None:
138            return None
139        if isinstance(activation, nn.Module):
140            return activation
141        if isinstance(activation, str):
142            return_activation = getattr(nn, activation, None)
143        if return_activation is None:
144            raise ValueError(f"Invalid activation: {activation}")
145        return return_activation()
146
147    def _get_postprocessing(self, postprocessing):
148        if postprocessing is None:
149            return None
150        elif isinstance(postprocessing, nn.Module):
151            return postprocessing
152        elif postprocessing in POSTPROCESSING:
153            return POSTPROCESSING[postprocessing]()
154        else:
155            raise ValueError(f"Invalid postprocessing: {postprocessing}")
156
157    # load encoder / decoder / base states for pretraining
158    def load_encoder_state(self, state):
159        self.encoder.load_state_dict(state)
160
161    def load_decoder_state(self, state):
162        self.decoder.load_state_dict(state)
163
164    def load_base_state(self, state):
165        self.base.load_state_dict(state)
166
167    def _apply_default(self, x):
168        self.encoder.return_outputs = True
169        self.decoder.return_outputs = False
170
171        x, encoder_out = self.encoder(x)
172        x = self.base(x)
173        x = self.decoder(x, encoder_inputs=encoder_out[::-1])
174
175        if self.out_conv is not None:
176            x = self.out_conv(x)
177        if self.final_activation is not None:
178            x = self.final_activation(x)
179        if self.postprocessing is not None:
180            x = self.postprocessing(x)
181
182        return x
183
184    def _apply_with_side_outputs(self, x):
185        self.encoder.return_outputs = True
186        self.decoder.return_outputs = True
187
188        x, encoder_out = self.encoder(x)
189        x = self.base(x)
190        x = self.decoder(x, encoder_inputs=encoder_out[::-1])
191
192        x = [x if conv is None else conv(xx) for xx, conv in zip(x, self.out_conv)]
193        if self.final_activation is not None:
194            x = [self.final_activation(xx) for xx in x]
195
196        if self.postprocessing is not None:
197            x = [self.postprocessing(xx) for xx in x]
198
199        # we reverse the list to have the full shape output as first element
200        return x[::-1]
201
202    def _check_shape(self, x):
203        spatial_shape = tuple(x.shape)[2:]
204        depth = len(self.encoder)
205        factor = [2**depth] * len(spatial_shape)
206        if any(sh % fac != 0 for sh, fac in zip(spatial_shape, factor)):
207            msg = f"Invalid shape for U-Net: {spatial_shape} is not divisible by {factor}"
208            raise ValueError(msg)
209
210    def forward(self, x):
211        # cast input data to float, hotfix for modelzoo deployment issues, leaving it here for reference
212        # x = x.float()
213        if getattr(self, "check_shape", True):
214            self._check_shape(x)
215        if self.return_decoder_outputs:
216            return self._apply_with_side_outputs(x)
217        else:
218            return self._apply_default(x)
219
220
221def _update_conv_kwargs(kwargs, scale_factor):
222    # if the scale factor is a scalar or all entries are the same we don"t need to update the kwargs
223    if isinstance(scale_factor, int) or scale_factor.count(scale_factor[0]) == len(scale_factor):
224        return kwargs
225    else:  # otherwise set anisotropic kernel
226        kernel_size = kwargs.get("kernel_size", 3)
227        padding = kwargs.get("padding", 1)
228
229        # bail out if kernel size or padding aren"t scalars, because it"s
230        # unclear what to do in this case
231        if not (isinstance(kernel_size, int) and isinstance(padding, int)):
232            return kwargs
233
234        kernel_size = tuple(1 if factor == 1 else kernel_size for factor in scale_factor)
235        padding = tuple(0 if factor == 1 else padding for factor in scale_factor)
236        kwargs.update({"kernel_size": kernel_size, "padding": padding})
237        return kwargs
238
239
class Encoder(nn.Module):
    """Encoder path: a sequence of conv blocks, each followed by a pooling step.

    Args:
        features: Channel counts; block `i` maps `features[i]` to `features[i + 1]`.
        scale_factors: One pooling factor per level; each is passed to `pooler_impl`.
        conv_block_impl: Callable `(in_channels, out_channels, **kwargs)` creating a block.
        pooler_impl: Callable `(factor)` creating a pooling module.
        anisotropic_kernel: If set, the conv kwargs are adapted per level via
            `_update_conv_kwargs`.
        conv_block_kwargs: Extra keyword arguments for `conv_block_impl`.

    Raises:
        ValueError: If `len(features) != len(scale_factors) + 1`.
    """
    def __init__(
        self,
        features,
        scale_factors,
        conv_block_impl,
        pooler_impl,
        anisotropic_kernel=False,
        **conv_block_kwargs
    ):
        super().__init__()
        if len(features) != len(scale_factors) + 1:
            raise ValueError(
                f"Incompatible number of features {len(features)} and scale_factors {len(scale_factors)}"
            )

        conv_kwargs = [conv_block_kwargs] * len(scale_factors)
        if anisotropic_kernel:
            conv_kwargs = [_update_conv_kwargs(kwargs, scale_factor)
                           for kwargs, scale_factor in zip(conv_kwargs, scale_factors)]

        self.blocks = nn.ModuleList(
            [conv_block_impl(inc, outc, **kwargs)
             for inc, outc, kwargs in zip(features[:-1], features[1:], conv_kwargs)]
        )
        self.poolers = nn.ModuleList(
            [pooler_impl(factor) for factor in scale_factors]
        )
        # When set, forward also returns the per-level outputs taken before pooling.
        self.return_outputs = True

        self.in_channels = features[0]
        self.out_channels = features[-1]

    def __len__(self):
        """Number of encoder levels."""
        return len(self.blocks)

    def forward(self, x):
        """Return the pooled output, plus the list of pre-pooling outputs if `return_outputs`."""
        encoder_out = []
        for block, pooler in zip(self.blocks, self.poolers):
            x = block(x)
            encoder_out.append(x)
            x = pooler(x)

        if self.return_outputs:
            return x, encoder_out
        else:
            return x
285
286
class Decoder(nn.Module):
    """Decoder path: per level, upsample, concatenate with an encoder output, apply a conv block.

    Args:
        features: Channel counts; level `i` maps `features[i]` to `features[i + 1]`.
        scale_factors: One upsampling factor per level.
        conv_block_impl: Callable `(in_channels, out_channels, **kwargs)` creating a block.
        sampler_impl: Callable `(factor, in_channels, out_channels)` creating an upsampler.
        anisotropic_kernel: If set, the conv kwargs are adapted per level via
            `_update_conv_kwargs`.
        conv_block_kwargs: Extra keyword arguments for `conv_block_impl`.

    Raises:
        ValueError: If `len(features) != len(scale_factors) + 1`.
    """
    def __init__(
        self,
        features,
        scale_factors,
        conv_block_impl,
        sampler_impl,
        anisotropic_kernel=False,
        **conv_block_kwargs
    ):
        super().__init__()
        if len(features) != len(scale_factors) + 1:
            raise ValueError(
                f"Incompatible number of features {len(features)} and scale_factors {len(scale_factors)}"
            )

        conv_kwargs = [conv_block_kwargs] * len(scale_factors)
        if anisotropic_kernel:
            conv_kwargs = [_update_conv_kwargs(kwargs, scale_factor)
                           for kwargs, scale_factor in zip(conv_kwargs, scale_factors)]

        self.blocks = nn.ModuleList(
            [conv_block_impl(inc, outc, **kwargs)
             for inc, outc, kwargs in zip(features[:-1], features[1:], conv_kwargs)]
        )
        self.samplers = nn.ModuleList(
            [sampler_impl(factor, inc, outc) for factor, inc, outc
             in zip(scale_factors, features[:-1], features[1:])]
        )
        # When set, forward returns the list of per-level outputs instead of the last one.
        self.return_outputs = False

        self.in_channels = features[0]
        self.out_channels = features[-1]

    def __len__(self):
        """Number of decoder levels."""
        return len(self.blocks)

    # FIXME this prevents traces from being valid for other input sizes, need to find
    # a solution to traceable cropping
    def _crop(self, x, shape):
        """Center-crop `x` towards `shape`, slicing every dimension symmetrically."""
        shape_diff = [(xsh - sh) // 2 for xsh, sh in zip(x.shape, shape)]
        crop = tuple([slice(sd, xsh - sd) for sd, xsh in zip(shape_diff, x.shape)])
        return x[crop]
        # # Implementation with torch.narrow, does not fix the tracing warnings!
        # for dim, (sh, sd) in enumerate(zip(shape, shape_diff)):
        #     x = torch.narrow(x, dim, sd, sh)
        # return x

    def _concat(self, x1, x2):
        """Concatenate `x1` with `x2` (cropped towards the shape of `x1`) along dim 1."""
        return torch.cat([x1, self._crop(x2, x1.shape)], dim=1)

    def forward(self, x, encoder_inputs):
        """Decode `x` using one entry of `encoder_inputs` per level.

        Returns the last level's output, or, if `return_outputs` is set, the list of all
        level outputs followed by the last output once more.

        Raises:
            ValueError: If the number of `encoder_inputs` differs from the number of levels.
        """
        if len(encoder_inputs) != len(self.blocks):
            raise ValueError(f"Invalid number of encoder_inputs: expect {len(self.blocks)}, got {len(encoder_inputs)}")

        decoder_out = []
        for block, sampler, from_encoder in zip(self.blocks, self.samplers, encoder_inputs):
            x = sampler(x)
            x = block(self._concat(x, from_encoder))
            decoder_out.append(x)

        if self.return_outputs:
            return decoder_out + [x]
        else:
            return x
350
351
def get_norm_layer(norm, dim, channels, n_groups=32):
    """Create the normalization layer called `norm` for `channels` channels.

    Args:
        norm: `None`, "InstanceNorm", "OldDefault", "GroupNorm" or "BatchNorm".
        dim: 2 selects the 2d variant of a layer; any other value selects the 3d variant.
        channels: Number of channels to normalize.
        n_groups: Upper bound on the number of groups used for "GroupNorm".

    Returns:
        The normalization module, or `None` if `norm` is `None`.

    Raises:
        ValueError: If `norm` is not one of the supported names.
    """
    if norm is None:
        return None

    use_2d = dim == 2
    if norm == "InstanceNorm":
        layer_cls = nn.InstanceNorm2d if use_2d else nn.InstanceNorm3d
        return layer_cls(channels, affine=True, track_running_stats=True, momentum=0.01)
    if norm == "OldDefault":
        # Instance norm with the default settings of the torch layer.
        layer_cls = nn.InstanceNorm2d if use_2d else nn.InstanceNorm3d
        return layer_cls(channels)
    if norm == "GroupNorm":
        # Never request more groups than there are channels.
        return nn.GroupNorm(min(n_groups, channels), channels)
    if norm == "BatchNorm":
        layer_cls = nn.BatchNorm2d if use_2d else nn.BatchNorm3d
        return layer_cls(channels)

    raise ValueError(f"Invalid norm: expect one of 'InstanceNorm', 'BatchNorm' or 'GroupNorm', got {norm}")
366
367
class ConvBlock(nn.Module):
    """Two convolutions, each followed by a ReLU and optionally preceded by a norm layer.

    Args:
        in_channels: Channels of the input.
        out_channels: Channels produced by both convolutions.
        dim: 2 selects `nn.Conv2d`; any other value selects `nn.Conv3d`.
        kernel_size: Kernel size of both convolutions.
        padding: Padding of both convolutions.
        norm: Name passed to `get_norm_layer`, or `None` for no normalization.
    """

    def __init__(self, in_channels, out_channels, dim,
                 kernel_size=3, padding=1, norm="InstanceNorm"):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        conv_cls = nn.Conv2d if dim == 2 else nn.Conv3d

        # Layer order per stage: [norm,] conv, ReLU. The first stage consumes
        # `in_channels`, the second stage consumes `out_channels`.
        layers = []
        for stage_in in (in_channels, out_channels):
            if norm is not None:
                layers.append(get_norm_layer(norm, dim, stage_in))
            layers.append(conv_cls(stage_in, out_channels, kernel_size=kernel_size, padding=padding))
            layers.append(nn.ReLU(inplace=True))
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the block to `x`."""
        return self.block(x)
400
401
class Upsampler(nn.Module):
    """Upsample by interpolation, then map the channels with a 1x1 convolution.

    Args:
        scale_factor: Factor passed to `nn.functional.interpolate`.
        in_channels: Channels of the input.
        out_channels: Channels after the convolution.
        dim: 2 selects `nn.Conv2d`; any other value selects `nn.Conv3d`.
        mode: Interpolation mode passed to `nn.functional.interpolate`.
    """

    def __init__(self, scale_factor, in_channels, out_channels, dim, mode):
        super().__init__()
        self.mode = mode
        self.scale_factor = scale_factor

        conv_cls = nn.Conv2d if dim == 2 else nn.Conv3d
        # Kernel size 1: the convolution only changes the number of channels.
        self.conv = conv_cls(in_channels, out_channels, 1)

    def forward(self, x):
        """Interpolate `x` by the stored scale factor and apply the convolution."""
        upsampled = nn.functional.interpolate(
            x, scale_factor=self.scale_factor, mode=self.mode, align_corners=False
        )
        return self.conv(upsampled)
418
419
420#
421# 2d unet implementations
422#
423
class ConvBlock2d(ConvBlock):
    """`ConvBlock` with `dim` fixed to 2; other keyword arguments are forwarded unchanged."""

    def __init__(self, in_channels, out_channels, **conv_kwargs):
        super().__init__(in_channels, out_channels, dim=2, **conv_kwargs)
427
428
class Upsampler2d(Upsampler):
    """`Upsampler` with `dim` fixed to 2 and bilinear interpolation as the default mode."""

    def __init__(self, scale_factor, in_channels, out_channels, mode="bilinear"):
        super().__init__(scale_factor, in_channels, out_channels, dim=2, mode=mode)
435
436
class UNet2d(UNetBase):
    """2D U-Net built from `Encoder`, `Decoder` and a base conv block.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels, or `None` for no output convolution.
            With `return_side_outputs`, either a single value used for every level or a
            sequence with one entry per level.
        depth: Number of encoder / decoder levels; every level scales by a factor of 2.
        initial_features: Number of features of the first encoder level.
        gain: Factor by which the number of features grows per level.
        final_activation: Forwarded to `UNetBase`.
        return_side_outputs: Whether to create one output convolution per decoder level.
        conv_block_impl: Callable creating conv blocks.
        pooler_impl: Callable creating the pooling modules.
        sampler_impl: Callable creating the upsampling modules.
        postprocessing: Forwarded to `UNetBase`.
        check_shape: Forwarded to `UNetBase`.
        conv_block_kwargs: Extra keyword arguments for `conv_block_impl`.

    Raises:
        ValueError: If `return_side_outputs` is set and `out_channels` does not have
            `depth` entries.
    """
    def __init__(
        self,
        in_channels,
        out_channels,
        depth=4,
        initial_features=32,
        gain=2,
        final_activation=None,
        return_side_outputs=False,
        conv_block_impl=ConvBlock2d,
        pooler_impl=nn.MaxPool2d,
        sampler_impl=Upsampler2d,
        postprocessing=None,
        check_shape=True,
        **conv_block_kwargs,
    ):
        features_encoder = [in_channels] + [initial_features * gain ** i for i in range(depth)]
        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
        scale_factors = depth * [2]

        if return_side_outputs:
            # A single value is broadcast to all decoder levels.
            if isinstance(out_channels, int) or out_channels is None:
                out_channels = [out_channels] * depth
            if len(out_channels) != depth:
                raise ValueError(
                    f"Invalid number of out_channels for side outputs: expected {depth}, got {len(out_channels)}"
                )
            out_conv = nn.ModuleList(
                [nn.Conv2d(feat, outc, 1) for feat, outc in zip(features_decoder[1:], out_channels)]
            )
        else:
            out_conv = None if out_channels is None else nn.Conv2d(features_decoder[-1], out_channels, 1)

        super().__init__(
            encoder=Encoder(
                features=features_encoder,
                scale_factors=scale_factors,
                conv_block_impl=conv_block_impl,
                pooler_impl=pooler_impl,
                **conv_block_kwargs
            ),
            decoder=Decoder(
                features=features_decoder,
                scale_factors=scale_factors[::-1],
                conv_block_impl=conv_block_impl,
                sampler_impl=sampler_impl,
                **conv_block_kwargs
            ),
            base=conv_block_impl(
                features_encoder[-1], features_encoder[-1] * gain,
                **conv_block_kwargs
            ),
            out_conv=out_conv,
            final_activation=final_activation,
            postprocessing=postprocessing,
            check_shape=check_shape,
        )
        # Constructor arguments, stored on the instance (check_shape is not included).
        self.init_kwargs = {"in_channels": in_channels, "out_channels": out_channels, "depth": depth,
                            "initial_features": initial_features, "gain": gain,
                            "final_activation": final_activation, "return_side_outputs": return_side_outputs,
                            "conv_block_impl": conv_block_impl, "pooler_impl": pooler_impl,
                            "sampler_impl": sampler_impl, "postprocessing": postprocessing, **conv_block_kwargs}
498
499
500#
501# 3d unet implementations
502#
503
class ConvBlock3d(ConvBlock):
    """`ConvBlock` with `dim` fixed to 3; other keyword arguments are forwarded unchanged."""

    def __init__(self, in_channels, out_channels, **conv_kwargs):
        super().__init__(in_channels, out_channels, dim=3, **conv_kwargs)
507
508
class Upsampler3d(Upsampler):
    """`Upsampler` with `dim` fixed to 3 and trilinear interpolation as the default mode."""

    def __init__(self, scale_factor, in_channels, out_channels, mode="trilinear"):
        super().__init__(scale_factor, in_channels, out_channels, dim=3, mode=mode)
515
516
class AnisotropicUNet(UNetBase):
    """3D U-Net whose pooling / upsampling factors are given per level.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels, or `None` for no output convolution.
            With `return_side_outputs`, either a single value used for every level or a
            sequence with one entry per level.
        scale_factors: One scale factor per level; its length defines the depth. Each entry
            is an int or a sequence with one factor per spatial axis.
        initial_features: Number of features of the first encoder level.
        gain: Factor by which the number of features grows per level.
        final_activation: Forwarded to `UNetBase`.
        return_side_outputs: Whether to create one output convolution per decoder level.
        conv_block_impl: Callable creating conv blocks.
        anisotropic_kernel: Forwarded to `Encoder` and `Decoder`.
        postprocessing: Forwarded to `UNetBase`.
        check_shape: Forwarded to `UNetBase`.
        conv_block_kwargs: Extra keyword arguments for `conv_block_impl`.

    Raises:
        ValueError: If `return_side_outputs` is set and `out_channels` does not have one
            entry per level.
    """
    def __init__(
        self,
        in_channels,
        out_channels,
        scale_factors,
        initial_features=32,
        gain=2,
        final_activation=None,
        return_side_outputs=False,
        conv_block_impl=ConvBlock3d,
        anisotropic_kernel=False,  # TODO benchmark which option is better and set as default
        postprocessing=None,
        check_shape=True,
        **conv_block_kwargs,
    ):
        depth = len(scale_factors)
        features_encoder = [in_channels] + [initial_features * gain ** i for i in range(depth)]
        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]

        if return_side_outputs:
            # A single value is broadcast to all decoder levels.
            if isinstance(out_channels, int) or out_channels is None:
                out_channels = [out_channels] * depth
            if len(out_channels) != depth:
                raise ValueError(
                    f"Invalid number of out_channels for side outputs: expected {depth}, got {len(out_channels)}"
                )
            out_conv = nn.ModuleList(
                [nn.Conv3d(feat, outc, 1) for feat, outc in zip(features_decoder[1:], out_channels)]
            )
        else:
            out_conv = None if out_channels is None else nn.Conv3d(features_decoder[-1], out_channels, 1)

        super().__init__(
            encoder=Encoder(
                features=features_encoder,
                scale_factors=scale_factors,
                conv_block_impl=conv_block_impl,
                pooler_impl=nn.MaxPool3d,
                anisotropic_kernel=anisotropic_kernel,
                **conv_block_kwargs
            ),
            decoder=Decoder(
                features=features_decoder,
                scale_factors=scale_factors[::-1],
                conv_block_impl=conv_block_impl,
                sampler_impl=Upsampler3d,
                anisotropic_kernel=anisotropic_kernel,
                **conv_block_kwargs
            ),
            base=conv_block_impl(
                features_encoder[-1], features_encoder[-1] * gain,
                **conv_block_kwargs
            ),
            out_conv=out_conv,
            final_activation=final_activation,
            postprocessing=postprocessing,
            check_shape=check_shape,
        )
        # Constructor arguments, stored on the instance (check_shape is not included).
        self.init_kwargs = {"in_channels": in_channels, "out_channels": out_channels, "scale_factors": scale_factors,
                            "initial_features": initial_features, "gain": gain,
                            "final_activation": final_activation, "return_side_outputs": return_side_outputs,
                            "conv_block_impl": conv_block_impl, "anisotropic_kernel": anisotropic_kernel,
                            "postprocessing": postprocessing, **conv_block_kwargs}

    def _check_shape(self, x):
        """Raise ValueError unless each spatial dim of `x` is divisible by its total scale factor."""
        spatial_shape = tuple(x.shape)[2:]
        # Fall back to isotropic factors of 2 if init_kwargs holds no scale factors.
        scale_factors = self.init_kwargs.get("scale_factors", [[2, 2, 2]]*len(self.encoder))
        # A scalar scale factor applies to all three spatial axes.
        per_level = [[sf] * 3 if isinstance(sf, int) else sf for sf in scale_factors]
        # Total downscaling per axis is the product of that axis' factors over all levels.
        factor = [int(np.prod([sf[i] for sf in per_level])) for i in range(3)]
        if len(spatial_shape) != len(factor):
            msg = f"Invalid shape for U-Net: dimensions don't agree {len(spatial_shape)} != {len(factor)}"
            raise ValueError(msg)
        if any(sh % fac != 0 for sh, fac in zip(spatial_shape, factor)):
            msg = f"Invalid shape for U-Net: {spatial_shape} is not divisible by {factor}"
            raise ValueError(msg)
590
591
class UNet3d(AnisotropicUNet):
    """Isotropic 3D U-Net: an `AnisotropicUNet` with a scale factor of 2 at every level.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels; forwarded to `AnisotropicUNet`.
        depth: Number of levels.
        initial_features: Number of features of the first encoder level.
        gain: Factor by which the number of features grows per level.
        final_activation: Forwarded to `AnisotropicUNet`.
        return_side_outputs: Forwarded to `AnisotropicUNet`.
        conv_block_impl: Callable creating conv blocks.
        postprocessing: Forwarded to `AnisotropicUNet`.
        check_shape: Forwarded to `AnisotropicUNet`.
        conv_block_kwargs: Extra keyword arguments for `conv_block_impl`.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        depth=4,
        initial_features=32,
        gain=2,
        final_activation=None,
        return_side_outputs=False,
        conv_block_impl=ConvBlock3d,
        postprocessing=None,
        check_shape=True,
        **conv_block_kwargs,
    ):
        isotropic_factors = [2] * depth
        super().__init__(
            in_channels,
            out_channels,
            isotropic_factors,
            initial_features=initial_features,
            gain=gain,
            final_activation=final_activation,
            return_side_outputs=return_side_outputs,
            anisotropic_kernel=False,
            postprocessing=postprocessing,
            conv_block_impl=conv_block_impl,
            check_shape=check_shape,
            **conv_block_kwargs,
        )
        # Replace the parent's init_kwargs with this class' own constructor arguments
        # (depth instead of scale_factors).
        init_kwargs = dict(
            in_channels=in_channels,
            out_channels=out_channels,
            depth=depth,
            initial_features=initial_features,
            gain=gain,
            final_activation=final_activation,
            return_side_outputs=return_side_outputs,
            conv_block_impl=conv_block_impl,
            postprocessing=postprocessing,
        )
        init_kwargs.update(conv_block_kwargs)
        self.init_kwargs = init_kwargs
class AccumulateChannels(torch.nn.modules.module.Module):
16class AccumulateChannels(nn.Module):
17    def __init__(
18        self,
19        invariant_channels,
20        accumulate_channels,
21        accumulator
22    ):
23        super().__init__()
24        self.invariant_channels = invariant_channels
25        self.accumulate_channels = accumulate_channels
26        assert accumulator in ("mean", "min", "max")
27        self.accumulator = getattr(torch, accumulator)
28
29    def _accumulate(self, x, c0, c1):
30        res = self.accumulator(x[:, c0:c1], dim=1, keepdim=True)
31        if not torch.is_tensor(res):
32            res = res.values
33        assert torch.is_tensor(res)
34        return res
35
36    def forward(self, x):
37        if self.invariant_channels is None:
38            c0, c1 = self.accumulate_channels
39            return self._accumulate(x, c0, c1)
40        else:
41            i0, i1 = self.invariant_channels
42            c0, c1 = self.accumulate_channels
43            return torch.cat([x[:, i0:i1], self._accumulate(x, c0, c1)], dim=1)

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

Attribute `training` (bool): whether this module is in training or evaluation mode.

AccumulateChannels(invariant_channels, accumulate_channels, accumulator)
17    def __init__(
18        self,
19        invariant_channels,
20        accumulate_channels,
21        accumulator
22    ):
23        super().__init__()
24        self.invariant_channels = invariant_channels
25        self.accumulate_channels = accumulate_channels
26        assert accumulator in ("mean", "min", "max")
27        self.accumulator = getattr(torch, accumulator)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

invariant_channels
accumulate_channels
accumulator
def forward(self, x):
36    def forward(self, x):
37        if self.invariant_channels is None:
38            c0, c1 = self.accumulate_channels
39            return self._accumulate(x, c0, c1)
40        else:
41            i0, i1 = self.invariant_channels
42            c0, c1 = self.accumulate_channels
43            return torch.cat([x[:, i0:i1], self._accumulate(x, c0, c1)], dim=1)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Inherited Members
torch.nn.modules.module.Module
dump_patches
training
call_super_init
register_buffer
register_parameter
add_module
register_module
get_submodule
get_parameter
get_buffer
get_extra_state
set_extra_state
apply
cuda
ipu
xpu
cpu
type
float
double
half
bfloat16
to_empty
to
register_full_backward_pre_hook
register_backward_hook
register_full_backward_hook
register_forward_pre_hook
register_forward_hook
register_state_dict_pre_hook
state_dict
register_load_state_dict_post_hook
load_state_dict
parameters
named_parameters
buffers
named_buffers
children
named_children
modules
named_modules
train
eval
requires_grad_
zero_grad
share_memory
extra_repr
compile
def affinities_to_boundaries(aff_channels, accumulator='max'):
46def affinities_to_boundaries(aff_channels, accumulator="max"):
47    return AccumulateChannels(None, aff_channels, accumulator)
def affinities_with_foreground_to_boundaries(aff_channels, fg_channel=(0, 1), accumulator='max'):
50def affinities_with_foreground_to_boundaries(aff_channels, fg_channel=(0, 1), accumulator="max"):
51    return AccumulateChannels(fg_channel, aff_channels, accumulator)
def affinities_to_boundaries2d():
54def affinities_to_boundaries2d():
55    return affinities_to_boundaries((0, 2))
def affinities_with_foreground_to_boundaries2d():
58def affinities_with_foreground_to_boundaries2d():
59    return affinities_with_foreground_to_boundaries((1, 3))
def affinities_to_boundaries3d():
62def affinities_to_boundaries3d():
63    return affinities_to_boundaries((0, 3))
def affinities_with_foreground_to_boundaries3d():
66def affinities_with_foreground_to_boundaries3d():
67    return affinities_with_foreground_to_boundaries((1, 4))
def affinities_to_boundaries_anisotropic():
70def affinities_to_boundaries_anisotropic():
71    return AccumulateChannels(None, (1, 3), "max")
POSTPROCESSING = {'affinities_to_boundaries_anisotropic': <function affinities_to_boundaries_anisotropic>, 'affinities_to_boundaries2d': <function affinities_to_boundaries2d>, 'affinities_with_foreground_to_boundaries2d': <function affinities_with_foreground_to_boundaries2d>, 'affinities_to_boundaries3d': <function affinities_to_boundaries3d>, 'affinities_with_foreground_to_boundaries3d': <function affinities_with_foreground_to_boundaries3d>}
class UNetBase(torch.nn.modules.module.Module):
 87class UNetBase(nn.Module):
 88    """
 89    """
 90    def __init__(
 91        self,
 92        encoder,
 93        base,
 94        decoder,
 95        out_conv=None,
 96        final_activation=None,
 97        postprocessing=None,
 98        check_shape=True,
 99    ):
100        super().__init__()
101        if len(encoder) != len(decoder):
102            raise ValueError(f"Incompatible depth of encoder (depth={len(encoder)}) and decoder (depth={len(decoder)})")
103
104        self.encoder = encoder
105        self.base = base
106        self.decoder = decoder
107
108        if out_conv is None:
109            self.return_decoder_outputs = False
110            self._out_channels = self.decoder.out_channels
111        elif isinstance(out_conv, nn.ModuleList):
112            if len(out_conv) != len(self.decoder):
113                raise ValueError(f"Invalid length of out_conv, expected {len(decoder)}, got {len(out_conv)}")
114            self.return_decoder_outputs = True
115            self._out_channels = [None if conv is None else conv.out_channels for conv in out_conv]
116        else:
117            self.return_decoder_outputs = False
118            self._out_channels = out_conv.out_channels
119        self.out_conv = out_conv
120        self.check_shape = check_shape
121        self.final_activation = self._get_activation(final_activation)
122        self.postprocessing = self._get_postprocessing(postprocessing)
123
124    @property
125    def in_channels(self):
126        return self.encoder.in_channels
127
128    @property
129    def out_channels(self):
130        return self._out_channels
131
132    @property
133    def depth(self):
134        return len(self.encoder)
135
136    def _get_activation(self, activation):
137        return_activation = None
138        if activation is None:
139            return None
140        if isinstance(activation, nn.Module):
141            return activation
142        if isinstance(activation, str):
143            return_activation = getattr(nn, activation, None)
144        if return_activation is None:
145            raise ValueError(f"Invalid activation: {activation}")
146        return return_activation()
147
148    def _get_postprocessing(self, postprocessing):
149        if postprocessing is None:
150            return None
151        elif isinstance(postprocessing, nn.Module):
152            return postprocessing
153        elif postprocessing in POSTPROCESSING:
154            return POSTPROCESSING[postprocessing]()
155        else:
156            raise ValueError(f"Invalid postprocessing: {postprocessing}")
157
158    # load encoder / decoder / base states for pretraining
159    def load_encoder_state(self, state):
160        self.encoder.load_state_dict(state)
161
162    def load_decoder_state(self, state):
163        self.decoder.load_state_dict(state)
164
165    def load_base_state(self, state):
166        self.base.load_state_dict(state)
167
168    def _apply_default(self, x):
169        self.encoder.return_outputs = True
170        self.decoder.return_outputs = False
171
172        x, encoder_out = self.encoder(x)
173        x = self.base(x)
174        x = self.decoder(x, encoder_inputs=encoder_out[::-1])
175
176        if self.out_conv is not None:
177            x = self.out_conv(x)
178        if self.final_activation is not None:
179            x = self.final_activation(x)
180        if self.postprocessing is not None:
181            x = self.postprocessing(x)
182
183        return x
184
185    def _apply_with_side_outputs(self, x):
186        self.encoder.return_outputs = True
187        self.decoder.return_outputs = True
188
189        x, encoder_out = self.encoder(x)
190        x = self.base(x)
191        x = self.decoder(x, encoder_inputs=encoder_out[::-1])
192
193        x = [x if conv is None else conv(xx) for xx, conv in zip(x, self.out_conv)]
194        if self.final_activation is not None:
195            x = [self.final_activation(xx) for xx in x]
196
197        if self.postprocessing is not None:
198            x = [self.postprocessing(xx) for xx in x]
199
200        # we reverse the list to have the full shape output as first element
201        return x[::-1]
202
203    def _check_shape(self, x):
204        spatial_shape = tuple(x.shape)[2:]
205        depth = len(self.encoder)
206        factor = [2**depth] * len(spatial_shape)
207        if any(sh % fac != 0 for sh, fac in zip(spatial_shape, factor)):
208            msg = f"Invalid shape for U-Net: {spatial_shape} is not divisible by {factor}"
209            raise ValueError(msg)
210
211    def forward(self, x):
212        # cast input data to float, hotfix for modelzoo deployment issues, leaving it here for reference
213        # x = x.float()
214        if getattr(self, "check_shape", True):
215            self._check_shape(x)
216        if self.return_decoder_outputs:
217            return self._apply_with_side_outputs(x)
218        else:
219            return self._apply_default(x)
UNetBase( encoder, base, decoder, out_conv=None, final_activation=None, postprocessing=None, check_shape=True)
 90    def __init__(
 91        self,
 92        encoder,
 93        base,
 94        decoder,
 95        out_conv=None,
 96        final_activation=None,
 97        postprocessing=None,
 98        check_shape=True,
 99    ):
100        super().__init__()
101        if len(encoder) != len(decoder):
102            raise ValueError(f"Incompatible depth of encoder (depth={len(encoder)}) and decoder (depth={len(decoder)})")
103
104        self.encoder = encoder
105        self.base = base
106        self.decoder = decoder
107
108        if out_conv is None:
109            self.return_decoder_outputs = False
110            self._out_channels = self.decoder.out_channels
111        elif isinstance(out_conv, nn.ModuleList):
112            if len(out_conv) != len(self.decoder):
113                raise ValueError(f"Invalid length of out_conv, expected {len(decoder)}, got {len(out_conv)}")
114            self.return_decoder_outputs = True
115            self._out_channels = [None if conv is None else conv.out_channels for conv in out_conv]
116        else:
117            self.return_decoder_outputs = False
118            self._out_channels = out_conv.out_channels
119        self.out_conv = out_conv
120        self.check_shape = check_shape
121        self.final_activation = self._get_activation(final_activation)
122        self.postprocessing = self._get_postprocessing(postprocessing)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

encoder
base
decoder
out_conv
check_shape
final_activation
postprocessing
in_channels
out_channels
depth
def load_encoder_state(self, state):
159    def load_encoder_state(self, state):
160        self.encoder.load_state_dict(state)
def load_decoder_state(self, state):
162    def load_decoder_state(self, state):
163        self.decoder.load_state_dict(state)
def load_base_state(self, state):
165    def load_base_state(self, state):
166        self.base.load_state_dict(state)
def forward(self, x):
211    def forward(self, x):
212        # cast input data to float, hotfix for modelzoo deployment issues, leaving it here for reference
213        # x = x.float()
214        if getattr(self, "check_shape", True):
215            self._check_shape(x)
216        if self.return_decoder_outputs:
217            return self._apply_with_side_outputs(x)
218        else:
219            return self._apply_default(x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Inherited Members
torch.nn.modules.module.Module
dump_patches
training
call_super_init
register_buffer
register_parameter
add_module
register_module
get_submodule
get_parameter
get_buffer
get_extra_state
set_extra_state
apply
cuda
ipu
xpu
cpu
type
float
double
half
bfloat16
to_empty
to
register_full_backward_pre_hook
register_backward_hook
register_full_backward_hook
register_forward_pre_hook
register_forward_hook
register_state_dict_pre_hook
state_dict
register_load_state_dict_post_hook
load_state_dict
parameters
named_parameters
buffers
named_buffers
children
named_children
modules
named_modules
train
eval
requires_grad_
zero_grad
share_memory
extra_repr
compile
class Encoder(torch.nn.modules.module.Module):
class Encoder(nn.Module):
    """Contracting path of the U-Net: alternating conv blocks and pooling layers.

    Args:
        features: Channel counts; block ``i`` maps ``features[i]`` to
            ``features[i + 1]``, so one more entry than ``scale_factors``.
        scale_factors: One pooling factor per level, passed to ``pooler_impl``.
        conv_block_impl: Callable ``(in_channels, out_channels, **kwargs)``
            returning the conv block of a level.
        pooler_impl: Callable ``(factor)`` returning the pooling module of a level.
        anisotropic_kernel: If True, the conv kwargs of every level are adapted
            to its scale factor via ``_update_conv_kwargs``.
        **conv_block_kwargs: Extra keyword arguments for ``conv_block_impl``.

    Raises:
        ValueError: If ``len(features) != len(scale_factors) + 1``.
    """
    def __init__(
        self,
        features,
        scale_factors,
        conv_block_impl,
        pooler_impl,
        anisotropic_kernel=False,
        **conv_block_kwargs
    ):
        super().__init__()
        if len(features) != len(scale_factors) + 1:
            raise ValueError(
                f"Incompatible number of features {len(features)} and scale_factors {len(scale_factors)}"
            )

        conv_kwargs = [conv_block_kwargs] * len(scale_factors)
        if anisotropic_kernel:
            conv_kwargs = [_update_conv_kwargs(kwargs, scale_factor)
                           for kwargs, scale_factor in zip(conv_kwargs, scale_factors)]

        self.blocks = nn.ModuleList(
            [conv_block_impl(inc, outc, **kwargs)
             for inc, outc, kwargs in zip(features[:-1], features[1:], conv_kwargs)]
        )
        self.poolers = nn.ModuleList(
            [pooler_impl(factor) for factor in scale_factors]
        )
        self.return_outputs = True

        self.in_channels = features[0]
        self.out_channels = features[-1]

    def __len__(self):
        """Number of encoder levels."""
        return len(self.blocks)

    def forward(self, x):
        """Encode ``x``.

        Returns ``(x, encoder_out)`` when ``return_outputs`` is set, where
        ``encoder_out`` holds the output of every block before pooling;
        otherwise only the pooled result of the last level.
        """
        encoder_out = []
        for block, pooler in zip(self.blocks, self.poolers):
            x = block(x)
            encoder_out.append(x)
            x = pooler(x)

        if self.return_outputs:
            return x, encoder_out
        else:
            return x

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing to nest them in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

Attribute ``training`` (bool): whether this module is in training or evaluation mode.

Encoder( features, scale_factors, conv_block_impl, pooler_impl, anisotropic_kernel=False, **conv_block_kwargs)
242    def __init__(
243        self,
244        features,
245        scale_factors,
246        conv_block_impl,
247        pooler_impl,
248        anisotropic_kernel=False,
249        **conv_block_kwargs
250    ):
251        super().__init__()
252        if len(features) != len(scale_factors) + 1:
253            raise ValueError("Incompatible number of features {len(features)} and scale_factors {len(scale_factors)}")
254
255        conv_kwargs = [conv_block_kwargs] * len(scale_factors)
256        if anisotropic_kernel:
257            conv_kwargs = [_update_conv_kwargs(kwargs, scale_factor)
258                           for kwargs, scale_factor in zip(conv_kwargs, scale_factors)]
259
260        self.blocks = nn.ModuleList(
261            [conv_block_impl(inc, outc, **kwargs)
262             for inc, outc, kwargs in zip(features[:-1], features[1:], conv_kwargs)]
263        )
264        self.poolers = nn.ModuleList(
265            [pooler_impl(factor) for factor in scale_factors]
266        )
267        self.return_outputs = True
268
269        self.in_channels = features[0]
270        self.out_channels = features[-1]

Initializes internal Module state, shared by both nn.Module and ScriptModule.

blocks
poolers
return_outputs
in_channels
out_channels
def forward(self, x):
275    def forward(self, x):
276        encoder_out = []
277        for block, pooler in zip(self.blocks, self.poolers):
278            x = block(x)
279            encoder_out.append(x)
280            x = pooler(x)
281
282        if self.return_outputs:
283            return x, encoder_out
284        else:
285            return x

Defines the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Inherited Members
torch.nn.modules.module.Module
dump_patches
training
call_super_init
register_buffer
register_parameter
add_module
register_module
get_submodule
get_parameter
get_buffer
get_extra_state
set_extra_state
apply
cuda
ipu
xpu
cpu
type
float
double
half
bfloat16
to_empty
to
register_full_backward_pre_hook
register_backward_hook
register_full_backward_hook
register_forward_pre_hook
register_forward_hook
register_state_dict_pre_hook
state_dict
register_load_state_dict_post_hook
load_state_dict
parameters
named_parameters
buffers
named_buffers
children
named_children
modules
named_modules
train
eval
requires_grad_
zero_grad
share_memory
extra_repr
compile
class Decoder(torch.nn.modules.module.Module):
class Decoder(nn.Module):
    """Expanding path of the U-Net: upsample, concatenate the skip input, convolve.

    Args:
        features: Channel counts; block ``i`` maps ``features[i]`` to
            ``features[i + 1]``, so one more entry than ``scale_factors``.
        scale_factors: One upsampling factor per level.
        conv_block_impl: Callable ``(in_channels, out_channels, **kwargs)``
            returning the conv block of a level.
        sampler_impl: Callable ``(factor, in_channels, out_channels)`` returning
            the upsampling module of a level.
        anisotropic_kernel: If True, the conv kwargs of every level are adapted
            to its scale factor via ``_update_conv_kwargs``.
        **conv_block_kwargs: Extra keyword arguments for ``conv_block_impl``.

    Raises:
        ValueError: If ``len(features) != len(scale_factors) + 1``.
    """
    def __init__(
        self,
        features,
        scale_factors,
        conv_block_impl,
        sampler_impl,
        anisotropic_kernel=False,
        **conv_block_kwargs
    ):
        super().__init__()
        if len(features) != len(scale_factors) + 1:
            raise ValueError(
                f"Incompatible number of features {len(features)} and scale_factors {len(scale_factors)}"
            )

        conv_kwargs = [conv_block_kwargs] * len(scale_factors)
        if anisotropic_kernel:
            conv_kwargs = [_update_conv_kwargs(kwargs, scale_factor)
                           for kwargs, scale_factor in zip(conv_kwargs, scale_factors)]

        self.blocks = nn.ModuleList(
            [conv_block_impl(inc, outc, **kwargs)
             for inc, outc, kwargs in zip(features[:-1], features[1:], conv_kwargs)]
        )
        self.samplers = nn.ModuleList(
            [sampler_impl(factor, inc, outc) for factor, inc, outc
             in zip(scale_factors, features[:-1], features[1:])]
        )
        self.return_outputs = False

        self.in_channels = features[0]
        self.out_channels = features[-1]

    def __len__(self):
        """Number of decoder levels."""
        return len(self.blocks)

    # FIXME this prevents traces from being valid for other input sizes, need to find
    # a solution to traceable cropping
    # (an implementation with torch.narrow does not fix the tracing warnings either)
    def _crop(self, x, shape):
        """Center-crop ``x`` to ``shape``; both must have the same number of dims."""
        shape_diff = [(xsh - sh) // 2 for xsh, sh in zip(x.shape, shape)]
        crop = tuple([slice(sd, xsh - sd) for sd, xsh in zip(shape_diff, x.shape)])
        return x[crop]

    def _concat(self, x1, x2):
        """Concatenate ``x1`` with ``x2`` (center-cropped to ``x1``'s shape) along dim 1."""
        return torch.cat([x1, self._crop(x2, x1.shape)], dim=1)

    def forward(self, x, encoder_inputs):
        """Decode ``x`` using one skip input per level.

        Args:
            x: Input of the first decoder level.
            encoder_inputs: Skip inputs in the order the levels consume them;
                must have exactly ``len(self)`` entries.

        Returns:
            The output of the last level, or, when ``return_outputs`` is set, the
            list of all level outputs followed by the last output once more.

        Raises:
            ValueError: If the number of ``encoder_inputs`` does not match.
        """
        if len(encoder_inputs) != len(self.blocks):
            raise ValueError(f"Invalid number of encoder_inputs: expect {len(self.blocks)}, got {len(encoder_inputs)}")

        decoder_out = []
        for block, sampler, from_encoder in zip(self.blocks, self.samplers, encoder_inputs):
            x = sampler(x)
            x = block(self._concat(x, from_encoder))
            decoder_out.append(x)

        if self.return_outputs:
            return decoder_out + [x]
        else:
            return x

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing to nest them in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

Attribute ``training`` (bool): whether this module is in training or evaluation mode.

Decoder( features, scale_factors, conv_block_impl, sampler_impl, anisotropic_kernel=False, **conv_block_kwargs)
289    def __init__(
290        self,
291        features,
292        scale_factors,
293        conv_block_impl,
294        sampler_impl,
295        anisotropic_kernel=False,
296        **conv_block_kwargs
297    ):
298        super().__init__()
299        if len(features) != len(scale_factors) + 1:
300            raise ValueError("Incompatible number of features {len(features)} and scale_factors {len(scale_factors)}")
301
302        conv_kwargs = [conv_block_kwargs] * len(scale_factors)
303        if anisotropic_kernel:
304            conv_kwargs = [_update_conv_kwargs(kwargs, scale_factor)
305                           for kwargs, scale_factor in zip(conv_kwargs, scale_factors)]
306
307        self.blocks = nn.ModuleList(
308            [conv_block_impl(inc, outc, **kwargs)
309             for inc, outc, kwargs in zip(features[:-1], features[1:], conv_kwargs)]
310        )
311        self.samplers = nn.ModuleList(
312            [sampler_impl(factor, inc, outc) for factor, inc, outc
313             in zip(scale_factors, features[:-1], features[1:])]
314        )
315        self.return_outputs = False
316
317        self.in_channels = features[0]
318        self.out_channels = features[-1]

Initializes internal Module state, shared by both nn.Module and ScriptModule.

blocks
samplers
return_outputs
in_channels
out_channels
def forward(self, x, encoder_inputs):
337    def forward(self, x, encoder_inputs):
338        if len(encoder_inputs) != len(self.blocks):
339            raise ValueError(f"Invalid number of encoder_inputs: expect {len(self.blocks)}, got {len(encoder_inputs)}")
340
341        decoder_out = []
342        for block, sampler, from_encoder in zip(self.blocks, self.samplers, encoder_inputs):
343            x = sampler(x)
344            x = block(self._concat(x, from_encoder))
345            decoder_out.append(x)
346
347        if self.return_outputs:
348            return decoder_out + [x]
349        else:
350            return x

Defines the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Inherited Members
torch.nn.modules.module.Module
dump_patches
training
call_super_init
register_buffer
register_parameter
add_module
register_module
get_submodule
get_parameter
get_buffer
get_extra_state
set_extra_state
apply
cuda
ipu
xpu
cpu
type
float
double
half
bfloat16
to_empty
to
register_full_backward_pre_hook
register_backward_hook
register_full_backward_hook
register_forward_pre_hook
register_forward_hook
register_state_dict_pre_hook
state_dict
register_load_state_dict_post_hook
load_state_dict
parameters
named_parameters
buffers
named_buffers
children
named_children
modules
named_modules
train
eval
requires_grad_
zero_grad
share_memory
extra_repr
compile
def get_norm_layer(norm, dim, channels, n_groups=32):
def get_norm_layer(norm, dim, channels, n_groups=32):
    """Create the normalization layer selected by ``norm``.

    Args:
        norm: ``None`` (no normalization), "InstanceNorm" (affine, with running
            statistics and momentum 0.01), "OldDefault" (instance norm with the
            torch defaults), "GroupNorm" or "BatchNorm".
        dim: Spatial dimensionality; 2 selects the 2d layer, any other value
            the 3d layer. Not used for "GroupNorm".
        channels: Number of channels to normalize.
        n_groups: Number of groups for "GroupNorm", capped at ``channels``.

    Returns:
        The normalization module, or ``None`` if ``norm`` is ``None``.

    Raises:
        ValueError: If ``norm`` is not one of the supported names.
    """
    if norm is None:
        return None
    if norm == "InstanceNorm":
        kwargs = {"affine": True, "track_running_stats": True, "momentum": 0.01}
        return nn.InstanceNorm2d(channels, **kwargs) if dim == 2 else nn.InstanceNorm3d(channels, **kwargs)
    elif norm == "OldDefault":
        return nn.InstanceNorm2d(channels) if dim == 2 else nn.InstanceNorm3d(channels)
    elif norm == "GroupNorm":
        return nn.GroupNorm(min(n_groups, channels), channels)
    elif norm == "BatchNorm":
        return nn.BatchNorm2d(channels) if dim == 2 else nn.BatchNorm3d(channels)
    else:
        raise ValueError(
            "Invalid norm: expect one of 'InstanceNorm', 'OldDefault', 'BatchNorm' or 'GroupNorm', "
            f"got {norm}"
        )
class ConvBlock(torch.nn.modules.module.Module):
class ConvBlock(nn.Module):
    """Two convolution + ReLU stages, each optionally preceded by a normalization layer.

    Args:
        in_channels: Channels of the input.
        out_channels: Channels produced by both convolutions.
        dim: 2 selects ``nn.Conv2d``; any other value selects ``nn.Conv3d``.
        kernel_size: Kernel size of both convolutions.
        padding: Padding of both convolutions.
        norm: Normalization name understood by ``get_norm_layer``, or ``None``
            to build the block without normalization layers.
    """
    def __init__(self, in_channels, out_channels, dim,
                 kernel_size=3, padding=1, norm="InstanceNorm"):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        conv = nn.Conv2d if dim == 2 else nn.Conv3d

        # Layer order per stage is [norm,] conv, relu. This keeps the indices
        # inside the Sequential (and thus the state dict keys) unchanged.
        layers = []
        for inc, outc in ((in_channels, out_channels), (out_channels, out_channels)):
            if norm is not None:
                layers.append(get_norm_layer(norm, dim, inc))
            layers.append(conv(inc, outc, kernel_size=kernel_size, padding=padding))
            layers.append(nn.ReLU(inplace=True))
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the block to ``x``."""
        return self.block(x)

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing to nest them in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

Attribute ``training`` (bool): whether this module is in training or evaluation mode.

ConvBlock( in_channels, out_channels, dim, kernel_size=3, padding=1, norm='InstanceNorm')
370    def __init__(self, in_channels, out_channels, dim,
371                 kernel_size=3, padding=1, norm="InstanceNorm"):
372        super().__init__()
373        self.in_channels = in_channels
374        self.out_channels = out_channels
375
376        conv = nn.Conv2d if dim == 2 else nn.Conv3d
377
378        if norm is None:
379            self.block = nn.Sequential(
380                conv(in_channels, out_channels,
381                     kernel_size=kernel_size, padding=padding),
382                nn.ReLU(inplace=True),
383                conv(out_channels, out_channels,
384                     kernel_size=kernel_size, padding=padding),
385                nn.ReLU(inplace=True)
386            )
387        else:
388            self.block = nn.Sequential(
389                get_norm_layer(norm, dim, in_channels),
390                conv(in_channels, out_channels,
391                     kernel_size=kernel_size, padding=padding),
392                nn.ReLU(inplace=True),
393                get_norm_layer(norm, dim, out_channels),
394                conv(out_channels, out_channels,
395                     kernel_size=kernel_size, padding=padding),
396                nn.ReLU(inplace=True)
397            )

Initializes internal Module state, shared by both nn.Module and ScriptModule.

in_channels
out_channels
def forward(self, x):
399    def forward(self, x):
400        return self.block(x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Inherited Members
torch.nn.modules.module.Module
dump_patches
training
call_super_init
register_buffer
register_parameter
add_module
register_module
get_submodule
get_parameter
get_buffer
get_extra_state
set_extra_state
apply
cuda
ipu
xpu
cpu
type
float
double
half
bfloat16
to_empty
to
register_full_backward_pre_hook
register_backward_hook
register_full_backward_hook
register_forward_pre_hook
register_forward_hook
register_state_dict_pre_hook
state_dict
register_load_state_dict_post_hook
load_state_dict
parameters
named_parameters
buffers
named_buffers
children
named_children
modules
named_modules
train
eval
requires_grad_
zero_grad
share_memory
extra_repr
compile
class Upsampler(torch.nn.modules.module.Module):
class Upsampler(nn.Module):
    """Upsample by interpolation, then map the channels with a 1x1 convolution.

    Args:
        scale_factor: Factor passed to ``nn.functional.interpolate``.
        in_channels: Channels of the input.
        out_channels: Channels produced by the convolution.
        dim: 2 selects ``nn.Conv2d``; any other value selects ``nn.Conv3d``.
        mode: Interpolation mode passed to ``nn.functional.interpolate``.
    """
    # modes for which interpolate accepts the align_corners option
    _ALIGNABLE_MODES = ("linear", "bilinear", "bicubic", "trilinear")

    def __init__(self, scale_factor,
                 in_channels, out_channels,
                 dim, mode):
        super().__init__()
        self.mode = mode
        self.scale_factor = scale_factor

        conv = nn.Conv2d if dim == 2 else nn.Conv3d
        self.conv = conv(in_channels, out_channels, 1)

    def forward(self, x):
        """Interpolate ``x`` by ``scale_factor`` and apply the convolution."""
        # interpolate raises if align_corners is set for a non-interpolating
        # mode such as "nearest", so only pass it where it is supported
        align_corners = False if self.mode in self._ALIGNABLE_MODES else None
        x = nn.functional.interpolate(x, scale_factor=self.scale_factor,
                                      mode=self.mode, align_corners=align_corners)
        x = self.conv(x)
        return x

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing to nest them in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

Attribute ``training`` (bool): whether this module is in training or evaluation mode.

Upsampler(scale_factor, in_channels, out_channels, dim, mode)
404    def __init__(self, scale_factor,
405                 in_channels, out_channels,
406                 dim, mode):
407        super().__init__()
408        self.mode = mode
409        self.scale_factor = scale_factor
410
411        conv = nn.Conv2d if dim == 2 else nn.Conv3d
412        self.conv = conv(in_channels, out_channels, 1)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

mode
scale_factor
conv
def forward(self, x):
414    def forward(self, x):
415        x = nn.functional.interpolate(x, scale_factor=self.scale_factor,
416                                      mode=self.mode, align_corners=False)
417        x = self.conv(x)
418        return x

Defines the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Inherited Members
torch.nn.modules.module.Module
dump_patches
training
call_super_init
register_buffer
register_parameter
add_module
register_module
get_submodule
get_parameter
get_buffer
get_extra_state
set_extra_state
apply
cuda
ipu
xpu
cpu
type
float
double
half
bfloat16
to_empty
to
register_full_backward_pre_hook
register_backward_hook
register_full_backward_hook
register_forward_pre_hook
register_forward_hook
register_state_dict_pre_hook
state_dict
register_load_state_dict_post_hook
load_state_dict
parameters
named_parameters
buffers
named_buffers
children
named_children
modules
named_modules
train
eval
requires_grad_
zero_grad
share_memory
extra_repr
compile
class ConvBlock2d(ConvBlock):
class ConvBlock2d(ConvBlock):
    """``ConvBlock`` with ``dim`` fixed to 2."""
    def __init__(self, in_channels, out_channels, **kwargs):
        # remaining keyword arguments are forwarded unchanged to ConvBlock
        super().__init__(in_channels, out_channels, dim=2, **kwargs)

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing to nest them in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

Attribute ``training`` (bool): whether this module is in training or evaluation mode.

ConvBlock2d(in_channels, out_channels, **kwargs)
426    def __init__(self, in_channels, out_channels, **kwargs):
427        super().__init__(in_channels, out_channels, dim=2, **kwargs)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

Inherited Members
ConvBlock
in_channels
out_channels
forward
torch.nn.modules.module.Module
dump_patches
training
call_super_init
register_buffer
register_parameter
add_module
register_module
get_submodule
get_parameter
get_buffer
get_extra_state
set_extra_state
apply
cuda
ipu
xpu
cpu
type
float
double
half
bfloat16
to_empty
to
register_full_backward_pre_hook
register_backward_hook
register_full_backward_hook
register_forward_pre_hook
register_forward_hook
register_state_dict_pre_hook
state_dict
register_load_state_dict_post_hook
load_state_dict
parameters
named_parameters
buffers
named_buffers
children
named_children
modules
named_modules
train
eval
requires_grad_
zero_grad
share_memory
extra_repr
compile
class Upsampler2d(Upsampler):
class Upsampler2d(Upsampler):
    """``Upsampler`` with ``dim`` fixed to 2 and bilinear interpolation by default."""
    def __init__(self, scale_factor,
                 in_channels, out_channels,
                 mode="bilinear"):
        super().__init__(scale_factor, in_channels, out_channels,
                         dim=2, mode=mode)

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing to nest them in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

Attribute training (bool): indicates whether this module is in training or evaluation mode.

Upsampler2d(scale_factor, in_channels, out_channels, mode='bilinear')
431    def __init__(self, scale_factor,
432                 in_channels, out_channels,
433                 mode="bilinear"):
434        super().__init__(scale_factor, in_channels, out_channels,
435                         dim=2, mode=mode)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

Inherited Members
Upsampler
mode
scale_factor
conv
forward
torch.nn.modules.module.Module
dump_patches
training
call_super_init
register_buffer
register_parameter
add_module
register_module
get_submodule
get_parameter
get_buffer
get_extra_state
set_extra_state
apply
cuda
ipu
xpu
cpu
type
float
double
half
bfloat16
to_empty
to
register_full_backward_pre_hook
register_backward_hook
register_full_backward_hook
register_forward_pre_hook
register_forward_hook
register_state_dict_pre_hook
state_dict
register_load_state_dict_post_hook
load_state_dict
parameters
named_parameters
buffers
named_buffers
children
named_children
modules
named_modules
train
eval
requires_grad_
zero_grad
share_memory
extra_repr
compile
class UNet2d(UNetBase):
438class UNet2d(UNetBase):
439    def __init__(
440        self,
441        in_channels,
442        out_channels,
443        depth=4,
444        initial_features=32,
445        gain=2,
446        final_activation=None,
447        return_side_outputs=False,
448        conv_block_impl=ConvBlock2d,
449        pooler_impl=nn.MaxPool2d,
450        sampler_impl=Upsampler2d,
451        postprocessing=None,
452        check_shape=True,
453        **conv_block_kwargs,
454    ):
455        features_encoder = [in_channels] + [initial_features * gain ** i for i in range(depth)]
456        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
457        scale_factors = depth * [2]
458
459        if return_side_outputs:
460            if isinstance(out_channels, int) or out_channels is None:
461                out_channels = [out_channels] * depth
462            if len(out_channels) != depth:
463                raise ValueError()
464            out_conv = nn.ModuleList(
465                [nn.Conv2d(feat, outc, 1) for feat, outc in zip(features_decoder[1:], out_channels)]
466            )
467        else:
468            out_conv = None if out_channels is None else nn.Conv2d(features_decoder[-1], out_channels, 1)
469
470        super().__init__(
471            encoder=Encoder(
472                features=features_encoder,
473                scale_factors=scale_factors,
474                conv_block_impl=conv_block_impl,
475                pooler_impl=pooler_impl,
476                **conv_block_kwargs
477            ),
478            decoder=Decoder(
479                features=features_decoder,
480                scale_factors=scale_factors[::-1],
481                conv_block_impl=conv_block_impl,
482                sampler_impl=sampler_impl,
483                **conv_block_kwargs
484            ),
485            base=conv_block_impl(
486                features_encoder[-1], features_encoder[-1] * gain,
487                **conv_block_kwargs
488            ),
489            out_conv=out_conv,
490            final_activation=final_activation,
491            postprocessing=postprocessing,
492            check_shape=check_shape,
493        )
494        self.init_kwargs = {"in_channels": in_channels, "out_channels": out_channels, "depth": depth,
495                            "initial_features": initial_features, "gain": gain,
496                            "final_activation": final_activation, "return_side_outputs": return_side_outputs,
497                            "conv_block_impl": conv_block_impl, "pooler_impl": pooler_impl,
498                            "sampler_impl": sampler_impl, "postprocessing": postprocessing, **conv_block_kwargs}
UNet2d( in_channels, out_channels, depth=4, initial_features=32, gain=2, final_activation=None, return_side_outputs=False, conv_block_impl=<class 'ConvBlock2d'>, pooler_impl=<class 'torch.nn.modules.pooling.MaxPool2d'>, sampler_impl=<class 'Upsampler2d'>, postprocessing=None, check_shape=True, **conv_block_kwargs)
439    def __init__(
440        self,
441        in_channels,
442        out_channels,
443        depth=4,
444        initial_features=32,
445        gain=2,
446        final_activation=None,
447        return_side_outputs=False,
448        conv_block_impl=ConvBlock2d,
449        pooler_impl=nn.MaxPool2d,
450        sampler_impl=Upsampler2d,
451        postprocessing=None,
452        check_shape=True,
453        **conv_block_kwargs,
454    ):
455        features_encoder = [in_channels] + [initial_features * gain ** i for i in range(depth)]
456        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
457        scale_factors = depth * [2]
458
459        if return_side_outputs:
460            if isinstance(out_channels, int) or out_channels is None:
461                out_channels = [out_channels] * depth
462            if len(out_channels) != depth:
463                raise ValueError()
464            out_conv = nn.ModuleList(
465                [nn.Conv2d(feat, outc, 1) for feat, outc in zip(features_decoder[1:], out_channels)]
466            )
467        else:
468            out_conv = None if out_channels is None else nn.Conv2d(features_decoder[-1], out_channels, 1)
469
470        super().__init__(
471            encoder=Encoder(
472                features=features_encoder,
473                scale_factors=scale_factors,
474                conv_block_impl=conv_block_impl,
475                pooler_impl=pooler_impl,
476                **conv_block_kwargs
477            ),
478            decoder=Decoder(
479                features=features_decoder,
480                scale_factors=scale_factors[::-1],
481                conv_block_impl=conv_block_impl,
482                sampler_impl=sampler_impl,
483                **conv_block_kwargs
484            ),
485            base=conv_block_impl(
486                features_encoder[-1], features_encoder[-1] * gain,
487                **conv_block_kwargs
488            ),
489            out_conv=out_conv,
490            final_activation=final_activation,
491            postprocessing=postprocessing,
492            check_shape=check_shape,
493        )
494        self.init_kwargs = {"in_channels": in_channels, "out_channels": out_channels, "depth": depth,
495                            "initial_features": initial_features, "gain": gain,
496                            "final_activation": final_activation, "return_side_outputs": return_side_outputs,
497                            "conv_block_impl": conv_block_impl, "pooler_impl": pooler_impl,
498                            "sampler_impl": sampler_impl, "postprocessing": postprocessing, **conv_block_kwargs}

Initializes internal Module state, shared by both nn.Module and ScriptModule.

init_kwargs
Inherited Members
UNetBase
encoder
base
decoder
out_conv
check_shape
final_activation
postprocessing
in_channels
out_channels
depth
load_encoder_state
load_decoder_state
load_base_state
forward
torch.nn.modules.module.Module
dump_patches
training
call_super_init
register_buffer
register_parameter
add_module
register_module
get_submodule
get_parameter
get_buffer
get_extra_state
set_extra_state
apply
cuda
ipu
xpu
cpu
type
float
double
half
bfloat16
to_empty
to
register_full_backward_pre_hook
register_backward_hook
register_full_backward_hook
register_forward_pre_hook
register_forward_hook
register_state_dict_pre_hook
state_dict
register_load_state_dict_post_hook
load_state_dict
parameters
named_parameters
buffers
named_buffers
children
named_children
modules
named_modules
train
eval
requires_grad_
zero_grad
share_memory
extra_repr
compile
class ConvBlock3d(ConvBlock):
505class ConvBlock3d(ConvBlock):
506    def __init__(self, in_channels, out_channels, **kwargs):
507        super().__init__(in_channels, out_channels, dim=3, **kwargs)

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

Attribute training (bool): indicates whether this module is in training or evaluation mode.

ConvBlock3d(in_channels, out_channels, **kwargs)
506    def __init__(self, in_channels, out_channels, **kwargs):
507        super().__init__(in_channels, out_channels, dim=3, **kwargs)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

Inherited Members
ConvBlock
in_channels
out_channels
forward
torch.nn.modules.module.Module
dump_patches
training
call_super_init
register_buffer
register_parameter
add_module
register_module
get_submodule
get_parameter
get_buffer
get_extra_state
set_extra_state
apply
cuda
ipu
xpu
cpu
type
float
double
half
bfloat16
to_empty
to
register_full_backward_pre_hook
register_backward_hook
register_full_backward_hook
register_forward_pre_hook
register_forward_hook
register_state_dict_pre_hook
state_dict
register_load_state_dict_post_hook
load_state_dict
parameters
named_parameters
buffers
named_buffers
children
named_children
modules
named_modules
train
eval
requires_grad_
zero_grad
share_memory
extra_repr
compile
class Upsampler3d(Upsampler):
510class Upsampler3d(Upsampler):
511    def __init__(self, scale_factor,
512                 in_channels, out_channels,
513                 mode="trilinear"):
514        super().__init__(scale_factor, in_channels, out_channels,
515                         dim=3, mode=mode)

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

Attribute training (bool): indicates whether this module is in training or evaluation mode.

Upsampler3d(scale_factor, in_channels, out_channels, mode='trilinear')
511    def __init__(self, scale_factor,
512                 in_channels, out_channels,
513                 mode="trilinear"):
514        super().__init__(scale_factor, in_channels, out_channels,
515                         dim=3, mode=mode)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

Inherited Members
Upsampler
mode
scale_factor
conv
forward
torch.nn.modules.module.Module
dump_patches
training
call_super_init
register_buffer
register_parameter
add_module
register_module
get_submodule
get_parameter
get_buffer
get_extra_state
set_extra_state
apply
cuda
ipu
xpu
cpu
type
float
double
half
bfloat16
to_empty
to
register_full_backward_pre_hook
register_backward_hook
register_full_backward_hook
register_forward_pre_hook
register_forward_hook
register_state_dict_pre_hook
state_dict
register_load_state_dict_post_hook
load_state_dict
parameters
named_parameters
buffers
named_buffers
children
named_children
modules
named_modules
train
eval
requires_grad_
zero_grad
share_memory
extra_repr
compile
class AnisotropicUNet(UNetBase):
518class AnisotropicUNet(UNetBase):
519    def __init__(
520        self,
521        in_channels,
522        out_channels,
523        scale_factors,
524        initial_features=32,
525        gain=2,
526        final_activation=None,
527        return_side_outputs=False,
528        conv_block_impl=ConvBlock3d,
529        anisotropic_kernel=False,  # TODO benchmark which option is better and set as default
530        postprocessing=None,
531        check_shape=True,
532        **conv_block_kwargs,
533    ):
534        depth = len(scale_factors)
535        features_encoder = [in_channels] + [initial_features * gain ** i for i in range(depth)]
536        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
537
538        if return_side_outputs:
539            if isinstance(out_channels, int) or out_channels is None:
540                out_channels = [out_channels] * depth
541            if len(out_channels) != depth:
542                raise ValueError()
543            out_conv = nn.ModuleList(
544                [nn.Conv3d(feat, outc, 1) for feat, outc in zip(features_decoder[1:], out_channels)]
545            )
546        else:
547            out_conv = None if out_channels is None else nn.Conv3d(features_decoder[-1], out_channels, 1)
548
549        super().__init__(
550            encoder=Encoder(
551                features=features_encoder,
552                scale_factors=scale_factors,
553                conv_block_impl=conv_block_impl,
554                pooler_impl=nn.MaxPool3d,
555                anisotropic_kernel=anisotropic_kernel,
556                **conv_block_kwargs
557            ),
558            decoder=Decoder(
559                features=features_decoder,
560                scale_factors=scale_factors[::-1],
561                conv_block_impl=conv_block_impl,
562                sampler_impl=Upsampler3d,
563                anisotropic_kernel=anisotropic_kernel,
564                **conv_block_kwargs
565            ),
566            base=conv_block_impl(
567                features_encoder[-1], features_encoder[-1] * gain,
568                **conv_block_kwargs
569            ),
570            out_conv=out_conv,
571            final_activation=final_activation,
572            postprocessing=postprocessing,
573            check_shape=check_shape,
574        )
575        self.init_kwargs = {"in_channels": in_channels, "out_channels": out_channels, "scale_factors": scale_factors,
576                            "initial_features": initial_features, "gain": gain,
577                            "final_activation": final_activation, "return_side_outputs": return_side_outputs,
578                            "conv_block_impl": conv_block_impl, "anisotropic_kernel": anisotropic_kernel,
579                            "postprocessing": postprocessing, **conv_block_kwargs}
580
581    def _check_shape(self, x):
582        spatial_shape = tuple(x.shape)[2:]
583        scale_factors = self.init_kwargs.get("scale_factors", [[2, 2, 2]]*len(self.encoder))
584        factor = [int(np.prod([sf[i] for sf in scale_factors])) for i in range(3)]
585        if len(spatial_shape) != len(factor):
586            msg = f"Invalid shape for U-Net: dimensions don't agree {len(spatial_shape)} != {len(factor)}"
587            raise ValueError(msg)
588        if any(sh % fac != 0 for sh, fac in zip(spatial_shape, factor)):
589            msg = f"Invalid shape for U-Net: {spatial_shape} is not divisible by {factor}"
590            raise ValueError(msg)
AnisotropicUNet( in_channels, out_channels, scale_factors, initial_features=32, gain=2, final_activation=None, return_side_outputs=False, conv_block_impl=<class 'ConvBlock3d'>, anisotropic_kernel=False, postprocessing=None, check_shape=True, **conv_block_kwargs)
519    def __init__(
520        self,
521        in_channels,
522        out_channels,
523        scale_factors,
524        initial_features=32,
525        gain=2,
526        final_activation=None,
527        return_side_outputs=False,
528        conv_block_impl=ConvBlock3d,
529        anisotropic_kernel=False,  # TODO benchmark which option is better and set as default
530        postprocessing=None,
531        check_shape=True,
532        **conv_block_kwargs,
533    ):
534        depth = len(scale_factors)
535        features_encoder = [in_channels] + [initial_features * gain ** i for i in range(depth)]
536        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
537
538        if return_side_outputs:
539            if isinstance(out_channels, int) or out_channels is None:
540                out_channels = [out_channels] * depth
541            if len(out_channels) != depth:
542                raise ValueError()
543            out_conv = nn.ModuleList(
544                [nn.Conv3d(feat, outc, 1) for feat, outc in zip(features_decoder[1:], out_channels)]
545            )
546        else:
547            out_conv = None if out_channels is None else nn.Conv3d(features_decoder[-1], out_channels, 1)
548
549        super().__init__(
550            encoder=Encoder(
551                features=features_encoder,
552                scale_factors=scale_factors,
553                conv_block_impl=conv_block_impl,
554                pooler_impl=nn.MaxPool3d,
555                anisotropic_kernel=anisotropic_kernel,
556                **conv_block_kwargs
557            ),
558            decoder=Decoder(
559                features=features_decoder,
560                scale_factors=scale_factors[::-1],
561                conv_block_impl=conv_block_impl,
562                sampler_impl=Upsampler3d,
563                anisotropic_kernel=anisotropic_kernel,
564                **conv_block_kwargs
565            ),
566            base=conv_block_impl(
567                features_encoder[-1], features_encoder[-1] * gain,
568                **conv_block_kwargs
569            ),
570            out_conv=out_conv,
571            final_activation=final_activation,
572            postprocessing=postprocessing,
573            check_shape=check_shape,
574        )
575        self.init_kwargs = {"in_channels": in_channels, "out_channels": out_channels, "scale_factors": scale_factors,
576                            "initial_features": initial_features, "gain": gain,
577                            "final_activation": final_activation, "return_side_outputs": return_side_outputs,
578                            "conv_block_impl": conv_block_impl, "anisotropic_kernel": anisotropic_kernel,
579                            "postprocessing": postprocessing, **conv_block_kwargs}

Initializes internal Module state, shared by both nn.Module and ScriptModule.

init_kwargs
Inherited Members
UNetBase
encoder
base
decoder
out_conv
check_shape
final_activation
postprocessing
in_channels
out_channels
depth
load_encoder_state
load_decoder_state
load_base_state
forward
torch.nn.modules.module.Module
dump_patches
training
call_super_init
register_buffer
register_parameter
add_module
register_module
get_submodule
get_parameter
get_buffer
get_extra_state
set_extra_state
apply
cuda
ipu
xpu
cpu
type
float
double
half
bfloat16
to_empty
to
register_full_backward_pre_hook
register_backward_hook
register_full_backward_hook
register_forward_pre_hook
register_forward_hook
register_state_dict_pre_hook
state_dict
register_load_state_dict_post_hook
load_state_dict
parameters
named_parameters
buffers
named_buffers
children
named_children
modules
named_modules
train
eval
requires_grad_
zero_grad
share_memory
extra_repr
compile
class UNet3d(AnisotropicUNet):
593class UNet3d(AnisotropicUNet):
594    def __init__(
595        self,
596        in_channels,
597        out_channels,
598        depth=4,
599        initial_features=32,
600        gain=2,
601        final_activation=None,
602        return_side_outputs=False,
603        conv_block_impl=ConvBlock3d,
604        postprocessing=None,
605        check_shape=True,
606        **conv_block_kwargs,
607    ):
608        scale_factors = depth * [2]
609        super().__init__(in_channels, out_channels, scale_factors,
610                         initial_features=initial_features, gain=gain,
611                         final_activation=final_activation,
612                         return_side_outputs=return_side_outputs,
613                         anisotropic_kernel=False,
614                         postprocessing=postprocessing,
615                         conv_block_impl=conv_block_impl,
616                         check_shape=check_shape,
617                         **conv_block_kwargs)
618        self.init_kwargs = {"in_channels": in_channels, "out_channels": out_channels, "depth": depth,
619                            "initial_features": initial_features, "gain": gain,
620                            "final_activation": final_activation, "return_side_outputs": return_side_outputs,
621                            "conv_block_impl": conv_block_impl, "postprocessing": postprocessing, **conv_block_kwargs}
UNet3d( in_channels, out_channels, depth=4, initial_features=32, gain=2, final_activation=None, return_side_outputs=False, conv_block_impl=<class 'ConvBlock3d'>, postprocessing=None, check_shape=True, **conv_block_kwargs)
594    def __init__(
595        self,
596        in_channels,
597        out_channels,
598        depth=4,
599        initial_features=32,
600        gain=2,
601        final_activation=None,
602        return_side_outputs=False,
603        conv_block_impl=ConvBlock3d,
604        postprocessing=None,
605        check_shape=True,
606        **conv_block_kwargs,
607    ):
608        scale_factors = depth * [2]
609        super().__init__(in_channels, out_channels, scale_factors,
610                         initial_features=initial_features, gain=gain,
611                         final_activation=final_activation,
612                         return_side_outputs=return_side_outputs,
613                         anisotropic_kernel=False,
614                         postprocessing=postprocessing,
615                         conv_block_impl=conv_block_impl,
616                         check_shape=check_shape,
617                         **conv_block_kwargs)
618        self.init_kwargs = {"in_channels": in_channels, "out_channels": out_channels, "depth": depth,
619                            "initial_features": initial_features, "gain": gain,
620                            "final_activation": final_activation, "return_side_outputs": return_side_outputs,
621                            "conv_block_impl": conv_block_impl, "postprocessing": postprocessing, **conv_block_kwargs}

Initializes internal Module state, shared by both nn.Module and ScriptModule.

init_kwargs
Inherited Members
UNetBase
encoder
base
decoder
out_conv
check_shape
final_activation
postprocessing
in_channels
out_channels
depth
load_encoder_state
load_decoder_state
load_base_state
forward
torch.nn.modules.module.Module
dump_patches
training
call_super_init
register_buffer
register_parameter
add_module
register_module
get_submodule
get_parameter
get_buffer
get_extra_state
set_extra_state
apply
cuda
ipu
xpu
cpu
type
float
double
half
bfloat16
to_empty
to
register_full_backward_pre_hook
register_backward_hook
register_full_backward_hook
register_forward_pre_hook
register_forward_hook
register_state_dict_pre_hook
state_dict
register_load_state_dict_post_hook
load_state_dict
parameters
named_parameters
buffers
named_buffers
children
named_children
modules
named_modules
train
eval
requires_grad_
zero_grad
share_memory
extra_repr
compile