torch_em.model.unetr

  1from functools import partial
  2from collections import OrderedDict
  3from typing import Optional, Tuple, Union, Literal
  4
  5import torch
  6import torch.nn as nn
  7import torch.nn.functional as F
  8
  9from .vit import get_vision_transformer
 10from .unet import Decoder, ConvBlock2d, ConvBlock3d, Upsampler2d, Upsampler3d, _update_conv_kwargs
 11
 12try:
 13    from micro_sam.util import get_sam_model
 14except ImportError:
 15    get_sam_model = None
 16
 17try:
 18    from micro_sam2.util import get_sam2_model
 19except ImportError:
 20    get_sam2_model = None
 21
 22try:
 23    from micro_sam3.util import get_sam3_model
 24except ImportError:
 25    get_sam3_model = None
 26
 27
 28#
 29# UNETR IMPLEMENTATION [Vision Transformer (ViT from SAM / SAM2 / SAM3 / DINOv2 / DINOv3 / MAE / ScaleMAE) + UNet Decoder from `torch_em`]  # noqa
 30#
 31
 32
 33class UNETRBase(nn.Module):
 34    """Base class for implementing a UNETR.
 35
 36    Args:
 37        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
 38        backbone: The name of the vision transformer implementation.
 39            One of "sam", "sam2", "sam3", "mae", "scalemae", "dinov2", "dinov3" (see all combinations below).
 40        encoder: The vision transformer. Can either be a name, such as "vit_b"
 41            (see all combinations for this below) or a torch module.
 42        decoder: The convolutional decoder.
 43        out_channels: The number of output channels of the UNETR.
 44        use_sam_stats: Whether to normalize the input data with the statistics of the
 45            pretrained SAM / SAM2 / SAM3 model.
 46        use_dino_stats: Whether to normalize the input data with the statistics of the
 47            pretrained DINOv2 / DINOv3 model.
 48        use_mae_stats: Whether to normalize the input data with the statistics of the pretrained MAE model.
 49        resize_input: Whether to resize the input images to match `img_size`.
 50            By default, it resizes the inputs to match the `img_size`.
 51        encoder_checkpoint: Checkpoint for initializing the vision transformer.
 52            Can either be a filepath or an already loaded checkpoint.
 53        final_activation: The activation to apply to the UNETR output.
 54        use_skip_connection: Whether to use skip connections. By default, it uses skip connections.
 55        embed_dim: The embedding dimensionality, corresponding to the output dimension of the vision transformer.
 56        use_conv_transpose: Whether to use transposed convolutions instead of resampling for upsampling.
 57            By default, it uses resampling for upsampling.
 58
 59        NOTE: The currently supported combinations of 'backbone' x 'encoder' are the following:
 60
 61        SAM_family_models:
 62            - 'sam' x 'vit_b'
 63            - 'sam' x 'vit_l'
 64            - 'sam' x 'vit_h'
 65            - 'sam2' x 'hvit_t'
 66            - 'sam2' x 'hvit_s'
 67            - 'sam2' x 'hvit_b'
 68            - 'sam2' x 'hvit_l'
 69            - 'sam3' x 'vit_pe'
 70
 71        DINO_family_models:
 72            - 'dinov2' x 'vit_s'
 73            - 'dinov2' x 'vit_b'
 74            - 'dinov2' x 'vit_l'
 75            - 'dinov2' x 'vit_g'
 76            - 'dinov2' x 'vit_s_reg4'
 77            - 'dinov2' x 'vit_b_reg4'
 78            - 'dinov2' x 'vit_l_reg4'
 79            - 'dinov2' x 'vit_g_reg4'
 80            - 'dinov3' x 'vit_s'
 81            - 'dinov3' x 'vit_s+'
 82            - 'dinov3' x 'vit_b'
 83            - 'dinov3' x 'vit_l'
 84            - 'dinov3' x 'vit_l+'
 85            - 'dinov3' x 'vit_h+'
 86            - 'dinov3' x 'vit_7b'
 87
 88        MAE_family_models:
 89            - 'mae' x 'vit_b'
 90            - 'mae' x 'vit_l'
 91            - 'mae' x 'vit_h'
 92            - 'scalemae' x 'vit_b'
 93            - 'scalemae' x 'vit_l'
 94            - 'scalemae' x 'vit_h'
 95    """
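    # Construction sketch (illustrative): any of the combinations listed above can be passed to
    # the concrete models defined below, e.g.
    #     UNETR(backbone="dinov2", encoder="vit_s", out_channels=1, use_dino_stats=True)
    # optionally together with an `encoder_checkpoint` for initializing the vision transformer.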
 96    def __init__(
 97        self,
 98        img_size: int = 1024,
 99        backbone: Literal["sam", "sam2", "sam3", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
100        encoder: Optional[Union[nn.Module, str]] = "vit_b",
101        decoder: Optional[nn.Module] = None,
102        out_channels: int = 1,
103        use_sam_stats: bool = False,
104        use_mae_stats: bool = False,
105        use_dino_stats: bool = False,
106        resize_input: bool = True,
107        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
108        final_activation: Optional[Union[str, nn.Module]] = None,
109        use_skip_connection: bool = True,
110        embed_dim: Optional[int] = None,
111        use_conv_transpose: bool = False,
112        **kwargs
113    ) -> None:
114        super().__init__()
115
116        self.img_size = img_size
117        self.use_sam_stats = use_sam_stats
118        self.use_mae_stats = use_mae_stats
119        self.use_dino_stats = use_dino_stats
120        self.use_skip_connection = use_skip_connection
121        self.resize_input = resize_input
122        self.use_conv_transpose = use_conv_transpose
123        self.backbone = backbone
124
125        if isinstance(encoder, str):  # e.g. "vit_b" / "hvit_b" / "vit_pe"
126            print(f"Using {encoder} from {backbone.upper()}")
127            self.encoder = get_vision_transformer(img_size=img_size, backbone=backbone, model=encoder, **kwargs)
128
129            if encoder_checkpoint is not None:
130                self._load_encoder_from_checkpoint(backbone=backbone, encoder=encoder, checkpoint=encoder_checkpoint)
131
132            if embed_dim is None:
133                embed_dim = self.encoder.embed_dim
134
135        else:  # `nn.Module` ViT backbone
136            self.encoder = encoder
137
138            have_neck = False
139            for name, _ in self.encoder.named_parameters():
140                if name.startswith("neck"):
141                    have_neck = True
142
143            if embed_dim is None:
144                if have_neck:
145                    embed_dim = self.encoder.neck[2].out_channels  # the value is 256
146                else:
147                    embed_dim = self.encoder.patch_embed.proj.out_channels
148
149        self.embed_dim = embed_dim
150        self.final_activation = self._get_activation(final_activation)
151
152    def _load_encoder_from_checkpoint(self, backbone, encoder, checkpoint):
153        """Function to load pretrained weights to the image encoder.
154        """
155        if isinstance(checkpoint, str):
156            if backbone == "sam" and isinstance(encoder, str):
157                # If we have a SAM encoder, then we first try to load the full SAM Model
158                # (using micro_sam) and otherwise fall back on directly loading the encoder state
159                # from the checkpoint
160                try:
161                    _, model = get_sam_model(model_type=encoder, checkpoint_path=checkpoint, return_sam=True)
162                    encoder_state = model.image_encoder.state_dict()
163                except Exception:
164                    # Try loading the encoder state directly from a checkpoint.
165                    encoder_state = torch.load(checkpoint, weights_only=False)
166
167            elif backbone == "sam2" and isinstance(encoder, str):
168                # If we have a SAM2 encoder, then we first try to load the full SAM2 Model.
169                # (using micro_sam2) and otherwise fall back on directly loading the encoder state
170                # from the checkpoint
171                try:
172                    model = get_sam2_model(model_type=encoder, checkpoint_path=checkpoint)
173                    encoder_state = model.image_encoder.state_dict()
174                except Exception:
175                    # Try loading the encoder state directly from a checkpoint.
176                    encoder_state = torch.load(checkpoint, weights_only=False)
177
178            elif backbone == "sam3" and isinstance(encoder, str):
179                # If we have a SAM3 encoder, then we first try to load the full SAM3 Model.
180                # (using micro_sam3) and otherwise fall back on directly loading the encoder state
181                # from the checkpoint
182                try:
183                    model = get_sam3_model(checkpoint_path=checkpoint)
184                    encoder_state = model.backbone.vision_backbone.state_dict()
185                    # Let's align loading the encoder weights with expected parameter names
186                    encoder_state = {
187                        k[len("trunk."):] if k.startswith("trunk.") else k: v for k, v in encoder_state.items()
188                    }
189                    # And drop the 'convs' and 'sam2_convs' - these seem like some upsampling blocks.
190                    encoder_state = {
191                        k: v for k, v in encoder_state.items()
192                        if not (k.startswith("convs.") or k.startswith("sam2_convs."))
193                    }
194                except Exception:
195                    # Try loading the encoder state directly from a checkpoint.
196                    encoder_state = torch.load(checkpoint, weights_only=False)
197
198            elif backbone == "mae":
199                # vit initialization hints from:
200                #     - https://github.com/facebookresearch/mae/blob/main/main_finetune.py#L233-L242
201                encoder_state = torch.load(checkpoint, weights_only=False)["model"]
202                encoder_state = OrderedDict({
203                    k: v for k, v in encoder_state.items() if (k != "mask_token" and not k.startswith("decoder"))
204                })
205                # Let's remove the `head` from our current encoder (the pretrained MAE weights don't include it).
206                current_encoder_state = self.encoder.state_dict()
207                if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
208                    del self.encoder.head
209
210            elif backbone == "scalemae":
211                # Load the encoder state directly from a checkpoint.
212                encoder_state = torch.load(checkpoint)["model"]
213                encoder_state = OrderedDict({
214                    k: v for k, v in encoder_state.items()
215                    if not k.startswith(("mask_token", "decoder", "fcn", "fpn", "pos_embed"))
216                })
217
218                # Let's remove the `head` from our current encoder (the pretrained ScaleMAE weights don't include it).
219                current_encoder_state = self.encoder.state_dict()
220                if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
221                    del self.encoder.head
222
223                if "pos_embed" in current_encoder_state:  # NOTE: ScaleMAE uses 'pos. embeddings' in a diff. format.
224                    del self.encoder.pos_embed
225
226            elif backbone in ["dinov2", "dinov3"]:  # Load the encoder state directly from a checkpoint.
227                encoder_state = torch.load(checkpoint)
228
229            else:
230                raise ValueError(
231                    f"The combination of the '{backbone}' backbone and the '{encoder}' encoder is not supported."
232                )
233
234        else:
235            encoder_state = checkpoint
236
237        self.encoder.load_state_dict(encoder_state)
238
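    # Note on checkpoints (illustrative): `encoder_checkpoint` can either be a filepath, e.g.
    # "/path/to/sam_vit_b.pth" (hypothetical), which is routed through the matching loader above,
    # or an already loaded state dict, which is passed to `load_state_dict` directly.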
239    def _get_activation(self, activation):
240        return_activation = None
241        if activation is None:
242            return None
243        if isinstance(activation, nn.Module):
244            return activation
245        if isinstance(activation, str):
246            return_activation = getattr(nn, activation, None)
247        if return_activation is None:
248            raise ValueError(f"Invalid activation: {activation}")
249
250        return return_activation()
251
252    @staticmethod
253    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
254        """Compute the output size given input size and target long side length.
255
256        Args:
257            oldh: The input image height.
258            oldw: The input image width.
259            long_side_length: The longest side length for resizing.
260
261        Returns:
262            The new image height.
263            The new image width.
264        """
265        scale = long_side_length * 1.0 / max(oldh, oldw)
266        newh, neww = oldh * scale, oldw * scale
267        neww = int(neww + 0.5)
268        newh = int(newh + 0.5)
269        return (newh, neww)
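        # Worked example (illustrative): for oldh=512, oldw=768 and long_side_length=1024 the scale
        # is 1024 / 768 = 1.333..., so the returned shape is (683, 1024).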
270
271    def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor:
272        """Resize the image so that the longest side has the correct length.
273
274        Expects batched images with shape BxCxHxW OR BxCxDxHxW and float format.
275
276        Args:
277            image: The input image.
278
279        Returns:
280            The resized image.
281        """
282        if image.ndim == 4:  # i.e. 2d image
283            target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.encoder.img_size)
284            return F.interpolate(image, target_size, mode="bilinear", align_corners=False, antialias=True)
285        elif image.ndim == 5:  # i.e. 3d volume
286            B, C, Z, H, W = image.shape
287            target_size = self.get_preprocess_shape(H, W, self.img_size)
288            return F.interpolate(image, (Z, *target_size), mode="trilinear", align_corners=False)
289        else:
290            raise ValueError(f"Expected 4d or 5d inputs, got {image.shape}.")
291
292    def _as_stats(self, mean, std, device, dtype, is_3d: bool):
293        """@private
294        """
295        # Either 2d batch: (1, C, 1, 1) or 3d batch: (1, C, 1, 1, 1).
296        view_shape = (1, -1, 1, 1, 1) if is_3d else (1, -1, 1, 1)
297        pixel_mean = torch.tensor(mean, device=device, dtype=dtype).view(*view_shape)
298        pixel_std = torch.tensor(std, device=device, dtype=dtype).view(*view_shape)
299        return pixel_mean, pixel_std
300
301    def preprocess(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, ...]]:
302        """@private
303        """
304        # Determine the input dimensionality, device and dtype.
305        is_3d = (x.ndim == 5)
306        device, dtype = x.device, x.dtype
307
308        if self.use_sam_stats and self.backbone == "sam3":  # Check the backbone-specific stats first.
309            mean, std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
310        elif self.use_dino_stats or (self.use_sam_stats and self.backbone == "sam2"):
311            mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
312        elif self.use_sam_stats:
313            mean, std = (123.675, 116.28, 103.53), (58.395, 57.12, 57.375)
314        elif self.use_mae_stats:  # TODO: add mean std from mae / scalemae experiments (or open up arguments for this)
315            raise NotImplementedError
316        else:
317            mean, std = (0.0, 0.0, 0.0), (1.0, 1.0, 1.0)
318
319        pixel_mean, pixel_std = self._as_stats(mean, std, device=device, dtype=dtype, is_3d=is_3d)
320
321        if self.resize_input:
322            x = self.resize_longest_side(x)
323        input_shape = x.shape[-3:] if is_3d else x.shape[-2:]
324
325        x = (x - pixel_mean) / pixel_std
326        h, w = x.shape[-2:]
327        padh = self.encoder.img_size - h
328        padw = self.encoder.img_size - w
329
330        if is_3d:
331            x = F.pad(x, (0, padw, 0, padh, 0, 0))
332        else:
333            x = F.pad(x, (0, padw, 0, padh))
334
335        return x, input_shape
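        # Preprocessing sketch (illustrative): with `resize_input=True` and an encoder image size
        # of 1024, a (1, 3, 512, 768) batch is rescaled to (1, 3, 683, 1024) by `resize_longest_side`,
        # normalized with the selected statistics, and zero-padded to (1, 3, 1024, 1024);
        # `input_shape` keeps the pre-padding spatial shape (683, 1024) for cropping later on.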
336
337    def postprocess_masks(
338        self, masks: torch.Tensor, input_size: Tuple[int, ...], original_size: Tuple[int, ...],
339    ) -> torch.Tensor:
340        """@private
341        """
342        if masks.ndim == 4:  # i.e. 2d labels
343            masks = F.interpolate(
344                masks,
345                (self.encoder.img_size, self.encoder.img_size),
346                mode="bilinear",
347                align_corners=False,
348            )
349            masks = masks[..., : input_size[0], : input_size[1]]
350            masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
351
352        elif masks.ndim == 5:  # i.e. 3d volumetric labels
353            masks = F.interpolate(
354                masks,
355                (input_size[0], self.img_size, self.img_size),
356                mode="trilinear",
357                align_corners=False,
358            )
359            masks = masks[..., :input_size[0], :input_size[1], :input_size[2]]
360            masks = F.interpolate(masks, original_size, mode="trilinear", align_corners=False)
361
362        else:
363            raise ValueError(f"Expected 4d or 5d labels, got {masks.shape}.")
364
365        return masks
366
367
368class UNETR(UNETRBase):
369    """A (2d-only) UNet Transformer using a vision transformer as encoder and a convolutional decoder.
370    """
371    def __init__(
372        self,
373        img_size: int = 1024,
374        backbone: Literal["sam", "sam2", "sam3", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
375        encoder: Optional[Union[nn.Module, str]] = "vit_b",
376        decoder: Optional[nn.Module] = None,
377        out_channels: int = 1,
378        use_sam_stats: bool = False,
379        use_mae_stats: bool = False,
380        use_dino_stats: bool = False,
381        resize_input: bool = True,
382        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
383        final_activation: Optional[Union[str, nn.Module]] = None,
384        use_skip_connection: bool = True,
385        embed_dim: Optional[int] = None,
386        use_conv_transpose: bool = False,
387        **kwargs
388    ) -> None:
389
390        super().__init__(
391            img_size=img_size,
392            backbone=backbone,
393            encoder=encoder,
394            decoder=decoder,
395            out_channels=out_channels,
396            use_sam_stats=use_sam_stats,
397            use_mae_stats=use_mae_stats,
398            use_dino_stats=use_dino_stats,
399            resize_input=resize_input,
400            encoder_checkpoint=encoder_checkpoint,
401            final_activation=final_activation,
402            use_skip_connection=use_skip_connection,
403            embed_dim=embed_dim,
404            use_conv_transpose=use_conv_transpose,
405            **kwargs,
406        )
407
408        encoder = self.encoder
409
410        if backbone == "sam2" and hasattr(encoder, "trunk"):
411            in_chans = encoder.trunk.patch_embed.proj.in_channels
412        elif hasattr(encoder, "in_chans"):
413            in_chans = encoder.in_chans
414        else:  # `nn.Module` ViT backbone.
415            try:
416                in_chans = encoder.patch_embed.proj.in_channels
417            except AttributeError:  # for getting the input channels while using 'vit_t' from MobileSam
418                in_chans = encoder.patch_embed.seq[0].c.in_channels
419
420        # parameters for the decoder network
421        depth = 3
422        initial_features = 64
423        gain = 2
424        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
425        scale_factors = depth * [2]
426        self.out_channels = out_channels
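        # e.g. with initial_features=64, gain=2 and depth=3 this yields
        # features_decoder = [512, 256, 128, 64] (coarse to fine) and scale_factors = [2, 2, 2].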
427
428        # choice of upsampler - to use (bilinear interpolation + conv) or conv transpose
429        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d
430
431        self.decoder = decoder or Decoder(
432            features=features_decoder,
433            scale_factors=scale_factors[::-1],
434            conv_block_impl=ConvBlock2d,
435            sampler_impl=_upsampler,
436        )
437
438        if use_skip_connection:
439            self.deconv1 = Deconv2DBlock(
440                in_channels=self.embed_dim,
441                out_channels=features_decoder[0],
442                use_conv_transpose=use_conv_transpose,
443            )
444            self.deconv2 = nn.Sequential(
445                Deconv2DBlock(
446                    in_channels=self.embed_dim,
447                    out_channels=features_decoder[0],
448                    use_conv_transpose=use_conv_transpose,
449                ),
450                Deconv2DBlock(
451                    in_channels=features_decoder[0],
452                    out_channels=features_decoder[1],
453                    use_conv_transpose=use_conv_transpose,
454                )
455            )
456            self.deconv3 = nn.Sequential(
457                Deconv2DBlock(
458                    in_channels=self.embed_dim,
459                    out_channels=features_decoder[0],
460                    use_conv_transpose=use_conv_transpose,
461                ),
462                Deconv2DBlock(
463                    in_channels=features_decoder[0],
464                    out_channels=features_decoder[1],
465                    use_conv_transpose=use_conv_transpose,
466                ),
467                Deconv2DBlock(
468                    in_channels=features_decoder[1],
469                    out_channels=features_decoder[2],
470                    use_conv_transpose=use_conv_transpose,
471                )
472            )
473            self.deconv4 = ConvBlock2d(in_chans, features_decoder[-1])
474        else:
475            self.deconv1 = Deconv2DBlock(
476                in_channels=self.embed_dim,
477                out_channels=features_decoder[0],
478                use_conv_transpose=use_conv_transpose,
479            )
480            self.deconv2 = Deconv2DBlock(
481                in_channels=features_decoder[0],
482                out_channels=features_decoder[1],
483                use_conv_transpose=use_conv_transpose,
484            )
485            self.deconv3 = Deconv2DBlock(
486                in_channels=features_decoder[1],
487                out_channels=features_decoder[2],
488                use_conv_transpose=use_conv_transpose,
489            )
490            self.deconv4 = Deconv2DBlock(
491                in_channels=features_decoder[2],
492                out_channels=features_decoder[3],
493                use_conv_transpose=use_conv_transpose,
494            )
495
496        self.base = ConvBlock2d(self.embed_dim, features_decoder[0])
497        self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1)
498        self.deconv_out = _upsampler(
499            scale_factor=2, in_channels=features_decoder[-1], out_channels=features_decoder[-1]
500        )
501        self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1])
502
503    def forward(self, x: torch.Tensor) -> torch.Tensor:
504        """Apply the UNETR to the input data.
505
506        Args:
507            x: The input tensor.
508
509        Returns:
510            The UNETR output.
511        """
512        original_shape = x.shape[-2:]
513
514        # Reshape the inputs to the shape expected by the encoder
515        # and normalize the inputs if normalization is part of the model.
516        x, input_shape = self.preprocess(x)
517
518        encoder_outputs = self.encoder(x)
519
520        if isinstance(encoder_outputs[-1], list):
521            # `encoder_outputs` can be arranged in only two forms:
522            #   - either we only return the image embeddings
523            #   - or, we return the image embeddings and the "list" of global attention layers
524            z12, from_encoder = encoder_outputs
525        else:
526            z12 = encoder_outputs
527
528        if self.use_skip_connection:
529            from_encoder = from_encoder[::-1]
530            z9 = self.deconv1(from_encoder[0])
531            z6 = self.deconv2(from_encoder[1])
532            z3 = self.deconv3(from_encoder[2])
533            z0 = self.deconv4(x)
534
535        else:
536            z9 = self.deconv1(z12)
537            z6 = self.deconv2(z9)
538            z3 = self.deconv3(z6)
539            z0 = self.deconv4(z3)
540
541        updated_from_encoder = [z9, z6, z3]
542
543        x = self.base(z12)
544        x = self.decoder(x, encoder_inputs=updated_from_encoder)
545        x = self.deconv_out(x)
546
547        x = torch.cat([x, z0], dim=1)
548        x = self.decoder_head(x)
549
550        x = self.out_conv(x)
551        if self.final_activation is not None:
552            x = self.final_activation(x)
553
554        x = self.postprocess_masks(x, input_shape, original_shape)
555        return x
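        # Shape sketch (illustrative): with a SAM ViT-B encoder (1024 x 1024 input size), a
        # (1, 3, 512, 512) input is preprocessed to (1, 3, 1024, 1024), encoded into patch features
        # on a 64 x 64 grid, decoded back to full resolution, and finally resized by
        # `postprocess_masks` to (1, out_channels, 512, 512).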
556
557
558class UNETR2D(UNETR):
559    """A two-dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.
560    """
561    pass
562
563
564class UNETR3D(UNETRBase):
565    """A three-dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.
566    """
567    def __init__(
568        self,
569        img_size: int = 1024,
570        backbone: Literal["sam", "sam2", "sam3", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
571        encoder: Optional[Union[nn.Module, str]] = "hvit_b",
572        decoder: Optional[nn.Module] = None,
573        out_channels: int = 1,
574        use_sam_stats: bool = False,
575        use_mae_stats: bool = False,
576        use_dino_stats: bool = False,
577        resize_input: bool = True,
578        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
579        final_activation: Optional[Union[str, nn.Module]] = None,
580        use_skip_connection: bool = False,
581        embed_dim: Optional[int] = None,
582        use_conv_transpose: bool = False,
583        use_strip_pooling: bool = True,
584        **kwargs
585    ):
586        if use_skip_connection:
587            raise NotImplementedError("Skip connections are not yet supported for the 3d UNETR.")
588        if use_conv_transpose:
589            raise NotImplementedError("Transposed convolutions are not yet supported for the 3d UNETR.")
590
591        # Sort the `embed_dim` out
592        embed_dim = 256 if embed_dim is None else embed_dim
593
594        super().__init__(
595            img_size=img_size,
596            backbone=backbone,
597            encoder=encoder,
598            decoder=decoder,
599            out_channels=out_channels,
600            use_sam_stats=use_sam_stats,
601            use_mae_stats=use_mae_stats,
602            use_dino_stats=use_dino_stats,
603            resize_input=resize_input,
604            encoder_checkpoint=encoder_checkpoint,
605            final_activation=final_activation,
606            use_skip_connection=use_skip_connection,
607            embed_dim=embed_dim,
608            use_conv_transpose=use_conv_transpose,
609            **kwargs,
610        )
611
612        # The 3d convolutional decoder.
613        # First, get the important parameters for the decoder.
614        depth = 3
615        initial_features = 64
616        gain = 2
617        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
618        scale_factors = [1, 2, 2]
619        self.out_channels = out_channels
620
621        # The mapping blocks.
622        self.deconv1 = Deconv3DBlock(
623            in_channels=embed_dim,
624            out_channels=features_decoder[0],
625            scale_factor=scale_factors,
626            use_strip_pooling=use_strip_pooling,
627        )
628        self.deconv2 = Deconv3DBlock(
629            in_channels=features_decoder[0],
630            out_channels=features_decoder[1],
631            scale_factor=scale_factors,
632            use_strip_pooling=use_strip_pooling,
633        )
634        self.deconv3 = Deconv3DBlock(
635            in_channels=features_decoder[1],
636            out_channels=features_decoder[2],
637            scale_factor=scale_factors,
638            use_strip_pooling=use_strip_pooling,
639        )
640        self.deconv4 = Deconv3DBlock(
641            in_channels=features_decoder[2],
642            out_channels=features_decoder[3],
643            scale_factor=scale_factors,
644            use_strip_pooling=use_strip_pooling,
645        )
646
647        # The core decoder block.
648        self.decoder = decoder or Decoder(
649            features=features_decoder,
650            scale_factors=[scale_factors] * depth,
651            conv_block_impl=partial(ConvBlock3dWithStrip, use_strip_pooling=use_strip_pooling),
652            sampler_impl=Upsampler3d,
653        )
654
655        # And the final upsampler to match the expected dimensions.
656        self.deconv_out = Deconv3DBlock(  # NOTE: changed `end_up` to `deconv_out`
657            in_channels=features_decoder[-1],
658            out_channels=features_decoder[-1],
659            scale_factor=scale_factors,
660            use_strip_pooling=use_strip_pooling,
661        )
662
663        # Additional conjunction blocks.
664        self.base = ConvBlock3dWithStrip(
665            in_channels=embed_dim,
666            out_channels=features_decoder[0],
667            use_strip_pooling=use_strip_pooling,
668        )
669
670        # And the output layers.
671        self.decoder_head = ConvBlock3dWithStrip(
672            in_channels=2 * features_decoder[-1],
673            out_channels=features_decoder[-1],
674            use_strip_pooling=use_strip_pooling,
675        )
676        self.out_conv = nn.Conv3d(features_decoder[-1], out_channels, 1)
677
678    def forward(self, x: torch.Tensor):
679        """Forward pass of the UNETR-3D model.
680
681        Args:
682            x: Inputs of expected shape (B, C, Z, Y, X), where the number of slices Z can vary.
683
684        Returns:
685            The UNETR output.
686        """
687        B, C, Z, H, W = x.shape
688        original_shape = (Z, H, W)
689
690        # Preprocessing step
691        x, input_shape = self.preprocess(x)
692
693        # Run the image encoder.
694        curr_features = torch.stack([self.encoder(x[:, :, i])[0] for i in range(Z)], dim=2)
695
696        # Prepare the counterparts for the decoder.
697        # NOTE: The section below is sequential, there's no skip connections atm.
698        z9 = self.deconv1(curr_features)
699        z6 = self.deconv2(z9)
700        z3 = self.deconv3(z6)
701        z0 = self.deconv4(z3)
702
703        updated_from_encoder = [z9, z6, z3]
704
705        # Align the features through the base block.
706        x = self.base(curr_features)
707        # Run the decoder
708        x = self.decoder(x, encoder_inputs=updated_from_encoder)
709        x = self.deconv_out(x)  # NOTE before `end_up`
710
711        # And the final output head.
712        x = torch.cat([x, z0], dim=1)
713        x = self.decoder_head(x)
714        x = self.out_conv(x)
715        if self.final_activation is not None:
716            x = self.final_activation(x)
717
718        # Postprocess the output back to original size.
719        x = self.postprocess_masks(x, input_shape, original_shape)
720        return x
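        # Shape sketch (illustrative): for a (1, 3, 8, 512, 512) volume, each of the 8 slices is
        # encoded separately by the 2d image encoder and the per-slice features are stacked along
        # the depth axis; since the decoder only upsamples the spatial axes (scale factors (1, 2, 2)),
        # the output is finally resized to (1, out_channels, 8, 512, 512).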
721
722#
723#  ADDITIONAL FUNCTIONALITIES
724#
725
726
727def _strip_pooling_layers(enabled, channels) -> nn.Module:
728    return DepthStripPooling(channels) if enabled else nn.Identity()
729
730
731class DepthStripPooling(nn.Module):
732    """@private
733    """
734    def __init__(self, channels: int, reduction: int = 4):
735        """Block for strip pooling along the depth dimension (only).
736
737        For 3d inputs (Z > 1), it aggregates global context across the depth by adaptive average
738        pooling to Z=1, passes the result through a small 1x1x1 MLP, and broadcasts it back to all
739        Z slices to modulate the original features via a gated residual.
740
741        For 2d-like inputs (Z == 1 or Z == 3), the input is returned unchanged (no-op).
742
743        Args:
744            channels: The number of input and output channels.
745            reduction: The channel reduction factor for the hidden layer.
746        """
747        super().__init__()
748        hidden = max(1, channels // reduction)
749        self.conv1 = nn.Conv3d(channels, hidden, kernel_size=1)
750        self.bn1 = nn.BatchNorm3d(hidden)
751        self.relu = nn.ReLU(inplace=True)
752        self.conv2 = nn.Conv3d(hidden, channels, kernel_size=1)
753
754    def forward(self, x: torch.Tensor) -> torch.Tensor:
755        if x.dim() != 5:
756            raise ValueError(f"DepthStripPooling expects 5D tensors as input, got '{x.shape}'.")
757
758        B, C, Z, H, W = x.shape
759        if Z == 1 or Z == 3:  # i.e. 2d-as-1-slice or RGB_2d-as-1-slice.
760            return x  # In this case, we simply do nothing.
761
762        # We pool only along the depth dimension: i.e. target shape (B, C, 1, H, W)
763        feat = F.adaptive_avg_pool3d(x, output_size=(1, H, W))
764        feat = self.conv1(feat)
765        feat = self.bn1(feat)
766        feat = self.relu(feat)
767        feat = self.conv2(feat)
768        gate = torch.sigmoid(feat).expand(B, C, Z, H, W)  # Broadcast the collapsed depth context back to all slices
769
770        # Gated residual fusion
771        return x * gate + x
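        # Gating sketch (illustrative): a (B, C, Z, H, W) input is average-pooled over the depth
        # axis to (B, C, 1, H, W), passed through the two 1x1x1 convolutions, and the resulting
        # sigmoid gate is broadcast back over Z, so each slice is re-weighted by depth-global
        # context and added back to the input (gated residual).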
772
773
774class Deconv3DBlock(nn.Module):
775    """@private
776    """
777    def __init__(
778        self,
779        scale_factor,
780        in_channels,
781        out_channels,
782        kernel_size=3,
783        anisotropic_kernel=True,
784        use_strip_pooling=True,
785    ):
786        super().__init__()
787        conv_block_kwargs = {
788            "in_channels": out_channels,
789            "out_channels": out_channels,
790            "kernel_size": kernel_size,
791            "padding": ((kernel_size - 1) // 2),
792        }
793        if anisotropic_kernel:
794            conv_block_kwargs = _update_conv_kwargs(conv_block_kwargs, scale_factor)
795
796        self.block = nn.Sequential(
797            Upsampler3d(scale_factor, in_channels, out_channels),
798            nn.Conv3d(**conv_block_kwargs),
799            nn.BatchNorm3d(out_channels),
800            nn.ReLU(True),
801            _strip_pooling_layers(enabled=use_strip_pooling, channels=out_channels),
802        )
803
804    def forward(self, x):
805        return self.block(x)
806
807
808class ConvBlock3dWithStrip(nn.Module):
809    """@private
810    """
811    def __init__(
812        self, in_channels: int, out_channels: int, use_strip_pooling: bool = True, **kwargs
813    ):
814        super().__init__()
815        self.block = nn.Sequential(
816            ConvBlock3d(in_channels, out_channels, **kwargs),
817            _strip_pooling_layers(enabled=use_strip_pooling, channels=out_channels),
818        )
819
820    def forward(self, x):
821        return self.block(x)
822
823
824class SingleDeconv2DBlock(nn.Module):
825    """@private
826    """
827    def __init__(self, scale_factor, in_channels, out_channels):
828        super().__init__()
829        self.block = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, padding=0, output_padding=0)
830
831    def forward(self, x):
832        return self.block(x)
833
834
835class SingleConv2DBlock(nn.Module):
836    """@private
837    """
838    def __init__(self, in_channels, out_channels, kernel_size):
839        super().__init__()
840        self.block = nn.Conv2d(
841            in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=((kernel_size - 1) // 2)
842        )
843
844    def forward(self, x):
845        return self.block(x)
846
847
848class Conv2DBlock(nn.Module):
849    """@private
850    """
851    def __init__(self, in_channels, out_channels, kernel_size=3):
852        super().__init__()
853        self.block = nn.Sequential(
854            SingleConv2DBlock(in_channels, out_channels, kernel_size),
855            nn.BatchNorm2d(out_channels),
856            nn.ReLU(True)
857        )
858
859    def forward(self, x):
860        return self.block(x)
861
862
863class Deconv2DBlock(nn.Module):
864    """@private
865    """
866    def __init__(self, in_channels, out_channels, kernel_size=3, use_conv_transpose=True):
867        super().__init__()
868        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d
869        self.block = nn.Sequential(
870            _upsampler(scale_factor=2, in_channels=in_channels, out_channels=out_channels),
871            SingleConv2DBlock(out_channels, out_channels, kernel_size),
872            nn.BatchNorm2d(out_channels),
873            nn.ReLU(True)
874        )
875
876    def forward(self, x):
877        return self.block(x)
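
Example usage: the snippet below is a minimal sketch of how the 2d model defined in this module can be used. It assumes the dependencies for the chosen backbone (here the SAM image encoder) are installed; no `encoder_checkpoint` is passed, so the encoder starts from randomly initialized weights, and all shapes and hyperparameters are arbitrary examples.

import torch
from torch_em.model.unetr import UNETR

# Build a 2d UNETR with a SAM ViT-B image encoder and two output channels
# (e.g. foreground and boundary predictions).
model = UNETR(
    img_size=1024,
    backbone="sam",
    encoder="vit_b",
    out_channels=2,
    use_sam_stats=True,
    final_activation="Sigmoid",
)

# Inputs are resized and padded to the encoder resolution internally and the
# prediction is mapped back, so the output matches the input spatial shape.
x = torch.rand(1, 3, 512, 512)
with torch.no_grad():
    y = model(x)
print(y.shape)  # e.g. torch.Size([1, 2, 512, 512])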
Base class for implementing a UNETR.

Arguments:
  • img_size: The size of the input for the image encoder. Input images will be resized to match this size.
  • backbone: The name of the vision transformer implementation. One of "sam", "sam2", "sam3, "mae", "scalemae", "dinov2", "dinov3" (see all combinations below)
  • encoder: The vision transformer. Can either be a name, such as "vit_b" (see all combinations for this below) or a torch module.
  • decoder: The convolutional decoder.
  • out_channels: The number of output channels of the UNETR.
  • use_sam_stats: Whether to normalize the input data with the statistics of the pretrained SAM / SAM2 / SAM3 model.
  • use_dino_stats: Whether to normalize the input data with the statistics of the pretrained DINOv2 / DINOv3 model.
  • use_mae_stats: Whether to normalize the input data with the statistics of the pretrained MAE model.
  • resize_input: Whether to resize the input images to match img_size. By default, it resizes the inputs to match the img_size.
  • encoder_checkpoint: Checkpoint for initializing the vision transformer. Can either be a filepath or an already loaded checkpoint.
  • final_activation: The activation to apply to the UNETR output.
  • use_skip_connection: Whether to use skip connections. By default, it uses skip connections.
  • embed_dim: The embedding dimensionality, corresponding to the output dimension of the vision transformer.
  • use_conv_transpose: Whether to use transposed convolutions instead of resampling for upsampling. By default, it uses resampling for upsampling.
  • NOTE: The currently supported combinations of 'backbone' x 'encoder' are the following:
  • SAM_family_models: - 'sam' x 'vit_b'
    • 'sam' x 'vit_l'
    • 'sam' x 'vit_h'
    • 'sam2' x 'hvit_t'
    • 'sam2' x 'hvit_s'
    • 'sam2' x 'hvit_b'
    • 'sam2' x 'hvit_l'
    • 'sam3' x 'vit_pe'
  • DINO_family_models: - 'dinov2' x 'vit_s'
    • 'dinov2' x 'vit_b'
    • 'dinov2' x 'vit_l'
    • 'dinov2' x 'vit_g'
    • 'dinov2' x 'vit_s_reg4'
    • 'dinov2' x 'vit_b_reg4'
    • 'dinov2' x 'vit_l_reg4'
    • 'dinov2' x 'vit_g_reg4'
    • 'dinov3' x 'vit_s'
    • 'dinov3' x 'vit_s+'
    • 'dinov3' x 'vit_b'
    • 'dinov3' x 'vit_l'
    • 'dinov3' x 'vit_l+'
    • 'dinov3' x 'vit_h+'
    • 'dinov3' x 'vit_7b'
  • MAE_family_models: - 'mae' x 'vit_b'
    • 'mae' x 'vit_l'
    • 'mae' x 'vit_h'
    • 'scalemae' x 'vit_b'
    • 'scalemae' x 'vit_l'
    • 'scalemae' x 'vit_h'
UNETRBase( img_size: int = 1024, backbone: Literal['sam', 'sam2', 'sam3', 'mae', 'scalemae', 'dinov2', 'dinov3'] = 'sam', encoder: Union[torch.nn.modules.module.Module, str, NoneType] = 'vit_b', decoder: Optional[torch.nn.modules.module.Module] = None, out_channels: int = 1, use_sam_stats: bool = False, use_mae_stats: bool = False, use_dino_stats: bool = False, resize_input: bool = True, encoder_checkpoint: Union[str, collections.OrderedDict, NoneType] = None, final_activation: Union[torch.nn.modules.module.Module, str, NoneType] = None, use_skip_connection: bool = True, embed_dim: Optional[int] = None, use_conv_transpose: bool = False, **kwargs)
 97    def __init__(
 98        self,
 99        img_size: int = 1024,
100        backbone: Literal["sam", "sam2", "sam3", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
101        encoder: Optional[Union[nn.Module, str]] = "vit_b",
102        decoder: Optional[nn.Module] = None,
103        out_channels: int = 1,
104        use_sam_stats: bool = False,
105        use_mae_stats: bool = False,
106        use_dino_stats: bool = False,
107        resize_input: bool = True,
108        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
109        final_activation: Optional[Union[str, nn.Module]] = None,
110        use_skip_connection: bool = True,
111        embed_dim: Optional[int] = None,
112        use_conv_transpose: bool = False,
113        **kwargs
114    ) -> None:
115        super().__init__()
116
117        self.img_size = img_size
118        self.use_sam_stats = use_sam_stats
119        self.use_mae_stats = use_mae_stats
120        self.use_dino_stats = use_dino_stats
121        self.use_skip_connection = use_skip_connection
122        self.resize_input = resize_input
123        self.use_conv_transpose = use_conv_transpose
124        self.backbone = backbone
125
126        if isinstance(encoder, str):  # e.g. "vit_b" / "hvit_b" / "vit_pe"
127            print(f"Using {encoder} from {backbone.upper()}")
128            self.encoder = get_vision_transformer(img_size=img_size, backbone=backbone, model=encoder, **kwargs)
129
130            if encoder_checkpoint is not None:
131                self._load_encoder_from_checkpoint(backbone=backbone, encoder=encoder, checkpoint=encoder_checkpoint)
132
133            if embed_dim is None:
134                embed_dim = self.encoder.embed_dim
135
136        else:  # `nn.Module` ViT backbone
137            self.encoder = encoder
138
139            have_neck = False
140            for name, _ in self.encoder.named_parameters():
141                if name.startswith("neck"):
142                    have_neck = True
143
144            if embed_dim is None:
145                if have_neck:
146                    embed_dim = self.encoder.neck[2].out_channels  # the value is 256
147                else:
148                    embed_dim = self.encoder.patch_embed.proj.out_channels
149
150        self.embed_dim = embed_dim
151        self.final_activation = self._get_activation(final_activation)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

img_size
use_sam_stats
use_mae_stats
use_dino_stats
use_skip_connection
resize_input
use_conv_transpose
backbone
embed_dim
final_activation
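
A short sketch, assuming the respective encoder dependencies are installed, of the two ways the constructor shown above resolves the encoder: passing a model name, or passing an already constructed module, in which case `embed_dim` is inferred from the module (neck output channels if a neck exists, patch-embedding channels otherwise):

    from torch_em.model.unetr import UNETR
    from torch_em.model.vit import get_vision_transformer

    # 1) Encoder given as a name: the vision transformer is built internally.
    model_a = UNETR(backbone="sam", encoder="vit_b")

    # 2) Encoder given as a module: `embed_dim` is read from the module if not passed explicitly.
    vit = get_vision_transformer(img_size=1024, backbone="sam", model="vit_b")
    model_b = UNETR(backbone="sam", encoder=vit)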
@staticmethod
def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
253    @staticmethod
254    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
255        """Compute the output size given input size and target long side length.
256
257        Args:
258            oldh: The input image height.
259            oldw: The input image width.
260            long_side_length: The longest side length for resizing.
261
262        Returns:
263            The new image height.
264            The new image width.
265        """
266        scale = long_side_length * 1.0 / max(oldh, oldw)
267        newh, neww = oldh * scale, oldw * scale
268        neww = int(neww + 0.5)
269        newh = int(newh + 0.5)
270        return (newh, neww)

Compute the output size given input size and target long side length.

Arguments:
  • oldh: The input image height.
  • oldw: The input image width.
  • long_side_length: The longest side length for resizing.
Returns:

The new image height and the new image width, returned as a tuple.
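
A quick worked example of the scaling rule above (the longer side is mapped to `long_side_length`, the shorter side is scaled proportionally and rounded):

    from torch_em.model.unetr import UNETRBase

    print(UNETRBase.get_preprocess_shape(500, 800, 1024))   # -> (640, 1024)
    print(UNETRBase.get_preprocess_shape(1024, 768, 1024))  # -> (1024, 768)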

def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor:
272    def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor:
273        """Resize the image so that the longest side has the correct length.
274
275        Expects batched images with shape BxCxHxW OR BxCxDxHxW and float format.
276
277        Args:
278            image: The input image.
279
280        Returns:
281            The resized image.
282        """
283        if image.ndim == 4:  # i.e. 2d image
284            target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.encoder.img_size)
285            return F.interpolate(image, target_size, mode="bilinear", align_corners=False, antialias=True)
286        elif image.ndim == 5:  # i.e. 3d volume
287            B, C, Z, H, W = image.shape
288            target_size = self.get_preprocess_shape(H, W, self.img_size)
289            return F.interpolate(image, (Z, *target_size), mode="trilinear", align_corners=False)
290        else:
291            raise ValueError(f"Expected 4d or 5d inputs, got shape {image.shape}.")

Resize the image so that the longest side has the correct length.

Expects batched images with shape BxCxHxW OR BxCxDxHxW and float format.

Arguments:
  • image: The input image.
Returns:

The resized image.
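
A minimal sketch of resizing a batched 2d input so that its longest side matches the encoder size; it assumes the SAM dependencies are available and that the "vit_b" encoder exposes an `img_size` of 1024:

    import torch
    from torch_em.model.unetr import UNETR

    model = UNETR(backbone="sam", encoder="vit_b")
    image = torch.rand(1, 3, 512, 768)      # B x C x H x W
    resized = model.resize_longest_side(image)
    print(resized.shape)                    # -> torch.Size([1, 3, 683, 1024])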

class UNETR(UNETRBase):
369class UNETR(UNETRBase):
370    """A (2d-only) UNet Transformer using a vision transformer as encoder and a convolutional decoder.
371    """
372    def __init__(
373        self,
374        img_size: int = 1024,
375        backbone: Literal["sam", "sam2", "sam3", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
376        encoder: Optional[Union[nn.Module, str]] = "vit_b",
377        decoder: Optional[nn.Module] = None,
378        out_channels: int = 1,
379        use_sam_stats: bool = False,
380        use_mae_stats: bool = False,
381        use_dino_stats: bool = False,
382        resize_input: bool = True,
383        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
384        final_activation: Optional[Union[str, nn.Module]] = None,
385        use_skip_connection: bool = True,
386        embed_dim: Optional[int] = None,
387        use_conv_transpose: bool = False,
388        **kwargs
389    ) -> None:
390
391        super().__init__(
392            img_size=img_size,
393            backbone=backbone,
394            encoder=encoder,
395            decoder=decoder,
396            out_channels=out_channels,
397            use_sam_stats=use_sam_stats,
398            use_mae_stats=use_mae_stats,
399            use_dino_stats=use_dino_stats,
400            resize_input=resize_input,
401            encoder_checkpoint=encoder_checkpoint,
402            final_activation=final_activation,
403            use_skip_connection=use_skip_connection,
404            embed_dim=embed_dim,
405            use_conv_transpose=use_conv_transpose,
406            **kwargs,
407        )
408
409        encoder = self.encoder
410
411        if backbone == "sam2" and hasattr(encoder, "trunk"):
412            in_chans = encoder.trunk.patch_embed.proj.in_channels
413        elif hasattr(encoder, "in_chans"):
414            in_chans = encoder.in_chans
415        else:  # `nn.Module` ViT backbone.
416            try:
417                in_chans = encoder.patch_embed.proj.in_channels
418            except AttributeError:  # get the input channels when using 'vit_t' from MobileSAM
419                in_chans = encoder.patch_embed.seq[0].c.in_channels
420
421        # parameters for the decoder network
422        depth = 3
423        initial_features = 64
424        gain = 2
425        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
426        scale_factors = depth * [2]
427        self.out_channels = out_channels
428
429        # choice of upsampler - to use (bilinear interpolation + conv) or conv transpose
430        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d
431
432        self.decoder = decoder or Decoder(
433            features=features_decoder,
434            scale_factors=scale_factors[::-1],
435            conv_block_impl=ConvBlock2d,
436            sampler_impl=_upsampler,
437        )
438
439        if use_skip_connection:
440            self.deconv1 = Deconv2DBlock(
441                in_channels=self.embed_dim,
442                out_channels=features_decoder[0],
443                use_conv_transpose=use_conv_transpose,
444            )
445            self.deconv2 = nn.Sequential(
446                Deconv2DBlock(
447                    in_channels=self.embed_dim,
448                    out_channels=features_decoder[0],
449                    use_conv_transpose=use_conv_transpose,
450                ),
451                Deconv2DBlock(
452                    in_channels=features_decoder[0],
453                    out_channels=features_decoder[1],
454                    use_conv_transpose=use_conv_transpose,
455                )
456            )
457            self.deconv3 = nn.Sequential(
458                Deconv2DBlock(
459                    in_channels=self.embed_dim,
460                    out_channels=features_decoder[0],
461                    use_conv_transpose=use_conv_transpose,
462                ),
463                Deconv2DBlock(
464                    in_channels=features_decoder[0],
465                    out_channels=features_decoder[1],
466                    use_conv_transpose=use_conv_transpose,
467                ),
468                Deconv2DBlock(
469                    in_channels=features_decoder[1],
470                    out_channels=features_decoder[2],
471                    use_conv_transpose=use_conv_transpose,
472                )
473            )
474            self.deconv4 = ConvBlock2d(in_chans, features_decoder[-1])
475        else:
476            self.deconv1 = Deconv2DBlock(
477                in_channels=self.embed_dim,
478                out_channels=features_decoder[0],
479                use_conv_transpose=use_conv_transpose,
480            )
481            self.deconv2 = Deconv2DBlock(
482                in_channels=features_decoder[0],
483                out_channels=features_decoder[1],
484                use_conv_transpose=use_conv_transpose,
485            )
486            self.deconv3 = Deconv2DBlock(
487                in_channels=features_decoder[1],
488                out_channels=features_decoder[2],
489                use_conv_transpose=use_conv_transpose,
490            )
491            self.deconv4 = Deconv2DBlock(
492                in_channels=features_decoder[2],
493                out_channels=features_decoder[3],
494                use_conv_transpose=use_conv_transpose,
495            )
496
497        self.base = ConvBlock2d(self.embed_dim, features_decoder[0])
498        self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1)
499        self.deconv_out = _upsampler(
500            scale_factor=2, in_channels=features_decoder[-1], out_channels=features_decoder[-1]
501        )
502        self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1])
503
504    def forward(self, x: torch.Tensor) -> torch.Tensor:
505        """Apply the UNETR to the input data.
506
507        Args:
508            x: The input tensor.
509
510        Returns:
511            The UNETR output.
512        """
513        original_shape = x.shape[-2:]
514
515        # Reshape the inputs to the shape expected by the encoder
516        # and normalize the inputs if normalization is part of the model.
517        x, input_shape = self.preprocess(x)
518
519        encoder_outputs = self.encoder(x)
520
521        if isinstance(encoder_outputs[-1], list):
522            # `encoder_outputs` can be arranged in only two forms:
523            #   - either we only return the image embeddings
524            #   - or, we return the image embeddings and the "list" of global attention layers
525            z12, from_encoder = encoder_outputs
526        else:
527            z12 = encoder_outputs
528
529        if self.use_skip_connection:
530            from_encoder = from_encoder[::-1]
531            z9 = self.deconv1(from_encoder[0])
532            z6 = self.deconv2(from_encoder[1])
533            z3 = self.deconv3(from_encoder[2])
534            z0 = self.deconv4(x)
535
536        else:
537            z9 = self.deconv1(z12)
538            z6 = self.deconv2(z9)
539            z3 = self.deconv3(z6)
540            z0 = self.deconv4(z3)
541
542        updated_from_encoder = [z9, z6, z3]
543
544        x = self.base(z12)
545        x = self.decoder(x, encoder_inputs=updated_from_encoder)
546        x = self.deconv_out(x)
547
548        x = torch.cat([x, z0], dim=1)
549        x = self.decoder_head(x)
550
551        x = self.out_conv(x)
552        if self.final_activation is not None:
553            x = self.final_activation(x)
554
555        x = self.postprocess_masks(x, input_shape, original_shape)
556        return x

A (2d-only) UNet Transformer using a vision transformer as encoder and a convolutional decoder.

UNETR( img_size: int = 1024, backbone: Literal['sam', 'sam2', 'sam3', 'mae', 'scalemae', 'dinov2', 'dinov3'] = 'sam', encoder: Union[torch.nn.modules.module.Module, str, NoneType] = 'vit_b', decoder: Optional[torch.nn.modules.module.Module] = None, out_channels: int = 1, use_sam_stats: bool = False, use_mae_stats: bool = False, use_dino_stats: bool = False, resize_input: bool = True, encoder_checkpoint: Union[str, collections.OrderedDict, NoneType] = None, final_activation: Union[torch.nn.modules.module.Module, str, NoneType] = None, use_skip_connection: bool = True, embed_dim: Optional[int] = None, use_conv_transpose: bool = False, **kwargs)
372    def __init__(
373        self,
374        img_size: int = 1024,
375        backbone: Literal["sam", "sam2", "sam3", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
376        encoder: Optional[Union[nn.Module, str]] = "vit_b",
377        decoder: Optional[nn.Module] = None,
378        out_channels: int = 1,
379        use_sam_stats: bool = False,
380        use_mae_stats: bool = False,
381        use_dino_stats: bool = False,
382        resize_input: bool = True,
383        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
384        final_activation: Optional[Union[str, nn.Module]] = None,
385        use_skip_connection: bool = True,
386        embed_dim: Optional[int] = None,
387        use_conv_transpose: bool = False,
388        **kwargs
389    ) -> None:
390
391        super().__init__(
392            img_size=img_size,
393            backbone=backbone,
394            encoder=encoder,
395            decoder=decoder,
396            out_channels=out_channels,
397            use_sam_stats=use_sam_stats,
398            use_mae_stats=use_mae_stats,
399            use_dino_stats=use_dino_stats,
400            resize_input=resize_input,
401            encoder_checkpoint=encoder_checkpoint,
402            final_activation=final_activation,
403            use_skip_connection=use_skip_connection,
404            embed_dim=embed_dim,
405            use_conv_transpose=use_conv_transpose,
406            **kwargs,
407        )
408
409        encoder = self.encoder
410
411        if backbone == "sam2" and hasattr(encoder, "trunk"):
412            in_chans = encoder.trunk.patch_embed.proj.in_channels
413        elif hasattr(encoder, "in_chans"):
414            in_chans = encoder.in_chans
415        else:  # `nn.Module` ViT backbone.
416            try:
417                in_chans = encoder.patch_embed.proj.in_channels
418            except AttributeError:  # get the input channels when using 'vit_t' from MobileSAM
419                in_chans = encoder.patch_embed.seq[0].c.in_channels
420
421        # parameters for the decoder network
422        depth = 3
423        initial_features = 64
424        gain = 2
425        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
426        scale_factors = depth * [2]
427        self.out_channels = out_channels
428
429        # choice of upsampler - to use (bilinear interpolation + conv) or conv transpose
430        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d
431
432        self.decoder = decoder or Decoder(
433            features=features_decoder,
434            scale_factors=scale_factors[::-1],
435            conv_block_impl=ConvBlock2d,
436            sampler_impl=_upsampler,
437        )
438
439        if use_skip_connection:
440            self.deconv1 = Deconv2DBlock(
441                in_channels=self.embed_dim,
442                out_channels=features_decoder[0],
443                use_conv_transpose=use_conv_transpose,
444            )
445            self.deconv2 = nn.Sequential(
446                Deconv2DBlock(
447                    in_channels=self.embed_dim,
448                    out_channels=features_decoder[0],
449                    use_conv_transpose=use_conv_transpose,
450                ),
451                Deconv2DBlock(
452                    in_channels=features_decoder[0],
453                    out_channels=features_decoder[1],
454                    use_conv_transpose=use_conv_transpose,
455                )
456            )
457            self.deconv3 = nn.Sequential(
458                Deconv2DBlock(
459                    in_channels=self.embed_dim,
460                    out_channels=features_decoder[0],
461                    use_conv_transpose=use_conv_transpose,
462                ),
463                Deconv2DBlock(
464                    in_channels=features_decoder[0],
465                    out_channels=features_decoder[1],
466                    use_conv_transpose=use_conv_transpose,
467                ),
468                Deconv2DBlock(
469                    in_channels=features_decoder[1],
470                    out_channels=features_decoder[2],
471                    use_conv_transpose=use_conv_transpose,
472                )
473            )
474            self.deconv4 = ConvBlock2d(in_chans, features_decoder[-1])
475        else:
476            self.deconv1 = Deconv2DBlock(
477                in_channels=self.embed_dim,
478                out_channels=features_decoder[0],
479                use_conv_transpose=use_conv_transpose,
480            )
481            self.deconv2 = Deconv2DBlock(
482                in_channels=features_decoder[0],
483                out_channels=features_decoder[1],
484                use_conv_transpose=use_conv_transpose,
485            )
486            self.deconv3 = Deconv2DBlock(
487                in_channels=features_decoder[1],
488                out_channels=features_decoder[2],
489                use_conv_transpose=use_conv_transpose,
490            )
491            self.deconv4 = Deconv2DBlock(
492                in_channels=features_decoder[2],
493                out_channels=features_decoder[3],
494                use_conv_transpose=use_conv_transpose,
495            )
496
497        self.base = ConvBlock2d(self.embed_dim, features_decoder[0])
498        self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1)
499        self.deconv_out = _upsampler(
500            scale_factor=2, in_channels=features_decoder[-1], out_channels=features_decoder[-1]
501        )
502        self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1])

Initialize internal Module state, shared by both nn.Module and ScriptModule.

out_channels
decoder
base
out_conv
deconv_out
decoder_head
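
For reference, a small worked example of the default decoder feature plan computed in the constructor above (each level halves the number of features, ending at 64 channels before the output convolution):

    depth, initial_features, gain = 3, 64, 2
    features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
    print(features_decoder)  # -> [512, 256, 128, 64]
    print(depth * [2])       # scale_factors -> [2, 2, 2]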
def forward(self, x: torch.Tensor) -> torch.Tensor:
504    def forward(self, x: torch.Tensor) -> torch.Tensor:
505        """Apply the UNETR to the input data.
506
507        Args:
508            x: The input tensor.
509
510        Returns:
511            The UNETR output.
512        """
513        original_shape = x.shape[-2:]
514
515        # Reshape the inputs to the shape expected by the encoder
516        # and normalize the inputs if normalization is part of the model.
517        x, input_shape = self.preprocess(x)
518
519        encoder_outputs = self.encoder(x)
520
521        if isinstance(encoder_outputs[-1], list):
522            # `encoder_outputs` can be arranged in only two forms:
523            #   - either we only return the image embeddings
524            #   - or, we return the image embeddings and the "list" of global attention layers
525            z12, from_encoder = encoder_outputs
526        else:
527            z12 = encoder_outputs
528
529        if self.use_skip_connection:
530            from_encoder = from_encoder[::-1]
531            z9 = self.deconv1(from_encoder[0])
532            z6 = self.deconv2(from_encoder[1])
533            z3 = self.deconv3(from_encoder[2])
534            z0 = self.deconv4(x)
535
536        else:
537            z9 = self.deconv1(z12)
538            z6 = self.deconv2(z9)
539            z3 = self.deconv3(z6)
540            z0 = self.deconv4(z3)
541
542        updated_from_encoder = [z9, z6, z3]
543
544        x = self.base(z12)
545        x = self.decoder(x, encoder_inputs=updated_from_encoder)
546        x = self.deconv_out(x)
547
548        x = torch.cat([x, z0], dim=1)
549        x = self.decoder_head(x)
550
551        x = self.out_conv(x)
552        if self.final_activation is not None:
553            x = self.final_activation(x)
554
555        x = self.postprocess_masks(x, input_shape, original_shape)
556        return x

Apply the UNETR to the input data.

Arguments:
  • x: The input tensor.
Returns:

The UNETR output.
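
A hypothetical end-to-end example for the 2d UNETR (no pretrained weights are loaded, the input size is illustrative, and it assumes the SAM dependencies are installed):

    import torch
    from torch_em.model.unetr import UNETR

    model = UNETR(backbone="sam", encoder="vit_b", out_channels=3, final_activation=torch.nn.Sigmoid())
    x = torch.rand(1, 3, 512, 512)
    with torch.no_grad():
        y = model(x)
    print(y.shape)  # expected: torch.Size([1, 3, 512, 512]), i.e. the input spatial size is restored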

class UNETR2D(UNETR):
559class UNETR2D(UNETR):
560    """A two-dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.
561    """
562    pass

A two-dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.

class UNETR3D(UNETRBase):
565class UNETR3D(UNETRBase):
566    """A three dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.
567    """
568    def __init__(
569        self,
570        img_size: int = 1024,
571        backbone: Literal["sam", "sam2", "sam3", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
572        encoder: Optional[Union[nn.Module, str]] = "hvit_b",
573        decoder: Optional[nn.Module] = None,
574        out_channels: int = 1,
575        use_sam_stats: bool = False,
576        use_mae_stats: bool = False,
577        use_dino_stats: bool = False,
578        resize_input: bool = True,
579        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
580        final_activation: Optional[Union[str, nn.Module]] = None,
581        use_skip_connection: bool = False,
582        embed_dim: Optional[int] = None,
583        use_conv_transpose: bool = False,
584        use_strip_pooling: bool = True,
585        **kwargs
586    ):
587        if use_skip_connection:
588            raise NotImplementedError("Skip connections are not supported by the 3d UNETR at the moment.")
589        if use_conv_transpose:
590            raise NotImplementedError("Switching from interpolation to transposed convolutions is not supported.")
591
592        # Sort the `embed_dim` out
593        embed_dim = 256 if embed_dim is None else embed_dim
594
595        super().__init__(
596            img_size=img_size,
597            backbone=backbone,
598            encoder=encoder,
599            decoder=decoder,
600            out_channels=out_channels,
601            use_sam_stats=use_sam_stats,
602            use_mae_stats=use_mae_stats,
603            use_dino_stats=use_dino_stats,
604            resize_input=resize_input,
605            encoder_checkpoint=encoder_checkpoint,
606            final_activation=final_activation,
607            use_skip_connection=use_skip_connection,
608            embed_dim=embed_dim,
609            use_conv_transpose=use_conv_transpose,
610            **kwargs,
611        )
612
613        # The 3d convolutional decoder.
614        # First, get the important parameters for the decoder.
615        depth = 3
616        initial_features = 64
617        gain = 2
618        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
619        scale_factors = [1, 2, 2]
620        self.out_channels = out_channels
621
622        # The mapping blocks.
623        self.deconv1 = Deconv3DBlock(
624            in_channels=embed_dim,
625            out_channels=features_decoder[0],
626            scale_factor=scale_factors,
627            use_strip_pooling=use_strip_pooling,
628        )
629        self.deconv2 = Deconv3DBlock(
630            in_channels=features_decoder[0],
631            out_channels=features_decoder[1],
632            scale_factor=scale_factors,
633            use_strip_pooling=use_strip_pooling,
634        )
635        self.deconv3 = Deconv3DBlock(
636            in_channels=features_decoder[1],
637            out_channels=features_decoder[2],
638            scale_factor=scale_factors,
639            use_strip_pooling=use_strip_pooling,
640        )
641        self.deconv4 = Deconv3DBlock(
642            in_channels=features_decoder[2],
643            out_channels=features_decoder[3],
644            scale_factor=scale_factors,
645            use_strip_pooling=use_strip_pooling,
646        )
647
648        # The core decoder block.
649        self.decoder = decoder or Decoder(
650            features=features_decoder,
651            scale_factors=[scale_factors] * depth,
652            conv_block_impl=partial(ConvBlock3dWithStrip, use_strip_pooling=use_strip_pooling),
653            sampler_impl=Upsampler3d,
654        )
655
656        # And the final upsampler to match the expected dimensions.
657        self.deconv_out = Deconv3DBlock(  # NOTE: changed `end_up` to `deconv_out`
658            in_channels=features_decoder[-1],
659            out_channels=features_decoder[-1],
660            scale_factor=scale_factors,
661            use_strip_pooling=use_strip_pooling,
662        )
663
664        # Additional conjunction blocks.
665        self.base = ConvBlock3dWithStrip(
666            in_channels=embed_dim,
667            out_channels=features_decoder[0],
668            use_strip_pooling=use_strip_pooling,
669        )
670
671        # And the output layers.
672        self.decoder_head = ConvBlock3dWithStrip(
673            in_channels=2 * features_decoder[-1],
674            out_channels=features_decoder[-1],
675            use_strip_pooling=use_strip_pooling,
676        )
677        self.out_conv = nn.Conv3d(features_decoder[-1], out_channels, 1)
678
679    def forward(self, x: torch.Tensor):
680        """Forward pass of the UNETR-3D model.
681
682        Args:
683            x: Inputs of expected shape (B, C, Z, Y, X), where the number of slices Z may vary.
684
685        Returns:
686            The UNETR output.
687        """
688        B, C, Z, H, W = x.shape
689        original_shape = (Z, H, W)
690
691        # Preprocessing step
692        x, input_shape = self.preprocess(x)
693
694        # Run the image encoder.
695        curr_features = torch.stack([self.encoder(x[:, :, i])[0] for i in range(Z)], dim=2)
696
697        # Prepare the counterparts for the decoder.
698        # NOTE: The section below is sequential; there are no skip connections at the moment.
699        z9 = self.deconv1(curr_features)
700        z6 = self.deconv2(z9)
701        z3 = self.deconv3(z6)
702        z0 = self.deconv4(z3)
703
704        updated_from_encoder = [z9, z6, z3]
705
706        # Align the features through the base block.
707        x = self.base(curr_features)
708        # Run the decoder
709        x = self.decoder(x, encoder_inputs=updated_from_encoder)
710        x = self.deconv_out(x)  # NOTE: previously named `end_up`
711
712        # And the final output head.
713        x = torch.cat([x, z0], dim=1)
714        x = self.decoder_head(x)
715        x = self.out_conv(x)
716        if self.final_activation is not None:
717            x = self.final_activation(x)
718
719        # Postprocess the output back to original size.
720        x = self.postprocess_masks(x, input_shape, original_shape)
721        return x

A three-dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.

UNETR3D( img_size: int = 1024, backbone: Literal['sam', 'sam2', 'sam3', 'mae', 'scalemae', 'dinov2', 'dinov3'] = 'sam', encoder: Union[torch.nn.modules.module.Module, str, NoneType] = 'hvit_b', decoder: Optional[torch.nn.modules.module.Module] = None, out_channels: int = 1, use_sam_stats: bool = False, use_mae_stats: bool = False, use_dino_stats: bool = False, resize_input: bool = True, encoder_checkpoint: Union[str, collections.OrderedDict, NoneType] = None, final_activation: Union[torch.nn.modules.module.Module, str, NoneType] = None, use_skip_connection: bool = False, embed_dim: Optional[int] = None, use_conv_transpose: bool = False, use_strip_pooling: bool = True, **kwargs)
568    def __init__(
569        self,
570        img_size: int = 1024,
571        backbone: Literal["sam", "sam2", "sam3", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
572        encoder: Optional[Union[nn.Module, str]] = "hvit_b",
573        decoder: Optional[nn.Module] = None,
574        out_channels: int = 1,
575        use_sam_stats: bool = False,
576        use_mae_stats: bool = False,
577        use_dino_stats: bool = False,
578        resize_input: bool = True,
579        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
580        final_activation: Optional[Union[str, nn.Module]] = None,
581        use_skip_connection: bool = False,
582        embed_dim: Optional[int] = None,
583        use_conv_transpose: bool = False,
584        use_strip_pooling: bool = True,
585        **kwargs
586    ):
587        if use_skip_connection:
588            raise NotImplementedError("Skip connections are not supported by the 3d UNETR at the moment.")
589        if use_conv_transpose:
590            raise NotImplementedError("Switching from interpolation to transposed convolutions is not supported.")
591
592        # Sort the `embed_dim` out
593        embed_dim = 256 if embed_dim is None else embed_dim
594
595        super().__init__(
596            img_size=img_size,
597            backbone=backbone,
598            encoder=encoder,
599            decoder=decoder,
600            out_channels=out_channels,
601            use_sam_stats=use_sam_stats,
602            use_mae_stats=use_mae_stats,
603            use_dino_stats=use_dino_stats,
604            resize_input=resize_input,
605            encoder_checkpoint=encoder_checkpoint,
606            final_activation=final_activation,
607            use_skip_connection=use_skip_connection,
608            embed_dim=embed_dim,
609            use_conv_transpose=use_conv_transpose,
610            **kwargs,
611        )
612
613        # The 3d convolutional decoder.
614        # First, get the important parameters for the decoder.
615        depth = 3
616        initial_features = 64
617        gain = 2
618        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
619        scale_factors = [1, 2, 2]
620        self.out_channels = out_channels
621
622        # The mapping blocks.
623        self.deconv1 = Deconv3DBlock(
624            in_channels=embed_dim,
625            out_channels=features_decoder[0],
626            scale_factor=scale_factors,
627            use_strip_pooling=use_strip_pooling,
628        )
629        self.deconv2 = Deconv3DBlock(
630            in_channels=features_decoder[0],
631            out_channels=features_decoder[1],
632            scale_factor=scale_factors,
633            use_strip_pooling=use_strip_pooling,
634        )
635        self.deconv3 = Deconv3DBlock(
636            in_channels=features_decoder[1],
637            out_channels=features_decoder[2],
638            scale_factor=scale_factors,
639            use_strip_pooling=use_strip_pooling,
640        )
641        self.deconv4 = Deconv3DBlock(
642            in_channels=features_decoder[2],
643            out_channels=features_decoder[3],
644            scale_factor=scale_factors,
645            use_strip_pooling=use_strip_pooling,
646        )
647
648        # The core decoder block.
649        self.decoder = decoder or Decoder(
650            features=features_decoder,
651            scale_factors=[scale_factors] * depth,
652            conv_block_impl=partial(ConvBlock3dWithStrip, use_strip_pooling=use_strip_pooling),
653            sampler_impl=Upsampler3d,
654        )
655
656        # And the final upsampler to match the expected dimensions.
657        self.deconv_out = Deconv3DBlock(  # NOTE: changed `end_up` to `deconv_out`
658            in_channels=features_decoder[-1],
659            out_channels=features_decoder[-1],
660            scale_factor=scale_factors,
661            use_strip_pooling=use_strip_pooling,
662        )
663
664        # Additional conjunction blocks.
665        self.base = ConvBlock3dWithStrip(
666            in_channels=embed_dim,
667            out_channels=features_decoder[0],
668            use_strip_pooling=use_strip_pooling,
669        )
670
671        # And the output layers.
672        self.decoder_head = ConvBlock3dWithStrip(
673            in_channels=2 * features_decoder[-1],
674            out_channels=features_decoder[-1],
675            use_strip_pooling=use_strip_pooling,
676        )
677        self.out_conv = nn.Conv3d(features_decoder[-1], out_channels, 1)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

out_channels
deconv1
deconv2
deconv3
deconv4
decoder
deconv_out
base
decoder_head
out_conv
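
For reference, a small sketch of the 3d decoder configuration set up above: the feature plan matches the 2d model, while upsampling is anisotropic (no scaling along z, 2x in-plane):

    depth, initial_features, gain = 3, 64, 2
    features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
    scale_factors = [1, 2, 2]
    print(features_decoder)         # -> [512, 256, 128, 64]
    print([scale_factors] * depth)  # -> [[1, 2, 2], [1, 2, 2], [1, 2, 2]]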
def forward(self, x: torch.Tensor):
679    def forward(self, x: torch.Tensor):
680        """Forward pass of the UNETR-3D model.
681
682        Args:
683            x: Inputs of expected shape (B, C, Z, Y, X), where the number of slices Z may vary.
684
685        Returns:
686            The UNETR output.
687        """
688        B, C, Z, H, W = x.shape
689        original_shape = (Z, H, W)
690
691        # Preprocessing step
692        x, input_shape = self.preprocess(x)
693
694        # Run the image encoder.
695        curr_features = torch.stack([self.encoder(x[:, :, i])[0] for i in range(Z)], dim=2)
696
697        # Prepare the counterparts for the decoder.
698        # NOTE: The section below is sequential; there are no skip connections at the moment.
699        z9 = self.deconv1(curr_features)
700        z6 = self.deconv2(z9)
701        z3 = self.deconv3(z6)
702        z0 = self.deconv4(z3)
703
704        updated_from_encoder = [z9, z6, z3]
705
706        # Align the features through the base block.
707        x = self.base(curr_features)
708        # Run the decoder
709        x = self.decoder(x, encoder_inputs=updated_from_encoder)
710        x = self.deconv_out(x)  # NOTE: previously named `end_up`
711
712        # And the final output head.
713        x = torch.cat([x, z0], dim=1)
714        x = self.decoder_head(x)
715        x = self.out_conv(x)
716        if self.final_activation is not None:
717            x = self.final_activation(x)
718
719        # Postprocess the output back to original size.
720        x = self.postprocess_masks(x, input_shape, original_shape)
721        return x

Forward pass of the UNETR-3D model.

Arguments:
  • x: Inputs of expected shape (B, C, Z, Y, X), where the number of slices Z may vary.
Returns:

The UNETR output.
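
A hypothetical usage sketch for the 3d variant, which encodes each z-slice separately and decodes the stacked features with the 3d decoder; the default "hvit_b" encoder comes from SAM2, so this assumes the corresponding optional dependencies are installed and no pretrained weights are loaded:

    import torch
    from torch_em.model.unetr import UNETR3D

    model = UNETR3D(backbone="sam2", encoder="hvit_b", out_channels=1)
    volume = torch.rand(1, 3, 8, 256, 256)  # B x C x Z x Y x X
    with torch.no_grad():
        prediction = model(volume)
    print(prediction.shape)  # expected: torch.Size([1, 1, 8, 256, 256])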