torch_em.model.unetr

  1from functools import partial
  2from collections import OrderedDict
  3from typing import Optional, Tuple, Union, Literal
  4
  5import torch
  6import torch.nn as nn
  7import torch.nn.functional as F
  8
  9from .vit import get_vision_transformer
 10from .unet import Decoder, ConvBlock2d, ConvBlock3d, Upsampler2d, Upsampler3d, _update_conv_kwargs
 11
 12try:
 13    from micro_sam.util import get_sam_model
 14except ImportError:
 15    get_sam_model = None
 16
 17try:
 18    from micro_sam.v2.util import get_sam2_model
 19except ImportError:
 20    get_sam2_model = None
 21
 22try:
 23    from micro_sam3.util import get_sam3_model
 24except ImportError:
 25    get_sam3_model = None
 26
 27
 28#
 29# UNETR IMPLEMENTATION [Vision Transformer (ViT from SAM / CellposeSAM / SAM2 / SAM3 / DINOv2 / DINOv3 / MAE / ScaleMAE) + UNet Decoder from `torch_em`]  # noqa
 30#
 31
 32
 33class UNETRBase(nn.Module):
 34    """Base class for implementing a UNETR.
 35
 36    Args:
 37        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
 38        backbone: The name of the vision transformer implementation.
 39            One of "sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"
 40            (see all combinations below)
 41        encoder: The vision transformer. Can either be a name, such as "vit_b"
 42            (see all combinations for this below) or a torch module.
 43        decoder: The convolutional decoder.
 44        out_channels: The number of output channels of the UNETR.
 45        use_sam_stats: Whether to normalize the input data with the statistics of the
 46            pretrained SAM / SAM2 / SAM3 model.
 47        use_dino_stats: Whether to normalize the input data with the statistics of the
 48            pretrained DINOv2 / DINOv3 model.
 49        use_mae_stats: Whether to normalize the input data with the statistics of the pretrained MAE model.
 50        resize_input: Whether to resize the input images to match `img_size`.
 51            By default, it resizes the inputs to match the `img_size`.
 52        encoder_checkpoint: Checkpoint for initializing the vision transformer.
 53            Can either be a filepath or an already loaded checkpoint.
 54        final_activation: The activation to apply to the UNETR output.
 55        use_skip_connection: Whether to use skip connections. By default, it uses skip connections.
 56        embed_dim: The embedding dimensionality, corresponding to the output dimension of the vision transformer.
 57        use_conv_transpose: Whether to use transposed convolutions instead of resampling for upsampling.
 58            By default, it uses resampling for upsampling.
 59
 60        NOTE: The currently supported combinations of 'backbone' x 'encoder' are the following:
 61
 62        SAM_family_models:
 63            - 'sam' x 'vit_b'
 64            - 'sam' x 'vit_l'
 65            - 'sam' x 'vit_h'
 66            - 'sam2' x 'hvit_t'
 67            - 'sam2' x 'hvit_s'
 68            - 'sam2' x 'hvit_b'
 69            - 'sam2' x 'hvit_l'
 70            - 'sam3' x 'vit_pe'
 71            - 'cellpose_sam' x 'vit_l'
 72
 73        DINO_family_models:
 74            - 'dinov2' x 'vit_s'
 75            - 'dinov2' x 'vit_b'
 76            - 'dinov2' x 'vit_l'
 77            - 'dinov2' x 'vit_g'
 78            - 'dinov2' x 'vit_s_reg4'
 79            - 'dinov2' x 'vit_b_reg4'
 80            - 'dinov2' x 'vit_l_reg4'
 81            - 'dinov2' x 'vit_g_reg4'
 82            - 'dinov3' x 'vit_s'
 83            - 'dinov3' x 'vit_s+'
 84            - 'dinov3' x 'vit_b'
 85            - 'dinov3' x 'vit_l'
 86            - 'dinov3' x 'vit_l+'
 87            - 'dinov3' x 'vit_h+'
 88            - 'dinov3' x 'vit_7b'
 89
 90        MAE_family_models:
 91            - 'mae' x 'vit_b'
 92            - 'mae' x 'vit_l'
 93            - 'mae' x 'vit_h'
 94            - 'scalemae' x 'vit_b'
 95            - 'scalemae' x 'vit_l'
 96            - 'scalemae' x 'vit_h'
 97    """
 98    def __init__(
 99        self,
100        img_size: int = 1024,
101        backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
102        encoder: Optional[Union[nn.Module, str]] = "vit_b",
103        decoder: Optional[nn.Module] = None,
104        out_channels: int = 1,
105        use_sam_stats: bool = False,
106        use_mae_stats: bool = False,
107        use_dino_stats: bool = False,
108        resize_input: bool = True,
109        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
110        final_activation: Optional[Union[str, nn.Module]] = None,
111        use_skip_connection: bool = True,
112        embed_dim: Optional[int] = None,
113        use_conv_transpose: bool = False,
114        **kwargs
115    ) -> None:
116        super().__init__()
117
118        self.img_size = img_size
119        self.use_sam_stats = use_sam_stats
120        self.use_mae_stats = use_mae_stats
121        self.use_dino_stats = use_dino_stats
122        self.use_skip_connection = use_skip_connection
123        self.resize_input = resize_input
124        self.use_conv_transpose = use_conv_transpose
125        self.backbone = backbone
126
127        if isinstance(encoder, str):  # e.g. "vit_b" / "hvit_b" / "vit_pe"
128            print(f"Using {encoder} from {backbone.upper()}")
129            self.encoder = get_vision_transformer(img_size=img_size, backbone=backbone, model=encoder, **kwargs)
130
131            if encoder_checkpoint is not None:
132                self._load_encoder_from_checkpoint(backbone=backbone, encoder=encoder, checkpoint=encoder_checkpoint)
133
134            if embed_dim is None:
135                embed_dim = self.encoder.embed_dim
136
137        else:  # `nn.Module` ViT backbone
138            self.encoder = encoder
139
140            have_neck = False
141            for name, _ in self.encoder.named_parameters():
142                if name.startswith("neck"):
143                    have_neck = True
144
145            if embed_dim is None:
146                if have_neck:
147                    embed_dim = self.encoder.neck[2].out_channels  # the value is 256
148                else:
149                    embed_dim = self.encoder.patch_embed.proj.out_channels
150
151        self.embed_dim = embed_dim
152        self.final_activation = self._get_activation(final_activation)
153
154    def _load_encoder_from_checkpoint(self, backbone, encoder, checkpoint):
155        """Function to load pretrained weights to the image encoder.
156        """
157        if isinstance(checkpoint, str):
158            if backbone == "sam" and isinstance(encoder, str):
159                # If we have a SAM encoder, then we first try to load the full SAM Model
160                # (using micro_sam) and otherwise fall back on directly loading the encoder state
161                # from the checkpoint
162                try:
163                    _, model = get_sam_model(model_type=encoder, checkpoint_path=checkpoint, return_sam=True)
164                    encoder_state = model.image_encoder.state_dict()
165                except Exception:
166                    # Try loading the encoder state directly from a checkpoint.
167                    encoder_state = torch.load(checkpoint, weights_only=False)
168
169            elif backbone == "cellpose_sam" and isinstance(encoder, str):
170                # The architecture matches CellposeSAM exactly (same rel_pos sizes),
171                # so weights load directly without any interpolation.
172                encoder_state = torch.load(checkpoint, map_location="cpu", weights_only=False)
173                # Handle DataParallel/DistributedDataParallel prefix.
174                if any(k.startswith("module.") for k in encoder_state.keys()):
175                    encoder_state = OrderedDict(
176                        {k[len("module."):]: v for k, v in encoder_state.items()}
177                    )
178                # Extract encoder weights from CellposeSAM checkpoint format (strip 'encoder.' prefix).
179                if any(k.startswith("encoder.") for k in encoder_state.keys()):
180                    encoder_state = OrderedDict(
181                        {k[len("encoder."):]: v for k, v in encoder_state.items() if k.startswith("encoder.")}
182                    )
183
184            elif backbone == "sam2" and isinstance(encoder, str):
185                # If we have a SAM2 encoder, then we first try to load the full SAM2 Model.
186                # (using micro_sam2) and otherwise fall back on directly loading the encoder state
187                # from the checkpoint
188                try:
189                    model = get_sam2_model(model_type=encoder, checkpoint_path=checkpoint)
190                    encoder_state = model.image_encoder.state_dict()
191                except Exception:
192                    # Try loading the encoder state directly from a checkpoint.
193                    encoder_state = torch.load(checkpoint, weights_only=False)
194
195            elif backbone == "sam3" and isinstance(encoder, str):
196                # If we have a SAM3 encoder, then we first try to load the full SAM3 Model.
197                # (using micro_sam3) and otherwise fall back on directly loading the encoder state
198                # from the checkpoint
199                try:
200                    model = get_sam3_model(checkpoint_path=checkpoint)
201                    encoder_state = model.backbone.vision_backbone.state_dict()
202                    # Let's align loading the encoder weights with expected parameter names
203                    encoder_state = {
204                        k[len("trunk."):] if k.startswith("trunk.") else k: v for k, v in encoder_state.items()
205                    }
206                    # And drop the 'convs' and 'sam2_convs' - these seem like some upsampling blocks.
207                    encoder_state = {
208                        k: v for k, v in encoder_state.items()
209                        if not (k.startswith("convs.") or k.startswith("sam2_convs."))
210                    }
211                except Exception:
212                    # Try loading the encoder state directly from a checkpoint.
213                    encoder_state = torch.load(checkpoint, weights_only=False)
214
215            elif backbone == "mae":
216                # vit initialization hints from:
217                #     - https://github.com/facebookresearch/mae/blob/main/main_finetune.py#L233-L242
218                encoder_state = torch.load(checkpoint, weights_only=False)["model"]
219                encoder_state = OrderedDict({
220                    k: v for k, v in encoder_state.items() if (k != "mask_token" and not k.startswith("decoder"))
221                })
222                # Let's remove the `head` from our current encoder (as the MAE pretrained don't expect it)
223                current_encoder_state = self.encoder.state_dict()
224                if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
225                    del self.encoder.head
226
227            elif backbone == "scalemae":
228                # Load the encoder state directly from a checkpoint.
229                encoder_state = torch.load(checkpoint)["model"]
230                encoder_state = OrderedDict({
231                    k: v for k, v in encoder_state.items()
232                    if not k.startswith(("mask_token", "decoder", "fcn", "fpn", "pos_embed"))
233                })
234
235                # Let's remove the `head` from our current encoder (as the MAE pretrained don't expect it)
236                current_encoder_state = self.encoder.state_dict()
237                if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
238                    del self.encoder.head
239
240                if "pos_embed" in current_encoder_state:  # NOTE: ScaleMAE uses positional embeddings in a different format.
241                    del self.encoder.pos_embed
242
243            elif backbone in ["dinov2", "dinov3"]:  # Load the encoder state directly from a checkpoint.
244                encoder_state = torch.load(checkpoint)
245
246            else:
247                raise ValueError(
248                    f"The combination of the '{backbone}' backbone and the '{encoder}' encoder is not supported."
249                )
250
251        else:
252            encoder_state = checkpoint
253
254        self.encoder.load_state_dict(encoder_state)
255
256    def _get_activation(self, activation):
257        return_activation = None
258        if activation is None:
259            return None
260        if isinstance(activation, nn.Module):
261            return activation
262        if isinstance(activation, str):
263            return_activation = getattr(nn, activation, None)
264        if return_activation is None:
265            raise ValueError(f"Invalid activation: {activation}")
266
267        return return_activation()
268
269    @staticmethod
270    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
271        """Compute the output size given input size and target long side length.
272
273        Args:
274            oldh: The input image height.
275            oldw: The input image width.
276            long_side_length: The longest side length for resizing.
277
278        Returns:
279            The new image height.
280            The new image width.
281        """
282        scale = long_side_length * 1.0 / max(oldh, oldw)
283        newh, neww = oldh * scale, oldw * scale
284        neww = int(neww + 0.5)
285        newh = int(newh + 0.5)
286        return (newh, neww)
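        # Illustrative example (not part of the original source): for a 768 x 1024 input and
        # long_side_length=1024 the scale is 1.0 and the shape stays (768, 1024); for a
        # 512 x 768 input the scale is 1024 / 768 and the result is (683, 1024).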
287
288    def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor:
289        """Resize the image so that the longest side has the correct length.
290
291        Expects batched images with shape BxCxHxW OR BxCxDxHxW and float format.
292
293        Args:
294            image: The input image.
295
296        Returns:
297            The resized image.
298        """
299        if image.ndim == 4:  # i.e. 2d image
300            target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.encoder.img_size)
301            return F.interpolate(image, target_size, mode="bilinear", align_corners=False, antialias=True)
302        elif image.ndim == 5:  # i.e. 3d volume
303            B, C, Z, H, W = image.shape
304            target_size = self.get_preprocess_shape(H, W, self.img_size)
305            return F.interpolate(image, (Z, *target_size), mode="trilinear", align_corners=False)
306        else:
307            raise ValueError(f"Expected 4d or 5d inputs, got an input of shape {image.shape}.")
308
309    def _as_stats(self, mean, std, device, dtype, is_3d: bool):
310        """@private
311        """
312        # Either 2d batch: (1, C, 1, 1) or 3d batch: (1, C, 1, 1, 1).
313        view_shape = (1, -1, 1, 1, 1) if is_3d else (1, -1, 1, 1)
314        pixel_mean = torch.tensor(mean, device=device, dtype=dtype).view(*view_shape)
315        pixel_std = torch.tensor(std, device=device, dtype=dtype).view(*view_shape)
316        return pixel_mean, pixel_std
317
318    def preprocess(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, ...]]:
319        """@private
320        """
322        is_3d = (x.ndim == 5)
323        device, dtype = x.device, x.dtype
324
325        if self.use_sam_stats and self.backbone not in ("sam2", "sam3"):  # SAM2 / SAM3 statistics are handled below.
326            mean, std = (123.675, 116.28, 103.53), (58.395, 57.12, 57.375)
327        elif self.use_mae_stats:  # TODO: add mean std from mae / scalemae experiments (or open up arguments for this)
328            raise NotImplementedError
329        elif self.use_dino_stats or (self.use_sam_stats and self.backbone == "sam2"):
330            mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
331        elif self.use_sam_stats and self.backbone == "sam3":
332            mean, std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
333        else:
334            mean, std = (0.0, 0.0, 0.0), (1.0, 1.0, 1.0)
335
336        pixel_mean, pixel_std = self._as_stats(mean, std, device=device, dtype=dtype, is_3d=is_3d)
337
338        if self.resize_input:
339            x = self.resize_longest_side(x)
340        input_shape = x.shape[-3:] if is_3d else x.shape[-2:]
341
342        x = (x - pixel_mean) / pixel_std
343        h, w = x.shape[-2:]
344        padh = self.encoder.img_size - h
345        padw = self.encoder.img_size - w
346
347        if is_3d:
348            x = F.pad(x, (0, padw, 0, padh, 0, 0))
349        else:
350            x = F.pad(x, (0, padw, 0, padh))
351
352        return x, input_shape
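        # Illustrative example (assuming `resize_input=True` and an encoder with img_size=1024):
        # a (1, 3, 512, 768) batch is resized to (1, 3, 683, 1024), normalized, and padded to
        # (1, 3, 1024, 1024); `input_shape` is then (683, 1024) and is used by `postprocess_masks`
        # to crop away the padding before resizing back to the original resolution.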
353
354    def postprocess_masks(
355        self, masks: torch.Tensor, input_size: Tuple[int, ...], original_size: Tuple[int, ...],
356    ) -> torch.Tensor:
357        """@private
358        """
359        if masks.ndim == 4:  # i.e. 2d labels
360            masks = F.interpolate(
361                masks,
362                (self.encoder.img_size, self.encoder.img_size),
363                mode="bilinear",
364                align_corners=False,
365            )
366            masks = masks[..., : input_size[0], : input_size[1]]
367            masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
368
369        elif masks.ndim == 5:  # i.e. 3d volumetric labels
370            masks = F.interpolate(
371                masks,
372                (input_size[0], self.img_size, self.img_size),
373                mode="trilinear",
374                align_corners=False,
375            )
376            masks = masks[..., :input_size[0], :input_size[1], :input_size[2]]
377            masks = F.interpolate(masks, original_size, mode="trilinear", align_corners=False)
378
379        else:
380            raise ValueError(f"Expected 4d or 5d labels, got labels of shape {masks.shape}.")
381
382        return masks
383
384
385class UNETR(UNETRBase):
386    """A (2d-only) UNet Transformer using a vision transformer as encoder and a convolutional decoder.
387    """
388    def __init__(
389        self,
390        img_size: int = 1024,
391        backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
392        encoder: Optional[Union[nn.Module, str]] = "vit_b",
393        decoder: Optional[nn.Module] = None,
394        out_channels: int = 1,
395        use_sam_stats: bool = False,
396        use_mae_stats: bool = False,
397        use_dino_stats: bool = False,
398        resize_input: bool = True,
399        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
400        final_activation: Optional[Union[str, nn.Module]] = None,
401        use_skip_connection: bool = True,
402        embed_dim: Optional[int] = None,
403        use_conv_transpose: bool = False,
404        **kwargs
405    ) -> None:
406
407        super().__init__(
408            img_size=img_size,
409            backbone=backbone,
410            encoder=encoder,
411            decoder=decoder,
412            out_channels=out_channels,
413            use_sam_stats=use_sam_stats,
414            use_mae_stats=use_mae_stats,
415            use_dino_stats=use_dino_stats,
416            resize_input=resize_input,
417            encoder_checkpoint=encoder_checkpoint,
418            final_activation=final_activation,
419            use_skip_connection=use_skip_connection,
420            embed_dim=embed_dim,
421            use_conv_transpose=use_conv_transpose,
422            **kwargs,
423        )
424
425        encoder = self.encoder
426
427        if backbone == "sam2" and hasattr(encoder, "trunk"):
428            in_chans = encoder.trunk.patch_embed.proj.in_channels
429        elif hasattr(encoder, "in_chans"):
430            in_chans = encoder.in_chans
431        else:  # `nn.Module` ViT backbone.
432            try:
433                in_chans = encoder.patch_embed.proj.in_channels
434            except AttributeError:  # for getting the input channels while using 'vit_t' from MobileSam
435                in_chans = encoder.patch_embed.seq[0].c.in_channels
436
437        # parameters for the decoder network
438        depth = 3
439        initial_features = 64
440        gain = 2
441        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
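        # With depth=3, initial_features=64 and gain=2 this yields [512, 256, 128, 64].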
442        scale_factors = depth * [2]
443        self.out_channels = out_channels
444
445        # choice of upsampler - to use (bilinear interpolation + conv) or conv transpose
446        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d
447
448        self.decoder = decoder or Decoder(
449            features=features_decoder,
450            scale_factors=scale_factors[::-1],
451            conv_block_impl=ConvBlock2d,
452            sampler_impl=_upsampler,
453        )
454
455        if use_skip_connection:
456            self.deconv1 = Deconv2DBlock(
457                in_channels=self.embed_dim,
458                out_channels=features_decoder[0],
459                use_conv_transpose=use_conv_transpose,
460            )
461            self.deconv2 = nn.Sequential(
462                Deconv2DBlock(
463                    in_channels=self.embed_dim,
464                    out_channels=features_decoder[0],
465                    use_conv_transpose=use_conv_transpose,
466                ),
467                Deconv2DBlock(
468                    in_channels=features_decoder[0],
469                    out_channels=features_decoder[1],
470                    use_conv_transpose=use_conv_transpose,
471                )
472            )
473            self.deconv3 = nn.Sequential(
474                Deconv2DBlock(
475                    in_channels=self.embed_dim,
476                    out_channels=features_decoder[0],
477                    use_conv_transpose=use_conv_transpose,
478                ),
479                Deconv2DBlock(
480                    in_channels=features_decoder[0],
481                    out_channels=features_decoder[1],
482                    use_conv_transpose=use_conv_transpose,
483                ),
484                Deconv2DBlock(
485                    in_channels=features_decoder[1],
486                    out_channels=features_decoder[2],
487                    use_conv_transpose=use_conv_transpose,
488                )
489            )
490            self.deconv4 = ConvBlock2d(in_chans, features_decoder[-1])
491        else:
492            self.deconv1 = Deconv2DBlock(
493                in_channels=self.embed_dim,
494                out_channels=features_decoder[0],
495                use_conv_transpose=use_conv_transpose,
496            )
497            self.deconv2 = Deconv2DBlock(
498                in_channels=features_decoder[0],
499                out_channels=features_decoder[1],
500                use_conv_transpose=use_conv_transpose,
501            )
502            self.deconv3 = Deconv2DBlock(
503                in_channels=features_decoder[1],
504                out_channels=features_decoder[2],
505                use_conv_transpose=use_conv_transpose,
506            )
507            self.deconv4 = Deconv2DBlock(
508                in_channels=features_decoder[2],
509                out_channels=features_decoder[3],
510                use_conv_transpose=use_conv_transpose,
511            )
512
513        self.base = ConvBlock2d(self.embed_dim, features_decoder[0])
514        self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1)
515        self.deconv_out = _upsampler(
516            scale_factor=2, in_channels=features_decoder[-1], out_channels=features_decoder[-1]
517        )
518        self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1])
519
520    def forward(self, x: torch.Tensor) -> torch.Tensor:
521        """Apply the UNETR to the input data.
522
523        Args:
524            x: The input tensor.
525
526        Returns:
527            The UNETR output.
528        """
529        original_shape = x.shape[-2:]
530
531        # Reshape the inputs to the shape expected by the encoder
532        # and normalize the inputs if normalization is part of the model.
533        x, input_shape = self.preprocess(x)
534
535        encoder_outputs = self.encoder(x)
536
537        if isinstance(encoder_outputs[-1], list):
538            # `encoder_outputs` can be arranged in only two forms:
539            #   - either we only return the image embeddings,
540            #   - or we return the image embeddings and the list of features from the global attention layers.
541            z12, from_encoder = encoder_outputs
542        else:
543            z12, from_encoder = encoder_outputs, None  # No intermediate features were returned.
544
545        if self.use_skip_connection:
546            from_encoder = from_encoder[::-1]
547            z9 = self.deconv1(from_encoder[0])
548            z6 = self.deconv2(from_encoder[1])
549            z3 = self.deconv3(from_encoder[2])
550            z0 = self.deconv4(x)
551
552        else:
553            z9 = self.deconv1(z12)
554            z6 = self.deconv2(z9)
555            z3 = self.deconv3(z6)
556            z0 = self.deconv4(z3)
557
558        updated_from_encoder = [z9, z6, z3]
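        # These features are passed to the decoder stages in order of increasing resolution.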
559
560        x = self.base(z12)
561        x = self.decoder(x, encoder_inputs=updated_from_encoder)
562        x = self.deconv_out(x)
563
564        x = torch.cat([x, z0], dim=1)
565        x = self.decoder_head(x)
566
567        x = self.out_conv(x)
568        if self.final_activation is not None:
569            x = self.final_activation(x)
570
571        x = self.postprocess_masks(x, input_shape, original_shape)
572        return x
573
574
575class UNETR2D(UNETR):
576    """A two-dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.
577    """
578    pass
579
580
581class UNETR3D(UNETRBase):
582    """A three-dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.
583    """
584    def __init__(
585        self,
586        img_size: int = 1024,
587        backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
588        encoder: Optional[Union[nn.Module, str]] = "hvit_b",
589        decoder: Optional[nn.Module] = None,
590        out_channels: int = 1,
591        use_sam_stats: bool = False,
592        use_mae_stats: bool = False,
593        use_dino_stats: bool = False,
594        resize_input: bool = True,
595        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
596        final_activation: Optional[Union[str, nn.Module]] = None,
597        use_skip_connection: bool = False,
598        embed_dim: Optional[int] = None,
599        use_conv_transpose: bool = False,
600        use_strip_pooling: bool = True,
601        **kwargs
602    ):
603        if use_skip_connection:
604            raise NotImplementedError("Skip connections are not yet supported for the 3d UNETR.")
605        if use_conv_transpose:
606            raise NotImplementedError("Transposed convolutions are not yet supported for the 3d UNETR; interpolation-based upsampling is always used.")
607
608        # Sort the `embed_dim` out
609        embed_dim = 256 if embed_dim is None else embed_dim
610
611        super().__init__(
612            img_size=img_size,
613            backbone=backbone,
614            encoder=encoder,
615            decoder=decoder,
616            out_channels=out_channels,
617            use_sam_stats=use_sam_stats,
618            use_mae_stats=use_mae_stats,
619            use_dino_stats=use_dino_stats,
620            resize_input=resize_input,
621            encoder_checkpoint=encoder_checkpoint,
622            final_activation=final_activation,
623            use_skip_connection=use_skip_connection,
624            embed_dim=embed_dim,
625            use_conv_transpose=use_conv_transpose,
626            **kwargs,
627        )
628
629        # The 3d convolutional decoder.
630        # First, get the important parameters for the decoder.
631        depth = 3
632        initial_features = 64
633        gain = 2
634        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
635        scale_factors = [1, 2, 2]
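        # This yields decoder features [512, 256, 128, 64]; the anisotropic scale factor [1, 2, 2]
        # upsamples only in-plane, so the number of slices Z is preserved throughout the decoder.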
636        self.out_channels = out_channels
637
638        # The mapping blocks.
639        self.deconv1 = Deconv3DBlock(
640            in_channels=embed_dim,
641            out_channels=features_decoder[0],
642            scale_factor=scale_factors,
643            use_strip_pooling=use_strip_pooling,
644        )
645        self.deconv2 = Deconv3DBlock(
646            in_channels=features_decoder[0],
647            out_channels=features_decoder[1],
648            scale_factor=scale_factors,
649            use_strip_pooling=use_strip_pooling,
650        )
651        self.deconv3 = Deconv3DBlock(
652            in_channels=features_decoder[1],
653            out_channels=features_decoder[2],
654            scale_factor=scale_factors,
655            use_strip_pooling=use_strip_pooling,
656        )
657        self.deconv4 = Deconv3DBlock(
658            in_channels=features_decoder[2],
659            out_channels=features_decoder[3],
660            scale_factor=scale_factors,
661            use_strip_pooling=use_strip_pooling,
662        )
663
664        # The core decoder block.
665        self.decoder = decoder or Decoder(
666            features=features_decoder,
667            scale_factors=[scale_factors] * depth,
668            conv_block_impl=partial(ConvBlock3dWithStrip, use_strip_pooling=use_strip_pooling),
669            sampler_impl=Upsampler3d,
670        )
671
672        # And the final upsampler to match the expected dimensions.
673        self.deconv_out = Deconv3DBlock(  # NOTE: changed `end_up` to `deconv_out`
674            in_channels=features_decoder[-1],
675            out_channels=features_decoder[-1],
676            scale_factor=scale_factors,
677            use_strip_pooling=use_strip_pooling,
678        )
679
680        # Additional conjunction blocks.
681        self.base = ConvBlock3dWithStrip(
682            in_channels=embed_dim,
683            out_channels=features_decoder[0],
684            use_strip_pooling=use_strip_pooling,
685        )
686
687        # And the output layers.
688        self.decoder_head = ConvBlock3dWithStrip(
689            in_channels=2 * features_decoder[-1],
690            out_channels=features_decoder[-1],
691            use_strip_pooling=use_strip_pooling,
692        )
693        self.out_conv = nn.Conv3d(features_decoder[-1], out_channels, 1)
694
695    def forward(self, x: torch.Tensor):
696        """Forward pass of the UNETR-3D model.
697
698        Args:
699            x: Inputs of expected shape (B, C, Z, Y, X), where the number of slices Z may vary.
700
701        Returns:
702            The UNETR output.
703        """
704        B, C, Z, H, W = x.shape
705        original_shape = (Z, H, W)
706
707        # Preprocessing step
708        x, input_shape = self.preprocess(x)
709
710        # Run the image encoder.
711        curr_features = torch.stack([self.encoder(x[:, :, i])[0] for i in range(Z)], dim=2)
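        # Each slice is encoded independently in 2d and the per-slice embeddings are stacked back
        # along the depth axis, giving a feature volume of shape (B, embed_dim, Z, H', W').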
712
713        # Prepare the counterparts for the decoder.
714        # NOTE: The section below is sequential; there are no skip connections at the moment.
715        z9 = self.deconv1(curr_features)
716        z6 = self.deconv2(z9)
717        z3 = self.deconv3(z6)
718        z0 = self.deconv4(z3)
719
720        updated_from_encoder = [z9, z6, z3]
721
722        # Align the features through the base block.
723        x = self.base(curr_features)
724        # Run the decoder
725        x = self.decoder(x, encoder_inputs=updated_from_encoder)
726        x = self.deconv_out(x)  # NOTE: this block was previously named `end_up`.
727
728        # And the final output head.
729        x = torch.cat([x, z0], dim=1)
730        x = self.decoder_head(x)
731        x = self.out_conv(x)
732        if self.final_activation is not None:
733            x = self.final_activation(x)
734
735        # Postprocess the output back to original size.
736        x = self.postprocess_masks(x, input_shape, original_shape)
737        return x
738
739#
740#  ADDITIONAL FUNCTIONALITIES
741#
742
743
744def _strip_pooling_layers(enabled, channels) -> nn.Module:
745    return DepthStripPooling(channels) if enabled else nn.Identity()
746
747
748class DepthStripPooling(nn.Module):
749    """@private
750    """
751    def __init__(self, channels: int, reduction: int = 4):
752        """Block for strip pooling along the depth dimension (only).
753
754        For 3D inputs (Z > 1), it aggregates global context across the depth axis by adaptive
755        average pooling to Z=1, passes the result through a small 1x1x1 MLP, and broadcasts it
756        back along Z to modulate the original features via a gated residual.
757
758        For 2D inputs (Z == 1 or Z == 3), the input is returned unchanged (no-op).
759
760        Args:
761            channels: The number of input (and output) channels.
762            reduction: The channel reduction factor for the hidden layer.
763        """
764        super().__init__()
765        hidden = max(1, channels // reduction)
766        self.conv1 = nn.Conv3d(channels, hidden, kernel_size=1)
767        self.bn1 = nn.BatchNorm3d(hidden)
768        self.relu = nn.ReLU(inplace=True)
769        self.conv2 = nn.Conv3d(hidden, channels, kernel_size=1)
770
771    def forward(self, x: torch.Tensor) -> torch.Tensor:
772        if x.dim() != 5:
773            raise ValueError(f"DepthStripPooling expects 5D tensors as input, got '{x.shape}'.")
774
775        B, C, Z, H, W = x.shape
776        if Z == 1 or Z == 3:  # i.e. 2d-as-1-slice or RGB_2d-as-1-slice.
777            return x  # We simply do nothing there.
778
779        # We pool only along the depth dimension: i.e. target shape (B, C, 1, H, W)
780        feat = F.adaptive_avg_pool3d(x, output_size=(1, H, W))
781        feat = self.conv1(feat)
782        feat = self.bn1(feat)
783        feat = self.relu(feat)
784        feat = self.conv2(feat)
785        gate = torch.sigmoid(feat).expand(B, C, Z, H, W)  # Broadcast the collapsed depth context back to all slices
786
787        # Gated residual fusion
788        return x * gate + x
789
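# Minimal usage sketch for DepthStripPooling (illustrative, not part of the original module):
#   pool = DepthStripPooling(channels=32)
#   y = pool(torch.randn(1, 32, 8, 64, 64))  # the pooled depth context gates the features; shape is preserved
#   y = pool(torch.randn(1, 32, 1, 64, 64))  # single-slice inputs are returned unchanged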
790
791class Deconv3DBlock(nn.Module):
792    """@private
793    """
794    def __init__(
795        self,
796        scale_factor,
797        in_channels,
798        out_channels,
799        kernel_size=3,
800        anisotropic_kernel=True,
801        use_strip_pooling=True,
802    ):
803        super().__init__()
804        conv_block_kwargs = {
805            "in_channels": out_channels,
806            "out_channels": out_channels,
807            "kernel_size": kernel_size,
808            "padding": ((kernel_size - 1) // 2),
809        }
810        if anisotropic_kernel:
811            conv_block_kwargs = _update_conv_kwargs(conv_block_kwargs, scale_factor)
812
813        self.block = nn.Sequential(
814            Upsampler3d(scale_factor, in_channels, out_channels),
815            nn.Conv3d(**conv_block_kwargs),
816            nn.BatchNorm3d(out_channels),
817            nn.ReLU(True),
818            _strip_pooling_layers(enabled=use_strip_pooling, channels=out_channels),
819        )
820
821    def forward(self, x):
822        return self.block(x)
823
824
825class ConvBlock3dWithStrip(nn.Module):
826    """@private
827    """
828    def __init__(
829        self, in_channels: int, out_channels: int, use_strip_pooling: bool = True, **kwargs
830    ):
831        super().__init__()
832        self.block = nn.Sequential(
833            ConvBlock3d(in_channels, out_channels, **kwargs),
834            _strip_pooling_layers(enabled=use_strip_pooling, channels=out_channels),
835        )
836
837    def forward(self, x):
838        return self.block(x)
839
840
841class SingleDeconv2DBlock(nn.Module):
842    """@private
843    """
844    def __init__(self, scale_factor, in_channels, out_channels):
845        super().__init__()
846        self.block = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, padding=0, output_padding=0)
847
848    def forward(self, x):
849        return self.block(x)
850
851
852class SingleConv2DBlock(nn.Module):
853    """@private
854    """
855    def __init__(self, in_channels, out_channels, kernel_size):
856        super().__init__()
857        self.block = nn.Conv2d(
858            in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=((kernel_size - 1) // 2)
859        )
860
861    def forward(self, x):
862        return self.block(x)
863
864
865class Conv2DBlock(nn.Module):
866    """@private
867    """
868    def __init__(self, in_channels, out_channels, kernel_size=3):
869        super().__init__()
870        self.block = nn.Sequential(
871            SingleConv2DBlock(in_channels, out_channels, kernel_size),
872            nn.BatchNorm2d(out_channels),
873            nn.ReLU(True)
874        )
875
876    def forward(self, x):
877        return self.block(x)
878
879
880class Deconv2DBlock(nn.Module):
881    """@private
882    """
883    def __init__(self, in_channels, out_channels, kernel_size=3, use_conv_transpose=True):
884        super().__init__()
885        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d
886        self.block = nn.Sequential(
887            _upsampler(scale_factor=2, in_channels=in_channels, out_channels=out_channels),
888            SingleConv2DBlock(out_channels, out_channels, kernel_size),
889            nn.BatchNorm2d(out_channels),
890            nn.ReLU(True)
891        )
892
893    def forward(self, x):
894        return self.block(x)
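
A minimal usage sketch (illustrative; it assumes a local `torch_em` installation and omits `encoder_checkpoint`, so the encoder starts from random weights):

    import torch
    from torch_em.model.unetr import UNETR

    # 2d UNETR with a SAM ViT-B image encoder and the default convolutional decoder.
    model = UNETR(
        img_size=1024,
        backbone="sam",
        encoder="vit_b",
        out_channels=1,
        use_sam_stats=True,          # normalize inputs with the SAM pixel statistics
        final_activation="Sigmoid",
    )

    x = torch.randn(1, 3, 512, 512)  # B x C x H x W; inputs are resized and padded internally
    with torch.no_grad():
        y = model(x)
    print(y.shape)  # torch.Size([1, 1, 512, 512]); outputs are mapped back to the input resolution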
class UNETRBase(torch.nn.modules.module.Module):
 34class UNETRBase(nn.Module):
 35    """Base class for implementing a UNETR.
 36
 37    Args:
 38        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
 39        backbone: The name of the vision transformer implementation.
 40            One of "sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"
 41            (see all combinations below)
 42        encoder: The vision transformer. Can either be a name, such as "vit_b"
 43            (see all combinations for this below) or a torch module.
 44        decoder: The convolutional decoder.
 45        out_channels: The number of output channels of the UNETR.
 46        use_sam_stats: Whether to normalize the input data with the statistics of the
 47            pretrained SAM / SAM2 / SAM3 model.
 48        use_dino_stats: Whether to normalize the input data with the statistics of the
 49            pretrained DINOv2 / DINOv3 model.
 50        use_mae_stats: Whether to normalize the input data with the statistics of the pretrained MAE model.
 51        resize_input: Whether to resize the input images to match `img_size`.
 52            By default, it resizes the inputs to match the `img_size`.
 53        encoder_checkpoint: Checkpoint for initializing the vision transformer.
 54            Can either be a filepath or an already loaded checkpoint.
 55        final_activation: The activation to apply to the UNETR output.
 56        use_skip_connection: Whether to use skip connections. By default, it uses skip connections.
 57        embed_dim: The embedding dimensionality, corresponding to the output dimension of the vision transformer.
 58        use_conv_transpose: Whether to use transposed convolutions instead of resampling for upsampling.
 59            By default, it uses resampling for upsampling.
 60
 61        NOTE: The currently supported combinations of 'backbone' x 'encoder' are the following:
 62
 63        SAM_family_models:
 64            - 'sam' x 'vit_b'
 65            - 'sam' x 'vit_l'
 66            - 'sam' x 'vit_h'
 67            - 'sam2' x 'hvit_t'
 68            - 'sam2' x 'hvit_s'
 69            - 'sam2' x 'hvit_b'
 70            - 'sam2' x 'hvit_l'
 71            - 'sam3' x 'vit_pe'
 72            - 'cellpose_sam' x 'vit_l'
 73
 74        DINO_family_models:
 75            - 'dinov2' x 'vit_s'
 76            - 'dinov2' x 'vit_b'
 77            - 'dinov2' x 'vit_l'
 78            - 'dinov2' x 'vit_g'
 79            - 'dinov2' x 'vit_s_reg4'
 80            - 'dinov2' x 'vit_b_reg4'
 81            - 'dinov2' x 'vit_l_reg4'
 82            - 'dinov2' x 'vit_g_reg4'
 83            - 'dinov3' x 'vit_s'
 84            - 'dinov3' x 'vit_s+'
 85            - 'dinov3' x 'vit_b'
 86            - 'dinov3' x 'vit_l'
 87            - 'dinov3' x 'vit_l+'
 88            - 'dinov3' x 'vit_h+'
 89            - 'dinov3' x 'vit_7b'
 90
 91        MAE_family_models:
 92            - 'mae' x 'vit_b'
 93            - 'mae' x 'vit_l'
 94            - 'mae' x 'vit_h'
 95            - 'scalemae' x 'vit_b'
 96            - 'scalemae' x 'vit_l'
 97            - 'scalemae' x 'vit_h'
 98    """
 99    def __init__(
100        self,
101        img_size: int = 1024,
102        backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
103        encoder: Optional[Union[nn.Module, str]] = "vit_b",
104        decoder: Optional[nn.Module] = None,
105        out_channels: int = 1,
106        use_sam_stats: bool = False,
107        use_mae_stats: bool = False,
108        use_dino_stats: bool = False,
109        resize_input: bool = True,
110        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
111        final_activation: Optional[Union[str, nn.Module]] = None,
112        use_skip_connection: bool = True,
113        embed_dim: Optional[int] = None,
114        use_conv_transpose: bool = False,
115        **kwargs
116    ) -> None:
117        super().__init__()
118
119        self.img_size = img_size
120        self.use_sam_stats = use_sam_stats
121        self.use_mae_stats = use_mae_stats
122        self.use_dino_stats = use_dino_stats
123        self.use_skip_connection = use_skip_connection
124        self.resize_input = resize_input
125        self.use_conv_transpose = use_conv_transpose
126        self.backbone = backbone
127
128        if isinstance(encoder, str):  # e.g. "vit_b" / "hvit_b" / "vit_pe"
129            print(f"Using {encoder} from {backbone.upper()}")
130            self.encoder = get_vision_transformer(img_size=img_size, backbone=backbone, model=encoder, **kwargs)
131
132            if encoder_checkpoint is not None:
133                self._load_encoder_from_checkpoint(backbone=backbone, encoder=encoder, checkpoint=encoder_checkpoint)
134
135            if embed_dim is None:
136                embed_dim = self.encoder.embed_dim
137
138        else:  # `nn.Module` ViT backbone
139            self.encoder = encoder
140
141            have_neck = False
142            for name, _ in self.encoder.named_parameters():
143                if name.startswith("neck"):
144                    have_neck = True
145
146            if embed_dim is None:
147                if have_neck:
148                    embed_dim = self.encoder.neck[2].out_channels  # the value is 256
149                else:
150                    embed_dim = self.encoder.patch_embed.proj.out_channels
151
152        self.embed_dim = embed_dim
153        self.final_activation = self._get_activation(final_activation)
154
155    def _load_encoder_from_checkpoint(self, backbone, encoder, checkpoint):
156        """Function to load pretrained weights to the image encoder.
157        """
158        if isinstance(checkpoint, str):
159            if backbone == "sam" and isinstance(encoder, str):
160                # If we have a SAM encoder, then we first try to load the full SAM Model
161                # (using micro_sam) and otherwise fall back on directly loading the encoder state
162                # from the checkpoint
163                try:
164                    _, model = get_sam_model(model_type=encoder, checkpoint_path=checkpoint, return_sam=True)
165                    encoder_state = model.image_encoder.state_dict()
166                except Exception:
167                    # Try loading the encoder state directly from a checkpoint.
168                    encoder_state = torch.load(checkpoint, weights_only=False)
169
170            elif backbone == "cellpose_sam" and isinstance(encoder, str):
171                # The architecture matches CellposeSAM exactly (same rel_pos sizes),
172                # so weights load directly without any interpolation.
173                encoder_state = torch.load(checkpoint, map_location="cpu", weights_only=False)
174                # Handle DataParallel/DistributedDataParallel prefix.
175                if any(k.startswith("module.") for k in encoder_state.keys()):
176                    encoder_state = OrderedDict(
177                        {k[len("module."):]: v for k, v in encoder_state.items()}
178                    )
179                # Extract encoder weights from CellposeSAM checkpoint format (strip 'encoder.' prefix).
180                if any(k.startswith("encoder.") for k in encoder_state.keys()):
181                    encoder_state = OrderedDict(
182                        {k[len("encoder."):]: v for k, v in encoder_state.items() if k.startswith("encoder.")}
183                    )
184
185            elif backbone == "sam2" and isinstance(encoder, str):
186                # If we have a SAM2 encoder, then we first try to load the full SAM2 Model.
187                # (using micro_sam2) and otherwise fall back on directly loading the encoder state
188                # from the checkpoint
189                try:
190                    model = get_sam2_model(model_type=encoder, checkpoint_path=checkpoint)
191                    encoder_state = model.image_encoder.state_dict()
192                except Exception:
193                    # Try loading the encoder state directly from a checkpoint.
194                    encoder_state = torch.load(checkpoint, weights_only=False)
195
196            elif backbone == "sam3" and isinstance(encoder, str):
197                # If we have a SAM3 encoder, then we first try to load the full SAM3 Model.
198                # (using micro_sam3) and otherwise fall back on directly loading the encoder state
199                # from the checkpoint
200                try:
201                    model = get_sam3_model(checkpoint_path=checkpoint)
202                    encoder_state = model.backbone.vision_backbone.state_dict()
203                    # Let's align loading the encoder weights with expected parameter names
204                    encoder_state = {
205                        k[len("trunk."):] if k.startswith("trunk.") else k: v for k, v in encoder_state.items()
206                    }
207                    # And drop the 'convs' and 'sam2_convs' - these seem like some upsampling blocks.
208                    encoder_state = {
209                        k: v for k, v in encoder_state.items()
210                        if not (k.startswith("convs.") or k.startswith("sam2_convs."))
211                    }
212                except Exception:
213                    # Try loading the encoder state directly from a checkpoint.
214                    encoder_state = torch.load(checkpoint, weights_only=False)
215
216            elif backbone == "mae":
217                # vit initialization hints from:
218                #     - https://github.com/facebookresearch/mae/blob/main/main_finetune.py#L233-L242
219                encoder_state = torch.load(checkpoint, weights_only=False)["model"]
220                encoder_state = OrderedDict({
221                    k: v for k, v in encoder_state.items() if (k != "mask_token" and not k.startswith("decoder"))
222                })
223                # Let's remove the `head` from our current encoder (as the MAE pretrained don't expect it)
224                current_encoder_state = self.encoder.state_dict()
225                if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
226                    del self.encoder.head
227
228            elif backbone == "scalemae":
229                # Load the encoder state directly from a checkpoint.
230                encoder_state = torch.load(checkpoint)["model"]
231                encoder_state = OrderedDict({
232                    k: v for k, v in encoder_state.items()
233                    if not k.startswith(("mask_token", "decoder", "fcn", "fpn", "pos_embed"))
234                })
235
236                # Let's remove the `head` from our current encoder (as the MAE pretrained don't expect it)
237                current_encoder_state = self.encoder.state_dict()
238                if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
239                    del self.encoder.head
240
241                if "pos_embed" in current_encoder_state:  # NOTE: ScaleMAE uses 'pos. embeddings' in a diff. format.
242                    del self.encoder.pos_embed
243
244            elif backbone in ["dinov2", "dinov3"]:  # Load the encoder state directly from a checkpoint.
245                encoder_state = torch.load(checkpoint)
246
247            else:
248                raise ValueError(
249                    f"We don't support either the '{backbone}' backbone or the '{encoder}' model combination (or both)."
250                )
251
252        else:
253            encoder_state = checkpoint
254
255        self.encoder.load_state_dict(encoder_state)
256
257    def _get_activation(self, activation):
258        return_activation = None
259        if activation is None:
260            return None
261        if isinstance(activation, nn.Module):
262            return activation
263        if isinstance(activation, str):
264            return_activation = getattr(nn, activation, None)
265        if return_activation is None:
266            raise ValueError(f"Invalid activation: {activation}")
267
268        return return_activation()
269
270    @staticmethod
271    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
272        """Compute the output size given input size and target long side length.
273
274        Args:
275            oldh: The input image height.
276            oldw: The input image width.
277            long_side_length: The longest side length for resizing.
278
279        Returns:
280            The new image height.
281            The new image width.
282        """
283        scale = long_side_length * 1.0 / max(oldh, oldw)
284        newh, neww = oldh * scale, oldw * scale
285        neww = int(neww + 0.5)
286        newh = int(newh + 0.5)
287        return (newh, neww)
288
289    def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor:
290        """Resize the image so that the longest side has the correct length.
291
292        Expects batched images with shape BxCxHxW OR BxCxDxHxW and float format.
293
294        Args:
295            image: The input image.
296
297        Returns:
298            The resized image.
299        """
300        if image.ndim == 4:  # i.e. 2d image
301            target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.encoder.img_size)
302            return F.interpolate(image, target_size, mode="bilinear", align_corners=False, antialias=True)
303        elif image.ndim == 5:  # i.e. 3d volume
304            B, C, Z, H, W = image.shape
305            target_size = self.get_preprocess_shape(H, W, self.img_size)
306            return F.interpolate(image, (Z, *target_size), mode="trilinear", align_corners=False)
307        else:
308            raise ValueError("Expected 4d or 5d inputs, got", image.shape)
309
310    def _as_stats(self, mean, std, device, dtype, is_3d: bool):
311        """@private
312        """
313        # Either 2d batch: (1, C, 1, 1) or 3d batch: (1, C, 1, 1, 1).
314        view_shape = (1, -1, 1, 1, 1) if is_3d else (1, -1, 1, 1)
315        pixel_mean = torch.tensor(mean, device=device, dtype=dtype).view(*view_shape)
316        pixel_std = torch.tensor(std, device=device, dtype=dtype).view(*view_shape)
317        return pixel_mean, pixel_std
318
319    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
320        """@private
321        """
322        device = x.device
323        is_3d = (x.ndim == 5)
324        device, dtype = x.device, x.dtype
325
326        if self.use_sam_stats:
327            mean, std = (123.675, 116.28, 103.53), (58.395, 57.12, 57.375)
328        elif self.use_mae_stats:  # TODO: add mean std from mae / scalemae experiments (or open up arguments for this)
329            raise NotImplementedError
330        elif self.use_dino_stats or (self.use_sam_stats and self.backbone == "sam2"):
331            mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
332        elif self.use_sam_stats and self.backbone == "sam3":
333            mean, std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
334        else:
335            mean, std = (0.0, 0.0, 0.0), (1.0, 1.0, 1.0)
336
337        pixel_mean, pixel_std = self._as_stats(mean, std, device=device, dtype=dtype, is_3d=is_3d)
338
339        if self.resize_input:
340            x = self.resize_longest_side(x)
341        input_shape = x.shape[-3:] if is_3d else x.shape[-2:]
342
343        x = (x - pixel_mean) / pixel_std
344        h, w = x.shape[-2:]
345        padh = self.encoder.img_size - h
346        padw = self.encoder.img_size - w
347
348        if is_3d:
349            x = F.pad(x, (0, padw, 0, padh, 0, 0))
350        else:
351            x = F.pad(x, (0, padw, 0, padh))
352
353        return x, input_shape
354
355    def postprocess_masks(
356        self, masks: torch.Tensor, input_size: Tuple[int, ...], original_size: Tuple[int, ...],
357    ) -> torch.Tensor:
358        """@private
359        """
360        if masks.ndim == 4:  # i.e. 2d labels
361            masks = F.interpolate(
362                masks,
363                (self.encoder.img_size, self.encoder.img_size),
364                mode="bilinear",
365                align_corners=False,
366            )
367            masks = masks[..., : input_size[0], : input_size[1]]
368            masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
369
370        elif masks.ndim == 5:  # i.e. 3d volumetric labels
371            masks = F.interpolate(
372                masks,
373                (input_size[0], self.img_size, self.img_size),
374                mode="trilinear",
375                align_corners=False,
376            )
377            masks = masks[..., :input_size[0], :input_size[1], :input_size[2]]
378            masks = F.interpolate(masks, original_size, mode="trilinear", align_corners=False)
379
380        else:
381            raise ValueError("Expected 4d or 5d labels, got", masks.shape)
382
383        return masks

Base class for implementing a UNETR.

Arguments:
  • img_size: The size of the input for the image encoder. Input images will be resized to match this size.
  • backbone: The name of the vision transformer implementation. One of "sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3" (see all combinations below)
  • encoder: The vision transformer. Can either be a name, such as "vit_b" (see all combinations for this below) or a torch module.
  • decoder: The convolutional decoder.
  • out_channels: The number of output channels of the UNETR.
  • use_sam_stats: Whether to normalize the input data with the statistics of the pretrained SAM / SAM2 / SAM3 model.
  • use_dino_stats: Whether to normalize the input data with the statistics of the pretrained DINOv2 / DINOv3 model.
  • use_mae_stats: Whether to normalize the input data with the statistics of the pretrained MAE model.
  • resize_input: Whether to resize the input images to match img_size. By default, it resizes the inputs to match the img_size.
  • encoder_checkpoint: Checkpoint for initializing the vision transformer. Can either be a filepath or an already loaded checkpoint.
  • final_activation: The activation to apply to the UNETR output.
  • use_skip_connection: Whether to use skip connections. By default, it uses skip connections.
  • embed_dim: The embedding dimensionality, corresponding to the output dimension of the vision transformer.
  • use_conv_transpose: Whether to use transposed convolutions instead of resampling for upsampling. By default, it uses resampling for upsampling.
  • NOTE: The currently supported combinations of 'backbone' x 'encoder' are the following:
  • SAM_family_models: - 'sam' x 'vit_b'
    • 'sam' x 'vit_l'
    • 'sam' x 'vit_h'
    • 'sam2' x 'hvit_t'
    • 'sam2' x 'hvit_s'
    • 'sam2' x 'hvit_b'
    • 'sam2' x 'hvit_l'
    • 'sam3' x 'vit_pe'
    • 'cellpose_sam' x 'vit_l'
  • DINO_family_models: - 'dinov2' x 'vit_s'
    • 'dinov2' x 'vit_b'
    • 'dinov2' x 'vit_l'
    • 'dinov2' x 'vit_g'
    • 'dinov2' x 'vit_s_reg4'
    • 'dinov2' x 'vit_b_reg4'
    • 'dinov2' x 'vit_l_reg4'
    • 'dinov2' x 'vit_g_reg4'
    • 'dinov3' x 'vit_s'
    • 'dinov3' x 'vit_s+'
    • 'dinov3' x 'vit_b'
    • 'dinov3' x 'vit_l'
    • 'dinov3' x 'vit_l+'
    • 'dinov3' x 'vit_h+'
    • 'dinov3' x 'vit_7b'
  • MAE_family_models:
    • 'mae' x 'vit_b'
    • 'mae' x 'vit_l'
    • 'mae' x 'vit_h'
    • 'scalemae' x 'vit_b'
    • 'scalemae' x 'vit_l'
    • 'scalemae' x 'vit_h'
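A minimal construction sketch for one of the combinations above (not part of the source; it assumes the dependencies for the chosen backbone are installed, and it uses the concrete UNETR subclass documented below, which is the class intended for direct use):

    from torch_em.model.unetr import UNETR

    # One row of the table above; other rows work the same way, e.g.
    # backbone="dinov2" with encoder="vit_s", or backbone="sam2" with encoder="hvit_b".
    model = UNETR(
        img_size=1024,
        backbone="sam",
        encoder="vit_b",
        out_channels=1,
        use_sam_stats=True,          # normalize inputs with the pretrained SAM statistics
        final_activation="Sigmoid",  # assumes string activations are resolved via _get_activation
    )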
UNETRBase( img_size: int = 1024, backbone: Literal['sam', 'sam2', 'sam3', 'cellpose_sam', 'mae', 'scalemae', 'dinov2', 'dinov3'] = 'sam', encoder: Union[torch.nn.modules.module.Module, str, NoneType] = 'vit_b', decoder: Optional[torch.nn.modules.module.Module] = None, out_channels: int = 1, use_sam_stats: bool = False, use_mae_stats: bool = False, use_dino_stats: bool = False, resize_input: bool = True, encoder_checkpoint: Union[str, collections.OrderedDict, NoneType] = None, final_activation: Union[torch.nn.modules.module.Module, str, NoneType] = None, use_skip_connection: bool = True, embed_dim: Optional[int] = None, use_conv_transpose: bool = False, **kwargs)
 99    def __init__(
100        self,
101        img_size: int = 1024,
102        backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
103        encoder: Optional[Union[nn.Module, str]] = "vit_b",
104        decoder: Optional[nn.Module] = None,
105        out_channels: int = 1,
106        use_sam_stats: bool = False,
107        use_mae_stats: bool = False,
108        use_dino_stats: bool = False,
109        resize_input: bool = True,
110        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
111        final_activation: Optional[Union[str, nn.Module]] = None,
112        use_skip_connection: bool = True,
113        embed_dim: Optional[int] = None,
114        use_conv_transpose: bool = False,
115        **kwargs
116    ) -> None:
117        super().__init__()
118
119        self.img_size = img_size
120        self.use_sam_stats = use_sam_stats
121        self.use_mae_stats = use_mae_stats
122        self.use_dino_stats = use_dino_stats
123        self.use_skip_connection = use_skip_connection
124        self.resize_input = resize_input
125        self.use_conv_transpose = use_conv_transpose
126        self.backbone = backbone
127
128        if isinstance(encoder, str):  # e.g. "vit_b" / "hvit_b" / "vit_pe"
129            print(f"Using {encoder} from {backbone.upper()}")
130            self.encoder = get_vision_transformer(img_size=img_size, backbone=backbone, model=encoder, **kwargs)
131
132            if encoder_checkpoint is not None:
133                self._load_encoder_from_checkpoint(backbone=backbone, encoder=encoder, checkpoint=encoder_checkpoint)
134
135            if embed_dim is None:
136                embed_dim = self.encoder.embed_dim
137
138        else:  # `nn.Module` ViT backbone
139            self.encoder = encoder
140
141            have_neck = False
142            for name, _ in self.encoder.named_parameters():
143                if name.startswith("neck"):
144                    have_neck = True
145
146            if embed_dim is None:
147                if have_neck:
148                    embed_dim = self.encoder.neck[2].out_channels  # the value is 256
149                else:
150                    embed_dim = self.encoder.patch_embed.proj.out_channels
151
152        self.embed_dim = embed_dim
153        self.final_activation = self._get_activation(final_activation)
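The encoder_checkpoint branch above accepts either a filepath or an already loaded checkpoint. A hedged sketch with a hypothetical path (the actual checkpoint location and format depend on the chosen backbone):

    from torch_em.model.unetr import UNETR

    # The checkpoint can be given as a filepath (shown here with a hypothetical path)
    # or as an already loaded OrderedDict.
    model = UNETR(
        backbone="sam",
        encoder="vit_b",
        encoder_checkpoint="/path/to/sam_vit_b_checkpoint.pth",  # hypothetical path
        out_channels=1,
    )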


img_size
use_sam_stats
use_mae_stats
use_dino_stats
use_skip_connection
resize_input
use_conv_transpose
backbone
embed_dim
final_activation
@staticmethod
def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
270    @staticmethod
271    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
272        """Compute the output size given input size and target long side length.
273
274        Args:
275            oldh: The input image height.
276            oldw: The input image width.
277            long_side_length: The longest side length for resizing.
278
279        Returns:
280            The new image height.
281            The new image width.
282        """
283        scale = long_side_length * 1.0 / max(oldh, oldw)
284        newh, neww = oldh * scale, oldw * scale
285        neww = int(neww + 0.5)
286        newh = int(newh + 0.5)
287        return (newh, neww)

Compute the output size given input size and target long side length.

Arguments:
  • oldh: The input image height.
  • oldw: The input image width.
  • long_side_length: The longest side length for resizing.
Returns:

The new image height and the new image width; a worked example follows.
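A quick worked example of the rounding above, with illustrative numbers (it assumes torch_em and its vision-transformer dependencies are importable):

    from torch_em.model.unetr import UNETRBase

    # scale = 1024 / max(768, 512), so 768 -> 1024 and 512 -> int(512 * 1024 / 768 + 0.5) = 683.
    UNETRBase.get_preprocess_shape(oldh=768, oldw=512, long_side_length=1024)  # returns (1024, 683)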

def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor:
289    def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor:
290        """Resize the image so that the longest side has the correct length.
291
292        Expects batched images with shape BxCxHxW OR BxCxDxHxW and float format.
293
294        Args:
295            image: The input image.
296
297        Returns:
298            The resized image.
299        """
300        if image.ndim == 4:  # i.e. 2d image
301            target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.encoder.img_size)
302            return F.interpolate(image, target_size, mode="bilinear", align_corners=False, antialias=True)
303        elif image.ndim == 5:  # i.e. 3d volume
304            B, C, Z, H, W = image.shape
305            target_size = self.get_preprocess_shape(H, W, self.img_size)
306            return F.interpolate(image, (Z, *target_size), mode="trilinear", align_corners=False)
307        else:
308            raise ValueError(f"Expected 4d or 5d inputs, got {image.shape}")

Resize the image so that the longest side has the correct length.

Expects batched images with shape BxCxHxW OR BxCxDxHxW and float format.

Arguments:
  • image: The input image.
Returns:

The resized image.
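For the 5d branch above, only the in-plane dimensions are rescaled while the Z extent is kept. A self-contained sketch that mirrors the method body with the same illustrative sizes as before:

    import torch
    import torch.nn.functional as F
    from torch_em.model.unetr import UNETRBase

    vol = torch.randn(1, 3, 8, 768, 512)                     # B x C x Z x H x W
    target = UNETRBase.get_preprocess_shape(768, 512, 1024)  # (1024, 683)
    out = F.interpolate(vol, (8, *target), mode="trilinear", align_corners=False)
    print(out.shape)                                         # torch.Size([1, 3, 8, 1024, 683])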

class UNETR(UNETRBase):
386class UNETR(UNETRBase):
387    """A (2d-only) UNet Transformer using a vision transformer as encoder and a convolutional decoder.
388    """
389    def __init__(
390        self,
391        img_size: int = 1024,
392        backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
393        encoder: Optional[Union[nn.Module, str]] = "vit_b",
394        decoder: Optional[nn.Module] = None,
395        out_channels: int = 1,
396        use_sam_stats: bool = False,
397        use_mae_stats: bool = False,
398        use_dino_stats: bool = False,
399        resize_input: bool = True,
400        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
401        final_activation: Optional[Union[str, nn.Module]] = None,
402        use_skip_connection: bool = True,
403        embed_dim: Optional[int] = None,
404        use_conv_transpose: bool = False,
405        **kwargs
406    ) -> None:
407
408        super().__init__(
409            img_size=img_size,
410            backbone=backbone,
411            encoder=encoder,
412            decoder=decoder,
413            out_channels=out_channels,
414            use_sam_stats=use_sam_stats,
415            use_mae_stats=use_mae_stats,
416            use_dino_stats=use_dino_stats,
417            resize_input=resize_input,
418            encoder_checkpoint=encoder_checkpoint,
419            final_activation=final_activation,
420            use_skip_connection=use_skip_connection,
421            embed_dim=embed_dim,
422            use_conv_transpose=use_conv_transpose,
423            **kwargs,
424        )
425
426        encoder = self.encoder
427
428        if backbone == "sam2" and hasattr(encoder, "trunk"):
429            in_chans = encoder.trunk.patch_embed.proj.in_channels
430        elif hasattr(encoder, "in_chans"):
431            in_chans = encoder.in_chans
432        else:  # `nn.Module` ViT backbone.
433            try:
434                in_chans = encoder.patch_embed.proj.in_channels
435            except AttributeError:  # for getting the input channels while using 'vit_t' from MobileSAM
436                in_chans = encoder.patch_embed.seq[0].c.in_channels
437
438        # parameters for the decoder network
439        depth = 3
440        initial_features = 64
441        gain = 2
442        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
443        scale_factors = depth * [2]
444        self.out_channels = out_channels
445
446        # choice of upsampler - to use (bilinear interpolation + conv) or conv transpose
447        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d
448
449        self.decoder = decoder or Decoder(
450            features=features_decoder,
451            scale_factors=scale_factors[::-1],
452            conv_block_impl=ConvBlock2d,
453            sampler_impl=_upsampler,
454        )
455
456        if use_skip_connection:
457            self.deconv1 = Deconv2DBlock(
458                in_channels=self.embed_dim,
459                out_channels=features_decoder[0],
460                use_conv_transpose=use_conv_transpose,
461            )
462            self.deconv2 = nn.Sequential(
463                Deconv2DBlock(
464                    in_channels=self.embed_dim,
465                    out_channels=features_decoder[0],
466                    use_conv_transpose=use_conv_transpose,
467                ),
468                Deconv2DBlock(
469                    in_channels=features_decoder[0],
470                    out_channels=features_decoder[1],
471                    use_conv_transpose=use_conv_transpose,
472                )
473            )
474            self.deconv3 = nn.Sequential(
475                Deconv2DBlock(
476                    in_channels=self.embed_dim,
477                    out_channels=features_decoder[0],
478                    use_conv_transpose=use_conv_transpose,
479                ),
480                Deconv2DBlock(
481                    in_channels=features_decoder[0],
482                    out_channels=features_decoder[1],
483                    use_conv_transpose=use_conv_transpose,
484                ),
485                Deconv2DBlock(
486                    in_channels=features_decoder[1],
487                    out_channels=features_decoder[2],
488                    use_conv_transpose=use_conv_transpose,
489                )
490            )
491            self.deconv4 = ConvBlock2d(in_chans, features_decoder[-1])
492        else:
493            self.deconv1 = Deconv2DBlock(
494                in_channels=self.embed_dim,
495                out_channels=features_decoder[0],
496                use_conv_transpose=use_conv_transpose,
497            )
498            self.deconv2 = Deconv2DBlock(
499                in_channels=features_decoder[0],
500                out_channels=features_decoder[1],
501                use_conv_transpose=use_conv_transpose,
502            )
503            self.deconv3 = Deconv2DBlock(
504                in_channels=features_decoder[1],
505                out_channels=features_decoder[2],
506                use_conv_transpose=use_conv_transpose,
507            )
508            self.deconv4 = Deconv2DBlock(
509                in_channels=features_decoder[2],
510                out_channels=features_decoder[3],
511                use_conv_transpose=use_conv_transpose,
512            )
513
514        self.base = ConvBlock2d(self.embed_dim, features_decoder[0])
515        self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1)
516        self.deconv_out = _upsampler(
517            scale_factor=2, in_channels=features_decoder[-1], out_channels=features_decoder[-1]
518        )
519        self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1])
520
521    def forward(self, x: torch.Tensor) -> torch.Tensor:
522        """Apply the UNETR to the input data.
523
524        Args:
525            x: The input tensor.
526
527        Returns:
528            The UNETR output.
529        """
530        original_shape = x.shape[-2:]
531
532        # Reshape the inputs to the shape expected by the encoder
533        # and normalize the inputs if normalization is part of the model.
534        x, input_shape = self.preprocess(x)
535
536        encoder_outputs = self.encoder(x)
537
538        if isinstance(encoder_outputs[-1], list):
539            # `encoder_outputs` can be arranged in only two forms:
540            #   - either we only return the image embeddings
541            #   - or, we return the image embeddings and the "list" of global attention layers
542            z12, from_encoder = encoder_outputs
543        else:
544            z12 = encoder_outputs
545
546        if self.use_skip_connection:
547            from_encoder = from_encoder[::-1]
548            z9 = self.deconv1(from_encoder[0])
549            z6 = self.deconv2(from_encoder[1])
550            z3 = self.deconv3(from_encoder[2])
551            z0 = self.deconv4(x)
552
553        else:
554            z9 = self.deconv1(z12)
555            z6 = self.deconv2(z9)
556            z3 = self.deconv3(z6)
557            z0 = self.deconv4(z3)
558
559        updated_from_encoder = [z9, z6, z3]
560
561        x = self.base(z12)
562        x = self.decoder(x, encoder_inputs=updated_from_encoder)
563        x = self.deconv_out(x)
564
565        x = torch.cat([x, z0], dim=1)
566        x = self.decoder_head(x)
567
568        x = self.out_conv(x)
569        if self.final_activation is not None:
570            x = self.final_activation(x)
571
572        x = self.postprocess_masks(x, input_shape, original_shape)
573        return x

A (2d-only) UNet Transformer using a vision transformer as encoder and a convolutional decoder.

UNETR( img_size: int = 1024, backbone: Literal['sam', 'sam2', 'sam3', 'cellpose_sam', 'mae', 'scalemae', 'dinov2', 'dinov3'] = 'sam', encoder: Union[torch.nn.modules.module.Module, str, NoneType] = 'vit_b', decoder: Optional[torch.nn.modules.module.Module] = None, out_channels: int = 1, use_sam_stats: bool = False, use_mae_stats: bool = False, use_dino_stats: bool = False, resize_input: bool = True, encoder_checkpoint: Union[str, collections.OrderedDict, NoneType] = None, final_activation: Union[torch.nn.modules.module.Module, str, NoneType] = None, use_skip_connection: bool = True, embed_dim: Optional[int] = None, use_conv_transpose: bool = False, **kwargs)

out_channels
decoder
base
out_conv
deconv_out
decoder_head
def forward(self, x: torch.Tensor) -> torch.Tensor:

Apply the UNETR to the input data.

Arguments:
  • x: The input tensor.
Returns:

The UNETR output.
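To illustrate the shape round trip of forward above (inputs are resized and padded to the encoder resolution, and the prediction is mapped back to the input size by postprocess_masks), a hedged end-to-end sketch; it assumes the dependencies for the default SAM "vit_b" encoder are available and no pretrained checkpoint is loaded:

    import torch
    from torch_em.model.unetr import UNETR

    model = UNETR(backbone="sam", encoder="vit_b", out_channels=2, final_activation="Sigmoid")
    x = torch.randn(1, 3, 512, 512)  # B x C x H x W
    with torch.no_grad():
        y = model(x)
    print(y.shape)  # expected: torch.Size([1, 2, 512, 512]) - same spatial size as the input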

class UNETR2D(UNETR):
576class UNETR2D(UNETR):
577    """A two-dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.
578    """
579    pass

A two-dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.

class UNETR3D(UNETRBase):
582class UNETR3D(UNETRBase):
583    """A three-dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.
584    """
585    def __init__(
586        self,
587        img_size: int = 1024,
588        backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
589        encoder: Optional[Union[nn.Module, str]] = "hvit_b",
590        decoder: Optional[nn.Module] = None,
591        out_channels: int = 1,
592        use_sam_stats: bool = False,
593        use_mae_stats: bool = False,
594        use_dino_stats: bool = False,
595        resize_input: bool = True,
596        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
597        final_activation: Optional[Union[str, nn.Module]] = None,
598        use_skip_connection: bool = False,
599        embed_dim: Optional[int] = None,
600        use_conv_transpose: bool = False,
601        use_strip_pooling: bool = True,
602        **kwargs
603    ):
604        if use_skip_connection:
605            raise NotImplementedError("Skip connections are not supported for the 3d UNETR yet.")
606        if use_conv_transpose:
607            raise NotImplementedError("Switching from interpolation to transposed convolutions is not supported.")
608
609        # Sort the `embed_dim` out
610        embed_dim = 256 if embed_dim is None else embed_dim
611
612        super().__init__(
613            img_size=img_size,
614            backbone=backbone,
615            encoder=encoder,
616            decoder=decoder,
617            out_channels=out_channels,
618            use_sam_stats=use_sam_stats,
619            use_mae_stats=use_mae_stats,
620            use_dino_stats=use_dino_stats,
621            resize_input=resize_input,
622            encoder_checkpoint=encoder_checkpoint,
623            final_activation=final_activation,
624            use_skip_connection=use_skip_connection,
625            embed_dim=embed_dim,
626            use_conv_transpose=use_conv_transpose,
627            **kwargs,
628        )
629
630        # The 3d convolutional decoder.
631        # First, get the important parameters for the decoder.
632        depth = 3
633        initial_features = 64
634        gain = 2
635        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
636        scale_factors = [1, 2, 2]
637        self.out_channels = out_channels
638
639        # The mapping blocks.
640        self.deconv1 = Deconv3DBlock(
641            in_channels=embed_dim,
642            out_channels=features_decoder[0],
643            scale_factor=scale_factors,
644            use_strip_pooling=use_strip_pooling,
645        )
646        self.deconv2 = Deconv3DBlock(
647            in_channels=features_decoder[0],
648            out_channels=features_decoder[1],
649            scale_factor=scale_factors,
650            use_strip_pooling=use_strip_pooling,
651        )
652        self.deconv3 = Deconv3DBlock(
653            in_channels=features_decoder[1],
654            out_channels=features_decoder[2],
655            scale_factor=scale_factors,
656            use_strip_pooling=use_strip_pooling,
657        )
658        self.deconv4 = Deconv3DBlock(
659            in_channels=features_decoder[2],
660            out_channels=features_decoder[3],
661            scale_factor=scale_factors,
662            use_strip_pooling=use_strip_pooling,
663        )
664
665        # The core decoder block.
666        self.decoder = decoder or Decoder(
667            features=features_decoder,
668            scale_factors=[scale_factors] * depth,
669            conv_block_impl=partial(ConvBlock3dWithStrip, use_strip_pooling=use_strip_pooling),
670            sampler_impl=Upsampler3d,
671        )
672
673        # And the final upsampler to match the expected dimensions.
674        self.deconv_out = Deconv3DBlock(  # NOTE: changed `end_up` to `deconv_out`
675            in_channels=features_decoder[-1],
676            out_channels=features_decoder[-1],
677            scale_factor=scale_factors,
678            use_strip_pooling=use_strip_pooling,
679        )
680
681        # Additional conjunction blocks.
682        self.base = ConvBlock3dWithStrip(
683            in_channels=embed_dim,
684            out_channels=features_decoder[0],
685            use_strip_pooling=use_strip_pooling,
686        )
687
688        # And the output layers.
689        self.decoder_head = ConvBlock3dWithStrip(
690            in_channels=2 * features_decoder[-1],
691            out_channels=features_decoder[-1],
692            use_strip_pooling=use_strip_pooling,
693        )
694        self.out_conv = nn.Conv3d(features_decoder[-1], out_channels, 1)
695
696    def forward(self, x: torch.Tensor):
697        """Forward pass of the UNETR-3D model.
698
699        Args:
700            x: Inputs of expected shape (B, C, Z, Y, X), where the number of slices Z may vary.
701
702        Returns:
703            The UNETR output.
704        """
705        B, C, Z, H, W = x.shape
706        original_shape = (Z, H, W)
707
708        # Preprocessing step
709        x, input_shape = self.preprocess(x)
710
711        # Run the image encoder.
712        curr_features = torch.stack([self.encoder(x[:, :, i])[0] for i in range(Z)], dim=2)
713
714        # Prepare the counterparts for the decoder.
715        # NOTE: The section below is sequential; there are no skip connections at the moment.
716        z9 = self.deconv1(curr_features)
717        z6 = self.deconv2(z9)
718        z3 = self.deconv3(z6)
719        z0 = self.deconv4(z3)
720
721        updated_from_encoder = [z9, z6, z3]
722
723        # Align the features through the base block.
724        x = self.base(curr_features)
725        # Run the decoder
726        x = self.decoder(x, encoder_inputs=updated_from_encoder)
727        x = self.deconv_out(x)  # NOTE: previously named `end_up`
728
729        # And the final output head.
730        x = torch.cat([x, z0], dim=1)
731        x = self.decoder_head(x)
732        x = self.out_conv(x)
733        if self.final_activation is not None:
734            x = self.final_activation(x)
735
736        # Postprocess the output back to original size.
737        x = self.postprocess_masks(x, input_shape, original_shape)
738        return x

A three-dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.

UNETR3D( img_size: int = 1024, backbone: Literal['sam', 'sam2', 'sam3', 'cellpose_sam', 'mae', 'scalemae', 'dinov2', 'dinov3'] = 'sam', encoder: Union[torch.nn.modules.module.Module, str, NoneType] = 'hvit_b', decoder: Optional[torch.nn.modules.module.Module] = None, out_channels: int = 1, use_sam_stats: bool = False, use_mae_stats: bool = False, use_dino_stats: bool = False, resize_input: bool = True, encoder_checkpoint: Union[str, collections.OrderedDict, NoneType] = None, final_activation: Union[torch.nn.modules.module.Module, str, NoneType] = None, use_skip_connection: bool = False, embed_dim: Optional[int] = None, use_conv_transpose: bool = False, use_strip_pooling: bool = True, **kwargs)

out_channels
deconv1
deconv2
deconv3
deconv4
decoder
deconv_out
base
decoder_head
out_conv
def forward(self, x: torch.Tensor):

Forward pass of the UNETR-3D model.

Arguments:
  • x: Inputs of expected shape (B, C, Z, Y, X), where the number of slices Z may vary.
Returns:

The UNETR output.
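A corresponding hedged sketch for the 3d model; the default "hvit_b" encoder comes from SAM2, so this assumes the SAM2 dependencies are installed, and each z-slice is passed through the image encoder separately, as in the forward pass above:

    import torch
    from torch_em.model.unetr import UNETR3D

    model = UNETR3D(backbone="sam2", encoder="hvit_b", out_channels=1)
    x = torch.randn(1, 3, 4, 256, 256)  # B x C x Z x Y x X
    with torch.no_grad():
        y = model(x)
    print(y.shape)  # expected: torch.Size([1, 1, 4, 256, 256])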