torch_em.model.unetr

  1from functools import partial
  2from collections import OrderedDict
  3from typing import Optional, Tuple, Union, Literal
  4
  5import torch
  6import torch.nn as nn
  7import torch.nn.functional as F
  8
  9from .vit import get_vision_transformer
 10from .unet import Decoder, ConvBlock2d, ConvBlock3d, Upsampler2d, Upsampler3d, _update_conv_kwargs
 11
 12try:
 13    from micro_sam.util import get_sam_model
 14except ImportError:
 15    get_sam_model = None
 16
 17try:
 18    from micro_sam2.util import get_sam2_model
 19except ImportError:
 20    get_sam2_model = None
 21
 22
 23#
 24# UNETR IMPLEMENTATION [Vision Transformer (ViT from SAM / MAE / ScaleMAE) + UNet Decoder from `torch_em`]
 25#
 26
 27
 28class UNETRBase(nn.Module):
 29    """Base class for implementing a UNETR.
 30
 31    Args:
 32        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
 33        backbone: The name of the vision transformer implementation. One of "sam" or "mae".
 34        encoder: The vision transformer. Can either be a name, such as "vit_b" or a torch module.
 35        decoder: The convolutional decoder.
 36        out_channels: The number of output channels of the UNETR.
 37        use_sam_stats: Whether to normalize the input data with the statistics of the pretrained SAM model.
 38        use_dino_stats: Whether to normalize the input data with the statistics of the pretrained DINOv3 model.
 39        use_mae_stats: Whether to normalize the input data with the statistics of the pretrained MAE model.
 40        resize_input: Whether to resize the input images to match `img_size`.
 41            By default, it resizes the inputs to match the `img_size`.
 42        encoder_checkpoint: Checkpoint for initializing the vision transformer.
 43            Can either be a filepath or an already loaded checkpoint.
 44        final_activation: The activation to apply to the UNETR output.
 45        use_skip_connection: Whether to use skip connections. By default, it uses skip connections.
 46        embed_dim: The embedding dimensionality, corresponding to the output dimension of the vision transformer.
 47        use_conv_transpose: Whether to use transposed convolutions instead of resampling for upsampling.
 48            By default, it uses resampling for upsampling.
 49    """
 50    def __init__(
 51        self,
 52        img_size: int = 1024,
 53        backbone: Literal["sam", "sam2", "mae", "scalemae", "dinov3"] = "sam2",
 54        encoder: Optional[Union[nn.Module, str]] = "vit_b",
 55        decoder: Optional[nn.Module] = None,
 56        out_channels: int = 1,
 57        use_sam_stats: bool = False,
 58        use_mae_stats: bool = False,
 59        use_dino_stats: bool = False,
 60        resize_input: bool = True,
 61        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
 62        final_activation: Optional[Union[str, nn.Module]] = None,
 63        use_skip_connection: bool = True,
 64        embed_dim: Optional[int] = None,
 65        use_conv_transpose: bool = False,
 66        **kwargs
 67    ) -> None:
 68        super().__init__()
 69
 70        self.img_size = img_size
 71        self.use_sam_stats = use_sam_stats
 72        self.use_mae_stats = use_mae_stats
 73        self.use_dino_stats = use_dino_stats
 74        self.use_skip_connection = use_skip_connection
 75        self.resize_input = resize_input
 76        self.use_conv_transpose = use_conv_transpose
 77        self.backbone = backbone
 78
 79        if isinstance(encoder, str):  # e.g. "vit_b" / "hvit_b"
 80            print(f"Using {encoder} from {backbone.upper()}")
 81            self.encoder = get_vision_transformer(img_size=img_size, backbone=backbone, model=encoder, **kwargs)
 82
 83            if encoder_checkpoint is not None:
 84                self._load_encoder_from_checkpoint(backbone=backbone, encoder=encoder, checkpoint=encoder_checkpoint)
 85
 86            if embed_dim is None:
 87                embed_dim = self.encoder.embed_dim
 88
 89        else:  # `nn.Module` ViT backbone
 90            self.encoder = encoder
 91
 92            have_neck = False
 93            for name, _ in self.encoder.named_parameters():
 94                if name.startswith("neck"):
 95                    have_neck = True
 96
 97            if embed_dim is None:
 98                if have_neck:
 99                    embed_dim = self.encoder.neck[2].out_channels  # the value is 256
100                else:
101                    embed_dim = self.encoder.patch_embed.proj.out_channels
102
103        self.embed_dim = embed_dim
104        self.final_activation = self._get_activation(final_activation)
105
106    def _load_encoder_from_checkpoint(self, backbone, encoder, checkpoint):
107        """Function to load pretrained weights to the image encoder.
108        """
109        if isinstance(checkpoint, str):
110            if backbone == "sam" and isinstance(encoder, str):
111                # If we have a SAM encoder, then we first try to load the full SAM Model
112                # (using micro_sam) and otherwise fall back on directly loading the encoder state
113                # from the checkpoint
114                try:
115                    _, model = get_sam_model(model_type=encoder, checkpoint_path=checkpoint, return_sam=True)
116                    encoder_state = model.image_encoder.state_dict()
117                except Exception:
118                    # Try loading the encoder state directly from a checkpoint.
119                    encoder_state = torch.load(checkpoint, weights_only=False)
120
121            elif backbone == "sam2" and isinstance(encoder, str):
122                # If we have a SAM2 encoder, then we first try to load the full SAM2 Model.
123                # (using micro_sam2) and otherwise fall back on directly loading the encoder state
124                # from the checkpoint
125                try:
126                    model = get_sam2_model(model_type=encoder, checkpoint_path=checkpoint)
127                    encoder_state = model.image_encoder.state_dict()
128                except Exception:
129                    # Try loading the encoder state directly from a checkpoint.
130                    encoder_state = torch.load(checkpoint, weights_only=False)
131
132            elif backbone == "mae":
133                # vit initialization hints from:
134                #     - https://github.com/facebookresearch/mae/blob/main/main_finetune.py#L233-L242
135                encoder_state = torch.load(checkpoint, weights_only=False)["model"]
136                encoder_state = OrderedDict({
137                    k: v for k, v in encoder_state.items() if (k != "mask_token" and not k.startswith("decoder"))
138                })
139                # Let's remove the `head` from our current encoder (as the MAE pretrained don't expect it)
140                current_encoder_state = self.encoder.state_dict()
141                if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
142                    del self.encoder.head
143
144            elif backbone == "scalemae":
145                # Load the encoder state directly from a checkpoint.
146                encoder_state = torch.load(checkpoint)["model"]
147                encoder_state = OrderedDict({
148                    k: v for k, v in encoder_state.items()
149                    if not k.startswith(("mask_token", "decoder", "fcn", "fpn", "pos_embed"))
150                })
151
152                # Let's remove the `head` from our current encoder (as the MAE pretrained don't expect it)
153                current_encoder_state = self.encoder.state_dict()
154                if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
155                    del self.encoder.head
156
157                if "pos_embed" in current_encoder_state:  # NOTE: ScaleMAE uses 'pos. embeddings' in a diff. format.
158                    del self.encoder.pos_embed
159
160            elif backbone == "dinov3":  # Load the encoder state directly from a checkpoint.
161                encoder_state = torch.load(checkpoint)
162
163            else:
164                raise ValueError(
165                    f"We don't support either the '{backbone}' backbone or the '{encoder}' model combination (or both)."
166                )
167
168        else:
169            encoder_state = checkpoint
170
171        self.encoder.load_state_dict(encoder_state)
172
173    def _get_activation(self, activation):
174        return_activation = None
175        if activation is None:
176            return None
177        if isinstance(activation, nn.Module):
178            return activation
179        if isinstance(activation, str):
180            return_activation = getattr(nn, activation, None)
181        if return_activation is None:
182            raise ValueError(f"Invalid activation: {activation}")
183
184        return return_activation()
185
186    @staticmethod
187    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
188        """Compute the output size given input size and target long side length.
189
190        Args:
191            oldh: The input image height.
192            oldw: The input image width.
193            long_side_length: The longest side length for resizing.
194
195        Returns:
196            The new image height.
197            The new image width.
198        """
199        scale = long_side_length * 1.0 / max(oldh, oldw)
200        newh, neww = oldh * scale, oldw * scale
201        neww = int(neww + 0.5)
202        newh = int(newh + 0.5)
203        return (newh, neww)
204
205    def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor:
206        """Resize the image so that the longest side has the correct length.
207
208        Expects batched images with shape BxCxHxW OR BxCxDxHxW and float format.
209
210        Args:
211            image: The input image.
212
213        Returns:
214            The resized image.
215        """
216        if image.ndim == 4:  # i.e. 2d image
217            target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.encoder.img_size)
218            return F.interpolate(image, target_size, mode="bilinear", align_corners=False, antialias=True)
219        elif image.ndim == 5:  # i.e. 3d volume
220            B, C, Z, H, W = image.shape
221            target_size = self.get_preprocess_shape(H, W, self.img_size)
222            return F.interpolate(image, (Z, *target_size), mode="trilinear", align_corners=False)
223        else:
224            raise ValueError("Expected 4d or 5d inputs, got", image.shape)
225
226    def _as_stats(self, mean, std, device, dtype, is_3d: bool):
227        """@private
228        """
229        # Either 2d batch: (1, C, 1, 1) or 3d batch: (1, C, 1, 1, 1).
230        view_shape = (1, -1, 1, 1, 1) if is_3d else (1, -1, 1, 1)
231        pixel_mean = torch.tensor(mean, device=device, dtype=dtype).view(*view_shape)
232        pixel_std = torch.tensor(std, device=device, dtype=dtype).view(*view_shape)
233        return pixel_mean, pixel_std
234
235    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
236        """@private
237        """
238        device = x.device
239        is_3d = (x.ndim == 5)
240        device, dtype = x.device, x.dtype
241
242        if self.use_sam_stats:
243            mean, std = (123.675, 116.28, 103.53), (58.395, 57.12, 57.375)
244        elif self.use_mae_stats:  # TODO: add mean std from mae / scalemae experiments (or open up arguments for this)
245            raise NotImplementedError
246        elif self.use_dino_stats or (self.use_sam_stats and self.backbone == "sam2"):
247            mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
248        else:
249            mean, std = (0.0, 0.0, 0.0), (1.0, 1.0, 1.0)
250
251        pixel_mean, pixel_std = self._as_stats(mean, std, device=device, dtype=dtype, is_3d=is_3d)
252
253        if self.resize_input:
254            x = self.resize_longest_side(x)
255        input_shape = x.shape[-3:] if is_3d else x.shape[-2:]
256
257        x = (x - pixel_mean) / pixel_std
258        h, w = x.shape[-2:]
259        padh = self.encoder.img_size - h
260        padw = self.encoder.img_size - w
261
262        if is_3d:
263            x = F.pad(x, (0, padw, 0, padh, 0, 0))
264        else:
265            x = F.pad(x, (0, padw, 0, padh))
266
267        return x, input_shape
268
269    def postprocess_masks(
270        self, masks: torch.Tensor, input_size: Tuple[int, ...], original_size: Tuple[int, ...],
271    ) -> torch.Tensor:
272        """@private
273        """
274        if masks.ndim == 4:  # i.e. 2d labels
275            masks = F.interpolate(
276                masks,
277                (self.encoder.img_size, self.encoder.img_size),
278                mode="bilinear",
279                align_corners=False,
280            )
281            masks = masks[..., : input_size[0], : input_size[1]]
282            masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
283
284        elif masks.ndim == 5:  # i.e. 3d volumetric labels
285            masks = F.interpolate(
286                masks,
287                (input_size[0], self.img_size, self.img_size),
288                mode="trilinear",
289                align_corners=False,
290            )
291            masks = masks[..., :input_size[0], :input_size[1], :input_size[2]]
292            masks = F.interpolate(masks, original_size, mode="trilinear", align_corners=False)
293
294        else:
295            raise ValueError("Expected 4d or 5d labels, got", masks.shape)
296
297        return masks
298
299
300class UNETR(UNETRBase):
301    """A (2d-only) UNet Transformer using a vision transformer as encoder and a convolutional decoder.
302    """
303    def __init__(
304        self,
305        img_size: int = 1024,
306        backbone: str = "sam",
307        encoder: Optional[Union[nn.Module, str]] = "vit_b",
308        decoder: Optional[nn.Module] = None,
309        out_channels: int = 1,
310        use_sam_stats: bool = False,
311        use_mae_stats: bool = False,
312        use_dino_stats: bool = False,
313        resize_input: bool = True,
314        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
315        final_activation: Optional[Union[str, nn.Module]] = None,
316        use_skip_connection: bool = True,
317        embed_dim: Optional[int] = None,
318        use_conv_transpose: bool = False,
319        **kwargs
320    ) -> None:
321
322        super().__init__(
323            img_size=img_size,
324            backbone=backbone,
325            encoder=encoder,
326            decoder=decoder,
327            out_channels=out_channels,
328            use_sam_stats=use_sam_stats,
329            use_mae_stats=use_mae_stats,
330            use_dino_stats=use_dino_stats,
331            resize_input=resize_input,
332            encoder_checkpoint=encoder_checkpoint,
333            final_activation=final_activation,
334            use_skip_connection=use_skip_connection,
335            embed_dim=embed_dim,
336            use_conv_transpose=use_conv_transpose,
337            **kwargs,
338        )
339
340        encoder = self.encoder
341
342        if backbone == "sam2" and hasattr(encoder, "trunk"):
343            in_chans = encoder.trunk.patch_embed.proj.in_channels
344        elif hasattr(encoder, "in_chans"):
345            in_chans = encoder.in_chans
346        else:  # `nn.Module` ViT backbone.
347            try:
348                in_chans = encoder.patch_embed.proj.in_channels
349            except AttributeError:  # for getting the input channels while using 'vit_t' from MobileSam
350                in_chans = encoder.patch_embed.seq[0].c.in_channels
351
352        # parameters for the decoder network
353        depth = 3
354        initial_features = 64
355        gain = 2
356        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
357        scale_factors = depth * [2]
358        self.out_channels = out_channels
359
360        # choice of upsampler - to use (bilinear interpolation + conv) or conv transpose
361        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d
362
363        self.decoder = decoder or Decoder(
364            features=features_decoder,
365            scale_factors=scale_factors[::-1],
366            conv_block_impl=ConvBlock2d,
367            sampler_impl=_upsampler,
368        )
369
370        if use_skip_connection:
371            self.deconv1 = Deconv2DBlock(
372                in_channels=self.embed_dim,
373                out_channels=features_decoder[0],
374                use_conv_transpose=use_conv_transpose,
375            )
376            self.deconv2 = nn.Sequential(
377                Deconv2DBlock(
378                    in_channels=self.embed_dim,
379                    out_channels=features_decoder[0],
380                    use_conv_transpose=use_conv_transpose,
381                ),
382                Deconv2DBlock(
383                    in_channels=features_decoder[0],
384                    out_channels=features_decoder[1],
385                    use_conv_transpose=use_conv_transpose,
386                )
387            )
388            self.deconv3 = nn.Sequential(
389                Deconv2DBlock(
390                    in_channels=self.embed_dim,
391                    out_channels=features_decoder[0],
392                    use_conv_transpose=use_conv_transpose,
393                ),
394                Deconv2DBlock(
395                    in_channels=features_decoder[0],
396                    out_channels=features_decoder[1],
397                    use_conv_transpose=use_conv_transpose,
398                ),
399                Deconv2DBlock(
400                    in_channels=features_decoder[1],
401                    out_channels=features_decoder[2],
402                    use_conv_transpose=use_conv_transpose,
403                )
404            )
405            self.deconv4 = ConvBlock2d(in_chans, features_decoder[-1])
406        else:
407            self.deconv1 = Deconv2DBlock(
408                in_channels=self.embed_dim,
409                out_channels=features_decoder[0],
410                use_conv_transpose=use_conv_transpose,
411            )
412            self.deconv2 = Deconv2DBlock(
413                in_channels=features_decoder[0],
414                out_channels=features_decoder[1],
415                use_conv_transpose=use_conv_transpose,
416            )
417            self.deconv3 = Deconv2DBlock(
418                in_channels=features_decoder[1],
419                out_channels=features_decoder[2],
420                use_conv_transpose=use_conv_transpose,
421            )
422            self.deconv4 = Deconv2DBlock(
423                in_channels=features_decoder[2],
424                out_channels=features_decoder[3],
425                use_conv_transpose=use_conv_transpose,
426            )
427
428        self.base = ConvBlock2d(self.embed_dim, features_decoder[0])
429        self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1)
430        self.deconv_out = _upsampler(
431            scale_factor=2, in_channels=features_decoder[-1], out_channels=features_decoder[-1]
432        )
433        self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1])
434
435    def forward(self, x: torch.Tensor) -> torch.Tensor:
436        """Apply the UNETR to the input data.
437
438        Args:
439            x: The input tensor.
440
441        Returns:
442            The UNETR output.
443        """
444        original_shape = x.shape[-2:]
445
446        # Reshape the inputs to the shape expected by the encoder
447        # and normalize the inputs if normalization is part of the model.
448        x, input_shape = self.preprocess(x)
449
450        encoder_outputs = self.encoder(x)
451
452        if isinstance(encoder_outputs[-1], list):
453            # `encoder_outputs` can be arranged in only two forms:
454            #   - either we only return the image embeddings
455            #   - or, we return the image embeddings and the "list" of global attention layers
456            z12, from_encoder = encoder_outputs
457        else:
458            z12 = encoder_outputs
459
460        if self.use_skip_connection:
461            from_encoder = from_encoder[::-1]
462            z9 = self.deconv1(from_encoder[0])
463            z6 = self.deconv2(from_encoder[1])
464            z3 = self.deconv3(from_encoder[2])
465            z0 = self.deconv4(x)
466
467        else:
468            z9 = self.deconv1(z12)
469            z6 = self.deconv2(z9)
470            z3 = self.deconv3(z6)
471            z0 = self.deconv4(z3)
472
473        updated_from_encoder = [z9, z6, z3]
474
475        x = self.base(z12)
476        x = self.decoder(x, encoder_inputs=updated_from_encoder)
477        x = self.deconv_out(x)
478
479        x = torch.cat([x, z0], dim=1)
480        x = self.decoder_head(x)
481
482        x = self.out_conv(x)
483        if self.final_activation is not None:
484            x = self.final_activation(x)
485
486        x = self.postprocess_masks(x, input_shape, original_shape)
487        return x
488
489
490class UNETR2D(UNETR):
491    """A two-dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.
492    """
493    pass
494
495
496class UNETR3D(UNETRBase):
497    """A three dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.
498    """
499    def __init__(
500        self,
501        img_size: int = 1024,
502        backbone: Literal["sam", "sam2", "mae", "scalemae", "dinov3"] = "sam2",
503        encoder: Optional[Union[nn.Module, str]] = "hvit_b",
504        decoder: Optional[nn.Module] = None,
505        out_channels: int = 1,
506        use_sam_stats: bool = False,
507        use_mae_stats: bool = False,
508        use_dino_stats: bool = False,
509        resize_input: bool = True,
510        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
511        final_activation: Optional[Union[str, nn.Module]] = None,
512        use_skip_connection: bool = False,
513        embed_dim: Optional[int] = None,
514        use_conv_transpose: bool = False,
515        use_strip_pooling: bool = True,
516        **kwargs
517    ):
518        if use_skip_connection:
519            raise NotImplementedError("The framework cannot handle skip connections atm.")
520        if use_conv_transpose:
521            raise NotImplementedError("It's not enabled to switch between interpolation and transposed convolutions.")
522
523        super().__init__(
524            img_size=img_size,
525            backbone=backbone,
526            encoder=encoder,
527            decoder=decoder,
528            out_channels=out_channels,
529            use_sam_stats=use_sam_stats,
530            use_mae_stats=use_mae_stats,
531            use_dino_stats=use_dino_stats,
532            resize_input=resize_input,
533            encoder_checkpoint=encoder_checkpoint,
534            final_activation=final_activation,
535            use_skip_connection=use_skip_connection,
536            embed_dim=embed_dim,
537            use_conv_transpose=use_conv_transpose,
538            **kwargs,
539        )
540
541        # Load the pretrained image encoder weights
542        self.image_encoder = self.encoder
543
544        # Step 2: the 3d convolutional decoder.
545        # First, get the important parameters for the decoder.
546        embed_dim = 256
547        depth = 3
548        initial_features = 64
549        gain = 2
550        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
551        scale_factors = [1, 2, 2]
552        self.out_channels = out_channels
553
554        # The mapping blocks.
555        self.deconv1 = Deconv3DBlock(
556            in_channels=embed_dim,
557            out_channels=features_decoder[0],
558            scale_factor=scale_factors,
559            use_strip_pooling=use_strip_pooling,
560        )
561        self.deconv2 = Deconv3DBlock(
562            in_channels=features_decoder[0],
563            out_channels=features_decoder[1],
564            scale_factor=scale_factors,
565            use_strip_pooling=use_strip_pooling,
566        )
567        self.deconv3 = Deconv3DBlock(
568            in_channels=features_decoder[1],
569            out_channels=features_decoder[2],
570            scale_factor=scale_factors,
571            use_strip_pooling=use_strip_pooling,
572        )
573        self.deconv4 = Deconv3DBlock(
574            in_channels=features_decoder[2],
575            out_channels=features_decoder[3],
576            scale_factor=scale_factors,
577            use_strip_pooling=use_strip_pooling,
578        )
579
580        # The core decoder block.
581        self.decoder = decoder or Decoder(
582            features=features_decoder,
583            scale_factors=[scale_factors] * depth,
584            conv_block_impl=partial(ConvBlock3dWithStrip, use_strip_pooling=use_strip_pooling),
585            sampler_impl=Upsampler3d,
586        )
587
588        # And the final upsampler to match the expected dimensions.
589        self.end_up = Deconv3DBlock(
590            in_channels=features_decoder[-1],
591            out_channels=features_decoder[-1],
592            scale_factor=scale_factors,
593            use_strip_pooling=use_strip_pooling,
594        )
595
596        # Additional conjunction blocks.
597        self.base = ConvBlock3dWithStrip(
598            in_channels=embed_dim,
599            out_channels=features_decoder[0],
600            use_strip_pooling=use_strip_pooling,
601        )
602
603        # And the output layers.
604        self.decoder_head = ConvBlock3dWithStrip(
605            in_channels=2 * features_decoder[-1],
606            out_channels=features_decoder[-1],
607            use_strip_pooling=use_strip_pooling,
608        )
609        self.out_conv = nn.Conv3d(features_decoder[-1], out_channels, 1)
610
611    def forward(self, x: torch.Tensor):
612        """Forward pass of the UNETR-3D model.
613
614        Args:
615            x: Inputs of expected shape (B, C, Z, Y, X), where Z considers flexible inputs.
616
617        Returns:
618            The UNETR output.
619        """
620        B, C, Z, H, W = x.shape
621        original_shape = (Z, H, W)
622
623        # Preprocessing step
624        x, input_shape = self.preprocess(x)
625
626        # Run the image encoder.
627        curr_features = torch.stack([self.image_encoder(x[:, :, i])[0] for i in range(Z)], dim=2)
628
629        # Prepare the counterparts for the decoder.
630        # NOTE: The section below is sequential, there's no skip connections atm.
631        z9 = self.deconv1(curr_features)
632        z6 = self.deconv2(z9)
633        z3 = self.deconv3(z6)
634        z0 = self.deconv4(z3)
635
636        # Align the features through the base block.
637        x = self.base(curr_features)
638
639        # Run the decoder.
640        updated_from_encoder = [z9, z6, z3]
641        x = self.decoder(x, encoder_inputs=updated_from_encoder)
642        x = self.end_up(x)
643
644        # And the final output head.
645        x = torch.cat([x, z0], dim=1)
646        x = self.decoder_head(x)
647        x = self.out_conv(x)
648        if self.final_activation is not None:
649            x = self.final_activation(x)
650
651        # Postprocess the output back to original size.
652        x = self.postprocess_masks(x, input_shape, original_shape)
653        return x
654
655#
656#  ADDITIONAL FUNCTIONALITIES
657#
658
659
660def _strip_pooling_layers(enabled, channels) -> nn.Module:
661    return DepthStripPooling(channels) if enabled else nn.Identity()
662
663
664class DepthStripPooling(nn.Module):
665    """@private
666    """
667    def __init__(self, channels: int, reduction: int = 4):
668        """Block for strip pooling along the depth dimension (only).
669
670        eg. for 3D (Z > 1) - it aggregates global context across depth by adaptive avg pooling
671        to Z=1, and then passes through a small 1x1x1 MLP, then broadcasts it back to Z to
672        modulate the original features (using a gated residual).
673
674        For 2D (Z == 1 or Z == 3): returns input unchanged (no-op).
675
676        Args:
677            channels: The output channels.
678            reduction: The reduction of the hidden layers.
679        """
680        super().__init__()
681        hidden = max(1, channels // reduction)
682        self.conv1 = nn.Conv3d(channels, hidden, kernel_size=1)
683        self.bn1 = nn.BatchNorm3d(hidden)
684        self.relu = nn.ReLU(inplace=True)
685        self.conv2 = nn.Conv3d(hidden, channels, kernel_size=1)
686
687    def forward(self, x: torch.Tensor) -> torch.Tensor:
688        if x.dim() != 5:
689            raise ValueError(f"DepthStripPooling expects 5D tensors as input, got '{x.shape}'.")
690
691        B, C, Z, H, W = x.shape
692        if Z == 1 or Z == 3:  # i.e. 2d-as-1-slice or RGB_2d-as-1-slice.
693            return x  # We simply do nothing there.
694
695        # We pool only along the depth dimension: i.e. target shape (B, C, 1, H, W)
696        feat = F.adaptive_avg_pool3d(x, output_size=(1, H, W))
697        feat = self.conv1(feat)
698        feat = self.bn1(feat)
699        feat = self.relu(feat)
700        feat = self.conv2(feat)
701        gate = torch.sigmoid(feat).expand(B, C, Z, H, W)  # Broadcast the collapsed depth context back to all slices
702
703        # Gated residual fusion
704        return x * gate + x
705
706
707class Deconv3DBlock(nn.Module):
708    """@private
709    """
710    def __init__(
711        self,
712        scale_factor,
713        in_channels,
714        out_channels,
715        kernel_size=3,
716        anisotropic_kernel=True,
717        use_strip_pooling=True,
718    ):
719        super().__init__()
720        conv_block_kwargs = {
721            "in_channels": out_channels,
722            "out_channels": out_channels,
723            "kernel_size": kernel_size,
724            "padding": ((kernel_size - 1) // 2),
725        }
726        if anisotropic_kernel:
727            conv_block_kwargs = _update_conv_kwargs(conv_block_kwargs, scale_factor)
728
729        self.block = nn.Sequential(
730            Upsampler3d(scale_factor, in_channels, out_channels),
731            nn.Conv3d(**conv_block_kwargs),
732            nn.BatchNorm3d(out_channels),
733            nn.ReLU(True),
734            _strip_pooling_layers(enabled=use_strip_pooling, channels=out_channels),
735        )
736
737    def forward(self, x):
738        return self.block(x)
739
740
741class ConvBlock3dWithStrip(nn.Module):
742    """@private
743    """
744    def __init__(
745        self, in_channels: int, out_channels: int, use_strip_pooling: bool = True, **kwargs
746    ):
747        super().__init__()
748        self.block = nn.Sequential(
749            ConvBlock3d(in_channels, out_channels, **kwargs),
750            _strip_pooling_layers(enabled=use_strip_pooling, channels=out_channels),
751        )
752
753    def forward(self, x):
754        return self.block(x)
755
756
757class SingleDeconv2DBlock(nn.Module):
758    """@private
759    """
760    def __init__(self, scale_factor, in_channels, out_channels):
761        super().__init__()
762        self.block = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, padding=0, output_padding=0)
763
764    def forward(self, x):
765        return self.block(x)
766
767
768class SingleConv2DBlock(nn.Module):
769    """@private
770    """
771    def __init__(self, in_channels, out_channels, kernel_size):
772        super().__init__()
773        self.block = nn.Conv2d(
774            in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=((kernel_size - 1) // 2)
775        )
776
777    def forward(self, x):
778        return self.block(x)
779
780
781class Conv2DBlock(nn.Module):
782    """@private
783    """
784    def __init__(self, in_channels, out_channels, kernel_size=3):
785        super().__init__()
786        self.block = nn.Sequential(
787            SingleConv2DBlock(in_channels, out_channels, kernel_size),
788            nn.BatchNorm2d(out_channels),
789            nn.ReLU(True)
790        )
791
792    def forward(self, x):
793        return self.block(x)
794
795
796class Deconv2DBlock(nn.Module):
797    """@private
798    """
799    def __init__(self, in_channels, out_channels, kernel_size=3, use_conv_transpose=True):
800        super().__init__()
801        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d
802        self.block = nn.Sequential(
803            _upsampler(scale_factor=2, in_channels=in_channels, out_channels=out_channels),
804            SingleConv2DBlock(out_channels, out_channels, kernel_size),
805            nn.BatchNorm2d(out_channels),
806            nn.ReLU(True)
807        )
808
809    def forward(self, x):
810        return self.block(x)
class UNETRBase(torch.nn.modules.module.Module):
 29class UNETRBase(nn.Module):
 30    """Base class for implementing a UNETR.
 31
 32    Args:
 33        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
 34        backbone: The name of the vision transformer implementation. One of "sam" or "mae".
 35        encoder: The vision transformer. Can either be a name, such as "vit_b" or a torch module.
 36        decoder: The convolutional decoder.
 37        out_channels: The number of output channels of the UNETR.
 38        use_sam_stats: Whether to normalize the input data with the statistics of the pretrained SAM model.
 39        use_dino_stats: Whether to normalize the input data with the statistics of the pretrained DINOv3 model.
 40        use_mae_stats: Whether to normalize the input data with the statistics of the pretrained MAE model.
 41        resize_input: Whether to resize the input images to match `img_size`.
 42            By default, it resizes the inputs to match the `img_size`.
 43        encoder_checkpoint: Checkpoint for initializing the vision transformer.
 44            Can either be a filepath or an already loaded checkpoint.
 45        final_activation: The activation to apply to the UNETR output.
 46        use_skip_connection: Whether to use skip connections. By default, it uses skip connections.
 47        embed_dim: The embedding dimensionality, corresponding to the output dimension of the vision transformer.
 48        use_conv_transpose: Whether to use transposed convolutions instead of resampling for upsampling.
 49            By default, it uses resampling for upsampling.
 50    """
 51    def __init__(
 52        self,
 53        img_size: int = 1024,
 54        backbone: Literal["sam", "sam2", "mae", "scalemae", "dinov3"] = "sam2",
 55        encoder: Optional[Union[nn.Module, str]] = "vit_b",
 56        decoder: Optional[nn.Module] = None,
 57        out_channels: int = 1,
 58        use_sam_stats: bool = False,
 59        use_mae_stats: bool = False,
 60        use_dino_stats: bool = False,
 61        resize_input: bool = True,
 62        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
 63        final_activation: Optional[Union[str, nn.Module]] = None,
 64        use_skip_connection: bool = True,
 65        embed_dim: Optional[int] = None,
 66        use_conv_transpose: bool = False,
 67        **kwargs
 68    ) -> None:
 69        super().__init__()
 70
 71        self.img_size = img_size
 72        self.use_sam_stats = use_sam_stats
 73        self.use_mae_stats = use_mae_stats
 74        self.use_dino_stats = use_dino_stats
 75        self.use_skip_connection = use_skip_connection
 76        self.resize_input = resize_input
 77        self.use_conv_transpose = use_conv_transpose
 78        self.backbone = backbone
 79
 80        if isinstance(encoder, str):  # e.g. "vit_b" / "hvit_b"
 81            print(f"Using {encoder} from {backbone.upper()}")
 82            self.encoder = get_vision_transformer(img_size=img_size, backbone=backbone, model=encoder, **kwargs)
 83
 84            if encoder_checkpoint is not None:
 85                self._load_encoder_from_checkpoint(backbone=backbone, encoder=encoder, checkpoint=encoder_checkpoint)
 86
 87            if embed_dim is None:
 88                embed_dim = self.encoder.embed_dim
 89
 90        else:  # `nn.Module` ViT backbone
 91            self.encoder = encoder
 92
 93            have_neck = False
 94            for name, _ in self.encoder.named_parameters():
 95                if name.startswith("neck"):
 96                    have_neck = True
 97
 98            if embed_dim is None:
 99                if have_neck:
100                    embed_dim = self.encoder.neck[2].out_channels  # the value is 256
101                else:
102                    embed_dim = self.encoder.patch_embed.proj.out_channels
103
104        self.embed_dim = embed_dim
105        self.final_activation = self._get_activation(final_activation)
106
107    def _load_encoder_from_checkpoint(self, backbone, encoder, checkpoint):
108        """Function to load pretrained weights to the image encoder.
109        """
110        if isinstance(checkpoint, str):
111            if backbone == "sam" and isinstance(encoder, str):
112                # If we have a SAM encoder, then we first try to load the full SAM Model
113                # (using micro_sam) and otherwise fall back on directly loading the encoder state
114                # from the checkpoint
115                try:
116                    _, model = get_sam_model(model_type=encoder, checkpoint_path=checkpoint, return_sam=True)
117                    encoder_state = model.image_encoder.state_dict()
118                except Exception:
119                    # Try loading the encoder state directly from a checkpoint.
120                    encoder_state = torch.load(checkpoint, weights_only=False)
121
122            elif backbone == "sam2" and isinstance(encoder, str):
123                # If we have a SAM2 encoder, then we first try to load the full SAM2 Model.
124                # (using micro_sam2) and otherwise fall back on directly loading the encoder state
125                # from the checkpoint
126                try:
127                    model = get_sam2_model(model_type=encoder, checkpoint_path=checkpoint)
128                    encoder_state = model.image_encoder.state_dict()
129                except Exception:
130                    # Try loading the encoder state directly from a checkpoint.
131                    encoder_state = torch.load(checkpoint, weights_only=False)
132
133            elif backbone == "mae":
134                # vit initialization hints from:
135                #     - https://github.com/facebookresearch/mae/blob/main/main_finetune.py#L233-L242
136                encoder_state = torch.load(checkpoint, weights_only=False)["model"]
137                encoder_state = OrderedDict({
138                    k: v for k, v in encoder_state.items() if (k != "mask_token" and not k.startswith("decoder"))
139                })
140                # Let's remove the `head` from our current encoder (as the MAE pretrained don't expect it)
141                current_encoder_state = self.encoder.state_dict()
142                if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
143                    del self.encoder.head
144
145            elif backbone == "scalemae":
146                # Load the encoder state directly from a checkpoint.
147                encoder_state = torch.load(checkpoint)["model"]
148                encoder_state = OrderedDict({
149                    k: v for k, v in encoder_state.items()
150                    if not k.startswith(("mask_token", "decoder", "fcn", "fpn", "pos_embed"))
151                })
152
153                # Let's remove the `head` from our current encoder (as the MAE pretrained don't expect it)
154                current_encoder_state = self.encoder.state_dict()
155                if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
156                    del self.encoder.head
157
158                if "pos_embed" in current_encoder_state:  # NOTE: ScaleMAE uses 'pos. embeddings' in a diff. format.
159                    del self.encoder.pos_embed
160
161            elif backbone == "dinov3":  # Load the encoder state directly from a checkpoint.
162                encoder_state = torch.load(checkpoint)
163
164            else:
165                raise ValueError(
166                    f"We don't support either the '{backbone}' backbone or the '{encoder}' model combination (or both)."
167                )
168
169        else:
170            encoder_state = checkpoint
171
172        self.encoder.load_state_dict(encoder_state)
173
174    def _get_activation(self, activation):
175        return_activation = None
176        if activation is None:
177            return None
178        if isinstance(activation, nn.Module):
179            return activation
180        if isinstance(activation, str):
181            return_activation = getattr(nn, activation, None)
182        if return_activation is None:
183            raise ValueError(f"Invalid activation: {activation}")
184
185        return return_activation()
186
187    @staticmethod
188    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
189        """Compute the output size given input size and target long side length.
190
191        Args:
192            oldh: The input image height.
193            oldw: The input image width.
194            long_side_length: The longest side length for resizing.
195
196        Returns:
197            The new image height.
198            The new image width.
199        """
200        scale = long_side_length * 1.0 / max(oldh, oldw)
201        newh, neww = oldh * scale, oldw * scale
202        neww = int(neww + 0.5)
203        newh = int(newh + 0.5)
204        return (newh, neww)
205
206    def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor:
207        """Resize the image so that the longest side has the correct length.
208
209        Expects batched images with shape BxCxHxW OR BxCxDxHxW and float format.
210
211        Args:
212            image: The input image.
213
214        Returns:
215            The resized image.
216        """
217        if image.ndim == 4:  # i.e. 2d image
218            target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.encoder.img_size)
219            return F.interpolate(image, target_size, mode="bilinear", align_corners=False, antialias=True)
220        elif image.ndim == 5:  # i.e. 3d volume
221            B, C, Z, H, W = image.shape
222            target_size = self.get_preprocess_shape(H, W, self.img_size)
223            return F.interpolate(image, (Z, *target_size), mode="trilinear", align_corners=False)
224        else:
225            raise ValueError("Expected 4d or 5d inputs, got", image.shape)
226
227    def _as_stats(self, mean, std, device, dtype, is_3d: bool):
228        """@private
229        """
230        # Either 2d batch: (1, C, 1, 1) or 3d batch: (1, C, 1, 1, 1).
231        view_shape = (1, -1, 1, 1, 1) if is_3d else (1, -1, 1, 1)
232        pixel_mean = torch.tensor(mean, device=device, dtype=dtype).view(*view_shape)
233        pixel_std = torch.tensor(std, device=device, dtype=dtype).view(*view_shape)
234        return pixel_mean, pixel_std
235
236    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
237        """@private
238        """
239        device = x.device
240        is_3d = (x.ndim == 5)
241        device, dtype = x.device, x.dtype
242
243        if self.use_sam_stats:
244            mean, std = (123.675, 116.28, 103.53), (58.395, 57.12, 57.375)
245        elif self.use_mae_stats:  # TODO: add mean std from mae / scalemae experiments (or open up arguments for this)
246            raise NotImplementedError
247        elif self.use_dino_stats or (self.use_sam_stats and self.backbone == "sam2"):
248            mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
249        else:
250            mean, std = (0.0, 0.0, 0.0), (1.0, 1.0, 1.0)
251
252        pixel_mean, pixel_std = self._as_stats(mean, std, device=device, dtype=dtype, is_3d=is_3d)
253
254        if self.resize_input:
255            x = self.resize_longest_side(x)
256        input_shape = x.shape[-3:] if is_3d else x.shape[-2:]
257
258        x = (x - pixel_mean) / pixel_std
259        h, w = x.shape[-2:]
260        padh = self.encoder.img_size - h
261        padw = self.encoder.img_size - w
262
263        if is_3d:
264            x = F.pad(x, (0, padw, 0, padh, 0, 0))
265        else:
266            x = F.pad(x, (0, padw, 0, padh))
267
268        return x, input_shape
269
270    def postprocess_masks(
271        self, masks: torch.Tensor, input_size: Tuple[int, ...], original_size: Tuple[int, ...],
272    ) -> torch.Tensor:
273        """@private
274        """
275        if masks.ndim == 4:  # i.e. 2d labels
276            masks = F.interpolate(
277                masks,
278                (self.encoder.img_size, self.encoder.img_size),
279                mode="bilinear",
280                align_corners=False,
281            )
282            masks = masks[..., : input_size[0], : input_size[1]]
283            masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
284
285        elif masks.ndim == 5:  # i.e. 3d volumetric labels
286            masks = F.interpolate(
287                masks,
288                (input_size[0], self.img_size, self.img_size),
289                mode="trilinear",
290                align_corners=False,
291            )
292            masks = masks[..., :input_size[0], :input_size[1], :input_size[2]]
293            masks = F.interpolate(masks, original_size, mode="trilinear", align_corners=False)
294
295        else:
296            raise ValueError("Expected 4d or 5d labels, got", masks.shape)
297
298        return masks

Base class for implementing a UNETR.

Arguments:
  • img_size: The size of the input for the image encoder. Input images will be resized to match this size.
  • backbone: The name of the vision transformer implementation. One of "sam" or "mae".
  • encoder: The vision transformer. Can either be a name, such as "vit_b" or a torch module.
  • decoder: The convolutional decoder.
  • out_channels: The number of output channels of the UNETR.
  • use_sam_stats: Whether to normalize the input data with the statistics of the pretrained SAM model.
  • use_dino_stats: Whether to normalize the input data with the statistics of the pretrained DINOv3 model.
  • use_mae_stats: Whether to normalize the input data with the statistics of the pretrained MAE model.
  • resize_input: Whether to resize the input images to match img_size. By default, it resizes the inputs to match the img_size.
  • encoder_checkpoint: Checkpoint for initializing the vision transformer. Can either be a filepath or an already loaded checkpoint.
  • final_activation: The activation to apply to the UNETR output.
  • use_skip_connection: Whether to use skip connections. By default, it uses skip connections.
  • embed_dim: The embedding dimensionality, corresponding to the output dimension of the vision transformer.
  • use_conv_transpose: Whether to use transposed convolutions instead of resampling for upsampling. By default, it uses resampling for upsampling.
UNETRBase( img_size: int = 1024, backbone: Literal['sam', 'sam2', 'mae', 'scalemae', 'dinov3'] = 'sam2', encoder: Union[torch.nn.modules.module.Module, str, NoneType] = 'vit_b', decoder: Optional[torch.nn.modules.module.Module] = None, out_channels: int = 1, use_sam_stats: bool = False, use_mae_stats: bool = False, use_dino_stats: bool = False, resize_input: bool = True, encoder_checkpoint: Union[str, collections.OrderedDict, NoneType] = None, final_activation: Union[torch.nn.modules.module.Module, str, NoneType] = None, use_skip_connection: bool = True, embed_dim: Optional[int] = None, use_conv_transpose: bool = False, **kwargs)
 51    def __init__(
 52        self,
 53        img_size: int = 1024,
 54        backbone: Literal["sam", "sam2", "mae", "scalemae", "dinov3"] = "sam2",
 55        encoder: Optional[Union[nn.Module, str]] = "vit_b",
 56        decoder: Optional[nn.Module] = None,
 57        out_channels: int = 1,
 58        use_sam_stats: bool = False,
 59        use_mae_stats: bool = False,
 60        use_dino_stats: bool = False,
 61        resize_input: bool = True,
 62        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
 63        final_activation: Optional[Union[str, nn.Module]] = None,
 64        use_skip_connection: bool = True,
 65        embed_dim: Optional[int] = None,
 66        use_conv_transpose: bool = False,
 67        **kwargs
 68    ) -> None:
 69        super().__init__()
 70
 71        self.img_size = img_size
 72        self.use_sam_stats = use_sam_stats
 73        self.use_mae_stats = use_mae_stats
 74        self.use_dino_stats = use_dino_stats
 75        self.use_skip_connection = use_skip_connection
 76        self.resize_input = resize_input
 77        self.use_conv_transpose = use_conv_transpose
 78        self.backbone = backbone
 79
 80        if isinstance(encoder, str):  # e.g. "vit_b" / "hvit_b"
 81            print(f"Using {encoder} from {backbone.upper()}")
 82            self.encoder = get_vision_transformer(img_size=img_size, backbone=backbone, model=encoder, **kwargs)
 83
 84            if encoder_checkpoint is not None:
 85                self._load_encoder_from_checkpoint(backbone=backbone, encoder=encoder, checkpoint=encoder_checkpoint)
 86
 87            if embed_dim is None:
 88                embed_dim = self.encoder.embed_dim
 89
 90        else:  # `nn.Module` ViT backbone
 91            self.encoder = encoder
 92
 93            have_neck = False
 94            for name, _ in self.encoder.named_parameters():
 95                if name.startswith("neck"):
 96                    have_neck = True
 97
 98            if embed_dim is None:
 99                if have_neck:
100                    embed_dim = self.encoder.neck[2].out_channels  # the value is 256
101                else:
102                    embed_dim = self.encoder.patch_embed.proj.out_channels
103
104        self.embed_dim = embed_dim
105        self.final_activation = self._get_activation(final_activation)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

img_size
use_sam_stats
use_mae_stats
use_dino_stats
use_skip_connection
resize_input
use_conv_transpose
backbone
embed_dim
final_activation
@staticmethod
def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
187    @staticmethod
188    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
189        """Compute the output size given input size and target long side length.
190
191        Args:
192            oldh: The input image height.
193            oldw: The input image width.
194            long_side_length: The longest side length for resizing.
195
196        Returns:
197            The new image height.
198            The new image width.
199        """
200        scale = long_side_length * 1.0 / max(oldh, oldw)
201        newh, neww = oldh * scale, oldw * scale
202        neww = int(neww + 0.5)
203        newh = int(newh + 0.5)
204        return (newh, neww)

Compute the output size given input size and target long side length.

Arguments:
  • oldh: The input image height.
  • oldw: The input image width.
  • long_side_length: The longest side length for resizing.
Returns:

The new image height. The new image width.

def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor:
206    def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor:
207        """Resize the image so that the longest side has the correct length.
208
209        Expects batched images with shape BxCxHxW OR BxCxDxHxW and float format.
210
211        Args:
212            image: The input image.
213
214        Returns:
215            The resized image.
216        """
217        if image.ndim == 4:  # i.e. 2d image
218            target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.encoder.img_size)
219            return F.interpolate(image, target_size, mode="bilinear", align_corners=False, antialias=True)
220        elif image.ndim == 5:  # i.e. 3d volume
221            B, C, Z, H, W = image.shape
222            target_size = self.get_preprocess_shape(H, W, self.img_size)
223            return F.interpolate(image, (Z, *target_size), mode="trilinear", align_corners=False)
224        else:
225            raise ValueError("Expected 4d or 5d inputs, got", image.shape)

Resize the image so that the longest side has the correct length.

Expects batched images with shape BxCxHxW OR BxCxDxHxW and float format.

Arguments:
  • image: The input image.
Returns:

The resized image.

class UNETR(UNETRBase):
301class UNETR(UNETRBase):
302    """A (2d-only) UNet Transformer using a vision transformer as encoder and a convolutional decoder.
303    """
304    def __init__(
305        self,
306        img_size: int = 1024,
307        backbone: str = "sam",
308        encoder: Optional[Union[nn.Module, str]] = "vit_b",
309        decoder: Optional[nn.Module] = None,
310        out_channels: int = 1,
311        use_sam_stats: bool = False,
312        use_mae_stats: bool = False,
313        use_dino_stats: bool = False,
314        resize_input: bool = True,
315        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
316        final_activation: Optional[Union[str, nn.Module]] = None,
317        use_skip_connection: bool = True,
318        embed_dim: Optional[int] = None,
319        use_conv_transpose: bool = False,
320        **kwargs
321    ) -> None:
322
323        super().__init__(
324            img_size=img_size,
325            backbone=backbone,
326            encoder=encoder,
327            decoder=decoder,
328            out_channels=out_channels,
329            use_sam_stats=use_sam_stats,
330            use_mae_stats=use_mae_stats,
331            use_dino_stats=use_dino_stats,
332            resize_input=resize_input,
333            encoder_checkpoint=encoder_checkpoint,
334            final_activation=final_activation,
335            use_skip_connection=use_skip_connection,
336            embed_dim=embed_dim,
337            use_conv_transpose=use_conv_transpose,
338            **kwargs,
339        )
340
341        encoder = self.encoder
342
343        if backbone == "sam2" and hasattr(encoder, "trunk"):
344            in_chans = encoder.trunk.patch_embed.proj.in_channels
345        elif hasattr(encoder, "in_chans"):
346            in_chans = encoder.in_chans
347        else:  # `nn.Module` ViT backbone.
348            try:
349                in_chans = encoder.patch_embed.proj.in_channels
350            except AttributeError:  # for getting the input channels while using 'vit_t' from MobileSam
351                in_chans = encoder.patch_embed.seq[0].c.in_channels
352
353        # parameters for the decoder network
354        depth = 3
355        initial_features = 64
356        gain = 2
357        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
358        scale_factors = depth * [2]
359        self.out_channels = out_channels
360
361        # choice of upsampler - to use (bilinear interpolation + conv) or conv transpose
362        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d
363
364        self.decoder = decoder or Decoder(
365            features=features_decoder,
366            scale_factors=scale_factors[::-1],
367            conv_block_impl=ConvBlock2d,
368            sampler_impl=_upsampler,
369        )
370
371        if use_skip_connection:
372            self.deconv1 = Deconv2DBlock(
373                in_channels=self.embed_dim,
374                out_channels=features_decoder[0],
375                use_conv_transpose=use_conv_transpose,
376            )
377            self.deconv2 = nn.Sequential(
378                Deconv2DBlock(
379                    in_channels=self.embed_dim,
380                    out_channels=features_decoder[0],
381                    use_conv_transpose=use_conv_transpose,
382                ),
383                Deconv2DBlock(
384                    in_channels=features_decoder[0],
385                    out_channels=features_decoder[1],
386                    use_conv_transpose=use_conv_transpose,
387                )
388            )
389            self.deconv3 = nn.Sequential(
390                Deconv2DBlock(
391                    in_channels=self.embed_dim,
392                    out_channels=features_decoder[0],
393                    use_conv_transpose=use_conv_transpose,
394                ),
395                Deconv2DBlock(
396                    in_channels=features_decoder[0],
397                    out_channels=features_decoder[1],
398                    use_conv_transpose=use_conv_transpose,
399                ),
400                Deconv2DBlock(
401                    in_channels=features_decoder[1],
402                    out_channels=features_decoder[2],
403                    use_conv_transpose=use_conv_transpose,
404                )
405            )
406            self.deconv4 = ConvBlock2d(in_chans, features_decoder[-1])
407        else:
408            self.deconv1 = Deconv2DBlock(
409                in_channels=self.embed_dim,
410                out_channels=features_decoder[0],
411                use_conv_transpose=use_conv_transpose,
412            )
413            self.deconv2 = Deconv2DBlock(
414                in_channels=features_decoder[0],
415                out_channels=features_decoder[1],
416                use_conv_transpose=use_conv_transpose,
417            )
418            self.deconv3 = Deconv2DBlock(
419                in_channels=features_decoder[1],
420                out_channels=features_decoder[2],
421                use_conv_transpose=use_conv_transpose,
422            )
423            self.deconv4 = Deconv2DBlock(
424                in_channels=features_decoder[2],
425                out_channels=features_decoder[3],
426                use_conv_transpose=use_conv_transpose,
427            )
428
429        self.base = ConvBlock2d(self.embed_dim, features_decoder[0])
430        self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1)
431        self.deconv_out = _upsampler(
432            scale_factor=2, in_channels=features_decoder[-1], out_channels=features_decoder[-1]
433        )
434        self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1])
435
436    def forward(self, x: torch.Tensor) -> torch.Tensor:
437        """Apply the UNETR to the input data.
438
439        Args:
440            x: The input tensor.
441
442        Returns:
443            The UNETR output.
444        """
445        original_shape = x.shape[-2:]
446
447        # Reshape the inputs to the shape expected by the encoder
448        # and normalize the inputs if normalization is part of the model.
449        x, input_shape = self.preprocess(x)
450
451        encoder_outputs = self.encoder(x)
452
453        if isinstance(encoder_outputs[-1], list):
454            # `encoder_outputs` can be arranged in only two forms:
455            #   - either we only return the image embeddings
456            #   - or, we return the image embeddings and the "list" of global attention layers
457            z12, from_encoder = encoder_outputs
458        else:
459            z12 = encoder_outputs
460
461        if self.use_skip_connection:
462            from_encoder = from_encoder[::-1]
463            z9 = self.deconv1(from_encoder[0])
464            z6 = self.deconv2(from_encoder[1])
465            z3 = self.deconv3(from_encoder[2])
466            z0 = self.deconv4(x)
467
468        else:
469            z9 = self.deconv1(z12)
470            z6 = self.deconv2(z9)
471            z3 = self.deconv3(z6)
472            z0 = self.deconv4(z3)
473
474        updated_from_encoder = [z9, z6, z3]
475
476        x = self.base(z12)
477        x = self.decoder(x, encoder_inputs=updated_from_encoder)
478        x = self.deconv_out(x)
479
480        x = torch.cat([x, z0], dim=1)
481        x = self.decoder_head(x)
482
483        x = self.out_conv(x)
484        if self.final_activation is not None:
485            x = self.final_activation(x)
486
487        x = self.postprocess_masks(x, input_shape, original_shape)
488        return x

A (2d-only) UNet Transformer using a vision transformer as encoder and a convolutional decoder.

UNETR( img_size: int = 1024, backbone: str = 'sam', encoder: Union[torch.nn.modules.module.Module, str, NoneType] = 'vit_b', decoder: Optional[torch.nn.modules.module.Module] = None, out_channels: int = 1, use_sam_stats: bool = False, use_mae_stats: bool = False, use_dino_stats: bool = False, resize_input: bool = True, encoder_checkpoint: Union[str, collections.OrderedDict, NoneType] = None, final_activation: Union[torch.nn.modules.module.Module, str, NoneType] = None, use_skip_connection: bool = True, embed_dim: Optional[int] = None, use_conv_transpose: bool = False, **kwargs)
304    def __init__(
305        self,
306        img_size: int = 1024,
307        backbone: str = "sam",
308        encoder: Optional[Union[nn.Module, str]] = "vit_b",
309        decoder: Optional[nn.Module] = None,
310        out_channels: int = 1,
311        use_sam_stats: bool = False,
312        use_mae_stats: bool = False,
313        use_dino_stats: bool = False,
314        resize_input: bool = True,
315        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
316        final_activation: Optional[Union[str, nn.Module]] = None,
317        use_skip_connection: bool = True,
318        embed_dim: Optional[int] = None,
319        use_conv_transpose: bool = False,
320        **kwargs
321    ) -> None:
322
323        super().__init__(
324            img_size=img_size,
325            backbone=backbone,
326            encoder=encoder,
327            decoder=decoder,
328            out_channels=out_channels,
329            use_sam_stats=use_sam_stats,
330            use_mae_stats=use_mae_stats,
331            use_dino_stats=use_dino_stats,
332            resize_input=resize_input,
333            encoder_checkpoint=encoder_checkpoint,
334            final_activation=final_activation,
335            use_skip_connection=use_skip_connection,
336            embed_dim=embed_dim,
337            use_conv_transpose=use_conv_transpose,
338            **kwargs,
339        )
340
341        encoder = self.encoder
342
343        if backbone == "sam2" and hasattr(encoder, "trunk"):
344            in_chans = encoder.trunk.patch_embed.proj.in_channels
345        elif hasattr(encoder, "in_chans"):
346            in_chans = encoder.in_chans
347        else:  # `nn.Module` ViT backbone.
348            try:
349                in_chans = encoder.patch_embed.proj.in_channels
350            except AttributeError:  # for getting the input channels while using 'vit_t' from MobileSam
351                in_chans = encoder.patch_embed.seq[0].c.in_channels
352
353        # parameters for the decoder network
354        depth = 3
355        initial_features = 64
356        gain = 2
357        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
358        scale_factors = depth * [2]
359        self.out_channels = out_channels
360
361        # choice of upsampler - to use (bilinear interpolation + conv) or conv transpose
362        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d
363
364        self.decoder = decoder or Decoder(
365            features=features_decoder,
366            scale_factors=scale_factors[::-1],
367            conv_block_impl=ConvBlock2d,
368            sampler_impl=_upsampler,
369        )
370
371        if use_skip_connection:
372            self.deconv1 = Deconv2DBlock(
373                in_channels=self.embed_dim,
374                out_channels=features_decoder[0],
375                use_conv_transpose=use_conv_transpose,
376            )
377            self.deconv2 = nn.Sequential(
378                Deconv2DBlock(
379                    in_channels=self.embed_dim,
380                    out_channels=features_decoder[0],
381                    use_conv_transpose=use_conv_transpose,
382                ),
383                Deconv2DBlock(
384                    in_channels=features_decoder[0],
385                    out_channels=features_decoder[1],
386                    use_conv_transpose=use_conv_transpose,
387                )
388            )
389            self.deconv3 = nn.Sequential(
390                Deconv2DBlock(
391                    in_channels=self.embed_dim,
392                    out_channels=features_decoder[0],
393                    use_conv_transpose=use_conv_transpose,
394                ),
395                Deconv2DBlock(
396                    in_channels=features_decoder[0],
397                    out_channels=features_decoder[1],
398                    use_conv_transpose=use_conv_transpose,
399                ),
400                Deconv2DBlock(
401                    in_channels=features_decoder[1],
402                    out_channels=features_decoder[2],
403                    use_conv_transpose=use_conv_transpose,
404                )
405            )
406            self.deconv4 = ConvBlock2d(in_chans, features_decoder[-1])
407        else:
408            self.deconv1 = Deconv2DBlock(
409                in_channels=self.embed_dim,
410                out_channels=features_decoder[0],
411                use_conv_transpose=use_conv_transpose,
412            )
413            self.deconv2 = Deconv2DBlock(
414                in_channels=features_decoder[0],
415                out_channels=features_decoder[1],
416                use_conv_transpose=use_conv_transpose,
417            )
418            self.deconv3 = Deconv2DBlock(
419                in_channels=features_decoder[1],
420                out_channels=features_decoder[2],
421                use_conv_transpose=use_conv_transpose,
422            )
423            self.deconv4 = Deconv2DBlock(
424                in_channels=features_decoder[2],
425                out_channels=features_decoder[3],
426                use_conv_transpose=use_conv_transpose,
427            )
428
429        self.base = ConvBlock2d(self.embed_dim, features_decoder[0])
430        self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1)
431        self.deconv_out = _upsampler(
432            scale_factor=2, in_channels=features_decoder[-1], out_channels=features_decoder[-1]
433        )
434        self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1])

Initialize internal Module state, shared by both nn.Module and ScriptModule.

out_channels
decoder
base
out_conv
deconv_out
decoder_head
def forward(self, x: torch.Tensor) -> torch.Tensor:
436    def forward(self, x: torch.Tensor) -> torch.Tensor:
437        """Apply the UNETR to the input data.
438
439        Args:
440            x: The input tensor.
441
442        Returns:
443            The UNETR output.
444        """
445        original_shape = x.shape[-2:]
446
447        # Reshape the inputs to the shape expected by the encoder
448        # and normalize the inputs if normalization is part of the model.
449        x, input_shape = self.preprocess(x)
450
451        encoder_outputs = self.encoder(x)
452
453        if isinstance(encoder_outputs[-1], list):
454            # `encoder_outputs` can be arranged in only two forms:
455            #   - either we only return the image embeddings
456            #   - or, we return the image embeddings and the "list" of global attention layers
457            z12, from_encoder = encoder_outputs
458        else:
459            z12 = encoder_outputs
460
461        if self.use_skip_connection:
462            from_encoder = from_encoder[::-1]
463            z9 = self.deconv1(from_encoder[0])
464            z6 = self.deconv2(from_encoder[1])
465            z3 = self.deconv3(from_encoder[2])
466            z0 = self.deconv4(x)
467
468        else:
469            z9 = self.deconv1(z12)
470            z6 = self.deconv2(z9)
471            z3 = self.deconv3(z6)
472            z0 = self.deconv4(z3)
473
474        updated_from_encoder = [z9, z6, z3]
475
476        x = self.base(z12)
477        x = self.decoder(x, encoder_inputs=updated_from_encoder)
478        x = self.deconv_out(x)
479
480        x = torch.cat([x, z0], dim=1)
481        x = self.decoder_head(x)
482
483        x = self.out_conv(x)
484        if self.final_activation is not None:
485            x = self.final_activation(x)
486
487        x = self.postprocess_masks(x, input_shape, original_shape)
488        return x

Apply the UNETR to the input data.

Arguments:
  • x: The input tensor.
Returns:

The UNETR output.

class UNETR2D(UNETR):
491class UNETR2D(UNETR):
492    """A two-dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.
493    """
494    pass

A two-dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.

class UNETR3D(UNETRBase):
497class UNETR3D(UNETRBase):
498    """A three dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.
499    """
500    def __init__(
501        self,
502        img_size: int = 1024,
503        backbone: Literal["sam", "sam2", "mae", "scalemae", "dinov3"] = "sam2",
504        encoder: Optional[Union[nn.Module, str]] = "hvit_b",
505        decoder: Optional[nn.Module] = None,
506        out_channels: int = 1,
507        use_sam_stats: bool = False,
508        use_mae_stats: bool = False,
509        use_dino_stats: bool = False,
510        resize_input: bool = True,
511        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
512        final_activation: Optional[Union[str, nn.Module]] = None,
513        use_skip_connection: bool = False,
514        embed_dim: Optional[int] = None,
515        use_conv_transpose: bool = False,
516        use_strip_pooling: bool = True,
517        **kwargs
518    ):
519        if use_skip_connection:
520            raise NotImplementedError("The framework cannot handle skip connections atm.")
521        if use_conv_transpose:
522            raise NotImplementedError("It's not enabled to switch between interpolation and transposed convolutions.")
523
524        super().__init__(
525            img_size=img_size,
526            backbone=backbone,
527            encoder=encoder,
528            decoder=decoder,
529            out_channels=out_channels,
530            use_sam_stats=use_sam_stats,
531            use_mae_stats=use_mae_stats,
532            use_dino_stats=use_dino_stats,
533            resize_input=resize_input,
534            encoder_checkpoint=encoder_checkpoint,
535            final_activation=final_activation,
536            use_skip_connection=use_skip_connection,
537            embed_dim=embed_dim,
538            use_conv_transpose=use_conv_transpose,
539            **kwargs,
540        )
541
542        # Load the pretrained image encoder weights
543        self.image_encoder = self.encoder
544
545        # Step 2: the 3d convolutional decoder.
546        # First, get the important parameters for the decoder.
547        embed_dim = 256
548        depth = 3
549        initial_features = 64
550        gain = 2
551        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
552        scale_factors = [1, 2, 2]
553        self.out_channels = out_channels
554
555        # The mapping blocks.
556        self.deconv1 = Deconv3DBlock(
557            in_channels=embed_dim,
558            out_channels=features_decoder[0],
559            scale_factor=scale_factors,
560            use_strip_pooling=use_strip_pooling,
561        )
562        self.deconv2 = Deconv3DBlock(
563            in_channels=features_decoder[0],
564            out_channels=features_decoder[1],
565            scale_factor=scale_factors,
566            use_strip_pooling=use_strip_pooling,
567        )
568        self.deconv3 = Deconv3DBlock(
569            in_channels=features_decoder[1],
570            out_channels=features_decoder[2],
571            scale_factor=scale_factors,
572            use_strip_pooling=use_strip_pooling,
573        )
574        self.deconv4 = Deconv3DBlock(
575            in_channels=features_decoder[2],
576            out_channels=features_decoder[3],
577            scale_factor=scale_factors,
578            use_strip_pooling=use_strip_pooling,
579        )
580
581        # The core decoder block.
582        self.decoder = decoder or Decoder(
583            features=features_decoder,
584            scale_factors=[scale_factors] * depth,
585            conv_block_impl=partial(ConvBlock3dWithStrip, use_strip_pooling=use_strip_pooling),
586            sampler_impl=Upsampler3d,
587        )
588
589        # And the final upsampler to match the expected dimensions.
590        self.end_up = Deconv3DBlock(
591            in_channels=features_decoder[-1],
592            out_channels=features_decoder[-1],
593            scale_factor=scale_factors,
594            use_strip_pooling=use_strip_pooling,
595        )
596
597        # Additional conjunction blocks.
598        self.base = ConvBlock3dWithStrip(
599            in_channels=embed_dim,
600            out_channels=features_decoder[0],
601            use_strip_pooling=use_strip_pooling,
602        )
603
604        # And the output layers.
605        self.decoder_head = ConvBlock3dWithStrip(
606            in_channels=2 * features_decoder[-1],
607            out_channels=features_decoder[-1],
608            use_strip_pooling=use_strip_pooling,
609        )
610        self.out_conv = nn.Conv3d(features_decoder[-1], out_channels, 1)
611
612    def forward(self, x: torch.Tensor):
613        """Forward pass of the UNETR-3D model.
614
615        Args:
616            x: Inputs of expected shape (B, C, Z, Y, X), where Z considers flexible inputs.
617
618        Returns:
619            The UNETR output.
620        """
621        B, C, Z, H, W = x.shape
622        original_shape = (Z, H, W)
623
624        # Preprocessing step
625        x, input_shape = self.preprocess(x)
626
627        # Run the image encoder.
628        curr_features = torch.stack([self.image_encoder(x[:, :, i])[0] for i in range(Z)], dim=2)
629
630        # Prepare the counterparts for the decoder.
631        # NOTE: The section below is sequential, there's no skip connections atm.
632        z9 = self.deconv1(curr_features)
633        z6 = self.deconv2(z9)
634        z3 = self.deconv3(z6)
635        z0 = self.deconv4(z3)
636
637        # Align the features through the base block.
638        x = self.base(curr_features)
639
640        # Run the decoder.
641        updated_from_encoder = [z9, z6, z3]
642        x = self.decoder(x, encoder_inputs=updated_from_encoder)
643        x = self.end_up(x)
644
645        # And the final output head.
646        x = torch.cat([x, z0], dim=1)
647        x = self.decoder_head(x)
648        x = self.out_conv(x)
649        if self.final_activation is not None:
650            x = self.final_activation(x)
651
652        # Postprocess the output back to original size.
653        x = self.postprocess_masks(x, input_shape, original_shape)
654        return x

A three dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.

UNETR3D( img_size: int = 1024, backbone: Literal['sam', 'sam2', 'mae', 'scalemae', 'dinov3'] = 'sam2', encoder: Union[torch.nn.modules.module.Module, str, NoneType] = 'hvit_b', decoder: Optional[torch.nn.modules.module.Module] = None, out_channels: int = 1, use_sam_stats: bool = False, use_mae_stats: bool = False, use_dino_stats: bool = False, resize_input: bool = True, encoder_checkpoint: Union[str, collections.OrderedDict, NoneType] = None, final_activation: Union[torch.nn.modules.module.Module, str, NoneType] = None, use_skip_connection: bool = False, embed_dim: Optional[int] = None, use_conv_transpose: bool = False, use_strip_pooling: bool = True, **kwargs)
500    def __init__(
501        self,
502        img_size: int = 1024,
503        backbone: Literal["sam", "sam2", "mae", "scalemae", "dinov3"] = "sam2",
504        encoder: Optional[Union[nn.Module, str]] = "hvit_b",
505        decoder: Optional[nn.Module] = None,
506        out_channels: int = 1,
507        use_sam_stats: bool = False,
508        use_mae_stats: bool = False,
509        use_dino_stats: bool = False,
510        resize_input: bool = True,
511        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
512        final_activation: Optional[Union[str, nn.Module]] = None,
513        use_skip_connection: bool = False,
514        embed_dim: Optional[int] = None,
515        use_conv_transpose: bool = False,
516        use_strip_pooling: bool = True,
517        **kwargs
518    ):
519        if use_skip_connection:
520            raise NotImplementedError("The framework cannot handle skip connections atm.")
521        if use_conv_transpose:
522            raise NotImplementedError("It's not enabled to switch between interpolation and transposed convolutions.")
523
524        super().__init__(
525            img_size=img_size,
526            backbone=backbone,
527            encoder=encoder,
528            decoder=decoder,
529            out_channels=out_channels,
530            use_sam_stats=use_sam_stats,
531            use_mae_stats=use_mae_stats,
532            use_dino_stats=use_dino_stats,
533            resize_input=resize_input,
534            encoder_checkpoint=encoder_checkpoint,
535            final_activation=final_activation,
536            use_skip_connection=use_skip_connection,
537            embed_dim=embed_dim,
538            use_conv_transpose=use_conv_transpose,
539            **kwargs,
540        )
541
542        # Load the pretrained image encoder weights
543        self.image_encoder = self.encoder
544
545        # Step 2: the 3d convolutional decoder.
546        # First, get the important parameters for the decoder.
547        embed_dim = 256
548        depth = 3
549        initial_features = 64
550        gain = 2
551        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
552        scale_factors = [1, 2, 2]
553        self.out_channels = out_channels
554
555        # The mapping blocks.
556        self.deconv1 = Deconv3DBlock(
557            in_channels=embed_dim,
558            out_channels=features_decoder[0],
559            scale_factor=scale_factors,
560            use_strip_pooling=use_strip_pooling,
561        )
562        self.deconv2 = Deconv3DBlock(
563            in_channels=features_decoder[0],
564            out_channels=features_decoder[1],
565            scale_factor=scale_factors,
566            use_strip_pooling=use_strip_pooling,
567        )
568        self.deconv3 = Deconv3DBlock(
569            in_channels=features_decoder[1],
570            out_channels=features_decoder[2],
571            scale_factor=scale_factors,
572            use_strip_pooling=use_strip_pooling,
573        )
574        self.deconv4 = Deconv3DBlock(
575            in_channels=features_decoder[2],
576            out_channels=features_decoder[3],
577            scale_factor=scale_factors,
578            use_strip_pooling=use_strip_pooling,
579        )
580
581        # The core decoder block.
582        self.decoder = decoder or Decoder(
583            features=features_decoder,
584            scale_factors=[scale_factors] * depth,
585            conv_block_impl=partial(ConvBlock3dWithStrip, use_strip_pooling=use_strip_pooling),
586            sampler_impl=Upsampler3d,
587        )
588
589        # And the final upsampler to match the expected dimensions.
590        self.end_up = Deconv3DBlock(
591            in_channels=features_decoder[-1],
592            out_channels=features_decoder[-1],
593            scale_factor=scale_factors,
594            use_strip_pooling=use_strip_pooling,
595        )
596
597        # Additional conjunction blocks.
598        self.base = ConvBlock3dWithStrip(
599            in_channels=embed_dim,
600            out_channels=features_decoder[0],
601            use_strip_pooling=use_strip_pooling,
602        )
603
604        # And the output layers.
605        self.decoder_head = ConvBlock3dWithStrip(
606            in_channels=2 * features_decoder[-1],
607            out_channels=features_decoder[-1],
608            use_strip_pooling=use_strip_pooling,
609        )
610        self.out_conv = nn.Conv3d(features_decoder[-1], out_channels, 1)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

image_encoder
out_channels
deconv1
deconv2
deconv3
deconv4
decoder
end_up
base
decoder_head
out_conv
def forward(self, x: torch.Tensor):
612    def forward(self, x: torch.Tensor):
613        """Forward pass of the UNETR-3D model.
614
615        Args:
616            x: Inputs of expected shape (B, C, Z, Y, X), where Z considers flexible inputs.
617
618        Returns:
619            The UNETR output.
620        """
621        B, C, Z, H, W = x.shape
622        original_shape = (Z, H, W)
623
624        # Preprocessing step
625        x, input_shape = self.preprocess(x)
626
627        # Run the image encoder.
628        curr_features = torch.stack([self.image_encoder(x[:, :, i])[0] for i in range(Z)], dim=2)
629
630        # Prepare the counterparts for the decoder.
631        # NOTE: The section below is sequential, there's no skip connections atm.
632        z9 = self.deconv1(curr_features)
633        z6 = self.deconv2(z9)
634        z3 = self.deconv3(z6)
635        z0 = self.deconv4(z3)
636
637        # Align the features through the base block.
638        x = self.base(curr_features)
639
640        # Run the decoder.
641        updated_from_encoder = [z9, z6, z3]
642        x = self.decoder(x, encoder_inputs=updated_from_encoder)
643        x = self.end_up(x)
644
645        # And the final output head.
646        x = torch.cat([x, z0], dim=1)
647        x = self.decoder_head(x)
648        x = self.out_conv(x)
649        if self.final_activation is not None:
650            x = self.final_activation(x)
651
652        # Postprocess the output back to original size.
653        x = self.postprocess_masks(x, input_shape, original_shape)
654        return x

Forward pass of the UNETR-3D model.

Arguments:
  • x: Inputs of expected shape (B, C, Z, Y, X), where Z considers flexible inputs.
Returns:

The UNETR output.