torch_em.model.unetr

   1from functools import partial
   2from collections import OrderedDict
   3from typing import Optional, Tuple, Union, Literal
   4
   5import torch
   6import torch.nn as nn
   7import torch.nn.functional as F
   8
   9from .vit import get_vision_transformer
  10from .unet import Decoder, ConvBlock2d, ConvBlock3d, Upsampler2d, Upsampler3d, _update_conv_kwargs
  11
  12try:
  13    from micro_sam.util import get_sam_model
  14except ImportError:
  15    get_sam_model = None
  16
  17try:
  18    from micro_sam.v2.util import get_sam2_model
  19except ImportError:
  20    get_sam2_model = None
  21
  22try:
  23    from micro_sam3.util import get_sam3_model
  24except ImportError:
  25    get_sam3_model = None
  26
  27
  28#
  29# UNETR IMPLEMENTATION [Vision Transformer (ViT from SAM / CellposeSAM / SAM2 / SAM3 / DINOv2 / DINOv3 / MAE / ScaleMAE) + UNet Decoder from `torch_em`]  # noqa
  30#
  31
  32
  33def _check_input_normalization_range(
  34    x: torch.Tensor,
  35    expected_range: Optional[Tuple[float, float]],
  36    unit_scale_max: Optional[float] = None,
  37) -> None:
  38    """Check whether raw inputs match the value range expected by the model normalizer.
  39
  40    Args:
  41        x: The input tensor to validate.
  42        expected_range: The (min, max) value range the input must lie within. If None, all checks are skipped.
  43        unit_scale_max: If set, raises an error when the input's maximum value is at or below this threshold,
  44            catching inputs that are likely in the wrong scale (e.g. [0, 1] instead of [0, 255] for SAM1).
  45    """
  46    if expected_range is None:
  47        return
  48
  49    if not torch.all(torch.isfinite(x)):
  50        raise ValueError("The input contains NaN or infinite values before normalization.")
  51
  52    min_value, max_value = expected_range
  53    if torch.any((x < min_value) | (x > max_value)):
  54        actual_min, actual_max = torch.aminmax(x.detach())
  55        raise ValueError(
  56            "The input is outside the expected scale before normalization: "
  57            f"expected values in [{min_value}, {max_value}], got [{actual_min.item()}, {actual_max.item()}]. "
  58            "Please check whether the raw inputs should be scaled to [0, 1] or kept in [0, 255] "
  59            "before applying the pretrained normalization statistics."
  60        )
  61
  62    if unit_scale_max is not None:
  63        actual_max = x.detach().max().item()
  64        if actual_max <= unit_scale_max:
  65            raise ValueError(
  66                f"The input maximum value ({actual_max:.4f}) suggests the input is in the wrong scale: "
  67                f"expected inputs with values in [{min_value}, {max_value}], "
  68                f"but the maximum is only {actual_max:.4f}. "
  69                "Please check whether the raw inputs should be scaled to [0, 255] instead of [0, 1]."
  70            )
  71
  72
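A minimal sketch of how the range check above behaves, assuming SAM1-style statistics (raw inputs
expected in [0, 255], as used further below in `preprocess_vit_inputs`): a [0, 1]-scaled tensor
trips the wrong-scale check, while the rescaled tensor passes.

    import torch

    x = torch.rand(1, 3, 64, 64)  # values in [0, 1]
    try:
        _check_input_normalization_range(x, expected_range=(0.0, 255.0), unit_scale_max=1.0)
    except ValueError as err:
        print("caught:", err)  # the input looks like it is in the wrong scale
    # The same tensor rescaled to [0, 255] passes both checks.
    _check_input_normalization_range(x * 255.0, expected_range=(0.0, 255.0), unit_scale_max=1.0)
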
  73def _as_stats(mean, std, device, dtype, is_3d: bool):
  74    view_shape = (1, -1, 1, 1, 1) if is_3d else (1, -1, 1, 1)
  75    pixel_mean = torch.tensor(mean, device=device, dtype=dtype).view(*view_shape)
  76    pixel_std = torch.tensor(std, device=device, dtype=dtype).view(*view_shape)
  77    return pixel_mean, pixel_std
  78
  79
  80class UNETRBase(nn.Module):
  81    """Base class for implementing a UNETR.
  82
  83    Args:
  84        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
  85        backbone: The name of the vision transformer implementation.
  86            One of "sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"
  87            (see all combinations below)
  88        encoder: The vision transformer. Can either be a name, such as "vit_b"
  89            (see all combinations for this below) or a torch module.
  90        decoder: The convolutional decoder.
  91        out_channels: The number of output channels of the UNETR.
  92        use_sam_stats: Whether to normalize the input data with the statistics of the
  93            pretrained SAM / SAM2 / SAM3 model.
  94        use_dino_stats: Whether to normalize the input data with the statistics of the
  95            pretrained DINOv2 / DINOv3 model.
  96        use_mae_stats: Whether to normalize the input data with the statistics of the pretrained MAE model.
  97        resize_input: Whether to resize the input images to match `img_size`.
  98            By default, it resizes the inputs to match the `img_size`.
  99        encoder_checkpoint: Checkpoint for initializing the vision transformer.
 100            Can either be a filepath or an already loaded checkpoint.
 101        final_activation: The activation to apply to the UNETR output.
 102        use_skip_connection: Whether to use skip connections. By default, it uses skip connections.
 103        embed_dim: The embedding dimensionality, corresponding to the output dimension of the vision transformer.
 104        use_conv_transpose: Whether to use transposed convolutions instead of resampling for upsampling.
 105            By default, it uses resampling for upsampling.
 106        perform_range_checks: Whether to validate the input value range before normalization on each forward pass.
 107            You can disable the checks to avoid GPU sync overhead during training when inputs are known to be correct.
 108
 109        NOTE: The currently supported combinations of 'backbone' x 'encoder' are the following:
 110
 111        SAM_family_models:
 112            - 'sam' x 'vit_b'
 113            - 'sam' x 'vit_l'
 114            - 'sam' x 'vit_h'
 115            - 'sam2' x 'hvit_t'
 116            - 'sam2' x 'hvit_s'
 117            - 'sam2' x 'hvit_b'
 118            - 'sam2' x 'hvit_l'
 119            - 'sam3' x 'vit_pe'
 120            - 'cellpose_sam' x 'vit_l'
 121
 122        DINO_family_models:
 123            - 'dinov2' x 'vit_s'
 124            - 'dinov2' x 'vit_b'
 125            - 'dinov2' x 'vit_l'
 126            - 'dinov2' x 'vit_g'
 127            - 'dinov2' x 'vit_s_reg4'
 128            - 'dinov2' x 'vit_b_reg4'
 129            - 'dinov2' x 'vit_l_reg4'
 130            - 'dinov2' x 'vit_g_reg4'
 131            - 'dinov3' x 'vit_s'
 132            - 'dinov3' x 'vit_s+'
 133            - 'dinov3' x 'vit_b'
 134            - 'dinov3' x 'vit_l'
 135            - 'dinov3' x 'vit_l+'
 136            - 'dinov3' x 'vit_h+'
 137            - 'dinov3' x 'vit_7b'
 138
 139        MAE_family_models:
 140            - 'mae' x 'vit_b'
 141            - 'mae' x 'vit_l'
 142            - 'mae' x 'vit_h'
 143            - 'scalemae' x 'vit_b'
 144            - 'scalemae' x 'vit_l'
 145            - 'scalemae' x 'vit_h'
 146    """
 147    def __init__(
 148        self,
 149        img_size: int = 1024,
 150        backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
 151        encoder: Optional[Union[nn.Module, str]] = "vit_b",
 152        decoder: Optional[nn.Module] = None,
 153        out_channels: int = 1,
 154        use_sam_stats: bool = False,
 155        use_mae_stats: bool = False,
 156        use_dino_stats: bool = False,
 157        resize_input: bool = True,
 158        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
 159        final_activation: Optional[Union[str, nn.Module]] = None,
 160        use_skip_connection: bool = True,
 161        embed_dim: Optional[int] = None,
 162        use_conv_transpose: bool = False,
 163        perform_range_checks: bool = True,
 164        **kwargs
 165    ) -> None:
 166        super().__init__()
 167
 168        self.img_size = img_size
 169        self.use_sam_stats = use_sam_stats
 170        self.use_mae_stats = use_mae_stats
 171        self.use_dino_stats = use_dino_stats
 172        self.use_skip_connection = use_skip_connection
 173        self.resize_input = resize_input
 174        self.perform_range_checks = perform_range_checks
 175        self.use_conv_transpose = use_conv_transpose
 176        self.backbone = backbone
 177
 178        if isinstance(encoder, str):  # e.g. "vit_b" / "hvit_b" / "vit_pe"
 179            print(f"Using {encoder} from {backbone.upper()}")
 180            self.encoder = get_vision_transformer(img_size=img_size, backbone=backbone, model=encoder, **kwargs)
 181
 182            if encoder_checkpoint is not None:
 183                self._load_encoder_from_checkpoint(backbone=backbone, encoder=encoder, checkpoint=encoder_checkpoint)
 184
 185            if embed_dim is None:
 186                embed_dim = self.encoder.embed_dim
 187
 188            # For the SAM1 encoder, if 'apply_neck' is set, the embedding dimension must match the neck output.
 189            if hasattr(self.encoder, "apply_neck") and self.encoder.apply_neck:
 190                embed_dim = self.encoder.neck[2].out_channels  # the value is 256
 191
 192        else:  # `nn.Module` ViT backbone
 193            self.encoder = encoder
 194
 195            have_neck = False
 196            for name, _ in self.encoder.named_parameters():
 197                if name.startswith("neck"):
 198                    have_neck = True
 199
 200            if embed_dim is None:
 201                if have_neck:
 202                    embed_dim = self.encoder.neck[2].out_channels  # the value is 256
 203                else:
 204                    embed_dim = self.encoder.patch_embed.proj.out_channels
 205
 206        self.embed_dim = embed_dim
 207        self.final_activation = self._get_activation(final_activation)
 208
 209    def _load_encoder_from_checkpoint(self, backbone, encoder, checkpoint):
 210        """Load pretrained weights into the image encoder.
 211        """
 212        if isinstance(checkpoint, str):
 213            if backbone == "sam" and isinstance(encoder, str):
 214                # If we have a SAM encoder, we first try to load the full SAM model
 215                # (using micro_sam) and otherwise fall back to loading the encoder state
 216                # directly from the checkpoint.
 217                try:
 218                    _, model = get_sam_model(model_type=encoder, checkpoint_path=checkpoint, return_sam=True)
 219                    encoder_state = model.image_encoder.state_dict()
 220                except Exception:
 221                    # Try loading the encoder state directly from a checkpoint.
 222                    encoder_state = torch.load(checkpoint, weights_only=False)
 223
 224            elif backbone == "cellpose_sam" and isinstance(encoder, str):
 225                # The architecture matches CellposeSAM exactly (same rel_pos sizes),
 226                # so weights load directly without any interpolation.
 227                encoder_state = torch.load(checkpoint, map_location="cpu", weights_only=False)
 228                # Handle DataParallel/DistributedDataParallel prefix.
 229                if any(k.startswith("module.") for k in encoder_state.keys()):
 230                    encoder_state = OrderedDict(
 231                        {k[len("module."):]: v for k, v in encoder_state.items()}
 232                    )
 233                # Extract encoder weights from CellposeSAM checkpoint format (strip 'encoder.' prefix).
 234                if any(k.startswith("encoder.") for k in encoder_state.keys()):
 235                    encoder_state = OrderedDict(
 236                        {k[len("encoder."):]: v for k, v in encoder_state.items() if k.startswith("encoder.")}
 237                    )
 238
 239            elif backbone == "sam2" and isinstance(encoder, str):
 240                # If we have a SAM2 encoder, we first try to load the full SAM2 model
 241                # (using micro_sam.v2) and otherwise fall back to loading the encoder state
 242                # directly from the checkpoint.
 243                try:
 244                    model = get_sam2_model(model_type=encoder, checkpoint_path=checkpoint)
 245                    encoder_state = model.image_encoder.state_dict()
 246                except Exception:
 247                    # Try loading the encoder state directly from a checkpoint.
 248                    encoder_state = torch.load(checkpoint, weights_only=False)
 249
 250            elif backbone == "sam3" and isinstance(encoder, str):
 251                # If we have a SAM3 encoder, we first try to load the full SAM3 model
 252                # (using micro_sam3) and otherwise fall back to loading the encoder state
 253                # directly from the checkpoint.
 254                try:
 255                    model = get_sam3_model(checkpoint_path=checkpoint)
 256                    encoder_state = model.backbone.vision_backbone.state_dict()
 257                    # Align the encoder weight names with the expected parameter names.
 258                    encoder_state = {
 259                        k[len("trunk."):] if k.startswith("trunk.") else k: v for k, v in encoder_state.items()
 260                    }
 261                    # And drop 'convs' and 'sam2_convs' - these appear to be upsampling blocks the encoder does not need.
 262                    encoder_state = {
 263                        k: v for k, v in encoder_state.items()
 264                        if not (k.startswith("convs.") or k.startswith("sam2_convs."))
 265                    }
 266                except Exception:
 267                    # Try loading the encoder state directly from a checkpoint.
 268                    encoder_state = torch.load(checkpoint, weights_only=False)
 269
 270            elif backbone == "mae":
 271                # vit initialization hints from:
 272                #     - https://github.com/facebookresearch/mae/blob/main/main_finetune.py#L233-L242
 273                encoder_state = torch.load(checkpoint, weights_only=False)["model"]
 274                encoder_state = OrderedDict({
 275                    k: v for k, v in encoder_state.items() if (k != "mask_token" and not k.startswith("decoder"))
 276                })
 277                # Remove the `head` from our current encoder (the pretrained MAE weights do not include it).
 278                current_encoder_state = self.encoder.state_dict()
 279                if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
 280                    del self.encoder.head
 281
 282            elif backbone == "scalemae":
 283                # Load the encoder state directly from a checkpoint.
 284                encoder_state = torch.load(checkpoint)["model"]
 285                encoder_state = OrderedDict({
 286                    k: v for k, v in encoder_state.items()
 287                    if not k.startswith(("mask_token", "decoder", "fcn", "fpn", "pos_embed"))
 288                })
 289
 290                # Remove the `head` from our current encoder (the pretrained weights do not include it).
 291                current_encoder_state = self.encoder.state_dict()
 292                if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
 293                    del self.encoder.head
 294
 295                if "pos_embed" in current_encoder_state:  # NOTE: ScaleMAE uses positional embeddings in a different format.
 296                    del self.encoder.pos_embed
 297
 298            elif backbone in ["dinov2", "dinov3"]:  # Load the encoder state directly from a checkpoint.
 299                encoder_state = torch.load(checkpoint)
 300
 301            else:
 302                raise ValueError(
 303                    f"The combination of backbone '{backbone}' and encoder '{encoder}' is not supported."
 304                )
 305
 306        else:
 307            encoder_state = checkpoint
 308
 309        self.encoder.load_state_dict(encoder_state)
 310
 311    def _get_activation(self, activation):
 312        return_activation = None
 313        if activation is None:
 314            return None
 315        if isinstance(activation, nn.Module):
 316            return activation
 317        if isinstance(activation, str):
 318            return_activation = getattr(nn, activation, None)
 319        if return_activation is None:
 320            raise ValueError(f"Invalid activation: {activation}")
 321
 322        return return_activation()
 323
 324    @staticmethod
 325    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
 326        """Compute the output size given input size and target long side length.
 327
 328        Args:
 329            oldh: The input image height.
 330            oldw: The input image width.
 331            long_side_length: The longest side length for resizing.
 332
 333        Returns:
 334            The new image height.
 335            The new image width.
 336        """
 337        scale = long_side_length * 1.0 / max(oldh, oldw)
 338        newh, neww = oldh * scale, oldw * scale
 339        neww = int(neww + 0.5)
 340        newh = int(newh + 0.5)
 341        return (newh, neww)
 342
 343    def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor:
 344        """Resize the image so that the longest side has the correct length.
 345
 346        Expects batched images with shape BxCxHxW OR BxCxDxHxW and float format.
 347
 348        Args:
 349            image: The input image.
 350
 351        Returns:
 352            The resized image.
 353        """
 354        if image.ndim == 4:  # i.e. 2d image
 355            target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.encoder.img_size)
 356            return F.interpolate(image, target_size, mode="bilinear", align_corners=False, antialias=True)
 357        elif image.ndim == 5:  # i.e. 3d volume
 358            B, C, Z, H, W = image.shape
 359            target_size = self.get_preprocess_shape(H, W, self.img_size)
 360            return F.interpolate(image, (Z, *target_size), mode="trilinear", align_corners=False)
 361        else:
 362            raise ValueError(f"Expected 4d or 5d inputs, got shape {image.shape}.")
 363
 364    def _as_stats(self, mean, std, device, dtype, is_3d: bool):
 365        """@private
 366        """
 367        return _as_stats(mean, std, device, dtype, is_3d)
 368
 369    def _check_input_normalization_range(self, x: torch.Tensor, expected_range: Optional[Tuple[float, float]]) -> None:
 370        """@private
 371        """
 372        _check_input_normalization_range(x, expected_range)
 373
 374    def preprocess(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, ...]]:
 375        """@private
 376        """
 377        return preprocess_vit_inputs(
 378            x,
 379            use_sam_stats=self.use_sam_stats,
 380            backbone=self.backbone,
 381            use_mae_stats=self.use_mae_stats,
 382            use_dino_stats=self.use_dino_stats,
 383            resize_input=self.resize_input,
 384            img_size=self.img_size,
 385            encoder_img_size=self.encoder.img_size,
 386            perform_range_checks=self.perform_range_checks,
 387        )
 388
 389    def postprocess_masks(
 390        self, masks: torch.Tensor, input_size: Tuple[int, ...], original_size: Tuple[int, ...],
 391    ) -> torch.Tensor:
 392        """@private
 393        """
 394        if masks.ndim == 4:  # i.e. 2d labels
 395            masks = F.interpolate(
 396                masks,
 397                (self.encoder.img_size, self.encoder.img_size),
 398                mode="bilinear",
 399                align_corners=False,
 400            )
 401            masks = masks[..., : input_size[0], : input_size[1]]
 402            masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
 403
 404        elif masks.ndim == 5:  # i.e. 3d volumetric labels
 405            masks = F.interpolate(
 406                masks,
 407                (input_size[0], self.img_size, self.img_size),
 408                mode="trilinear",
 409                align_corners=False,
 410            )
 411            masks = masks[..., :input_size[0], :input_size[1], :input_size[2]]
 412            masks = F.interpolate(masks, original_size, mode="trilinear", align_corners=False)
 413
 414        else:
 415            raise ValueError(f"Expected 4d or 5d labels, got shape {masks.shape}.")
 416
 417        return masks
 418
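To illustrate the aspect-preserving resize used throughout the class above, `get_preprocess_shape`
scales the longer side to the target length and rounds the shorter side accordingly. A small worked
example (the values follow directly from the formula):

    from torch_em.model.unetr import UNETRBase

    # The longest side (1024) already matches the target, so the shape is unchanged.
    print(UNETRBase.get_preprocess_shape(768, 1024, long_side_length=1024))  # -> (768, 1024)

    # Here the width 768 is scaled up to 1024 (factor 4/3) and the height follows: 512 * 4/3 ~ 683.
    print(UNETRBase.get_preprocess_shape(512, 768, long_side_length=1024))   # -> (683, 1024)
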
 419
 420def preprocess_vit_inputs(
 421    x: torch.Tensor,
 422    use_sam_stats: bool = False,
 423    backbone: str = "sam",
 424    use_mae_stats: bool = False,
 425    use_dino_stats: bool = False,
 426    resize_input: bool = True,
 427    img_size: int = 1024,
 428    encoder_img_size: int = 1024,
 429    perform_range_checks: bool = True,
 430) -> Tuple[torch.Tensor, Tuple]:
 431    """Preprocess inputs for ViT-backbones in UNETR models.
 432
 433    Handles normalization stat selection, input range validation, optional resizing to the longest side,
 434    and padding to `encoder_img_size`. Can be used as a standalone function without a model instance.
 435
 436    Args:
 437        x: Input tensor of shape (B, C, H, W) for 2D or (B, C, Z, H, W) for 3D.
 438        use_sam_stats: Whether to normalize with SAM/SAM2/SAM3 backbone statistics.
 439        backbone: The backbone name - controls which SAM stats are used when `use_sam_stats=True`.
 440        use_mae_stats: Whether to normalize with MAE statistics.
 441        use_dino_stats: Whether to normalize with DINOv2/DINOv3 statistics.
 442        resize_input: Whether to resize the input to the longest side before padding.
 443        img_size: The model image size, used for 3D resize.
 444        encoder_img_size: The encoder image size, used for 2D resize and padding.
 445        perform_range_checks: Whether to validate the expected input value range before normalization.
 446            You can disable the checks to avoid GPU sync overhead during training when inputs are known to be correct.
 447
 448    Returns:
 449        The preprocessed tensor and the spatial shape after resizing (before padding).
 450    """
 451    is_3d = (x.ndim == 5)
 452    device, dtype = x.device, x.dtype
 453    mean, std = (0.0, 0.0, 0.0), (1.0, 1.0, 1.0)
 454    expected_range = None
 455    unit_scale_max = None
 456
 457    if use_sam_stats:
 458        if backbone == "sam2":
 459            mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
 460            expected_range = (0.0, 1.0)
 461        elif backbone == "sam3":
 462            mean, std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
 463            expected_range = (0.0, 1.0)
 464        else:  # sam1 / default
 465            mean, std = (123.675, 116.28, 103.53), (58.395, 57.12, 57.375)
 466            expected_range = (0.0, 255.0)
 467            unit_scale_max = 1.0
 468    elif use_mae_stats:  # TODO: add mean std from mae / scalemae experiments (or open up arguments for this)
 469        raise NotImplementedError
 470    elif use_dino_stats:
 471        mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
 472        expected_range = (0.0, 1.0)
 473    else:
 474        mean, std = (0.0, 0.0, 0.0), (1.0, 1.0, 1.0)
 475        expected_range = None
 476
 477    if perform_range_checks:
 478        _check_input_normalization_range(x, expected_range, unit_scale_max)
 479    pixel_mean, pixel_std = _as_stats(mean, std, device=device, dtype=dtype, is_3d=is_3d)
 480
 481    if resize_input:
 482        if x.ndim == 4:
 483            target_size = UNETRBase.get_preprocess_shape(x.shape[2], x.shape[3], encoder_img_size)
 484            x = F.interpolate(x, target_size, mode="bilinear", align_corners=False, antialias=True)
 485        elif x.ndim == 5:
 486            B, C, Z, H, W = x.shape
 487            target_size = UNETRBase.get_preprocess_shape(H, W, img_size)
 488            x = F.interpolate(x, (Z, *target_size), mode="trilinear", align_corners=False)
 489
 490    input_shape = x.shape[-3:] if is_3d else x.shape[-2:]
 491
 492    x = (x - pixel_mean) / pixel_std
 493    h, w = x.shape[-2:]
 494    padh = encoder_img_size - h
 495    padw = encoder_img_size - w
 496
 497    if is_3d:
 498        x = F.pad(x, (0, padw, 0, padh, 0, 0))
 499    else:
 500        x = F.pad(x, (0, padw, 0, padh))
 501
 502    return x, input_shape
 503
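A short sketch of using `preprocess_vit_inputs` as a standalone function with DINO statistics: the
input is normalized with ImageNet mean/std, resized so that its longest side matches the encoder
size, and padded to a square; the returned shape is the post-resize (pre-padding) spatial size.

    import torch
    from torch_em.model.unetr import preprocess_vit_inputs

    raw = torch.rand(1, 3, 384, 512)  # values in [0, 1], as expected for DINO statistics
    padded, input_shape = preprocess_vit_inputs(
        raw, use_dino_stats=True, resize_input=True, encoder_img_size=1024,
    )
    print(padded.shape)   # torch.Size([1, 3, 1024, 1024]) after resizing and padding
    print(input_shape)    # torch.Size([768, 1024]): the resized shape before padding
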
 504
 505class UNETR(UNETRBase):
 506    """A (2d-only) UNet Transformer using a vision transformer as encoder and a convolutional decoder.
 507    """
 508    def __init__(
 509        self,
 510        img_size: int = 1024,
 511        backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
 512        encoder: Optional[Union[nn.Module, str]] = "vit_b",
 513        decoder: Optional[nn.Module] = None,
 514        out_channels: int = 1,
 515        use_sam_stats: bool = False,
 516        use_mae_stats: bool = False,
 517        use_dino_stats: bool = False,
 518        resize_input: bool = True,
 519        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
 520        final_activation: Optional[Union[str, nn.Module]] = None,
 521        use_skip_connection: bool = True,
 522        embed_dim: Optional[int] = None,
 523        use_conv_transpose: bool = False,
 524        perform_range_checks: bool = True,
 525        **kwargs
 526    ) -> None:
 527
 528        super().__init__(
 529            img_size=img_size,
 530            backbone=backbone,
 531            encoder=encoder,
 532            decoder=decoder,
 533            out_channels=out_channels,
 534            use_sam_stats=use_sam_stats,
 535            use_mae_stats=use_mae_stats,
 536            use_dino_stats=use_dino_stats,
 537            resize_input=resize_input,
 538            encoder_checkpoint=encoder_checkpoint,
 539            final_activation=final_activation,
 540            use_skip_connection=use_skip_connection,
 541            embed_dim=embed_dim,
 542            use_conv_transpose=use_conv_transpose,
 543            perform_range_checks=perform_range_checks,
 544            **kwargs,
 545        )
 546
 547        encoder = self.encoder
 548
 549        if backbone == "sam2" and hasattr(encoder, "trunk"):
 550            in_chans = encoder.trunk.patch_embed.proj.in_channels
 551        elif hasattr(encoder, "in_chans"):
 552            in_chans = encoder.in_chans
 553        else:  # `nn.Module` ViT backbone.
 554            try:
 555                in_chans = encoder.patch_embed.proj.in_channels
 556            except AttributeError:  # for getting the input channels while using 'vit_t' from MobileSam
 557                in_chans = encoder.patch_embed.seq[0].c.in_channels
 558
 559        # parameters for the decoder network
 560        depth = 3
 561        initial_features = 64
 562        gain = 2
 563        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
 564        scale_factors = depth * [2]
 565        self.out_channels = out_channels
 566
 567        # choice of upsampler - to use (bilinear interpolation + conv) or conv transpose
 568        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d
 569
 570        self.decoder = decoder or Decoder(
 571            features=features_decoder,
 572            scale_factors=scale_factors[::-1],
 573            conv_block_impl=ConvBlock2d,
 574            sampler_impl=_upsampler,
 575        )
 576
 577        if use_skip_connection:
 578            self.deconv1 = Deconv2DBlock(
 579                in_channels=self.embed_dim,
 580                out_channels=features_decoder[0],
 581                use_conv_transpose=use_conv_transpose,
 582            )
 583            self.deconv2 = nn.Sequential(
 584                Deconv2DBlock(
 585                    in_channels=self.embed_dim,
 586                    out_channels=features_decoder[0],
 587                    use_conv_transpose=use_conv_transpose,
 588                ),
 589                Deconv2DBlock(
 590                    in_channels=features_decoder[0],
 591                    out_channels=features_decoder[1],
 592                    use_conv_transpose=use_conv_transpose,
 593                )
 594            )
 595            self.deconv3 = nn.Sequential(
 596                Deconv2DBlock(
 597                    in_channels=self.embed_dim,
 598                    out_channels=features_decoder[0],
 599                    use_conv_transpose=use_conv_transpose,
 600                ),
 601                Deconv2DBlock(
 602                    in_channels=features_decoder[0],
 603                    out_channels=features_decoder[1],
 604                    use_conv_transpose=use_conv_transpose,
 605                ),
 606                Deconv2DBlock(
 607                    in_channels=features_decoder[1],
 608                    out_channels=features_decoder[2],
 609                    use_conv_transpose=use_conv_transpose,
 610                )
 611            )
 612            self.deconv4 = ConvBlock2d(in_chans, features_decoder[-1])
 613        else:
 614            self.deconv1 = Deconv2DBlock(
 615                in_channels=self.embed_dim,
 616                out_channels=features_decoder[0],
 617                use_conv_transpose=use_conv_transpose,
 618            )
 619            self.deconv2 = Deconv2DBlock(
 620                in_channels=features_decoder[0],
 621                out_channels=features_decoder[1],
 622                use_conv_transpose=use_conv_transpose,
 623            )
 624            self.deconv3 = Deconv2DBlock(
 625                in_channels=features_decoder[1],
 626                out_channels=features_decoder[2],
 627                use_conv_transpose=use_conv_transpose,
 628            )
 629            self.deconv4 = Deconv2DBlock(
 630                in_channels=features_decoder[2],
 631                out_channels=features_decoder[3],
 632                use_conv_transpose=use_conv_transpose,
 633            )
 634
 635        self.base = ConvBlock2d(self.embed_dim, features_decoder[0])
 636        self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1)
 637        self.deconv_out = _upsampler(
 638            scale_factor=2, in_channels=features_decoder[-1], out_channels=features_decoder[-1]
 639        )
 640        self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1])
 641
 642    def forward(self, x: torch.Tensor) -> torch.Tensor:
 643        """Apply the UNETR to the input data.
 644
 645        Args:
 646            x: The input tensor.
 647
 648        Returns:
 649            The UNETR output.
 650        """
 651        original_shape = x.shape[-2:]
 652
 653        # Reshape the inputs to the shape expected by the encoder
 654        # and normalize the inputs if normalization is part of the model.
 655        x, input_shape = self.preprocess(x)
 656
 657        encoder_outputs = self.encoder(x)
 658
 659        if isinstance(encoder_outputs[-1], list):
 660            # `encoder_outputs` can be arranged in only two forms:
 661            #   - either we only return the image embeddings
 662            #   - or, we return the image embeddings and the "list" of global attention layers
 663            z12, from_encoder = encoder_outputs
 664        else:
 665            z12 = encoder_outputs
 666
 667        if self.use_skip_connection:
 668            from_encoder = from_encoder[::-1]
 669            z9 = self.deconv1(from_encoder[0])
 670            z6 = self.deconv2(from_encoder[1])
 671            z3 = self.deconv3(from_encoder[2])
 672            z0 = self.deconv4(x)
 673
 674        else:
 675            z9 = self.deconv1(z12)
 676            z6 = self.deconv2(z9)
 677            z3 = self.deconv3(z6)
 678            z0 = self.deconv4(z3)
 679
 680        updated_from_encoder = [z9, z6, z3]
 681
 682        x = self.base(z12)
 683        x = self.decoder(x, encoder_inputs=updated_from_encoder)
 684        x = self.deconv_out(x)
 685
 686        x = torch.cat([x, z0], dim=1)
 687        x = self.decoder_head(x)
 688
 689        x = self.out_conv(x)
 690        if self.final_activation is not None:
 691            x = self.final_activation(x)
 692
 693        x = self.postprocess_masks(x, input_shape, original_shape)
 694        return x
 695
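A minimal usage sketch for the 2D model, assuming the dependencies required by
`get_vision_transformer` for the SAM backbone are installed; the encoder is randomly initialized
unless `encoder_checkpoint` is passed. Note that with `use_sam_stats=True` the raw inputs are
expected in [0, 255].

    import torch
    from torch_em.model.unetr import UNETR

    model = UNETR(
        backbone="sam", encoder="vit_b", out_channels=2,
        use_sam_stats=True, final_activation="Sigmoid",
    )
    x = torch.rand(1, 3, 512, 512) * 255.0  # raw intensities in [0, 255]
    with torch.no_grad():
        y = model(x)
    print(y.shape)  # torch.Size([1, 2, 512, 512]): predictions at the original resolution
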
 696
 697class UNETR2D(UNETR):
 698    """A two-dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.
 699    """
 700    pass
 701
 702
 703class UNETR3D(UNETRBase):
 704    """A three-dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.
 705    """
 706    def __init__(
 707        self,
 708        img_size: int = 1024,
 709        backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
 710        encoder: Optional[Union[nn.Module, str]] = "hvit_b",
 711        decoder: Optional[nn.Module] = None,
 712        out_channels: int = 1,
 713        use_sam_stats: bool = False,
 714        use_mae_stats: bool = False,
 715        use_dino_stats: bool = False,
 716        resize_input: bool = True,
 717        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
 718        final_activation: Optional[Union[str, nn.Module]] = None,
 719        use_skip_connection: bool = False,
 720        embed_dim: Optional[int] = None,
 721        use_conv_transpose: bool = False,
 722        use_strip_pooling: bool = True,
 723        perform_range_checks: bool = True,
 724        **kwargs
 725    ):
 726        if use_skip_connection:
 727            raise NotImplementedError("Skip connections are not yet supported for UNETR3D.")
 728        if use_conv_transpose:
 729            raise NotImplementedError("Transposed convolutions are not yet supported for UNETR3D; interpolation is always used.")
 730
 731        # Set the default embedding dimension.
 732        embed_dim = 256 if embed_dim is None else embed_dim
 733
 734        super().__init__(
 735            img_size=img_size,
 736            backbone=backbone,
 737            encoder=encoder,
 738            decoder=decoder,
 739            out_channels=out_channels,
 740            use_sam_stats=use_sam_stats,
 741            use_mae_stats=use_mae_stats,
 742            use_dino_stats=use_dino_stats,
 743            resize_input=resize_input,
 744            encoder_checkpoint=encoder_checkpoint,
 745            final_activation=final_activation,
 746            use_skip_connection=use_skip_connection,
 747            embed_dim=embed_dim,
 748            use_conv_transpose=use_conv_transpose,
 749            perform_range_checks=perform_range_checks,
 750            **kwargs,
 751        )
 752
 753        # The 3d convolutional decoder.
 754        # First, get the important parameters for the decoder.
 755        depth = 3
 756        initial_features = 64
 757        gain = 2
 758        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
 759        scale_factors = [1, 2, 2]
 760        self.out_channels = out_channels
 761
 762        # The mapping blocks.
 763        self.deconv1 = Deconv3DBlock(
 764            in_channels=embed_dim,
 765            out_channels=features_decoder[0],
 766            scale_factor=scale_factors,
 767            use_strip_pooling=use_strip_pooling,
 768        )
 769        self.deconv2 = Deconv3DBlock(
 770            in_channels=features_decoder[0],
 771            out_channels=features_decoder[1],
 772            scale_factor=scale_factors,
 773            use_strip_pooling=use_strip_pooling,
 774        )
 775        self.deconv3 = Deconv3DBlock(
 776            in_channels=features_decoder[1],
 777            out_channels=features_decoder[2],
 778            scale_factor=scale_factors,
 779            use_strip_pooling=use_strip_pooling,
 780        )
 781        self.deconv4 = Deconv3DBlock(
 782            in_channels=features_decoder[2],
 783            out_channels=features_decoder[3],
 784            scale_factor=scale_factors,
 785            use_strip_pooling=use_strip_pooling,
 786        )
 787
 788        # The core decoder block.
 789        self.decoder = decoder or Decoder(
 790            features=features_decoder,
 791            scale_factors=[scale_factors] * depth,
 792            conv_block_impl=partial(ConvBlock3dWithStrip, use_strip_pooling=use_strip_pooling),
 793            sampler_impl=Upsampler3d,
 794        )
 795
 796        # And the final upsampler to match the expected dimensions.
 797        self.deconv_out = Deconv3DBlock(  # NOTE: previously named `end_up`.
 798            in_channels=features_decoder[-1],
 799            out_channels=features_decoder[-1],
 800            scale_factor=scale_factors,
 801            use_strip_pooling=use_strip_pooling,
 802        )
 803
 804        # Additional conjunction blocks.
 805        self.base = ConvBlock3dWithStrip(
 806            in_channels=embed_dim,
 807            out_channels=features_decoder[0],
 808            use_strip_pooling=use_strip_pooling,
 809        )
 810
 811        # And the output layers.
 812        self.decoder_head = ConvBlock3dWithStrip(
 813            in_channels=2 * features_decoder[-1],
 814            out_channels=features_decoder[-1],
 815            use_strip_pooling=use_strip_pooling,
 816        )
 817        self.out_conv = nn.Conv3d(features_decoder[-1], out_channels, 1)
 818
 819    def forward(self, x: torch.Tensor):
 820        """Forward pass of the UNETR-3D model.
 821
 822        Args:
 823            x: Inputs of expected shape (B, C, Z, Y, X), where the number of slices Z can vary.
 824
 825        Returns:
 826            The UNETR output.
 827        """
 828        B, C, Z, H, W = x.shape
 829        original_shape = (Z, H, W)
 830
 831        # Preprocessing step
 832        x, input_shape = self.preprocess(x)
 833
 834        # Run the image encoder.
 835        curr_features = torch.stack([self.encoder(x[:, :, i])[0] for i in range(Z)], dim=2)
 836
 837        # Prepare the counterparts for the decoder.
 838        # NOTE: The section below is sequential; there are no skip connections at the moment.
 839        z9 = self.deconv1(curr_features)
 840        z6 = self.deconv2(z9)
 841        z3 = self.deconv3(z6)
 842        z0 = self.deconv4(z3)
 843
 844        updated_from_encoder = [z9, z6, z3]
 845
 846        # Align the features through the base block.
 847        x = self.base(curr_features)
 848        # Run the decoder
 849        x = self.decoder(x, encoder_inputs=updated_from_encoder)
 850        x = self.deconv_out(x)  # NOTE: previously named `end_up`.
 851
 852        # And the final output head.
 853        x = torch.cat([x, z0], dim=1)
 854        x = self.decoder_head(x)
 855        x = self.out_conv(x)
 856        if self.final_activation is not None:
 857            x = self.final_activation(x)
 858
 859        # Postprocess the output back to original size.
 860        x = self.postprocess_masks(x, input_shape, original_shape)
 861        return x
 862
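A minimal sketch for the 3D variant, assuming the SAM2 hiera-tiny encoder can be constructed (the
SAM2 dependencies must be installed): each z-slice is passed through the 2D image encoder and the
stacked features are decoded with the 3D decoder.

    import torch
    from torch_em.model.unetr import UNETR3D

    model = UNETR3D(
        backbone="sam2", encoder="hvit_t", out_channels=1,
        use_sam_stats=True, use_strip_pooling=True,
    )
    x = torch.rand(1, 3, 8, 256, 256)  # (B, C, Z, H, W), values in [0, 1] for SAM2 statistics
    with torch.no_grad():
        y = model(x)
    print(y.shape)  # torch.Size([1, 1, 8, 256, 256])
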
 863#
 864#  ADDITIONAL FUNCTIONALITIES
 865#
 866
 867
 868def _strip_pooling_layers(enabled, channels) -> nn.Module:
 869    return DepthStripPooling(channels) if enabled else nn.Identity()
 870
 871
 872class DepthStripPooling(nn.Module):
 873    """@private
 874    """
 875    def __init__(self, channels: int, reduction: int = 4):
 876        """Block for strip pooling along the depth dimension (only).
 877
 878        For 3D inputs (Z > 1), it aggregates global context across the depth dimension by adaptive
 879        average pooling to Z=1, passes the result through a small 1x1x1 MLP, and broadcasts it back
 880        across Z to modulate the original features via a gated residual.
 881
 882        For 2D inputs (Z == 1), the input is returned unchanged (no-op).
 883
 884        Args:
 885            channels: The number of feature channels.
 886            reduction: The channel reduction factor for the hidden layer.
 887        """
 888        super().__init__()
 889        hidden = max(1, channels // reduction)
 890        self.conv1 = nn.Conv3d(channels, hidden, kernel_size=1)
 891        self.bn1 = nn.BatchNorm3d(hidden)
 892        self.relu = nn.ReLU(inplace=True)
 893        self.conv2 = nn.Conv3d(hidden, channels, kernel_size=1)
 894
 895    def forward(self, x: torch.Tensor) -> torch.Tensor:
 896        if x.dim() != 5:
 897            raise ValueError(f"DepthStripPooling expects 5D tensors as input, got '{x.shape}'.")
 898
 899        B, C, Z, H, W = x.shape
 900        if Z == 1:  # i.e. effectively a 2d input.
 901            return x  # Nothing to pool over a single slice.
 902
 903        # We pool only along the depth dimension: i.e. target shape (B, C, 1, H, W)
 904        feat = F.adaptive_avg_pool3d(x, output_size=(1, H, W))
 905        feat = self.conv1(feat)
 906        feat = self.bn1(feat)
 907        feat = self.relu(feat)
 908        feat = self.conv2(feat)
 909        gate = torch.sigmoid(feat).expand(B, C, Z, H, W)  # Broadcast the collapsed depth context back to all slices
 910
 911        # Gated residual fusion
 912        return x * gate + x
 913
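A small sketch of the depth-strip pooling block (a private helper, shown only for illustration):
single-slice inputs pass through unchanged, while multi-slice inputs keep their shape and are
re-weighted by the gate computed from the depth-pooled context.

    import torch
    from torch_em.model.unetr import DepthStripPooling

    pool = DepthStripPooling(channels=16)
    single = torch.rand(2, 16, 1, 32, 32)
    multi = torch.rand(2, 16, 8, 32, 32)
    assert torch.equal(pool(single), single)   # Z == 1 is a no-op
    assert pool(multi).shape == multi.shape    # Z > 1 keeps the shape, features are modulated
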
 914
 915class Deconv3DBlock(nn.Module):
 916    """@private
 917    """
 918    def __init__(
 919        self,
 920        scale_factor,
 921        in_channels,
 922        out_channels,
 923        kernel_size=3,
 924        anisotropic_kernel=True,
 925        use_strip_pooling=True,
 926    ):
 927        super().__init__()
 928        conv_block_kwargs = {
 929            "in_channels": out_channels,
 930            "out_channels": out_channels,
 931            "kernel_size": kernel_size,
 932            "padding": ((kernel_size - 1) // 2),
 933        }
 934        if anisotropic_kernel:
 935            conv_block_kwargs = _update_conv_kwargs(conv_block_kwargs, scale_factor)
 936
 937        self.block = nn.Sequential(
 938            Upsampler3d(scale_factor, in_channels, out_channels),
 939            nn.Conv3d(**conv_block_kwargs),
 940            nn.BatchNorm3d(out_channels),
 941            nn.ReLU(True),
 942            _strip_pooling_layers(enabled=use_strip_pooling, channels=out_channels),
 943        )
 944
 945    def forward(self, x):
 946        return self.block(x)
 947
 948
 949class ConvBlock3dWithStrip(nn.Module):
 950    """@private
 951    """
 952    def __init__(
 953        self, in_channels: int, out_channels: int, use_strip_pooling: bool = True, **kwargs
 954    ):
 955        super().__init__()
 956        self.block = nn.Sequential(
 957            ConvBlock3d(in_channels, out_channels, **kwargs),
 958            _strip_pooling_layers(enabled=use_strip_pooling, channels=out_channels),
 959        )
 960
 961    def forward(self, x):
 962        return self.block(x)
 963
 964
 965class SingleDeconv2DBlock(nn.Module):
 966    """@private
 967    """
 968    def __init__(self, scale_factor, in_channels, out_channels):
 969        super().__init__()
 970        self.block = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, padding=0, output_padding=0)
 971
 972    def forward(self, x):
 973        return self.block(x)
 974
 975
 976class SingleConv2DBlock(nn.Module):
 977    """@private
 978    """
 979    def __init__(self, in_channels, out_channels, kernel_size):
 980        super().__init__()
 981        self.block = nn.Conv2d(
 982            in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=((kernel_size - 1) // 2)
 983        )
 984
 985    def forward(self, x):
 986        return self.block(x)
 987
 988
 989class Conv2DBlock(nn.Module):
 990    """@private
 991    """
 992    def __init__(self, in_channels, out_channels, kernel_size=3):
 993        super().__init__()
 994        self.block = nn.Sequential(
 995            SingleConv2DBlock(in_channels, out_channels, kernel_size),
 996            nn.BatchNorm2d(out_channels),
 997            nn.ReLU(True)
 998        )
 999
1000    def forward(self, x):
1001        return self.block(x)
1002
1003
1004class Deconv2DBlock(nn.Module):
1005    """@private
1006    """
1007    def __init__(self, in_channels, out_channels, kernel_size=3, use_conv_transpose=True):
1008        super().__init__()
1009        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d
1010        self.block = nn.Sequential(
1011            _upsampler(scale_factor=2, in_channels=in_channels, out_channels=out_channels),
1012            SingleConv2DBlock(out_channels, out_channels, kernel_size),
1013            nn.BatchNorm2d(out_channels),
1014            nn.ReLU(True)
1015        )
1016
1017    def forward(self, x):
1018        return self.block(x)
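
The 2D deconvolution helpers above upsample by a fixed factor of two; a quick sketch of their shape
behaviour (these are private building blocks of the module, shown only for illustration):

    import torch
    from torch_em.model.unetr import Deconv2DBlock

    feats = torch.rand(1, 128, 32, 32)
    up_interp = Deconv2DBlock(in_channels=128, out_channels=64, use_conv_transpose=False)
    up_deconv = Deconv2DBlock(in_channels=128, out_channels=64, use_conv_transpose=True)
    print(up_interp(feats).shape)  # torch.Size([1, 64, 64, 64]) via interpolation + conv
    print(up_deconv(feats).shape)  # torch.Size([1, 64, 64, 64]) via transposed convolution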
class UNETRBase(torch.nn.modules.module.Module):
 81class UNETRBase(nn.Module):
 82    """Base class for implementing a UNETR.
 83
 84    Args:
 85        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
 86        backbone: The name of the vision transformer implementation.
 87            One of "sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"
 88            (see all combinations below)
 89        encoder: The vision transformer. Can either be a name, such as "vit_b"
 90            (see all combinations for this below) or a torch module.
 91        decoder: The convolutional decoder.
 92        out_channels: The number of output channels of the UNETR.
 93        use_sam_stats: Whether to normalize the input data with the statistics of the
 94            pretrained SAM / SAM2 / SAM3 model.
 95        use_dino_stats: Whether to normalize the input data with the statistics of the
 96            pretrained DINOv2 / DINOv3 model.
 97        use_mae_stats: Whether to normalize the input data with the statistics of the pretrained MAE model.
 98        resize_input: Whether to resize the input images to match `img_size`.
 99            By default, it resizes the inputs to match the `img_size`.
100        encoder_checkpoint: Checkpoint for initializing the vision transformer.
101            Can either be a filepath or an already loaded checkpoint.
102        final_activation: The activation to apply to the UNETR output.
103        use_skip_connection: Whether to use skip connections. By default, it uses skip connections.
104        embed_dim: The embedding dimensionality, corresponding to the output dimension of the vision transformer.
105        use_conv_transpose: Whether to use transposed convolutions instead of resampling for upsampling.
106            By default, it uses resampling for upsampling.
107        perform_range_checks: Whether to validate the input value range before normalization on each forward pass.
108            You can disable the checks to avoid GPU sync overhead during training when inputs are known to be correct.
109
110        NOTE: The currently supported combinations of 'backbone' x 'encoder' are the following:
111
112        SAM_family_models:
113            - 'sam' x 'vit_b'
114            - 'sam' x 'vit_l'
115            - 'sam' x 'vit_h'
116            - 'sam2' x 'hvit_t'
117            - 'sam2' x 'hvit_s'
118            - 'sam2' x 'hvit_b'
119            - 'sam2' x 'hvit_l'
120            - 'sam3' x 'vit_pe'
121            - 'cellpose_sam' x 'vit_l'
122
123        DINO_family_models:
124            - 'dinov2' x 'vit_s'
125            - 'dinov2' x 'vit_b'
126            - 'dinov2' x 'vit_l'
127            - 'dinov2' x 'vit_g'
128            - 'dinov2' x 'vit_s_reg4'
129            - 'dinov2' x 'vit_b_reg4'
130            - 'dinov2' x 'vit_l_reg4'
131            - 'dinov2' x 'vit_g_reg4'
132            - 'dinov3' x 'vit_s'
133            - 'dinov3' x 'vit_s+'
134            - 'dinov3' x 'vit_b'
135            - 'dinov3' x 'vit_l'
136            - 'dinov3' x 'vit_l+'
137            - 'dinov3' x 'vit_h+'
138            - 'dinov3' x 'vit_7b'
139
140        MAE_family_models:
141            - 'mae' x 'vit_b'
142            - 'mae' x 'vit_l'
143            - 'mae' x 'vit_h'
144            - 'scalemae' x 'vit_b'
145            - 'scalemae' x 'vit_l'
146            - 'scalemae' x 'vit_h'
147    """
148    def __init__(
149        self,
150        img_size: int = 1024,
151        backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
152        encoder: Optional[Union[nn.Module, str]] = "vit_b",
153        decoder: Optional[nn.Module] = None,
154        out_channels: int = 1,
155        use_sam_stats: bool = False,
156        use_mae_stats: bool = False,
157        use_dino_stats: bool = False,
158        resize_input: bool = True,
159        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
160        final_activation: Optional[Union[str, nn.Module]] = None,
161        use_skip_connection: bool = True,
162        embed_dim: Optional[int] = None,
163        use_conv_transpose: bool = False,
164        perform_range_checks: bool = True,
165        **kwargs
166    ) -> None:
167        super().__init__()
168
169        self.img_size = img_size
170        self.use_sam_stats = use_sam_stats
171        self.use_mae_stats = use_mae_stats
172        self.use_dino_stats = use_dino_stats
173        self.use_skip_connection = use_skip_connection
174        self.resize_input = resize_input
175        self.perform_range_checks = perform_range_checks
176        self.use_conv_transpose = use_conv_transpose
177        self.backbone = backbone
178
179        if isinstance(encoder, str):  # e.g. "vit_b" / "hvit_b" / "vit_pe"
180            print(f"Using {encoder} from {backbone.upper()}")
181            self.encoder = get_vision_transformer(img_size=img_size, backbone=backbone, model=encoder, **kwargs)
182
183            if encoder_checkpoint is not None:
184                self._load_encoder_from_checkpoint(backbone=backbone, encoder=encoder, checkpoint=encoder_checkpoint)
185
186            if embed_dim is None:
187                embed_dim = self.encoder.embed_dim
188
189            # For SAM1 encoder, if 'apply_neck' is applied, the embedding dimension must change.
190            if hasattr(self.encoder, "apply_neck") and self.encoder.apply_neck:
191                embed_dim = self.encoder.neck[2].out_channels  # the value is 256
192
193        else:  # `nn.Module` ViT backbone
194            self.encoder = encoder
195
196            have_neck = False
197            for name, _ in self.encoder.named_parameters():
198                if name.startswith("neck"):
199                    have_neck = True
200
201            if embed_dim is None:
202                if have_neck:
203                    embed_dim = self.encoder.neck[2].out_channels  # the value is 256
204                else:
205                    embed_dim = self.encoder.patch_embed.proj.out_channels
206
207        self.embed_dim = embed_dim
208        self.final_activation = self._get_activation(final_activation)
209
210    def _load_encoder_from_checkpoint(self, backbone, encoder, checkpoint):
211        """Function to load pretrained weights to the image encoder.
212        """
213        if isinstance(checkpoint, str):
214            if backbone == "sam" and isinstance(encoder, str):
215                # If we have a SAM encoder, then we first try to load the full SAM Model
216                # (using micro_sam) and otherwise fall back on directly loading the encoder state
217                # from the checkpoint
218                try:
219                    _, model = get_sam_model(model_type=encoder, checkpoint_path=checkpoint, return_sam=True)
220                    encoder_state = model.image_encoder.state_dict()
221                except Exception:
222                    # Try loading the encoder state directly from a checkpoint.
223                    encoder_state = torch.load(checkpoint, weights_only=False)
224
225            elif backbone == "cellpose_sam" and isinstance(encoder, str):
226                # The architecture matches CellposeSAM exactly (same rel_pos sizes),
227                # so weights load directly without any interpolation.
228                encoder_state = torch.load(checkpoint, map_location="cpu", weights_only=False)
229                # Handle DataParallel/DistributedDataParallel prefix.
230                if any(k.startswith("module.") for k in encoder_state.keys()):
231                    encoder_state = OrderedDict(
232                        {k[len("module."):]: v for k, v in encoder_state.items()}
233                    )
234                # Extract encoder weights from CellposeSAM checkpoint format (strip 'encoder.' prefix).
235                if any(k.startswith("encoder.") for k in encoder_state.keys()):
236                    encoder_state = OrderedDict(
237                        {k[len("encoder."):]: v for k, v in encoder_state.items() if k.startswith("encoder.")}
238                    )
239
240            elif backbone == "sam2" and isinstance(encoder, str):
241                # If we have a SAM2 encoder, then we first try to load the full SAM2 Model.
242                # (using micro_sam2) and otherwise fall back on directly loading the encoder state
243                # from the checkpoint
244                try:
245                    model = get_sam2_model(model_type=encoder, checkpoint_path=checkpoint)
246                    encoder_state = model.image_encoder.state_dict()
247                except Exception:
248                    # Try loading the encoder state directly from a checkpoint.
249                    encoder_state = torch.load(checkpoint, weights_only=False)
250
251            elif backbone == "sam3" and isinstance(encoder, str):
252                # If we have a SAM3 encoder, then we first try to load the full SAM3 Model.
253                # (using micro_sam3) and otherwise fall back on directly loading the encoder state
254                # from the checkpoint
255                try:
256                    model = get_sam3_model(checkpoint_path=checkpoint)
257                    encoder_state = model.backbone.vision_backbone.state_dict()
258                    # Let's align loading the encoder weights with expected parameter names
259                    encoder_state = {
260                        k[len("trunk."):] if k.startswith("trunk.") else k: v for k, v in encoder_state.items()
261                    }
262                    # And drop the 'convs' and 'sam2_convs' - these seem like some upsampling blocks.
263                    encoder_state = {
264                        k: v for k, v in encoder_state.items()
265                        if not (k.startswith("convs.") or k.startswith("sam2_convs."))
266                    }
267                except Exception:
268                    # Try loading the encoder state directly from a checkpoint.
269                    encoder_state = torch.load(checkpoint, weights_only=False)
270
271            elif backbone == "mae":
272                # vit initialization hints from:
273                #     - https://github.com/facebookresearch/mae/blob/main/main_finetune.py#L233-L242
274                encoder_state = torch.load(checkpoint, weights_only=False)["model"]
275                encoder_state = OrderedDict({
276                    k: v for k, v in encoder_state.items() if (k != "mask_token" and not k.startswith("decoder"))
277                })
278                # Let's remove the `head` from our current encoder (as the MAE pretrained don't expect it)
279                current_encoder_state = self.encoder.state_dict()
280                if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
281                    del self.encoder.head
282
283            elif backbone == "scalemae":
284                # Load the encoder state directly from a checkpoint.
285                encoder_state = torch.load(checkpoint)["model"]
286                encoder_state = OrderedDict({
287                    k: v for k, v in encoder_state.items()
288                    if not k.startswith(("mask_token", "decoder", "fcn", "fpn", "pos_embed"))
289                })
290
291                # Let's remove the `head` from our current encoder (as the MAE pretrained don't expect it)
292                current_encoder_state = self.encoder.state_dict()
293                if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
294                    del self.encoder.head
295
296                if "pos_embed" in current_encoder_state:  # NOTE: ScaleMAE uses 'pos. embeddings' in a diff. format.
297                    del self.encoder.pos_embed
298
299            elif backbone in ["dinov2", "dinov3"]:  # Load the encoder state directly from a checkpoint.
300                encoder_state = torch.load(checkpoint)
301
302            else:
303                raise ValueError(
304                    f"We don't support either the '{backbone}' backbone or the '{encoder}' model combination (or both)."
305                )
306
307        else:
308            encoder_state = checkpoint
309
310        self.encoder.load_state_dict(encoder_state)
311
312    def _get_activation(self, activation):
313        return_activation = None
314        if activation is None:
315            return None
316        if isinstance(activation, nn.Module):
317            return activation
318        if isinstance(activation, str):
319            return_activation = getattr(nn, activation, None)
320        if return_activation is None:
321            raise ValueError(f"Invalid activation: {activation}")
322
323        return return_activation()
324
325    @staticmethod
326    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
327        """Compute the output size given input size and target long side length.
328
329        Args:
330            oldh: The input image height.
331            oldw: The input image width.
332            long_side_length: The longest side length for resizing.
333
334        Returns:
335            The new image height.
336            The new image width.
337        """
338        scale = long_side_length * 1.0 / max(oldh, oldw)
339        newh, neww = oldh * scale, oldw * scale
340        neww = int(neww + 0.5)
341        newh = int(newh + 0.5)
342        return (newh, neww)
343
344    def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor:
345        """Resize the image so that the longest side has the correct length.
346
347        Expects batched images with shape BxCxHxW OR BxCxDxHxW and float format.
348
349        Args:
350            image: The input image.
351
352        Returns:
353            The resized image.
354        """
355        if image.ndim == 4:  # i.e. 2d image
356            target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.encoder.img_size)
357            return F.interpolate(image, target_size, mode="bilinear", align_corners=False, antialias=True)
358        elif image.ndim == 5:  # i.e. 3d volume
359            B, C, Z, H, W = image.shape
360            target_size = self.get_preprocess_shape(H, W, self.img_size)
361            return F.interpolate(image, (Z, *target_size), mode="trilinear", align_corners=False)
362        else:
363            raise ValueError(f"Expected 4d or 5d inputs, got shape {image.shape}.")
364
365    def _as_stats(self, mean, std, device, dtype, is_3d: bool):
366        """@private
367        """
368        return _as_stats(mean, std, device, dtype, is_3d)
369
370    def _check_input_normalization_range(self, x: torch.Tensor, expected_range: Optional[Tuple[float, float]]) -> None:
371        """@private
372        """
373        _check_input_normalization_range(x, expected_range)
374
375    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
376        """@private
377        """
378        return preprocess_vit_inputs(
379            x,
380            use_sam_stats=self.use_sam_stats,
381            backbone=self.backbone,
382            use_mae_stats=self.use_mae_stats,
383            use_dino_stats=self.use_dino_stats,
384            resize_input=self.resize_input,
385            img_size=self.img_size,
386            encoder_img_size=self.encoder.img_size,
387            perform_range_checks=self.perform_range_checks,
388        )
389
390    def postprocess_masks(
391        self, masks: torch.Tensor, input_size: Tuple[int, ...], original_size: Tuple[int, ...],
392    ) -> torch.Tensor:
393        """@private
394        """
395        if masks.ndim == 4:  # i.e. 2d labels
396            masks = F.interpolate(
397                masks,
398                (self.encoder.img_size, self.encoder.img_size),
399                mode="bilinear",
400                align_corners=False,
401            )
402            masks = masks[..., : input_size[0], : input_size[1]]
403            masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
404
405        elif masks.ndim == 5:  # i.e. 3d volumetric labels
406            masks = F.interpolate(
407                masks,
408                (input_size[0], self.img_size, self.img_size),
409                mode="trilinear",
410                align_corners=False,
411            )
412            masks = masks[..., :input_size[0], :input_size[1], :input_size[2]]
413            masks = F.interpolate(masks, original_size, mode="trilinear", align_corners=False)
414
415        else:
416            raise ValueError(f"Expected 4d or 5d labels, got shape {masks.shape}.")
417
418        return masks

Base class for implementing a UNETR.

Arguments:
  • img_size: The size of the input for the image encoder. Input images will be resized to match this size.
  • backbone: The name of the vision transformer implementation. One of "sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3" (see all combinations below)
  • encoder: The vision transformer. Can either be a name, such as "vit_b" (see all combinations for this below) or a torch module.
  • decoder: The convolutional decoder.
  • out_channels: The number of output channels of the UNETR.
  • use_sam_stats: Whether to normalize the input data with the statistics of the pretrained SAM / SAM2 / SAM3 model.
  • use_dino_stats: Whether to normalize the input data with the statistics of the pretrained DINOv2 / DINOv3 model.
  • use_mae_stats: Whether to normalize the input data with the statistics of the pretrained MAE model.
  • resize_input: Whether to resize the input images to match img_size. By default, it resizes the inputs to match the img_size.
  • encoder_checkpoint: Checkpoint for initializing the vision transformer. Can either be a filepath or an already loaded checkpoint.
  • final_activation: The activation to apply to the UNETR output.
  • use_skip_connection: Whether to use skip connections. By default, it uses skip connections.
  • embed_dim: The embedding dimensionality, corresponding to the output dimension of the vision transformer.
  • use_conv_transpose: Whether to use transposed convolutions instead of resampling for upsampling. By default, it uses resampling for upsampling.
  • perform_range_checks: Whether to validate the input value range before normalization on each forward pass. You can disable the checks to avoid GPU sync overhead during training when inputs are known to be correct.
  • NOTE: The currently supported combinations of 'backbone' x 'encoder' are the following:
    • SAM family models:
      • 'sam' x 'vit_b'
      • 'sam' x 'vit_l'
      • 'sam' x 'vit_h'
      • 'sam2' x 'hvit_t'
      • 'sam2' x 'hvit_s'
      • 'sam2' x 'hvit_b'
      • 'sam2' x 'hvit_l'
      • 'sam3' x 'vit_pe'
      • 'cellpose_sam' x 'vit_l'
    • DINO family models:
      • 'dinov2' x 'vit_s'
      • 'dinov2' x 'vit_b'
      • 'dinov2' x 'vit_l'
      • 'dinov2' x 'vit_g'
      • 'dinov2' x 'vit_s_reg4'
      • 'dinov2' x 'vit_b_reg4'
      • 'dinov2' x 'vit_l_reg4'
      • 'dinov2' x 'vit_g_reg4'
      • 'dinov3' x 'vit_s'
      • 'dinov3' x 'vit_s+'
      • 'dinov3' x 'vit_b'
      • 'dinov3' x 'vit_l'
      • 'dinov3' x 'vit_l+'
      • 'dinov3' x 'vit_h+'
      • 'dinov3' x 'vit_7b'
    • MAE family models:
      • 'mae' x 'vit_b'
      • 'mae' x 'vit_l'
      • 'mae' x 'vit_h'
      • 'scalemae' x 'vit_b'
      • 'scalemae' x 'vit_l'
      • 'scalemae' x 'vit_h'
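A minimal sketch of constructing a model with one of the supported 'backbone' x 'encoder' combinations listed above. It uses the concrete UNETR subclass documented below (this base class is not meant to be instantiated directly); the checkpoint path is hypothetical and commented out, and the optional dependencies for the chosen backbone are assumed to be installed.

from torch_em.model.unetr import UNETR

# Hypothetical example: 'sam' x 'vit_b' with SAM pixel statistics and a sigmoid output.
model = UNETR(
    img_size=1024,
    backbone="sam",
    encoder="vit_b",
    out_channels=1,
    use_sam_stats=True,          # normalize inputs with the pretrained SAM statistics
    final_activation="Sigmoid",  # resolved to torch.nn.Sigmoid via its name
    # encoder_checkpoint="/path/to/sam_vit_b.pth",  # hypothetical filepath with pretrained encoder weights
)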
UNETRBase( img_size: int = 1024, backbone: Literal['sam', 'sam2', 'sam3', 'cellpose_sam', 'mae', 'scalemae', 'dinov2', 'dinov3'] = 'sam', encoder: Union[torch.nn.modules.module.Module, str, NoneType] = 'vit_b', decoder: Optional[torch.nn.modules.module.Module] = None, out_channels: int = 1, use_sam_stats: bool = False, use_mae_stats: bool = False, use_dino_stats: bool = False, resize_input: bool = True, encoder_checkpoint: Union[str, collections.OrderedDict, NoneType] = None, final_activation: Union[torch.nn.modules.module.Module, str, NoneType] = None, use_skip_connection: bool = True, embed_dim: Optional[int] = None, use_conv_transpose: bool = False, perform_range_checks: bool = True, **kwargs)
148    def __init__(
149        self,
150        img_size: int = 1024,
151        backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
152        encoder: Optional[Union[nn.Module, str]] = "vit_b",
153        decoder: Optional[nn.Module] = None,
154        out_channels: int = 1,
155        use_sam_stats: bool = False,
156        use_mae_stats: bool = False,
157        use_dino_stats: bool = False,
158        resize_input: bool = True,
159        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
160        final_activation: Optional[Union[str, nn.Module]] = None,
161        use_skip_connection: bool = True,
162        embed_dim: Optional[int] = None,
163        use_conv_transpose: bool = False,
164        perform_range_checks: bool = True,
165        **kwargs
166    ) -> None:
167        super().__init__()
168
169        self.img_size = img_size
170        self.use_sam_stats = use_sam_stats
171        self.use_mae_stats = use_mae_stats
172        self.use_dino_stats = use_dino_stats
173        self.use_skip_connection = use_skip_connection
174        self.resize_input = resize_input
175        self.perform_range_checks = perform_range_checks
176        self.use_conv_transpose = use_conv_transpose
177        self.backbone = backbone
178
179        if isinstance(encoder, str):  # e.g. "vit_b" / "hvit_b" / "vit_pe"
180            print(f"Using {encoder} from {backbone.upper()}")
181            self.encoder = get_vision_transformer(img_size=img_size, backbone=backbone, model=encoder, **kwargs)
182
183            if encoder_checkpoint is not None:
184                self._load_encoder_from_checkpoint(backbone=backbone, encoder=encoder, checkpoint=encoder_checkpoint)
185
186            if embed_dim is None:
187                embed_dim = self.encoder.embed_dim
188
189        # For the SAM1 encoder, if 'apply_neck' is enabled, the embedding dimension changes to the neck's output channels.
190            if hasattr(self.encoder, "apply_neck") and self.encoder.apply_neck:
191                embed_dim = self.encoder.neck[2].out_channels  # the value is 256
192
193        else:  # `nn.Module` ViT backbone
194            self.encoder = encoder
195
196            have_neck = False
197            for name, _ in self.encoder.named_parameters():
198                if name.startswith("neck"):
199                    have_neck = True
200
201            if embed_dim is None:
202                if have_neck:
203                    embed_dim = self.encoder.neck[2].out_channels  # the value is 256
204                else:
205                    embed_dim = self.encoder.patch_embed.proj.out_channels
206
207        self.embed_dim = embed_dim
208        self.final_activation = self._get_activation(final_activation)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

img_size
use_sam_stats
use_mae_stats
use_dino_stats
use_skip_connection
resize_input
perform_range_checks
use_conv_transpose
backbone
embed_dim
final_activation
@staticmethod
def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
325    @staticmethod
326    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
327        """Compute the output size given input size and target long side length.
328
329        Args:
330            oldh: The input image height.
331            oldw: The input image width.
332            long_side_length: The longest side length for resizing.
333
334        Returns:
335            The new image height.
336            The new image width.
337        """
338        scale = long_side_length * 1.0 / max(oldh, oldw)
339        newh, neww = oldh * scale, oldw * scale
340        neww = int(neww + 0.5)
341        newh = int(newh + 0.5)
342        return (newh, neww)

Compute the output size given input size and target long side length.

Arguments:
  • oldh: The input image height.
  • oldw: The input image width.
  • long_side_length: The longest side length for resizing.
Returns:

The new image height. The new image width.
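For intuition, here is a small worked example of the scaling rule (a static method, so no model instance is needed): the longest side is mapped to long_side_length and the shorter side is scaled proportionally and rounded.

from torch_em.model.unetr import UNETRBase

# The longest side (1024) already matches the target, so the shape is unchanged.
print(UNETRBase.get_preprocess_shape(768, 1024, long_side_length=1024))  # (768, 1024)

# scale = 1024 / 896 ≈ 1.1429, so 512 * scale ≈ 585.1, which rounds to 585.
print(UNETRBase.get_preprocess_shape(512, 896, long_side_length=1024))   # (585, 1024)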

def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor:
344    def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor:
345        """Resize the image so that the longest side has the correct length.
346
347        Expects batched images with shape BxCxHxW OR BxCxDxHxW and float format.
348
349        Args:
350            image: The input image.
351
352        Returns:
353            The resized image.
354        """
355        if image.ndim == 4:  # i.e. 2d image
356            target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.encoder.img_size)
357            return F.interpolate(image, target_size, mode="bilinear", align_corners=False, antialias=True)
358        elif image.ndim == 5:  # i.e. 3d volume
359            B, C, Z, H, W = image.shape
360            target_size = self.get_preprocess_shape(H, W, self.img_size)
361            return F.interpolate(image, (Z, *target_size), mode="trilinear", align_corners=False)
362        else:
363            raise ValueError(f"Expected 4d or 5d inputs, got shape {image.shape}.")

Resize the image so that the longest side has the correct length.

Expects batched images with shape BxCxHxW OR BxCxDxHxW and float format.

Arguments:
  • image: The input image.
Returns:

The resized image.
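A short usage sketch, assuming a default UNETR can be constructed in your environment: a non-square batch is resized so that its longest side matches the encoder input size while the aspect ratio is preserved.

import torch
from torch_em.model.unetr import UNETR

model = UNETR(backbone="sam", encoder="vit_b", img_size=1024)
image = torch.rand(1, 3, 512, 768)          # B x C x H x W, float
resized = model.resize_longest_side(image)
print(resized.shape)                        # torch.Size([1, 3, 683, 1024])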

def preprocess_vit_inputs( x: torch.Tensor, use_sam_stats: bool = False, backbone: str = 'sam', use_mae_stats: bool = False, use_dino_stats: bool = False, resize_input: bool = True, img_size: int = 1024, encoder_img_size: int = 1024, perform_range_checks: bool = True) -> Tuple[torch.Tensor, Tuple]:
421def preprocess_vit_inputs(
422    x: torch.Tensor,
423    use_sam_stats: bool = False,
424    backbone: str = "sam",
425    use_mae_stats: bool = False,
426    use_dino_stats: bool = False,
427    resize_input: bool = True,
428    img_size: int = 1024,
429    encoder_img_size: int = 1024,
430    perform_range_checks: bool = True,
431) -> Tuple[torch.Tensor, Tuple]:
432    """Preprocess inputs for ViT backbones in UNETR models.
433
434    Handles normalization stat selection, input range validation, optional resizing to the longest side,
435    and padding to `encoder_img_size`. Can be used as a standalone function without a model instance.
436
437    Args:
438        x: Input tensor of shape (B, C, H, W) for 2D or (B, C, Z, H, W) for 3D.
439        use_sam_stats: Whether to normalize with SAM/SAM2/SAM3 backbone statistics.
440        backbone: The backbone name - controls which SAM stats are used when `use_sam_stats=True`.
441        use_mae_stats: Whether to normalize with MAE statistics.
442        use_dino_stats: Whether to normalize with DINOv2/DINOv3 statistics.
443        resize_input: Whether to resize the input to the longest side before padding.
444        img_size: The model image size, used for 3D resize.
445        encoder_img_size: The encoder image size, used for 2D resize and padding.
446        perform_range_checks: Whether to validate the expected input value range before normalization.
447            You can disable the checks to avoid GPU sync overhead during training when inputs are known to be correct.
448
449    Returns:
450        The preprocessed tensor and the spatial shape after resizing (before padding).
451    """
452    is_3d = (x.ndim == 5)
453    device, dtype = x.device, x.dtype
454    mean, std = (0.0, 0.0, 0.0), (1.0, 1.0, 1.0)
455    expected_range = None
456    unit_scale_max = None
457
458    if use_sam_stats:
459        if backbone == "sam2":
460            mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
461            expected_range = (0.0, 1.0)
462        elif backbone == "sam3":
463            mean, std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
464            expected_range = (0.0, 1.0)
465        else:  # sam1 / default
466            mean, std = (123.675, 116.28, 103.53), (58.395, 57.12, 57.375)
467            expected_range = (0.0, 255.0)
468            unit_scale_max = 1.0
469    elif use_mae_stats:  # TODO: add mean std from mae / scalemae experiments (or open up arguments for this)
470        raise NotImplementedError
471    elif use_dino_stats:
472        mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
473        expected_range = (0.0, 1.0)
474    else:
475        mean, std = (0.0, 0.0, 0.0), (1.0, 1.0, 1.0)
476        expected_range = None
477
478    if perform_range_checks:
479        _check_input_normalization_range(x, expected_range, unit_scale_max)
480    pixel_mean, pixel_std = _as_stats(mean, std, device=device, dtype=dtype, is_3d=is_3d)
481
482    if resize_input:
483        if x.ndim == 4:
484            target_size = UNETRBase.get_preprocess_shape(x.shape[2], x.shape[3], encoder_img_size)
485            x = F.interpolate(x, target_size, mode="bilinear", align_corners=False, antialias=True)
486        elif x.ndim == 5:
487            B, C, Z, H, W = x.shape
488            target_size = UNETRBase.get_preprocess_shape(H, W, img_size)
489            x = F.interpolate(x, (Z, *target_size), mode="trilinear", align_corners=False)
490
491    input_shape = x.shape[-3:] if is_3d else x.shape[-2:]
492
493    x = (x - pixel_mean) / pixel_std
494    h, w = x.shape[-2:]
495    padh = encoder_img_size - h
496    padw = encoder_img_size - w
497
498    if is_3d:
499        x = F.pad(x, (0, padw, 0, padh, 0, 0))
500    else:
501        x = F.pad(x, (0, padw, 0, padh))
502
503    return x, input_shape

Preprocess inputs for ViT backbones in UNETR models.

Handles normalization stat selection, input range validation, optional resizing to the longest side, and padding to encoder_img_size. Can be used as a standalone function without a model instance.

Arguments:
  • x: Input tensor of shape (B, C, H, W) for 2D or (B, C, Z, H, W) for 3D.
  • use_sam_stats: Whether to normalize with SAM/SAM2/SAM3 backbone statistics.
  • backbone: The backbone name - controls which SAM stats are used when use_sam_stats=True.
  • use_mae_stats: Whether to normalize with MAE statistics.
  • use_dino_stats: Whether to normalize with DINOv2/DINOv3 statistics.
  • resize_input: Whether to resize the input to the longest side before padding.
  • img_size: The model image size, used for 3D resize.
  • encoder_img_size: The encoder image size, used for 2D resize and padding.
  • perform_range_checks: Whether to validate the expected input value range before normalization. You can disable the checks to avoid GPU sync overhead during training when inputs are known to be correct.
Returns:

The preprocessed tensor and the spatial shape after resizing (before padding).
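A standalone usage sketch (no model instance required). Note that with use_sam_stats=True and the default SAM1 backbone the raw inputs are expected in [0, 255]; values in [0, 1] would trigger the range check described under perform_range_checks.

import torch
from torch_em.model.unetr import preprocess_vit_inputs

x = torch.randint(0, 256, (1, 3, 512, 768)).float()   # raw image batch in [0, 255]
x_prep, input_shape = preprocess_vit_inputs(x, use_sam_stats=True, backbone="sam", encoder_img_size=1024)
print(x_prep.shape)   # padded to the encoder size: torch.Size([1, 3, 1024, 1024])
print(input_shape)    # spatial shape after resizing, before padding: torch.Size([683, 1024])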

class UNETR(UNETRBase):
506class UNETR(UNETRBase):
507    """A (2d-only) UNet Transformer using a vision transformer as encoder and a convolutional decoder.
508    """
509    def __init__(
510        self,
511        img_size: int = 1024,
512        backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
513        encoder: Optional[Union[nn.Module, str]] = "vit_b",
514        decoder: Optional[nn.Module] = None,
515        out_channels: int = 1,
516        use_sam_stats: bool = False,
517        use_mae_stats: bool = False,
518        use_dino_stats: bool = False,
519        resize_input: bool = True,
520        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
521        final_activation: Optional[Union[str, nn.Module]] = None,
522        use_skip_connection: bool = True,
523        embed_dim: Optional[int] = None,
524        use_conv_transpose: bool = False,
525        perform_range_checks: bool = True,
526        **kwargs
527    ) -> None:
528
529        super().__init__(
530            img_size=img_size,
531            backbone=backbone,
532            encoder=encoder,
533            decoder=decoder,
534            out_channels=out_channels,
535            use_sam_stats=use_sam_stats,
536            use_mae_stats=use_mae_stats,
537            use_dino_stats=use_dino_stats,
538            resize_input=resize_input,
539            encoder_checkpoint=encoder_checkpoint,
540            final_activation=final_activation,
541            use_skip_connection=use_skip_connection,
542            embed_dim=embed_dim,
543            use_conv_transpose=use_conv_transpose,
544            perform_range_checks=perform_range_checks,
545            **kwargs,
546        )
547
548        encoder = self.encoder
549
550        if backbone == "sam2" and hasattr(encoder, "trunk"):
551            in_chans = encoder.trunk.patch_embed.proj.in_channels
552        elif hasattr(encoder, "in_chans"):
553            in_chans = encoder.in_chans
554        else:  # `nn.Module` ViT backbone.
555            try:
556                in_chans = encoder.patch_embed.proj.in_channels
557            except AttributeError:  # for getting the input channels while using 'vit_t' from MobileSAM
558                in_chans = encoder.patch_embed.seq[0].c.in_channels
559
560        # parameters for the decoder network
561        depth = 3
562        initial_features = 64
563        gain = 2
564        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
565        scale_factors = depth * [2]
566        self.out_channels = out_channels
567
568        # choice of upsampler - to use (bilinear interpolation + conv) or conv transpose
569        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d
570
571        self.decoder = decoder or Decoder(
572            features=features_decoder,
573            scale_factors=scale_factors[::-1],
574            conv_block_impl=ConvBlock2d,
575            sampler_impl=_upsampler,
576        )
577
578        if use_skip_connection:
579            self.deconv1 = Deconv2DBlock(
580                in_channels=self.embed_dim,
581                out_channels=features_decoder[0],
582                use_conv_transpose=use_conv_transpose,
583            )
584            self.deconv2 = nn.Sequential(
585                Deconv2DBlock(
586                    in_channels=self.embed_dim,
587                    out_channels=features_decoder[0],
588                    use_conv_transpose=use_conv_transpose,
589                ),
590                Deconv2DBlock(
591                    in_channels=features_decoder[0],
592                    out_channels=features_decoder[1],
593                    use_conv_transpose=use_conv_transpose,
594                )
595            )
596            self.deconv3 = nn.Sequential(
597                Deconv2DBlock(
598                    in_channels=self.embed_dim,
599                    out_channels=features_decoder[0],
600                    use_conv_transpose=use_conv_transpose,
601                ),
602                Deconv2DBlock(
603                    in_channels=features_decoder[0],
604                    out_channels=features_decoder[1],
605                    use_conv_transpose=use_conv_transpose,
606                ),
607                Deconv2DBlock(
608                    in_channels=features_decoder[1],
609                    out_channels=features_decoder[2],
610                    use_conv_transpose=use_conv_transpose,
611                )
612            )
613            self.deconv4 = ConvBlock2d(in_chans, features_decoder[-1])
614        else:
615            self.deconv1 = Deconv2DBlock(
616                in_channels=self.embed_dim,
617                out_channels=features_decoder[0],
618                use_conv_transpose=use_conv_transpose,
619            )
620            self.deconv2 = Deconv2DBlock(
621                in_channels=features_decoder[0],
622                out_channels=features_decoder[1],
623                use_conv_transpose=use_conv_transpose,
624            )
625            self.deconv3 = Deconv2DBlock(
626                in_channels=features_decoder[1],
627                out_channels=features_decoder[2],
628                use_conv_transpose=use_conv_transpose,
629            )
630            self.deconv4 = Deconv2DBlock(
631                in_channels=features_decoder[2],
632                out_channels=features_decoder[3],
633                use_conv_transpose=use_conv_transpose,
634            )
635
636        self.base = ConvBlock2d(self.embed_dim, features_decoder[0])
637        self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1)
638        self.deconv_out = _upsampler(
639            scale_factor=2, in_channels=features_decoder[-1], out_channels=features_decoder[-1]
640        )
641        self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1])
642
643    def forward(self, x: torch.Tensor) -> torch.Tensor:
644        """Apply the UNETR to the input data.
645
646        Args:
647            x: The input tensor.
648
649        Returns:
650            The UNETR output.
651        """
652        original_shape = x.shape[-2:]
653
654        # Reshape the inputs to the shape expected by the encoder
655        # and normalize the inputs if normalization is part of the model.
656        x, input_shape = self.preprocess(x)
657
658        encoder_outputs = self.encoder(x)
659
660        if isinstance(encoder_outputs[-1], list):
661            # `encoder_outputs` can be arranged in only two forms:
662            #   - either we only return the image embeddings
663            #   - or, we return the image embeddings and a list of features from the global attention layers
664            z12, from_encoder = encoder_outputs
665        else:
666            z12 = encoder_outputs
667
668        if self.use_skip_connection:
669            from_encoder = from_encoder[::-1]
670            z9 = self.deconv1(from_encoder[0])
671            z6 = self.deconv2(from_encoder[1])
672            z3 = self.deconv3(from_encoder[2])
673            z0 = self.deconv4(x)
674
675        else:
676            z9 = self.deconv1(z12)
677            z6 = self.deconv2(z9)
678            z3 = self.deconv3(z6)
679            z0 = self.deconv4(z3)
680
681        updated_from_encoder = [z9, z6, z3]
682
683        x = self.base(z12)
684        x = self.decoder(x, encoder_inputs=updated_from_encoder)
685        x = self.deconv_out(x)
686
687        x = torch.cat([x, z0], dim=1)
688        x = self.decoder_head(x)
689
690        x = self.out_conv(x)
691        if self.final_activation is not None:
692            x = self.final_activation(x)
693
694        x = self.postprocess_masks(x, input_shape, original_shape)
695        return x

A (2d-only) UNet Transformer using a vision transformer as encoder and a convolutional decoder.

UNETR( img_size: int = 1024, backbone: Literal['sam', 'sam2', 'sam3', 'cellpose_sam', 'mae', 'scalemae', 'dinov2', 'dinov3'] = 'sam', encoder: Union[torch.nn.modules.module.Module, str, NoneType] = 'vit_b', decoder: Optional[torch.nn.modules.module.Module] = None, out_channels: int = 1, use_sam_stats: bool = False, use_mae_stats: bool = False, use_dino_stats: bool = False, resize_input: bool = True, encoder_checkpoint: Union[str, collections.OrderedDict, NoneType] = None, final_activation: Union[torch.nn.modules.module.Module, str, NoneType] = None, use_skip_connection: bool = True, embed_dim: Optional[int] = None, use_conv_transpose: bool = False, perform_range_checks: bool = True, **kwargs)
509    def __init__(
510        self,
511        img_size: int = 1024,
512        backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
513        encoder: Optional[Union[nn.Module, str]] = "vit_b",
514        decoder: Optional[nn.Module] = None,
515        out_channels: int = 1,
516        use_sam_stats: bool = False,
517        use_mae_stats: bool = False,
518        use_dino_stats: bool = False,
519        resize_input: bool = True,
520        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
521        final_activation: Optional[Union[str, nn.Module]] = None,
522        use_skip_connection: bool = True,
523        embed_dim: Optional[int] = None,
524        use_conv_transpose: bool = False,
525        perform_range_checks: bool = True,
526        **kwargs
527    ) -> None:
528
529        super().__init__(
530            img_size=img_size,
531            backbone=backbone,
532            encoder=encoder,
533            decoder=decoder,
534            out_channels=out_channels,
535            use_sam_stats=use_sam_stats,
536            use_mae_stats=use_mae_stats,
537            use_dino_stats=use_dino_stats,
538            resize_input=resize_input,
539            encoder_checkpoint=encoder_checkpoint,
540            final_activation=final_activation,
541            use_skip_connection=use_skip_connection,
542            embed_dim=embed_dim,
543            use_conv_transpose=use_conv_transpose,
544            perform_range_checks=perform_range_checks,
545            **kwargs,
546        )
547
548        encoder = self.encoder
549
550        if backbone == "sam2" and hasattr(encoder, "trunk"):
551            in_chans = encoder.trunk.patch_embed.proj.in_channels
552        elif hasattr(encoder, "in_chans"):
553            in_chans = encoder.in_chans
554        else:  # `nn.Module` ViT backbone.
555            try:
556                in_chans = encoder.patch_embed.proj.in_channels
557            except AttributeError:  # for getting the input channels while using 'vit_t' from MobileSAM
558                in_chans = encoder.patch_embed.seq[0].c.in_channels
559
560        # parameters for the decoder network
561        depth = 3
562        initial_features = 64
563        gain = 2
564        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
565        scale_factors = depth * [2]
566        self.out_channels = out_channels
567
568        # choice of upsampler - to use (bilinear interpolation + conv) or conv transpose
569        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d
570
571        self.decoder = decoder or Decoder(
572            features=features_decoder,
573            scale_factors=scale_factors[::-1],
574            conv_block_impl=ConvBlock2d,
575            sampler_impl=_upsampler,
576        )
577
578        if use_skip_connection:
579            self.deconv1 = Deconv2DBlock(
580                in_channels=self.embed_dim,
581                out_channels=features_decoder[0],
582                use_conv_transpose=use_conv_transpose,
583            )
584            self.deconv2 = nn.Sequential(
585                Deconv2DBlock(
586                    in_channels=self.embed_dim,
587                    out_channels=features_decoder[0],
588                    use_conv_transpose=use_conv_transpose,
589                ),
590                Deconv2DBlock(
591                    in_channels=features_decoder[0],
592                    out_channels=features_decoder[1],
593                    use_conv_transpose=use_conv_transpose,
594                )
595            )
596            self.deconv3 = nn.Sequential(
597                Deconv2DBlock(
598                    in_channels=self.embed_dim,
599                    out_channels=features_decoder[0],
600                    use_conv_transpose=use_conv_transpose,
601                ),
602                Deconv2DBlock(
603                    in_channels=features_decoder[0],
604                    out_channels=features_decoder[1],
605                    use_conv_transpose=use_conv_transpose,
606                ),
607                Deconv2DBlock(
608                    in_channels=features_decoder[1],
609                    out_channels=features_decoder[2],
610                    use_conv_transpose=use_conv_transpose,
611                )
612            )
613            self.deconv4 = ConvBlock2d(in_chans, features_decoder[-1])
614        else:
615            self.deconv1 = Deconv2DBlock(
616                in_channels=self.embed_dim,
617                out_channels=features_decoder[0],
618                use_conv_transpose=use_conv_transpose,
619            )
620            self.deconv2 = Deconv2DBlock(
621                in_channels=features_decoder[0],
622                out_channels=features_decoder[1],
623                use_conv_transpose=use_conv_transpose,
624            )
625            self.deconv3 = Deconv2DBlock(
626                in_channels=features_decoder[1],
627                out_channels=features_decoder[2],
628                use_conv_transpose=use_conv_transpose,
629            )
630            self.deconv4 = Deconv2DBlock(
631                in_channels=features_decoder[2],
632                out_channels=features_decoder[3],
633                use_conv_transpose=use_conv_transpose,
634            )
635
636        self.base = ConvBlock2d(self.embed_dim, features_decoder[0])
637        self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1)
638        self.deconv_out = _upsampler(
639            scale_factor=2, in_channels=features_decoder[-1], out_channels=features_decoder[-1]
640        )
641        self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1])

Initialize internal Module state, shared by both nn.Module and ScriptModule.

out_channels
decoder
base
out_conv
deconv_out
decoder_head
def forward(self, x: torch.Tensor) -> torch.Tensor:
643    def forward(self, x: torch.Tensor) -> torch.Tensor:
644        """Apply the UNETR to the input data.
645
646        Args:
647            x: The input tensor.
648
649        Returns:
650            The UNETR output.
651        """
652        original_shape = x.shape[-2:]
653
654        # Reshape the inputs to the shape expected by the encoder
655        # and normalize the inputs if normalization is part of the model.
656        x, input_shape = self.preprocess(x)
657
658        encoder_outputs = self.encoder(x)
659
660        if isinstance(encoder_outputs[-1], list):
661            # `encoder_outputs` can be arranged in only two forms:
662            #   - either we only return the image embeddings
663            #   - or, we return the image embeddings and a list of features from the global attention layers
664            z12, from_encoder = encoder_outputs
665        else:
666            z12 = encoder_outputs
667
668        if self.use_skip_connection:
669            from_encoder = from_encoder[::-1]
670            z9 = self.deconv1(from_encoder[0])
671            z6 = self.deconv2(from_encoder[1])
672            z3 = self.deconv3(from_encoder[2])
673            z0 = self.deconv4(x)
674
675        else:
676            z9 = self.deconv1(z12)
677            z6 = self.deconv2(z9)
678            z3 = self.deconv3(z6)
679            z0 = self.deconv4(z3)
680
681        updated_from_encoder = [z9, z6, z3]
682
683        x = self.base(z12)
684        x = self.decoder(x, encoder_inputs=updated_from_encoder)
685        x = self.deconv_out(x)
686
687        x = torch.cat([x, z0], dim=1)
688        x = self.decoder_head(x)
689
690        x = self.out_conv(x)
691        if self.final_activation is not None:
692            x = self.final_activation(x)
693
694        x = self.postprocess_masks(x, input_shape, original_shape)
695        return x

Apply the UNETR to the input data.

Arguments:
  • x: The input tensor.
Returns:

The UNETR output.
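A minimal end-to-end sketch of the 2D forward pass, assuming the default 'sam' x 'vit_b' model can be constructed in your environment. The output is interpolated back to the spatial size of the input.

import torch
from torch_em.model.unetr import UNETR

model = UNETR(backbone="sam", encoder="vit_b", out_channels=2, final_activation="Sigmoid")
x = torch.rand(1, 3, 512, 512)   # B x C x H x W
with torch.no_grad():
    y = model(x)
print(y.shape)                   # torch.Size([1, 2, 512, 512])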

class UNETR2D(UNETR):
698class UNETR2D(UNETR):
699    """A two-dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.
700    """
701    pass

A two-dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.

class UNETR3D(UNETRBase):
704class UNETR3D(UNETRBase):
705    """A three-dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.
706    """
707    def __init__(
708        self,
709        img_size: int = 1024,
710        backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
711        encoder: Optional[Union[nn.Module, str]] = "hvit_b",
712        decoder: Optional[nn.Module] = None,
713        out_channels: int = 1,
714        use_sam_stats: bool = False,
715        use_mae_stats: bool = False,
716        use_dino_stats: bool = False,
717        resize_input: bool = True,
718        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
719        final_activation: Optional[Union[str, nn.Module]] = None,
720        use_skip_connection: bool = False,
721        embed_dim: Optional[int] = None,
722        use_conv_transpose: bool = False,
723        use_strip_pooling: bool = True,
724        perform_range_checks: bool = True,
725        **kwargs
726    ):
727        if use_skip_connection:
728            raise NotImplementedError("Skip connections are not supported for the 3D UNETR at the moment.")
729        if use_conv_transpose:
730            raise NotImplementedError("Transposed convolutions are not supported for the 3D UNETR; interpolation is always used.")
731
732        # Sort the `embed_dim` out
733        embed_dim = 256 if embed_dim is None else embed_dim
734
735        super().__init__(
736            img_size=img_size,
737            backbone=backbone,
738            encoder=encoder,
739            decoder=decoder,
740            out_channels=out_channels,
741            use_sam_stats=use_sam_stats,
742            use_mae_stats=use_mae_stats,
743            use_dino_stats=use_dino_stats,
744            resize_input=resize_input,
745            encoder_checkpoint=encoder_checkpoint,
746            final_activation=final_activation,
747            use_skip_connection=use_skip_connection,
748            embed_dim=embed_dim,
749            use_conv_transpose=use_conv_transpose,
750            perform_range_checks=perform_range_checks,
751            **kwargs,
752        )
753
754        # The 3d convolutional decoder.
755        # First, get the important parameters for the decoder.
756        depth = 3
757        initial_features = 64
758        gain = 2
759        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
760        scale_factors = [1, 2, 2]
761        self.out_channels = out_channels
762
763        # The mapping blocks.
764        self.deconv1 = Deconv3DBlock(
765            in_channels=embed_dim,
766            out_channels=features_decoder[0],
767            scale_factor=scale_factors,
768            use_strip_pooling=use_strip_pooling,
769        )
770        self.deconv2 = Deconv3DBlock(
771            in_channels=features_decoder[0],
772            out_channels=features_decoder[1],
773            scale_factor=scale_factors,
774            use_strip_pooling=use_strip_pooling,
775        )
776        self.deconv3 = Deconv3DBlock(
777            in_channels=features_decoder[1],
778            out_channels=features_decoder[2],
779            scale_factor=scale_factors,
780            use_strip_pooling=use_strip_pooling,
781        )
782        self.deconv4 = Deconv3DBlock(
783            in_channels=features_decoder[2],
784            out_channels=features_decoder[3],
785            scale_factor=scale_factors,
786            use_strip_pooling=use_strip_pooling,
787        )
788
789        # The core decoder block.
790        self.decoder = decoder or Decoder(
791            features=features_decoder,
792            scale_factors=[scale_factors] * depth,
793            conv_block_impl=partial(ConvBlock3dWithStrip, use_strip_pooling=use_strip_pooling),
794            sampler_impl=Upsampler3d,
795        )
796
797        # And the final upsampler to match the expected dimensions.
798        self.deconv_out = Deconv3DBlock(  # NOTE: changed `end_up` to `deconv_out`
799            in_channels=features_decoder[-1],
800            out_channels=features_decoder[-1],
801            scale_factor=scale_factors,
802            use_strip_pooling=use_strip_pooling,
803        )
804
805        # Additional conjunction blocks.
806        self.base = ConvBlock3dWithStrip(
807            in_channels=embed_dim,
808            out_channels=features_decoder[0],
809            use_strip_pooling=use_strip_pooling,
810        )
811
812        # And the output layers.
813        self.decoder_head = ConvBlock3dWithStrip(
814            in_channels=2 * features_decoder[-1],
815            out_channels=features_decoder[-1],
816            use_strip_pooling=use_strip_pooling,
817        )
818        self.out_conv = nn.Conv3d(features_decoder[-1], out_channels, 1)
819
820    def forward(self, x: torch.Tensor):
821        """Forward pass of the UNETR-3D model.
822
823        Args:
824            x: Inputs of expected shape (B, C, Z, Y, X), where the number of slices Z may vary.
825
826        Returns:
827            The UNETR output.
828        """
829        B, C, Z, H, W = x.shape
830        original_shape = (Z, H, W)
831
832        # Preprocessing step
833        x, input_shape = self.preprocess(x)
834
835        # Run the image encoder.
836        curr_features = torch.stack([self.encoder(x[:, :, i])[0] for i in range(Z)], dim=2)
837
838        # Prepare the counterparts for the decoder.
839        # NOTE: The section below is sequential; there are no skip connections at the moment.
840        z9 = self.deconv1(curr_features)
841        z6 = self.deconv2(z9)
842        z3 = self.deconv3(z6)
843        z0 = self.deconv4(z3)
844
845        updated_from_encoder = [z9, z6, z3]
846
847        # Align the features through the base block.
848        x = self.base(curr_features)
849        # Run the decoder
850        x = self.decoder(x, encoder_inputs=updated_from_encoder)
851        x = self.deconv_out(x)  # NOTE: previously named `end_up`.
852
853        # And the final output head.
854        x = torch.cat([x, z0], dim=1)
855        x = self.decoder_head(x)
856        x = self.out_conv(x)
857        if self.final_activation is not None:
858            x = self.final_activation(x)
859
860        # Postprocess the output back to original size.
861        x = self.postprocess_masks(x, input_shape, original_shape)
862        return x

A three-dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.

UNETR3D( img_size: int = 1024, backbone: Literal['sam', 'sam2', 'sam3', 'cellpose_sam', 'mae', 'scalemae', 'dinov2', 'dinov3'] = 'sam', encoder: Union[torch.nn.modules.module.Module, str, NoneType] = 'hvit_b', decoder: Optional[torch.nn.modules.module.Module] = None, out_channels: int = 1, use_sam_stats: bool = False, use_mae_stats: bool = False, use_dino_stats: bool = False, resize_input: bool = True, encoder_checkpoint: Union[str, collections.OrderedDict, NoneType] = None, final_activation: Union[torch.nn.modules.module.Module, str, NoneType] = None, use_skip_connection: bool = False, embed_dim: Optional[int] = None, use_conv_transpose: bool = False, use_strip_pooling: bool = True, perform_range_checks: bool = True, **kwargs)
707    def __init__(
708        self,
709        img_size: int = 1024,
710        backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
711        encoder: Optional[Union[nn.Module, str]] = "hvit_b",
712        decoder: Optional[nn.Module] = None,
713        out_channels: int = 1,
714        use_sam_stats: bool = False,
715        use_mae_stats: bool = False,
716        use_dino_stats: bool = False,
717        resize_input: bool = True,
718        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
719        final_activation: Optional[Union[str, nn.Module]] = None,
720        use_skip_connection: bool = False,
721        embed_dim: Optional[int] = None,
722        use_conv_transpose: bool = False,
723        use_strip_pooling: bool = True,
724        perform_range_checks: bool = True,
725        **kwargs
726    ):
727        if use_skip_connection:
728            raise NotImplementedError("Skip connections are not supported for the 3D UNETR at the moment.")
729        if use_conv_transpose:
730            raise NotImplementedError("Transposed convolutions are not supported for the 3D UNETR; interpolation is always used.")
731
732        # Sort the `embed_dim` out
733        embed_dim = 256 if embed_dim is None else embed_dim
734
735        super().__init__(
736            img_size=img_size,
737            backbone=backbone,
738            encoder=encoder,
739            decoder=decoder,
740            out_channels=out_channels,
741            use_sam_stats=use_sam_stats,
742            use_mae_stats=use_mae_stats,
743            use_dino_stats=use_dino_stats,
744            resize_input=resize_input,
745            encoder_checkpoint=encoder_checkpoint,
746            final_activation=final_activation,
747            use_skip_connection=use_skip_connection,
748            embed_dim=embed_dim,
749            use_conv_transpose=use_conv_transpose,
750            perform_range_checks=perform_range_checks,
751            **kwargs,
752        )
753
754        # The 3d convolutional decoder.
755        # First, get the important parameters for the decoder.
756        depth = 3
757        initial_features = 64
758        gain = 2
759        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
760        scale_factors = [1, 2, 2]
761        self.out_channels = out_channels
762
763        # The mapping blocks.
764        self.deconv1 = Deconv3DBlock(
765            in_channels=embed_dim,
766            out_channels=features_decoder[0],
767            scale_factor=scale_factors,
768            use_strip_pooling=use_strip_pooling,
769        )
770        self.deconv2 = Deconv3DBlock(
771            in_channels=features_decoder[0],
772            out_channels=features_decoder[1],
773            scale_factor=scale_factors,
774            use_strip_pooling=use_strip_pooling,
775        )
776        self.deconv3 = Deconv3DBlock(
777            in_channels=features_decoder[1],
778            out_channels=features_decoder[2],
779            scale_factor=scale_factors,
780            use_strip_pooling=use_strip_pooling,
781        )
782        self.deconv4 = Deconv3DBlock(
783            in_channels=features_decoder[2],
784            out_channels=features_decoder[3],
785            scale_factor=scale_factors,
786            use_strip_pooling=use_strip_pooling,
787        )
788
789        # The core decoder block.
790        self.decoder = decoder or Decoder(
791            features=features_decoder,
792            scale_factors=[scale_factors] * depth,
793            conv_block_impl=partial(ConvBlock3dWithStrip, use_strip_pooling=use_strip_pooling),
794            sampler_impl=Upsampler3d,
795        )
796
797        # And the final upsampler to match the expected dimensions.
798        self.deconv_out = Deconv3DBlock(  # NOTE: changed `end_up` to `deconv_out`
799            in_channels=features_decoder[-1],
800            out_channels=features_decoder[-1],
801            scale_factor=scale_factors,
802            use_strip_pooling=use_strip_pooling,
803        )
804
805        # Additional conjunction blocks.
806        self.base = ConvBlock3dWithStrip(
807            in_channels=embed_dim,
808            out_channels=features_decoder[0],
809            use_strip_pooling=use_strip_pooling,
810        )
811
812        # And the output layers.
813        self.decoder_head = ConvBlock3dWithStrip(
814            in_channels=2 * features_decoder[-1],
815            out_channels=features_decoder[-1],
816            use_strip_pooling=use_strip_pooling,
817        )
818        self.out_conv = nn.Conv3d(features_decoder[-1], out_channels, 1)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

out_channels
deconv1
deconv2
deconv3
deconv4
decoder
deconv_out
base
decoder_head
out_conv
def forward(self, x: torch.Tensor):
820    def forward(self, x: torch.Tensor):
821        """Forward pass of the UNETR-3D model.
822
823        Args:
824            x: Inputs of expected shape (B, C, Z, Y, X), where the number of slices Z may vary.
825
826        Returns:
827            The UNETR output.
828        """
829        B, C, Z, H, W = x.shape
830        original_shape = (Z, H, W)
831
832        # Preprocessing step
833        x, input_shape = self.preprocess(x)
834
835        # Run the image encoder.
836        curr_features = torch.stack([self.encoder(x[:, :, i])[0] for i in range(Z)], dim=2)
837
838        # Prepare the counterparts for the decoder.
839        # NOTE: The section below is sequential; there are no skip connections at the moment.
840        z9 = self.deconv1(curr_features)
841        z6 = self.deconv2(z9)
842        z3 = self.deconv3(z6)
843        z0 = self.deconv4(z3)
844
845        updated_from_encoder = [z9, z6, z3]
846
847        # Align the features through the base block.
848        x = self.base(curr_features)
849        # Run the decoder
850        x = self.decoder(x, encoder_inputs=updated_from_encoder)
851        x = self.deconv_out(x)  # NOTE: previously named `end_up`.
852
853        # And the final output head.
854        x = torch.cat([x, z0], dim=1)
855        x = self.decoder_head(x)
856        x = self.out_conv(x)
857        if self.final_activation is not None:
858            x = self.final_activation(x)
859
860        # Postprocess the output back to original size.
861        x = self.postprocess_masks(x, input_shape, original_shape)
862        return x

Forward pass of the UNETR-3D model.

Arguments:
  • x: Inputs of expected shape (B, C, Z, Y, X), where the number of slices Z may vary.
Returns:

The UNETR output.
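A volumetric usage sketch under the assumption that the SAM2 dependencies for the 'hvit_b' encoder are available in your environment; the encoder is applied slice-wise along Z and the output is resized back to the shape of the input volume. The input size here is only illustrative.

import torch
from torch_em.model.unetr import UNETR3D

model = UNETR3D(backbone="sam2", encoder="hvit_b", out_channels=1)
x = torch.rand(1, 3, 4, 256, 256)   # B x C x Z x H x W
with torch.no_grad():
    y = model(x)
print(y.shape)                      # torch.Size([1, 1, 4, 256, 256])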