torch_em.model.unet

  1from typing import List, Optional, Union
  2
  3import numpy as np
  4import torch
  5import torch.nn as nn
  6
  7
  8#
  9# Model Internal Post-processing
 10#
# Note: these are mainly for bioimage.io models, where postprocessing has to be done
# inside of the model unless it's defined in the general spec.
 13
 14
 15class AccumulateChannels(nn.Module):
 16    """@private
 17    """
 18    def __init__(
 19        self,
 20        invariant_channels,
 21        accumulate_channels,
 22        accumulator
 23    ):
 24        super().__init__()
 25        self.invariant_channels = invariant_channels
 26        self.accumulate_channels = accumulate_channels
 27        assert accumulator in ("mean", "min", "max")
 28        self.accumulator = getattr(torch, accumulator)
 29
 30    def _accumulate(self, x, c0, c1):
 31        res = self.accumulator(x[:, c0:c1], dim=1, keepdim=True)
 32        if not torch.is_tensor(res):
 33            res = res.values
 34        assert torch.is_tensor(res)
 35        return res
 36
 37    def forward(self, x):
 38        if self.invariant_channels is None:
 39            c0, c1 = self.accumulate_channels
 40            return self._accumulate(x, c0, c1)
 41        else:
 42            i0, i1 = self.invariant_channels
 43            c0, c1 = self.accumulate_channels
 44            return torch.cat([x[:, i0:i1], self._accumulate(x, c0, c1)], dim=1)
 45
 46
 47def affinities_to_boundaries(aff_channels, accumulator="max"):
 48    """@private
 49    """
 50    return AccumulateChannels(None, aff_channels, accumulator)
 51
 52
 53def affinities_with_foreground_to_boundaries(aff_channels, fg_channel=(0, 1), accumulator="max"):
 54    """@private
 55    """
 56    return AccumulateChannels(fg_channel, aff_channels, accumulator)
 57
 58
 59def affinities_to_boundaries2d():
 60    """@private
 61    """
 62    return affinities_to_boundaries((0, 2))
 63
 64
 65def affinities_with_foreground_to_boundaries2d():
 66    """@private
 67    """
 68    return affinities_with_foreground_to_boundaries((1, 3))
 69
 70
 71def affinities_to_boundaries3d():
 72    """@private
 73    """
 74    return affinities_to_boundaries((0, 3))
 75
 76
 77def affinities_with_foreground_to_boundaries3d():
 78    """@private
 79    """
 80    return affinities_with_foreground_to_boundaries((1, 4))
 81
 82
 83def affinities_to_boundaries_anisotropic():
 84    """@private
 85    """
 86    return AccumulateChannels(None, (1, 3), "max")
 87
 88
 89POSTPROCESSING = {
 90    "affinities_to_boundaries_anisotropic": affinities_to_boundaries_anisotropic,
 91    "affinities_to_boundaries2d": affinities_to_boundaries2d,
 92    "affinities_with_foreground_to_boundaries2d": affinities_with_foreground_to_boundaries2d,
 93    "affinities_to_boundaries3d": affinities_to_boundaries3d,
 94    "affinities_with_foreground_to_boundaries3d": affinities_with_foreground_to_boundaries3d,
 95}
 96"""@private
 97"""
 98
 99
100#
101# Base Implementations
102#
103
class UNetBase(nn.Module):
    """Base class for implementing a U-Net.

    Args:
        encoder: The encoder of the U-Net.
        base: The base layer of the U-Net.
        decoder: The decoder of the U-Net.
        out_conv: The output convolution applied after the last decoder layer.
            If an `nn.ModuleList` with one entry per decoder level is given, the outputs of all decoder
            levels are returned (side outputs); `None` entries pass that level's output through unchanged.
        final_activation: The activation applied after the output convolution or last decoder layer.
        postprocessing: A postprocessing function to apply after the U-Net output.
        check_shape: Whether to check the input shape to the U-Net forward call.
    """
    def __init__(
        self,
        encoder: nn.Module,
        base: nn.Module,
        decoder: nn.Module,
        out_conv: Optional[nn.Module] = None,
        final_activation: Optional[Union[nn.Module, str]] = None,
        postprocessing: Optional[Union[nn.Module, str]] = None,
        check_shape: bool = True,
    ):
        super().__init__()
        if len(encoder) != len(decoder):
            raise ValueError(f"Incompatible depth of encoder (depth={len(encoder)}) and decoder (depth={len(decoder)})")

        self.encoder = encoder
        self.base = base
        self.decoder = decoder

        if out_conv is None:
            self.return_decoder_outputs = False
            self._out_channels = self.decoder.out_channels
        elif isinstance(out_conv, nn.ModuleList):
            # One (optional) output conv per decoder level -> the forward pass returns side outputs.
            if len(out_conv) != len(self.decoder):
                raise ValueError(f"Invalid length of out_conv, expected {len(decoder)}, got {len(out_conv)}")
            self.return_decoder_outputs = True
            self._out_channels = [None if conv is None else conv.out_channels for conv in out_conv]
        else:
            self.return_decoder_outputs = False
            self._out_channels = out_conv.out_channels
        self.out_conv = out_conv
        self.check_shape = check_shape
        self.final_activation = self._get_activation(final_activation)
        self.postprocessing = self._get_postprocessing(postprocessing)

    @property
    def in_channels(self):
        """The number of input channels, taken from the encoder."""
        return self.encoder.in_channels

    @property
    def out_channels(self):
        """The number of output channels; a list (one entry per decoder level) for side outputs."""
        return self._out_channels

    @property
    def depth(self):
        """The depth of the U-Net, i.e. `len(self.encoder)`."""
        return len(self.encoder)

    def _get_activation(self, activation):
        # Accepts None, a module instance, or the name of a class in `torch.nn` (e.g. "Sigmoid").
        if activation is None:
            return None
        if isinstance(activation, nn.Module):
            return activation
        activation_cls = getattr(nn, activation, None) if isinstance(activation, str) else None
        if activation_cls is None:
            raise ValueError(f"Invalid activation: {activation}")
        return activation_cls()

    def _get_postprocessing(self, postprocessing):
        # Accepts None, a module instance, or a key of the module-level POSTPROCESSING registry.
        if postprocessing is None:
            return None
        elif isinstance(postprocessing, nn.Module):
            return postprocessing
        elif postprocessing in POSTPROCESSING:
            return POSTPROCESSING[postprocessing]()
        else:
            raise ValueError(f"Invalid postprocessing: {postprocessing}")

    # load encoder / decoder / base states for pretraining
    def load_encoder_state(self, state):
        """Load a state dict into the encoder."""
        self.encoder.load_state_dict(state)

    def load_decoder_state(self, state):
        """Load a state dict into the decoder."""
        self.decoder.load_state_dict(state)

    def load_base_state(self, state):
        """Load a state dict into the base layer."""
        self.base.load_state_dict(state)

    def _apply_default(self, x):
        self.encoder.return_outputs = True
        self.decoder.return_outputs = False

        x, encoder_out = self.encoder(x)
        x = self.base(x)
        # The decoder consumes the skip connections from the deepest level upwards.
        x = self.decoder(x, encoder_inputs=encoder_out[::-1])

        if self.out_conv is not None:
            x = self.out_conv(x)
        if self.final_activation is not None:
            x = self.final_activation(x)
        if self.postprocessing is not None:
            x = self.postprocessing(x)

        return x

    def _apply_with_side_outputs(self, x):
        self.encoder.return_outputs = True
        self.decoder.return_outputs = True

        x, encoder_out = self.encoder(x)
        x = self.base(x)
        decoder_out = self.decoder(x, encoder_inputs=encoder_out[::-1])

        # Apply each level's output conv; levels without one (None) keep their decoder output.
        # zip stops at len(self.out_conv), i.e. one entry per decoder level.
        outputs = [out if conv is None else conv(out) for out, conv in zip(decoder_out, self.out_conv)]
        if self.final_activation is not None:
            outputs = [self.final_activation(out) for out in outputs]

        if self.postprocessing is not None:
            outputs = [self.postprocessing(out) for out in outputs]

        # we reverse the list to have the full shape output as first element
        return outputs[::-1]

    def _check_shape(self, x):
        spatial_shape = tuple(x.shape)[2:]
        depth = len(self.encoder)
        factor = [2**depth] * len(spatial_shape)
        if any(sh % fac != 0 for sh, fac in zip(spatial_shape, factor)):
            msg = f"Invalid shape for U-Net: {spatial_shape} is not divisible by {factor}"
            raise ValueError(msg)

    def forward(self, x: torch.Tensor) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Apply U-Net to input data.

        Args:
            x: The input data.

        Returns:
            The output of the U-Net. If the U-Net was built with side outputs, a list with the
            output of each decoder level, with the full-resolution output as the first element.

        Raises:
            ValueError: If shape checking is enabled and the spatial shape is not compatible with the depth.
        """
        # Cast input data to float, hotfix for modelzoo deployment issues, leaving it here for reference.
        # x = x.float()
        # getattr with a default keeps models pickled before `check_shape` existed working.
        if getattr(self, "check_shape", True):
            self._check_shape(x)
        if self.return_decoder_outputs:
            return self._apply_with_side_outputs(x)
        else:
            return self._apply_default(x)
254
255
256def _update_conv_kwargs(kwargs, scale_factor):
257    # if the scale factor is a scalar or all entries are the same we don"t need to update the kwargs
258    if isinstance(scale_factor, int) or scale_factor.count(scale_factor[0]) == len(scale_factor):
259        return kwargs
260    else:  # otherwise set anisotropic kernel
261        kernel_size = kwargs.get("kernel_size", 3)
262        padding = kwargs.get("padding", 1)
263
264        # bail out if kernel size or padding aren"t scalars, because it"s
265        # unclear what to do in this case
266        if not (isinstance(kernel_size, int) and isinstance(padding, int)):
267            return kwargs
268
269        kernel_size = tuple(1 if factor == 1 else kernel_size for factor in scale_factor)
270        padding = tuple(0 if factor == 1 else padding for factor in scale_factor)
271        kwargs.update({"kernel_size": kernel_size, "padding": padding})
272        return kwargs
273
274
class Encoder(nn.Module):
    """@private

    U-Net encoder: a stack of conv blocks, each followed by a pooling layer.

    Args:
        features: Number of channels per stage; block i is built as
            conv_block_impl(features[i], features[i + 1], **kwargs).
        scale_factors: Pooling factor for each level; must have exactly one entry less than `features`.
        conv_block_impl: Callable building a conv block as conv_block_impl(in_channels, out_channels, **kwargs).
        pooler_impl: Callable building a pooling layer as pooler_impl(scale_factor).
        anisotropic_kernel: Whether to adapt the kernel of each level to its scale factor.
        conv_block_kwargs: Additional keyword arguments for the conv blocks.

    Raises:
        ValueError: If the lengths of `features` and `scale_factors` do not match.
    """
    def __init__(
        self,
        features,
        scale_factors,
        conv_block_impl,
        pooler_impl,
        anisotropic_kernel=False,
        **conv_block_kwargs
    ):
        super().__init__()
        if len(features) != len(scale_factors) + 1:
            raise ValueError(f"Incompatible number of features {len(features)} and scale_factors {len(scale_factors)}")

        # Every level gets its own copy so that per-level kwarg updates cannot leak into other levels.
        conv_kwargs = [dict(conv_block_kwargs) for _ in scale_factors]
        if anisotropic_kernel:
            conv_kwargs = [_update_conv_kwargs(kwargs, scale_factor)
                           for kwargs, scale_factor in zip(conv_kwargs, scale_factors)]

        self.blocks = nn.ModuleList(
            [conv_block_impl(inc, outc, **kwargs)
             for inc, outc, kwargs in zip(features[:-1], features[1:], conv_kwargs)]
        )
        self.poolers = nn.ModuleList(
            [pooler_impl(factor) for factor in scale_factors]
        )
        self.return_outputs = True

        self.in_channels = features[0]
        self.out_channels = features[-1]

    def __len__(self):
        return len(self.blocks)

    def forward(self, x):
        """Apply all levels.

        Returns the pooled output of the last level and, if `return_outputs` is set,
        also the list of conv block outputs (before pooling) of every level.
        """
        encoder_out = []
        for block, pooler in zip(self.blocks, self.poolers):
            x = block(x)
            encoder_out.append(x)
            x = pooler(x)

        if self.return_outputs:
            return x, encoder_out
        else:
            return x
322
323
class Decoder(nn.Module):
    """@private

    U-Net decoder: each level upsamples its input, concatenates it with the matching
    encoder output along the channel axis and applies a conv block.

    Args:
        features: Number of channels per stage; block i is built as
            conv_block_impl(features[i], features[i + 1], **kwargs) and sampler i as
            sampler_impl(scale_factors[i], features[i], features[i + 1]).
        scale_factors: Upsampling factor for each level; must have exactly one entry less than `features`.
        conv_block_impl: Callable building a conv block as conv_block_impl(in_channels, out_channels, **kwargs).
        sampler_impl: Callable building an upsampling layer as sampler_impl(scale_factor, in_channels, out_channels).
        anisotropic_kernel: Whether to adapt the kernel of each level to its scale factor.
        conv_block_kwargs: Additional keyword arguments for the conv blocks.

    Raises:
        ValueError: If the lengths of `features` and `scale_factors` do not match.
    """
    def __init__(
        self,
        features,
        scale_factors,
        conv_block_impl,
        sampler_impl,
        anisotropic_kernel=False,
        **conv_block_kwargs
    ):
        super().__init__()
        if len(features) != len(scale_factors) + 1:
            raise ValueError(f"Incompatible number of features {len(features)} and scale_factors {len(scale_factors)}")

        # Every level gets its own copy so that per-level kwarg updates cannot leak into other levels.
        conv_kwargs = [dict(conv_block_kwargs) for _ in scale_factors]
        if anisotropic_kernel:
            conv_kwargs = [_update_conv_kwargs(kwargs, scale_factor)
                           for kwargs, scale_factor in zip(conv_kwargs, scale_factors)]

        self.blocks = nn.ModuleList(
            [conv_block_impl(inc, outc, **kwargs)
             for inc, outc, kwargs in zip(features[:-1], features[1:], conv_kwargs)]
        )
        self.samplers = nn.ModuleList(
            [sampler_impl(factor, inc, outc) for factor, inc, outc
             in zip(scale_factors, features[:-1], features[1:])]
        )
        self.return_outputs = False

        self.in_channels = features[0]
        self.out_channels = features[-1]

    def __len__(self):
        return len(self.blocks)

    # FIXME this prevents traces from being valid for other input sizes, need to find
    # a solution to traceable cropping
    def _crop(self, x, shape):
        # Center-crop every dimension of `x` to `shape`.
        # (An implementation with torch.narrow was tried; it does not fix the tracing warnings either.)
        shape_diff = [(xsh - sh) // 2 for xsh, sh in zip(x.shape, shape)]
        crop = tuple([slice(sd, xsh - sd) for sd, xsh in zip(shape_diff, x.shape)])
        return x[crop]

    def _concat(self, x1, x2):
        return torch.cat([x1, self._crop(x2, x1.shape)], dim=1)

    def forward(self, x, encoder_inputs):
        """Apply all decoder levels.

        Args:
            x: The input to the first level (the output of the U-Net base when called from UNetBase).
            encoder_inputs: One skip-connection tensor per level, consumed in order.

        Returns:
            The output of the last level. If `return_outputs` is set, instead the list of outputs of
            every level followed by the last output once more (so the final output appears twice).

        Raises:
            ValueError: If the number of `encoder_inputs` does not match the number of levels.
        """
        if len(encoder_inputs) != len(self.blocks):
            raise ValueError(f"Invalid number of encoder_inputs: expect {len(self.blocks)}, got {len(encoder_inputs)}")

        decoder_out = []
        for block, sampler, from_encoder in zip(self.blocks, self.samplers, encoder_inputs):
            x = sampler(x)
            x = block(self._concat(x, from_encoder))
            decoder_out.append(x)

        if self.return_outputs:
            return decoder_out + [x]
        else:
            return x
389
390
def get_norm_layer(norm, dim, channels, n_groups=32):
    """@private

    Create the normalization layer for a conv block.

    Args:
        norm: Name of the normalization: "InstanceNorm", "InstanceNormTrackStats", "GroupNorm",
            "BatchNorm", or None for no normalization.
        dim: Spatial dimensionality; 2 selects the 2d variants, any other value the 3d variants.
        channels: Number of channels to normalize.
        n_groups: Maximal number of groups for GroupNorm.

    Returns:
        The normalization module, or None if `norm` is None.

    Raises:
        ValueError: If `norm` is not one of the supported names.
    """
    if norm is None:
        return None
    if norm == "InstanceNorm":
        return nn.InstanceNorm2d(channels) if dim == 2 else nn.InstanceNorm3d(channels)
    elif norm == "InstanceNormTrackStats":
        kwargs = {"affine": True, "track_running_stats": True, "momentum": 0.01}
        return nn.InstanceNorm2d(channels, **kwargs) if dim == 2 else nn.InstanceNorm3d(channels, **kwargs)
    elif norm == "GroupNorm":
        # GroupNorm requires `channels` to be divisible by the number of groups, so use the largest
        # group count <= n_groups that divides `channels`. For every configuration that worked before
        # (channels <= n_groups, or channels divisible by n_groups) this equals min(n_groups, channels).
        num_groups = min(n_groups, channels)
        while num_groups > 1 and channels % num_groups != 0:
            num_groups -= 1
        return nn.GroupNorm(num_groups, channels)
    elif norm == "BatchNorm":
        return nn.BatchNorm2d(channels) if dim == 2 else nn.BatchNorm3d(channels)
    else:
        raise ValueError(
            "Invalid norm: expect one of 'InstanceNorm', 'InstanceNormTrackStats', 'BatchNorm' or 'GroupNorm', "
            f"got {norm}"
        )
407
408
class ConvBlock(nn.Module):
    """@private

    Two convolutions, each followed by an in-place ReLU and, unless `norm` is None,
    preceded by a normalization layer from `get_norm_layer`.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels of both convolutions.
        dim: Spatial dimensionality; 2 uses Conv2d, any other value Conv3d.
        kernel_size: Kernel size of both convolutions.
        padding: Padding of both convolutions.
        norm: Name of the normalization passed to `get_norm_layer`, or None for no normalization.
    """
    def __init__(self, in_channels, out_channels, dim, kernel_size=3, padding=1, norm="InstanceNorm"):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        conv_cls = nn.Conv3d if dim != 2 else nn.Conv2d
        use_norm = norm is not None

        # Layers are appended in the order [norm], conv, relu for each of the two stages,
        # which keeps the nn.Sequential indices (and thus state dict keys) stable.
        layers = []
        for stage_in in (in_channels, out_channels):
            if use_norm:
                layers.append(get_norm_layer(norm, dim, stage_in))
            layers.append(conv_cls(stage_in, out_channels, kernel_size=kernel_size, padding=padding))
            layers.append(nn.ReLU(inplace=True))
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)
442
443
class Upsampler(nn.Module):
    """@private

    Upsample by interpolation, then map `in_channels` to `out_channels` with a 1x1 convolution.

    Args:
        scale_factor: Scale factor passed to `nn.functional.interpolate` (scalar or one per spatial axis).
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        dim: Spatial dimensionality; 2 uses Conv2d, any other value Conv3d.
        mode: Interpolation mode passed to `nn.functional.interpolate`.
    """
    # Modes for which `nn.functional.interpolate` accepts `align_corners`; it raises for others (e.g. "nearest").
    _ALIGN_CORNERS_MODES = ("linear", "bilinear", "bicubic", "trilinear")

    def __init__(self, scale_factor, in_channels, out_channels, dim, mode):
        super().__init__()
        self.mode = mode
        self.scale_factor = scale_factor

        conv = nn.Conv2d if dim == 2 else nn.Conv3d
        self.conv = conv(in_channels, out_channels, 1)

    def forward(self, x):
        align_corners = False if self.mode in self._ALIGN_CORNERS_MODES else None
        x = nn.functional.interpolate(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=align_corners)
        x = self.conv(x)
        return x
459
460
461#
462# 2d unet implementations
463#
464
class ConvBlock2d(ConvBlock):
    """@private

    `ConvBlock` with 2d convolutions; all other keyword arguments are forwarded unchanged.
    """
    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__(in_channels=in_channels, out_channels=out_channels, dim=2, **kwargs)
470
471
class Upsampler2d(Upsampler):
    """@private

    `Upsampler` for 2d inputs, using bilinear interpolation by default.
    """
    def __init__(self, scale_factor, in_channels, out_channels, mode="bilinear"):
        super().__init__(
            scale_factor=scale_factor,
            in_channels=in_channels,
            out_channels=out_channels,
            dim=2,
            mode=mode,
        )
479
480
class UNet2d(UNetBase):
    """A 2D U-Net network for segmentation and other image-to-image tasks.

    The number of features for each level of the U-Net are computed as follows: initial_features * gain ** level.
    The number of levels is determined by the depth argument. By default the U-Net uses two convolutional layers
    per level, max-pooling for downsampling and linear interpolation for upsampling.
    These implementations can be changed by providing arguments for `conv_block_impl`, `pooler_impl`
    and `sampler_impl` respectively.

    Args:
        in_channels: The number of input image channels.
        out_channels: The number of output image channels. If None (and no side outputs are returned),
            no output convolution is applied. With `return_side_outputs`, this may also be a list
            with one entry per level.
        depth: The number of encoder / decoder levels of the U-Net.
        initial_features: The initial number of features, corresponding to the features of the first conv block.
        gain: The gain factor for increasing the features after each level.
        final_activation: The activation applied after the output convolution or last decoder layer.
        return_side_outputs: Whether to return the outputs after each decoder level.
        conv_block_impl: The implementation of the convolutional block.
        pooler_impl: The implementation of the pooling layer.
        sampler_impl: The implementation of the upsampling layer.
        postprocessing: A postprocessing function to apply after the U-Net output.
        check_shape: Whether to check the input shape to the U-Net forward call.
        conv_block_kwargs: The keyword arguments for the convolutional block.

    Raises:
        ValueError: If side outputs are requested and `out_channels` does not have one entry per level.
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        depth: int = 4,
        initial_features: int = 32,
        gain: int = 2,
        final_activation: Optional[Union[str, nn.Module]] = None,
        return_side_outputs: bool = False,
        conv_block_impl: nn.Module = ConvBlock2d,
        pooler_impl: nn.Module = nn.MaxPool2d,
        sampler_impl: nn.Module = Upsampler2d,
        postprocessing: Optional[Union[nn.Module, str]] = None,
        check_shape: bool = True,
        **conv_block_kwargs,
    ):
        features_encoder = [in_channels] + [initial_features * gain ** i for i in range(depth)]
        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
        scale_factors = depth * [2]

        if return_side_outputs:
            if isinstance(out_channels, int) or out_channels is None:
                out_channels = [out_channels] * depth
            if len(out_channels) != depth:
                raise ValueError(
                    f"Invalid number of out_channels for side outputs: expected {depth}, got {len(out_channels)}"
                )
            out_conv = nn.ModuleList(
                [nn.Conv2d(feat, outc, 1) for feat, outc in zip(features_decoder[1:], out_channels)]
            )
        else:
            out_conv = None if out_channels is None else nn.Conv2d(features_decoder[-1], out_channels, 1)

        super().__init__(
            encoder=Encoder(
                features=features_encoder,
                scale_factors=scale_factors,
                conv_block_impl=conv_block_impl,
                pooler_impl=pooler_impl,
                **conv_block_kwargs
            ),
            decoder=Decoder(
                features=features_decoder,
                scale_factors=scale_factors[::-1],
                conv_block_impl=conv_block_impl,
                sampler_impl=sampler_impl,
                **conv_block_kwargs
            ),
            base=conv_block_impl(
                features_encoder[-1], features_encoder[-1] * gain,
                **conv_block_kwargs
            ),
            out_conv=out_conv,
            final_activation=final_activation,
            postprocessing=postprocessing,
            check_shape=check_shape,
        )
        # Constructor arguments, so the model can be re-created; includes check_shape so that
        # a rebuilt model keeps the same shape-checking behavior.
        self.init_kwargs = {"in_channels": in_channels, "out_channels": out_channels, "depth": depth,
                            "initial_features": initial_features, "gain": gain,
                            "final_activation": final_activation, "return_side_outputs": return_side_outputs,
                            "conv_block_impl": conv_block_impl, "pooler_impl": pooler_impl,
                            "sampler_impl": sampler_impl, "postprocessing": postprocessing,
                            "check_shape": check_shape, **conv_block_kwargs}
564
565
566#
567# 3d unet implementations
568#
569
class ConvBlock3d(ConvBlock):
    """@private

    `ConvBlock` with 3d convolutions; all other keyword arguments are forwarded unchanged.
    """
    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__(in_channels=in_channels, out_channels=out_channels, dim=3, **kwargs)
575
576
class Upsampler3d(Upsampler):
    """@private

    `Upsampler` for 3d inputs, using trilinear interpolation by default.
    """
    def __init__(self, scale_factor, in_channels, out_channels, mode="trilinear"):
        super().__init__(
            scale_factor=scale_factor,
            in_channels=in_channels,
            out_channels=out_channels,
            dim=3,
            mode=mode,
        )
582
583
class AnisotropicUNet(UNetBase):
    """A 3D U-Net network for segmentation and other image-to-image tasks.

    The number of features for each level of the U-Net are computed as follows: initial_features * gain ** level.
    The number of levels is determined by the length of the scale_factors argument.
    The scale factors determine the pooling factors for each level. By specifying [1, 2, 2] the pooling
    is done in an anisotropic fashion, i.e. only across the xy-plane,
    by specifying [2, 2, 2] it is done in an isotropic fashion.

    By default the U-Net uses two convolutional layers per level.
    This can be changed by providing an argument for `conv_block_impl`.

    Args:
        in_channels: The number of input image channels.
        out_channels: The number of output image channels. If None (and no side outputs are returned),
            no output convolution is applied. With `return_side_outputs`, this may also be a list
            with one entry per level.
        scale_factors: The factors for max pooling for the levels of the U-Net. Each entry is a list
            with one factor per spatial axis, or a single int for isotropic pooling.
        initial_features: The initial number of features, corresponding to the features of the first conv block.
        gain: The gain factor for increasing the features after each level.
        final_activation: The activation applied after the output convolution or last decoder layer.
        return_side_outputs: Whether to return the outputs after each decoder level.
        conv_block_impl: The implementation of the convolutional block.
        anisotropic_kernel: Whether to use an anisotropic kernel in addition to anisotropic scaling factor.
        postprocessing: A postprocessing function to apply after the U-Net output.
        check_shape: Whether to check the input shape to the U-Net forward call.
        conv_block_kwargs: The keyword arguments for the convolutional block.

    Raises:
        ValueError: If side outputs are requested and `out_channels` does not have one entry per level.
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        scale_factors: List[List[int]],
        initial_features: int = 32,
        gain: int = 2,
        final_activation: Optional[Union[str, nn.Module]] = None,
        return_side_outputs: bool = False,
        conv_block_impl: nn.Module = ConvBlock3d,
        anisotropic_kernel: bool = False,
        postprocessing: Optional[Union[str, nn.Module]] = None,
        check_shape: bool = True,
        **conv_block_kwargs,
    ):
        depth = len(scale_factors)
        features_encoder = [in_channels] + [initial_features * gain ** i for i in range(depth)]
        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]

        if return_side_outputs:
            if isinstance(out_channels, int) or out_channels is None:
                out_channels = [out_channels] * depth
            if len(out_channels) != depth:
                raise ValueError(
                    f"Invalid number of out_channels for side outputs: expected {depth}, got {len(out_channels)}"
                )
            out_conv = nn.ModuleList(
                [nn.Conv3d(feat, outc, 1) for feat, outc in zip(features_decoder[1:], out_channels)]
            )
        else:
            out_conv = None if out_channels is None else nn.Conv3d(features_decoder[-1], out_channels, 1)

        super().__init__(
            encoder=Encoder(
                features=features_encoder,
                scale_factors=scale_factors,
                conv_block_impl=conv_block_impl,
                pooler_impl=nn.MaxPool3d,
                anisotropic_kernel=anisotropic_kernel,
                **conv_block_kwargs
            ),
            decoder=Decoder(
                features=features_decoder,
                scale_factors=scale_factors[::-1],
                conv_block_impl=conv_block_impl,
                sampler_impl=Upsampler3d,
                anisotropic_kernel=anisotropic_kernel,
                **conv_block_kwargs
            ),
            base=conv_block_impl(
                features_encoder[-1], features_encoder[-1] * gain, **conv_block_kwargs
            ),
            out_conv=out_conv,
            final_activation=final_activation,
            postprocessing=postprocessing,
            check_shape=check_shape,
        )
        # Constructor arguments, so the model can be re-created; includes check_shape so that
        # a rebuilt model keeps the same shape-checking behavior.
        self.init_kwargs = {"in_channels": in_channels, "out_channels": out_channels, "scale_factors": scale_factors,
                            "initial_features": initial_features, "gain": gain,
                            "final_activation": final_activation, "return_side_outputs": return_side_outputs,
                            "conv_block_impl": conv_block_impl, "anisotropic_kernel": anisotropic_kernel,
                            "postprocessing": postprocessing, "check_shape": check_shape, **conv_block_kwargs}

    def _check_shape(self, x):
        spatial_shape = tuple(x.shape)[2:]
        # Subclasses may store init_kwargs without "scale_factors"; fall back to isotropic pooling by 2.
        scale_factors = self.init_kwargs.get("scale_factors", [[2, 2, 2]]*len(self.encoder))
        # A single int per level means the same pooling factor along all three axes.
        scale_factors = [[sf] * 3 if isinstance(sf, int) else sf for sf in scale_factors]
        # The total downsampling along each axis is the product of the per-level factors.
        factor = [int(np.prod([sf[i] for sf in scale_factors])) for i in range(3)]
        if len(spatial_shape) != len(factor):
            msg = f"Invalid shape for U-Net: dimensions don't agree {len(spatial_shape)} != {len(factor)}"
            raise ValueError(msg)
        if any(sh % fac != 0 for sh, fac in zip(spatial_shape, factor)):
            msg = f"Invalid shape for U-Net: {spatial_shape} is not divisible by {factor}"
            raise ValueError(msg)
679            msg = f"Invalid shape for U-Net: {spatial_shape} is not divisible by {factor}"
680            raise ValueError(msg)
681
682
class UNet3d(AnisotropicUNet):
    """A 3D U-Net network for segmentation and other image-to-image tasks.

    This class uses the same implementation as `AnisotropicUNet`, with isotropic scaling in each level.

    Args:
        in_channels: The number of input image channels.
        out_channels: The number of output image channels.
        depth: The number of encoder / decoder levels of the U-Net.
        initial_features: The initial number of features, corresponding to the features of the first conv block.
        gain: The gain factor for increasing the features after each level.
        final_activation: The activation applied after the output convolution or last decoder layer.
        return_side_outputs: Whether to return the outputs after each decoder level.
        conv_block_impl: The implementation of the convolutional block.
        postprocessing: A postprocessing function to apply after the U-Net output.
        check_shape: Whether to check the input shape to the U-Net forward call.
        conv_block_kwargs: The keyword arguments for the convolutional block.
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        depth: int = 4,
        initial_features: int = 32,
        gain: int = 2,
        final_activation: Optional[Union[str, nn.Module]] = None,
        return_side_outputs: bool = False,
        conv_block_impl: nn.Module = ConvBlock3d,
        postprocessing: Optional[Union[str, nn.Module]] = None,
        check_shape: bool = True,
        **conv_block_kwargs,
    ):
        # Isotropic pooling by a factor of 2 on every level.
        scale_factors = depth * [2]
        super().__init__(in_channels, out_channels, scale_factors,
                         initial_features=initial_features, gain=gain,
                         final_activation=final_activation,
                         return_side_outputs=return_side_outputs,
                         anisotropic_kernel=False,
                         postprocessing=postprocessing,
                         conv_block_impl=conv_block_impl,
                         check_shape=check_shape,
                         **conv_block_kwargs)
        # Replace the parent's init_kwargs so they match this constructor's signature;
        # includes check_shape so that a rebuilt model keeps the same shape-checking behavior.
        self.init_kwargs = {"in_channels": in_channels, "out_channels": out_channels, "depth": depth,
                            "initial_features": initial_features, "gain": gain,
                            "final_activation": final_activation, "return_side_outputs": return_side_outputs,
                            "conv_block_impl": conv_block_impl, "postprocessing": postprocessing,
                            "check_shape": check_shape, **conv_block_kwargs}
class UNetBase(torch.nn.modules.module.Module):
105class UNetBase(nn.Module):
106    """Base class for implementing a U-Net.
107
108    Args:
109        encoder: The encoder of the U-Net.
110        base: The base layer of the U-Net.
111        decoder: The decoder of the U-Net.
112        out_conv: The output convolution applied after the last decoder layer.
113        final_activation: The activation applied after the output convolution or last decoder layer.
114        postprocessing: A postprocessing function to apply after the U-Net output.
115        check_shape: Whether to check the input shape to the U-Net forward call.
116    """
117    def __init__(
118        self,
119        encoder: nn.Module,
120        base: nn.Module,
121        decoder: nn.Module,
122        out_conv: Optional[nn.Module] = None,
123        final_activation: Optional[Union[nn.Module, str]] = None,
124        postprocessing: Optional[Union[nn.Module, str]] = None,
125        check_shape: bool = True,
126    ):
127        super().__init__()
128        if len(encoder) != len(decoder):
129            raise ValueError(f"Incompatible depth of encoder (depth={len(encoder)}) and decoder (depth={len(decoder)})")
130
131        self.encoder = encoder
132        self.base = base
133        self.decoder = decoder
134
135        if out_conv is None:
136            self.return_decoder_outputs = False
137            self._out_channels = self.decoder.out_channels
138        elif isinstance(out_conv, nn.ModuleList):
139            if len(out_conv) != len(self.decoder):
140                raise ValueError(f"Invalid length of out_conv, expected {len(decoder)}, got {len(out_conv)}")
141            self.return_decoder_outputs = True
142            self._out_channels = [None if conv is None else conv.out_channels for conv in out_conv]
143        else:
144            self.return_decoder_outputs = False
145            self._out_channels = out_conv.out_channels
146        self.out_conv = out_conv
147        self.check_shape = check_shape
148        self.final_activation = self._get_activation(final_activation)
149        self.postprocessing = self._get_postprocessing(postprocessing)
150
151    @property
152    def in_channels(self):
153        return self.encoder.in_channels
154
155    @property
156    def out_channels(self):
157        return self._out_channels
158
159    @property
160    def depth(self):
161        return len(self.encoder)
162
163    def _get_activation(self, activation):
164        return_activation = None
165        if activation is None:
166            return None
167        if isinstance(activation, nn.Module):
168            return activation
169        if isinstance(activation, str):
170            return_activation = getattr(nn, activation, None)
171        if return_activation is None:
172            raise ValueError(f"Invalid activation: {activation}")
173        return return_activation()
174
175    def _get_postprocessing(self, postprocessing):
176        if postprocessing is None:
177            return None
178        elif isinstance(postprocessing, nn.Module):
179            return postprocessing
180        elif postprocessing in POSTPROCESSING:
181            return POSTPROCESSING[postprocessing]()
182        else:
183            raise ValueError(f"Invalid postprocessing: {postprocessing}")
184
185    # load encoder / decoder / base states for pretraining
186    def load_encoder_state(self, state):
187        self.encoder.load_state_dict(state)
188
189    def load_decoder_state(self, state):
190        self.decoder.load_state_dict(state)
191
192    def load_base_state(self, state):
193        self.base.load_state_dict(state)
194
195    def _apply_default(self, x):
196        self.encoder.return_outputs = True
197        self.decoder.return_outputs = False
198
199        x, encoder_out = self.encoder(x)
200        x = self.base(x)
201        x = self.decoder(x, encoder_inputs=encoder_out[::-1])
202
203        if self.out_conv is not None:
204            x = self.out_conv(x)
205        if self.final_activation is not None:
206            x = self.final_activation(x)
207        if self.postprocessing is not None:
208            x = self.postprocessing(x)
209
210        return x
211
212    def _apply_with_side_outputs(self, x):
213        self.encoder.return_outputs = True
214        self.decoder.return_outputs = True
215
216        x, encoder_out = self.encoder(x)
217        x = self.base(x)
218        x = self.decoder(x, encoder_inputs=encoder_out[::-1])
219
220        x = [x if conv is None else conv(xx) for xx, conv in zip(x, self.out_conv)]
221        if self.final_activation is not None:
222            x = [self.final_activation(xx) for xx in x]
223
224        if self.postprocessing is not None:
225            x = [self.postprocessing(xx) for xx in x]
226
227        # we reverse the list to have the full shape output as first element
228        return x[::-1]
229
230    def _check_shape(self, x):
231        spatial_shape = tuple(x.shape)[2:]
232        depth = len(self.encoder)
233        factor = [2**depth] * len(spatial_shape)
234        if any(sh % fac != 0 for sh, fac in zip(spatial_shape, factor)):
235            msg = f"Invalid shape for U-Net: {spatial_shape} is not divisible by {factor}"
236            raise ValueError(msg)
237
238    def forward(self, x: torch.Tensor) -> torch.tensor:
239        """Apply U-Net to input data.
240
241        Args:
242            x: The input data.
243
244        Returns:
245            The output of the U-Net.
246        """
247        # Cast input data to float, hotfix for modelzoo deployment issues, leaving it here for reference.
248        # x = x.float()
249        if getattr(self, "check_shape", True):
250            self._check_shape(x)
251        if self.return_decoder_outputs:
252            return self._apply_with_side_outputs(x)
253        else:
254            return self._apply_default(x)

Base class for implementing a U-Net.

Arguments:
  • encoder: The encoder of the U-Net.
  • base: The base layer of the U-Net.
  • decoder: The decoder of the U-Net.
  • out_conv: The output convolution applied after the last decoder layer.
  • final_activation: The activation applied after the output convolution or last decoder layer.
  • postprocessing: A postprocessing function to apply after the U-Net output.
  • check_shape: Whether to check the input shape to the U-Net forward call.
UNetBase( encoder: torch.nn.modules.module.Module, base: torch.nn.modules.module.Module, decoder: torch.nn.modules.module.Module, out_conv: Optional[torch.nn.modules.module.Module] = None, final_activation: Union[torch.nn.modules.module.Module, str, NoneType] = None, postprocessing: Union[torch.nn.modules.module.Module, str, NoneType] = None, check_shape: bool = True)
117    def __init__(
118        self,
119        encoder: nn.Module,
120        base: nn.Module,
121        decoder: nn.Module,
122        out_conv: Optional[nn.Module] = None,
123        final_activation: Optional[Union[nn.Module, str]] = None,
124        postprocessing: Optional[Union[nn.Module, str]] = None,
125        check_shape: bool = True,
126    ):
127        super().__init__()
128        if len(encoder) != len(decoder):
129            raise ValueError(f"Incompatible depth of encoder (depth={len(encoder)}) and decoder (depth={len(decoder)})")
130
131        self.encoder = encoder
132        self.base = base
133        self.decoder = decoder
134
135        if out_conv is None:
136            self.return_decoder_outputs = False
137            self._out_channels = self.decoder.out_channels
138        elif isinstance(out_conv, nn.ModuleList):
139            if len(out_conv) != len(self.decoder):
140                raise ValueError(f"Invalid length of out_conv, expected {len(decoder)}, got {len(out_conv)}")
141            self.return_decoder_outputs = True
142            self._out_channels = [None if conv is None else conv.out_channels for conv in out_conv]
143        else:
144            self.return_decoder_outputs = False
145            self._out_channels = out_conv.out_channels
146        self.out_conv = out_conv
147        self.check_shape = check_shape
148        self.final_activation = self._get_activation(final_activation)
149        self.postprocessing = self._get_postprocessing(postprocessing)

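Initialize the U-Net from an encoder, a base and a decoder module of matching depth (a ValueError is raised if their lengths differ), together with an optional output convolution (or a ModuleList of per-level output convolutions for side outputs), final activation and postprocessing.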

encoder
base
decoder
out_conv
check_shape
final_activation
postprocessing
in_channels
151    @property
152    def in_channels(self):
153        return self.encoder.in_channels
out_channels
155    @property
156    def out_channels(self):
157        return self._out_channels
depth
159    @property
160    def depth(self):
161        return len(self.encoder)
def load_encoder_state(self, state):
186    def load_encoder_state(self, state):
187        self.encoder.load_state_dict(state)
def load_decoder_state(self, state):
189    def load_decoder_state(self, state):
190        self.decoder.load_state_dict(state)
def load_base_state(self, state):
192    def load_base_state(self, state):
193        self.base.load_state_dict(state)
def forward(self, x: torch.Tensor) -> torch.Tensor:
238    def forward(self, x: torch.Tensor) -> torch.tensor:
239        """Apply U-Net to input data.
240
241        Args:
242            x: The input data.
243
244        Returns:
245            The output of the U-Net.
246        """
247        # Cast input data to float, hotfix for modelzoo deployment issues, leaving it here for reference.
248        # x = x.float()
249        if getattr(self, "check_shape", True):
250            self._check_shape(x)
251        if self.return_decoder_outputs:
252            return self._apply_with_side_outputs(x)
253        else:
254            return self._apply_default(x)

Apply U-Net to input data.

Arguments:
  • x: The input data.
Returns:

The output of the U-Net.

class UNet2d(UNetBase):
482class UNet2d(UNetBase):
483    """A 2D U-Net network for segmentation and other image-to-image tasks.
484
485    The number of features for each level of the U-Net are computed as follows: initial_features * gain ** level.
486    The number of levels is determined by the depth argument. By default the U-Net uses two convolutional layers
487    per level, max-pooling for downsampling and linear interpolation for upsampling.
488    These implementations can be changed by providing arguments for `conv_block_impl`, `pooler_impl`
489    and `sampler_impl` respectively.
490
491    Args:
492        in_channels: The number of input image channels.
493        out_channels: The number of output image channels.
494        depth: The number of encoder / decoder levels of the U-Net.
495        initial_features: The initial number of features, corresponding to the features of the first conv block.
496        gain: The gain factor for increasing the features after each level.
497        final_activation: The activation applied after the output convolution or last decoder layer.
498        return_side_outputs: Whether to return the outputs after each decoder level.
499        conv_block_impl: The implementation of the convolutional block.
500        pooler_impl: The implementation of the pooling layer.
501        postprocessing: A postprocessing function to apply after the U-Net output.
502        check_shape: Whether to check the input shape to the U-Net forward call.
503        conv_block_kwargs: The keyword arguments for the convolutional block.
504    """
505    def __init__(
506        self,
507        in_channels: int,
508        out_channels: int,
509        depth: int = 4,
510        initial_features: int = 32,
511        gain: int = 2,
512        final_activation=None,
513        return_side_outputs: bool = False,
514        conv_block_impl: nn.Module = ConvBlock2d,
515        pooler_impl: nn.Module = nn.MaxPool2d,
516        sampler_impl: nn.Module = Upsampler2d,
517        postprocessing: Optional[Union[nn.Module, str]] = None,
518        check_shape: bool = True,
519        **conv_block_kwargs,
520    ):
521        features_encoder = [in_channels] + [initial_features * gain ** i for i in range(depth)]
522        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
523        scale_factors = depth * [2]
524
525        if return_side_outputs:
526            if isinstance(out_channels, int) or out_channels is None:
527                out_channels = [out_channels] * depth
528            if len(out_channels) != depth:
529                raise ValueError()
530            out_conv = nn.ModuleList(
531                [nn.Conv2d(feat, outc, 1) for feat, outc in zip(features_decoder[1:], out_channels)]
532            )
533        else:
534            out_conv = None if out_channels is None else nn.Conv2d(features_decoder[-1], out_channels, 1)
535
536        super().__init__(
537            encoder=Encoder(
538                features=features_encoder,
539                scale_factors=scale_factors,
540                conv_block_impl=conv_block_impl,
541                pooler_impl=pooler_impl,
542                **conv_block_kwargs
543            ),
544            decoder=Decoder(
545                features=features_decoder,
546                scale_factors=scale_factors[::-1],
547                conv_block_impl=conv_block_impl,
548                sampler_impl=sampler_impl,
549                **conv_block_kwargs
550            ),
551            base=conv_block_impl(
552                features_encoder[-1], features_encoder[-1] * gain,
553                **conv_block_kwargs
554            ),
555            out_conv=out_conv,
556            final_activation=final_activation,
557            postprocessing=postprocessing,
558            check_shape=check_shape,
559        )
560        self.init_kwargs = {"in_channels": in_channels, "out_channels": out_channels, "depth": depth,
561                            "initial_features": initial_features, "gain": gain,
562                            "final_activation": final_activation, "return_side_outputs": return_side_outputs,
563                            "conv_block_impl": conv_block_impl, "pooler_impl": pooler_impl,
564                            "sampler_impl": sampler_impl, "postprocessing": postprocessing, **conv_block_kwargs}

A 2D U-Net network for segmentation and other image-to-image tasks.

The number of features for each level of the U-Net is computed as follows: initial_features * gain ** level. The number of levels is determined by the depth argument. By default the U-Net uses two convolutional layers per level, max-pooling for downsampling and linear interpolation for upsampling. These implementations can be changed by providing arguments for conv_block_impl, pooler_impl and sampler_impl respectively.

Arguments:
  • in_channels: The number of input image channels.
  • out_channels: The number of output image channels.
  • depth: The number of encoder / decoder levels of the U-Net.
  • initial_features: The initial number of features, corresponding to the features of the first conv block.
  • gain: The gain factor for increasing the features after each level.
  • final_activation: The activation applied after the output convolution or last decoder layer.
  • return_side_outputs: Whether to return the outputs after each decoder level.
  • conv_block_impl: The implementation of the convolutional block.
  • pooler_impl: The implementation of the pooling layer.
  • sampler_impl: The implementation of the upsampling layer.
  • postprocessing: A postprocessing function to apply after the U-Net output.
  • check_shape: Whether to check the input shape to the U-Net forward call.
  • conv_block_kwargs: The keyword arguments for the convolutional block.
UNet2d( in_channels: int, out_channels: int, depth: int = 4, initial_features: int = 32, gain: int = 2, final_activation=None, return_side_outputs: bool = False, conv_block_impl: torch.nn.modules.module.Module = <class 'torch_em.model.unet.ConvBlock2d'>, pooler_impl: torch.nn.modules.module.Module = <class 'torch.nn.modules.pooling.MaxPool2d'>, sampler_impl: torch.nn.modules.module.Module = <class 'torch_em.model.unet.Upsampler2d'>, postprocessing: Union[torch.nn.modules.module.Module, str, NoneType] = None, check_shape: bool = True, **conv_block_kwargs)
505    def __init__(
506        self,
507        in_channels: int,
508        out_channels: int,
509        depth: int = 4,
510        initial_features: int = 32,
511        gain: int = 2,
512        final_activation=None,
513        return_side_outputs: bool = False,
514        conv_block_impl: nn.Module = ConvBlock2d,
515        pooler_impl: nn.Module = nn.MaxPool2d,
516        sampler_impl: nn.Module = Upsampler2d,
517        postprocessing: Optional[Union[nn.Module, str]] = None,
518        check_shape: bool = True,
519        **conv_block_kwargs,
520    ):
521        features_encoder = [in_channels] + [initial_features * gain ** i for i in range(depth)]
522        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
523        scale_factors = depth * [2]
524
525        if return_side_outputs:
526            if isinstance(out_channels, int) or out_channels is None:
527                out_channels = [out_channels] * depth
528            if len(out_channels) != depth:
529                raise ValueError()
530            out_conv = nn.ModuleList(
531                [nn.Conv2d(feat, outc, 1) for feat, outc in zip(features_decoder[1:], out_channels)]
532            )
533        else:
534            out_conv = None if out_channels is None else nn.Conv2d(features_decoder[-1], out_channels, 1)
535
536        super().__init__(
537            encoder=Encoder(
538                features=features_encoder,
539                scale_factors=scale_factors,
540                conv_block_impl=conv_block_impl,
541                pooler_impl=pooler_impl,
542                **conv_block_kwargs
543            ),
544            decoder=Decoder(
545                features=features_decoder,
546                scale_factors=scale_factors[::-1],
547                conv_block_impl=conv_block_impl,
548                sampler_impl=sampler_impl,
549                **conv_block_kwargs
550            ),
551            base=conv_block_impl(
552                features_encoder[-1], features_encoder[-1] * gain,
553                **conv_block_kwargs
554            ),
555            out_conv=out_conv,
556            final_activation=final_activation,
557            postprocessing=postprocessing,
558            check_shape=check_shape,
559        )
560        self.init_kwargs = {"in_channels": in_channels, "out_channels": out_channels, "depth": depth,
561                            "initial_features": initial_features, "gain": gain,
562                            "final_activation": final_activation, "return_side_outputs": return_side_outputs,
563                            "conv_block_impl": conv_block_impl, "pooler_impl": pooler_impl,
564                            "sampler_impl": sampler_impl, "postprocessing": postprocessing, **conv_block_kwargs}

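Build the 2D U-Net. The encoder and decoder feature sizes are derived from initial_features, gain and depth. A 1x1 output convolution is created (one per decoder level if return_side_outputs is set). Everything is then passed to UNetBase.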

init_kwargs
class AnisotropicUNet(UNetBase):
585class AnisotropicUNet(UNetBase):
586    """A 3D U-Net network for segmentation and other image-to-image tasks.
587
588    The number of features for each level of the U-Net are computed as follows: initial_features * gain ** level.
589    The number of levels is determined by the length of the scale_factors argument.
590    The scale factors determine the pooling factors for each level. By specifying [1, 2, 2] the pooling
591    is done in an anisotropic fashion, i.e. only across the xy-plane,
592    by specifying [2, 2, 2] it is done in an isotropic fashion.
593
594    By default the U-Net uses two convolutional layers per level.
595    This can be changed by providing an argument for `conv_block_impl`.
596
597    Args:
598        in_channels: The number of input image channels.
599        out_channels: The number of output image channels.
600        scale_factors: The factors for max pooling for the levels of the U-Net.
601        initial_features: The initial number of features, corresponding to the features of the first conv block.
602        gain: The gain factor for increasing the features after each level.
603        final_activation: The activation applied after the output convolution or last decoder layer.
604        return_side_outputs: Whether to return the outputs after each decoder level.
605        conv_block_impl: The implementation of the convolutional block.
606        anisotropic_kernel: Whether to use an anisotropic kernel in addition to anisotropic scaling factor.
607        postprocessing: A postprocessing function to apply after the U-Net output.
608        check_shape: Whether to check the input shape to the U-Net forward call.
609        conv_block_kwargs: The keyword arguments for the convolutional block.
610    """
611    def __init__(
612        self,
613        in_channels: int,
614        out_channels: int,
615        scale_factors: List[List[int]],
616        initial_features: int = 32,
617        gain: int = 2,
618        final_activation: Optional[Union[str, nn.Module]] = None,
619        return_side_outputs: bool = False,
620        conv_block_impl: nn.Module = ConvBlock3d,
621        anisotropic_kernel: bool = False,
622        postprocessing: Optional[Union[str, nn.Module]] = None,
623        check_shape: bool = True,
624        **conv_block_kwargs,
625    ):
626        depth = len(scale_factors)
627        features_encoder = [in_channels] + [initial_features * gain ** i for i in range(depth)]
628        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
629
630        if return_side_outputs:
631            if isinstance(out_channels, int) or out_channels is None:
632                out_channels = [out_channels] * depth
633            if len(out_channels) != depth:
634                raise ValueError()
635            out_conv = nn.ModuleList(
636                [nn.Conv3d(feat, outc, 1) for feat, outc in zip(features_decoder[1:], out_channels)]
637            )
638        else:
639            out_conv = None if out_channels is None else nn.Conv3d(features_decoder[-1], out_channels, 1)
640
641        super().__init__(
642            encoder=Encoder(
643                features=features_encoder,
644                scale_factors=scale_factors,
645                conv_block_impl=conv_block_impl,
646                pooler_impl=nn.MaxPool3d,
647                anisotropic_kernel=anisotropic_kernel,
648                **conv_block_kwargs
649            ),
650            decoder=Decoder(
651                features=features_decoder,
652                scale_factors=scale_factors[::-1],
653                conv_block_impl=conv_block_impl,
654                sampler_impl=Upsampler3d,
655                anisotropic_kernel=anisotropic_kernel,
656                **conv_block_kwargs
657            ),
658            base=conv_block_impl(
659                features_encoder[-1], features_encoder[-1] * gain, **conv_block_kwargs
660            ),
661            out_conv=out_conv,
662            final_activation=final_activation,
663            postprocessing=postprocessing,
664            check_shape=check_shape,
665        )
666        self.init_kwargs = {"in_channels": in_channels, "out_channels": out_channels, "scale_factors": scale_factors,
667                            "initial_features": initial_features, "gain": gain,
668                            "final_activation": final_activation, "return_side_outputs": return_side_outputs,
669                            "conv_block_impl": conv_block_impl, "anisotropic_kernel": anisotropic_kernel,
670                            "postprocessing": postprocessing, **conv_block_kwargs}
671
672    def _check_shape(self, x):
673        spatial_shape = tuple(x.shape)[2:]
674        scale_factors = self.init_kwargs.get("scale_factors", [[2, 2, 2]]*len(self.encoder))
675        factor = [int(np.prod([sf[i] for sf in scale_factors])) for i in range(3)]
676        if len(spatial_shape) != len(factor):
677            msg = f"Invalid shape for U-Net: dimensions don't agree {len(spatial_shape)} != {len(factor)}"
678            raise ValueError(msg)
679        if any(sh % fac != 0 for sh, fac in zip(spatial_shape, factor)):
680            msg = f"Invalid shape for U-Net: {spatial_shape} is not divisible by {factor}"
681            raise ValueError(msg)

A 3D U-Net network for segmentation and other image-to-image tasks.

The number of features for each level of the U-Net is computed as follows: initial_features * gain ** level. The number of levels is determined by the length of the scale_factors argument. The scale factors determine the pooling factors for each level. By specifying [1, 2, 2] the pooling is done in an anisotropic fashion, i.e. only across the xy-plane, by specifying [2, 2, 2] it is done in an isotropic fashion.

By default the U-Net uses two convolutional layers per level. This can be changed by providing an argument for conv_block_impl.

Arguments:
  • in_channels: The number of input image channels.
  • out_channels: The number of output image channels.
  • scale_factors: The factors for max pooling for the levels of the U-Net.
  • initial_features: The initial number of features, corresponding to the features of the first conv block.
  • gain: The gain factor for increasing the features after each level.
  • final_activation: The activation applied after the output convolution or last decoder layer.
  • return_side_outputs: Whether to return the outputs after each decoder level.
  • conv_block_impl: The implementation of the convolutional block.
  • anisotropic_kernel: Whether to use an anisotropic kernel in addition to the anisotropic scaling factors.
  • postprocessing: A postprocessing function to apply after the U-Net output.
  • check_shape: Whether to check the input shape to the U-Net forward call.
  • conv_block_kwargs: The keyword arguments for the convolutional block.
AnisotropicUNet( in_channels: int, out_channels: int, scale_factors: List[List[int]], initial_features: int = 32, gain: int = 2, final_activation: Union[torch.nn.modules.module.Module, str, NoneType] = None, return_side_outputs: bool = False, conv_block_impl: torch.nn.modules.module.Module = <class 'torch_em.model.unet.ConvBlock3d'>, anisotropic_kernel: bool = False, postprocessing: Union[torch.nn.modules.module.Module, str, NoneType] = None, check_shape: bool = True, **conv_block_kwargs)
611    def __init__(
612        self,
613        in_channels: int,
614        out_channels: int,
615        scale_factors: List[List[int]],
616        initial_features: int = 32,
617        gain: int = 2,
618        final_activation: Optional[Union[str, nn.Module]] = None,
619        return_side_outputs: bool = False,
620        conv_block_impl: nn.Module = ConvBlock3d,
621        anisotropic_kernel: bool = False,
622        postprocessing: Optional[Union[str, nn.Module]] = None,
623        check_shape: bool = True,
624        **conv_block_kwargs,
625    ):
626        depth = len(scale_factors)
627        features_encoder = [in_channels] + [initial_features * gain ** i for i in range(depth)]
628        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
629
630        if return_side_outputs:
631            if isinstance(out_channels, int) or out_channels is None:
632                out_channels = [out_channels] * depth
633            if len(out_channels) != depth:
634                raise ValueError()
635            out_conv = nn.ModuleList(
636                [nn.Conv3d(feat, outc, 1) for feat, outc in zip(features_decoder[1:], out_channels)]
637            )
638        else:
639            out_conv = None if out_channels is None else nn.Conv3d(features_decoder[-1], out_channels, 1)
640
641        super().__init__(
642            encoder=Encoder(
643                features=features_encoder,
644                scale_factors=scale_factors,
645                conv_block_impl=conv_block_impl,
646                pooler_impl=nn.MaxPool3d,
647                anisotropic_kernel=anisotropic_kernel,
648                **conv_block_kwargs
649            ),
650            decoder=Decoder(
651                features=features_decoder,
652                scale_factors=scale_factors[::-1],
653                conv_block_impl=conv_block_impl,
654                sampler_impl=Upsampler3d,
655                anisotropic_kernel=anisotropic_kernel,
656                **conv_block_kwargs
657            ),
658            base=conv_block_impl(
659                features_encoder[-1], features_encoder[-1] * gain, **conv_block_kwargs
660            ),
661            out_conv=out_conv,
662            final_activation=final_activation,
663            postprocessing=postprocessing,
664            check_shape=check_shape,
665        )
666        self.init_kwargs = {"in_channels": in_channels, "out_channels": out_channels, "scale_factors": scale_factors,
667                            "initial_features": initial_features, "gain": gain,
668                            "final_activation": final_activation, "return_side_outputs": return_side_outputs,
669                            "conv_block_impl": conv_block_impl, "anisotropic_kernel": anisotropic_kernel,
670                            "postprocessing": postprocessing, **conv_block_kwargs}

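Build the 3D U-Net. The depth is given by the number of scale factors, and the encoder and decoder feature sizes are derived from initial_features and gain. A 1x1x1 output convolution is created (one per decoder level if return_side_outputs is set). Everything is then passed to UNetBase.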

init_kwargs
class UNet3d(AnisotropicUNet):
684class UNet3d(AnisotropicUNet):
685    """A 3D U-Net network for segmentation and other image-to-image tasks.
686
687    This class uses the same implementation as `AnisotropicUNet`, with isotropic scaling in each level.
688
689    Args:
690        in_channels: The number of input image channels.
691        out_channels: The number of output image channels.
692        depth: The number of encoder / decoder levels of the U-Net.
693        initial_features: The initial number of features, corresponding to the features of the first conv block.
694        gain: The gain factor for increasing the features after each level.
695        final_activation: The activation applied after the output convolution or last decoder layer.
696        return_side_outputs: Whether to return the outputs after each decoder level.
697        conv_block_impl: The implementation of the convolutional block.
698        postprocessing: A postprocessing function to apply after the U-Net output.
699        check_shape: Whether to check the input shape to the U-Net forward call.
700        conv_block_kwargs: The keyword arguments for the convolutional block.
701    """
702    def __init__(
703        self,
704        in_channels: int,
705        out_channels: int,
706        depth: int = 4,
707        initial_features: int = 32,
708        gain: int = 2,
709        final_activation: Optional[Union[str, nn.Module]] = None,
710        return_side_outputs: bool = False,
711        conv_block_impl: nn.Module = ConvBlock3d,
712        postprocessing: Optional[Union[str, nn.Module]] = None,
713        check_shape: bool = True,
714        **conv_block_kwargs,
715    ):
716        scale_factors = depth * [2]
717        super().__init__(in_channels, out_channels, scale_factors,
718                         initial_features=initial_features, gain=gain,
719                         final_activation=final_activation,
720                         return_side_outputs=return_side_outputs,
721                         anisotropic_kernel=False,
722                         postprocessing=postprocessing,
723                         conv_block_impl=conv_block_impl,
724                         check_shape=check_shape,
725                         **conv_block_kwargs)
726        self.init_kwargs = {"in_channels": in_channels, "out_channels": out_channels, "depth": depth,
727                            "initial_features": initial_features, "gain": gain,
728                            "final_activation": final_activation, "return_side_outputs": return_side_outputs,
729                            "conv_block_impl": conv_block_impl, "postprocessing": postprocessing, **conv_block_kwargs}

A 3D U-Net network for segmentation and other image-to-image tasks.

This class uses the same implementation as AnisotropicUNet, with isotropic scaling in each level.

Arguments:
  • in_channels: The number of input image channels.
  • out_channels: The number of output image channels.
  • depth: The number of encoder / decoder levels of the U-Net.
  • initial_features: The initial number of features, corresponding to the features of the first conv block.
  • gain: The gain factor for increasing the features after each level.
  • final_activation: The activation applied after the output convolution or last decoder layer.
  • return_side_outputs: Whether to return the outputs after each decoder level.
  • conv_block_impl: The implementation of the convolutional block.
  • postprocessing: A postprocessing function to apply after the U-Net output.
  • check_shape: Whether to check the input shape to the U-Net forward call.
  • conv_block_kwargs: The keyword arguments for the convolutional block.
UNet3d( in_channels: int, out_channels: int, depth: int = 4, initial_features: int = 32, gain: int = 2, final_activation: Union[torch.nn.modules.module.Module, str, NoneType] = None, return_side_outputs: bool = False, conv_block_impl: torch.nn.modules.module.Module = <class 'torch_em.model.unet.ConvBlock3d'>, postprocessing: Union[torch.nn.modules.module.Module, str, NoneType] = None, check_shape: bool = True, **conv_block_kwargs)
702    def __init__(
703        self,
704        in_channels: int,
705        out_channels: int,
706        depth: int = 4,
707        initial_features: int = 32,
708        gain: int = 2,
709        final_activation: Optional[Union[str, nn.Module]] = None,
710        return_side_outputs: bool = False,
711        conv_block_impl: nn.Module = ConvBlock3d,
712        postprocessing: Optional[Union[str, nn.Module]] = None,
713        check_shape: bool = True,
714        **conv_block_kwargs,
715    ):
716        scale_factors = depth * [2]
717        super().__init__(in_channels, out_channels, scale_factors,
718                         initial_features=initial_features, gain=gain,
719                         final_activation=final_activation,
720                         return_side_outputs=return_side_outputs,
721                         anisotropic_kernel=False,
722                         postprocessing=postprocessing,
723                         conv_block_impl=conv_block_impl,
724                         check_shape=check_shape,
725                         **conv_block_kwargs)
726        self.init_kwargs = {"in_channels": in_channels, "out_channels": out_channels, "depth": depth,
727                            "initial_features": initial_features, "gain": gain,
728                            "final_activation": final_activation, "return_side_outputs": return_side_outputs,
729                            "conv_block_impl": conv_block_impl, "postprocessing": postprocessing, **conv_block_kwargs}

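Construct the 3D U-Net. This calls the `AnisotropicUNet` constructor with `scale_factors = depth * [2]` and `anisotropic_kernel=False`, then stores the constructor arguments in `init_kwargs`. See the class documentation for the meaning of each argument.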

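init_kwargs: Dictionary of the arguments passed to the constructor: `in_channels`, `out_channels`, `depth`, `initial_features`, `gain`, `final_activation`, `return_side_outputs`, `conv_block_impl` and `postprocessing`. The `conv_block_kwargs` are merged in at the top level. `depth` is stored instead of the derived `scale_factors`. Note that `check_shape` is not included, so passing these arguments back to `UNet3d` uses its default `check_shape=True`.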