# torch_em.model.unetr

  1from functools import partial
  2from collections import OrderedDict
  3from typing import Optional, Tuple, Union, Literal
  4
  5import torch
  6import torch.nn as nn
  7import torch.nn.functional as F
  8
  9from .vit import get_vision_transformer
 10from .unet import Decoder, ConvBlock2d, ConvBlock3d, Upsampler2d, Upsampler3d, _update_conv_kwargs
 11
 12try:
 13    from micro_sam.util import get_sam_model
 14except ImportError:
 15    get_sam_model = None
 16
 17try:
 18    from micro_sam.v2.util import get_sam2_model
 19except ImportError:
 20    get_sam2_model = None
 21
 22try:
 23    from micro_sam3.util import get_sam3_model
 24except ImportError:
 25    get_sam3_model = None
 26
 27
 28#
 29# UNETR IMPLEMENTATION [Vision Transformer (ViT from SAM / CellposeSAM / SAM2 / SAM3 / DINOv2 / DINOv3 / MAE / ScaleMAE) + UNet Decoder from `torch_em`]  # noqa
 30#
 31
 32
class UNETRBase(nn.Module):
    """Base class for implementing a UNETR.

    Args:
        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
        backbone: The name of the vision transformer implementation.
            One of "sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"
            (see all combinations below)
        encoder: The vision transformer. Can either be a name, such as "vit_b"
            (see all combinations for this below) or a torch module.
        decoder: The convolutional decoder.
        out_channels: The number of output channels of the UNETR.
        use_sam_stats: Whether to normalize the input data with the statistics of the
            pretrained SAM / SAM2 / SAM3 model.
        use_dino_stats: Whether to normalize the input data with the statistics of the
            pretrained DINOv2 / DINOv3 model.
        use_mae_stats: Whether to normalize the input data with the statistics of the pretrained MAE model.
        resize_input: Whether to resize the input images to match `img_size`.
            By default, it resizes the inputs to match the `img_size`.
        encoder_checkpoint: Checkpoint for initializing the vision transformer.
            Can either be a filepath or an already loaded checkpoint.
        final_activation: The activation to apply to the UNETR output.
        use_skip_connection: Whether to use skip connections. By default, it uses skip connections.
        embed_dim: The embedding dimensionality, corresponding to the output dimension of the vision transformer.
        use_conv_transpose: Whether to use transposed convolutions instead of resampling for upsampling.
            By default, it uses resampling for upsampling.

        NOTE: The currently supported combinations of 'backbone' x 'encoder' are the following:

        SAM_family_models:
            - 'sam' x 'vit_b'
            - 'sam' x 'vit_l'
            - 'sam' x 'vit_h'
            - 'sam2' x 'hvit_t'
            - 'sam2' x 'hvit_s'
            - 'sam2' x 'hvit_b'
            - 'sam2' x 'hvit_l'
            - 'sam3' x 'vit_pe'
            - 'cellpose_sam' x 'vit_l'

        DINO_family_models:
            - 'dinov2' x 'vit_s'
            - 'dinov2' x 'vit_b'
            - 'dinov2' x 'vit_l'
            - 'dinov2' x 'vit_g'
            - 'dinov2' x 'vit_s_reg4'
            - 'dinov2' x 'vit_b_reg4'
            - 'dinov2' x 'vit_l_reg4'
            - 'dinov2' x 'vit_g_reg4'
            - 'dinov3' x 'vit_s'
            - 'dinov3' x 'vit_s+'
            - 'dinov3' x 'vit_b'
            - 'dinov3' x 'vit_l'
            - 'dinov3' x 'vit_l+'
            - 'dinov3' x 'vit_h+'
            - 'dinov3' x 'vit_7b'

        MAE_family_models:
            - 'mae' x 'vit_b'
            - 'mae' x 'vit_l'
            - 'mae' x 'vit_h'
            - 'scalemae' x 'vit_b'
            - 'scalemae' x 'vit_l'
            - 'scalemae' x 'vit_h'
    """
    def __init__(
        self,
        img_size: int = 1024,
        backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
        encoder: Optional[Union[nn.Module, str]] = "vit_b",
        decoder: Optional[nn.Module] = None,
        out_channels: int = 1,
        use_sam_stats: bool = False,
        use_mae_stats: bool = False,
        use_dino_stats: bool = False,
        resize_input: bool = True,
        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
        final_activation: Optional[Union[str, nn.Module]] = None,
        use_skip_connection: bool = True,
        embed_dim: Optional[int] = None,
        use_conv_transpose: bool = False,
        **kwargs
    ) -> None:
        super().__init__()

        # NOTE: 'decoder' and 'out_channels' are accepted but not consumed here;
        # the subclasses build the convolutional decoder and output layers.
        self.img_size = img_size
        self.use_sam_stats = use_sam_stats
        self.use_mae_stats = use_mae_stats
        self.use_dino_stats = use_dino_stats
        self.use_skip_connection = use_skip_connection
        self.resize_input = resize_input
        self.use_conv_transpose = use_conv_transpose
        self.backbone = backbone

        if isinstance(encoder, str):  # e.g. "vit_b" / "hvit_b" / "vit_pe"
            print(f"Using {encoder} from {backbone.upper()}")
            self.encoder = get_vision_transformer(img_size=img_size, backbone=backbone, model=encoder, **kwargs)

            if encoder_checkpoint is not None:
                self._load_encoder_from_checkpoint(backbone=backbone, encoder=encoder, checkpoint=encoder_checkpoint)

            if embed_dim is None:
                embed_dim = self.encoder.embed_dim

            # For SAM1 encoder, if 'apply_neck' is applied, the embedding dimension must change.
            if hasattr(self.encoder, "apply_neck") and self.encoder.apply_neck:
                embed_dim = self.encoder.neck[2].out_channels  # the value is 256

        else:  # `nn.Module` ViT backbone
            self.encoder = encoder

            # Detect whether the given module carries a SAM-style 'neck' by scanning parameter names.
            have_neck = False
            for name, _ in self.encoder.named_parameters():
                if name.startswith("neck"):
                    have_neck = True

            # Infer the embedding dimension from the module when not given explicitly.
            if embed_dim is None:
                if have_neck:
                    embed_dim = self.encoder.neck[2].out_channels  # the value is 256
                else:
                    embed_dim = self.encoder.patch_embed.proj.out_channels

        self.embed_dim = embed_dim
        self.final_activation = self._get_activation(final_activation)

    def _load_encoder_from_checkpoint(self, backbone, encoder, checkpoint):
        """Load pretrained weights into the image encoder.

        Args:
            backbone: The name of the vision transformer implementation (e.g. "sam", "dinov2").
            encoder: The encoder (model) name, e.g. "vit_b", matching the instantiated encoder.
            checkpoint: Filepath to a checkpoint or an already loaded state dict.
        """
        if isinstance(checkpoint, str):
            if backbone == "sam" and isinstance(encoder, str):
                # If we have a SAM encoder, then we first try to load the full SAM Model
                # (using micro_sam) and otherwise fall back on directly loading the encoder state
                # from the checkpoint
                try:
                    _, model = get_sam_model(model_type=encoder, checkpoint_path=checkpoint, return_sam=True)
                    encoder_state = model.image_encoder.state_dict()
                except Exception:
                    # Try loading the encoder state directly from a checkpoint.
                    # NOTE: 'weights_only=False' unpickles arbitrary objects - only load trusted checkpoints.
                    encoder_state = torch.load(checkpoint, weights_only=False)

            elif backbone == "cellpose_sam" and isinstance(encoder, str):
                # The architecture matches CellposeSAM exactly (same rel_pos sizes),
                # so weights load directly without any interpolation.
                encoder_state = torch.load(checkpoint, map_location="cpu", weights_only=False)
                # Handle DataParallel/DistributedDataParallel prefix.
                if any(k.startswith("module.") for k in encoder_state.keys()):
                    encoder_state = OrderedDict(
                        {k[len("module."):]: v for k, v in encoder_state.items()}
                    )
                # Extract encoder weights from CellposeSAM checkpoint format (strip 'encoder.' prefix).
                if any(k.startswith("encoder.") for k in encoder_state.keys()):
                    encoder_state = OrderedDict(
                        {k[len("encoder."):]: v for k, v in encoder_state.items() if k.startswith("encoder.")}
                    )

            elif backbone == "sam2" and isinstance(encoder, str):
                # If we have a SAM2 encoder, then we first try to load the full SAM2 Model.
                # (using micro_sam2) and otherwise fall back on directly loading the encoder state
                # from the checkpoint
                try:
                    model = get_sam2_model(model_type=encoder, checkpoint_path=checkpoint)
                    encoder_state = model.image_encoder.state_dict()
                except Exception:
                    # Try loading the encoder state directly from a checkpoint.
                    encoder_state = torch.load(checkpoint, weights_only=False)

            elif backbone == "sam3" and isinstance(encoder, str):
                # If we have a SAM3 encoder, then we first try to load the full SAM3 Model.
                # (using micro_sam3) and otherwise fall back on directly loading the encoder state
                # from the checkpoint
                try:
                    model = get_sam3_model(checkpoint_path=checkpoint)
                    encoder_state = model.backbone.vision_backbone.state_dict()
                    # Let's align loading the encoder weights with expected parameter names
                    encoder_state = {
                        k[len("trunk."):] if k.startswith("trunk.") else k: v for k, v in encoder_state.items()
                    }
                    # And drop the 'convs' and 'sam2_convs' - these seem like some upsampling blocks.
                    encoder_state = {
                        k: v for k, v in encoder_state.items()
                        if not (k.startswith("convs.") or k.startswith("sam2_convs."))
                    }
                except Exception:
                    # Try loading the encoder state directly from a checkpoint.
                    encoder_state = torch.load(checkpoint, weights_only=False)

            elif backbone == "mae":
                # vit initialization hints from:
                #     - https://github.com/facebookresearch/mae/blob/main/main_finetune.py#L233-L242
                encoder_state = torch.load(checkpoint, weights_only=False)["model"]
                # Drop the mask token and the MAE decoder weights - only the encoder is needed here.
                encoder_state = OrderedDict({
                    k: v for k, v in encoder_state.items() if (k != "mask_token" and not k.startswith("decoder"))
                })
                # Let's remove the `head` from our current encoder (as the MAE pretrained don't expect it)
                current_encoder_state = self.encoder.state_dict()
                if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
                    del self.encoder.head

            elif backbone == "scalemae":
                # Load the encoder state directly from a checkpoint.
                encoder_state = torch.load(checkpoint)["model"]
                encoder_state = OrderedDict({
                    k: v for k, v in encoder_state.items()
                    if not k.startswith(("mask_token", "decoder", "fcn", "fpn", "pos_embed"))
                })

                # Let's remove the `head` from our current encoder (as the MAE pretrained don't expect it)
                current_encoder_state = self.encoder.state_dict()
                if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
                    del self.encoder.head

                if "pos_embed" in current_encoder_state:  # NOTE: ScaleMAE uses 'pos. embeddings' in a diff. format.
                    del self.encoder.pos_embed

            elif backbone in ["dinov2", "dinov3"]:  # Load the encoder state directly from a checkpoint.
                encoder_state = torch.load(checkpoint)

            else:
                raise ValueError(
                    f"We don't support either the '{backbone}' backbone or the '{encoder}' model combination (or both)."
                )

        else:
            # The checkpoint is already a loaded state dict.
            encoder_state = checkpoint

        self.encoder.load_state_dict(encoder_state)

    def _get_activation(self, activation):
        """Resolve 'activation' (None, an nn.Module instance, or the name of an 'nn' attribute,
        e.g. "Sigmoid") into an activation module instance; raises ValueError for unknown names.
        """
        return_activation = None
        if activation is None:
            return None
        if isinstance(activation, nn.Module):
            return activation
        if isinstance(activation, str):
            # Look the activation class up on torch.nn by name, e.g. "Sigmoid" -> nn.Sigmoid.
            return_activation = getattr(nn, activation, None)
        if return_activation is None:
            raise ValueError(f"Invalid activation: {activation}")

        return return_activation()

    @staticmethod
    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
        """Compute the output size given input size and target long side length.

        Args:
            oldh: The input image height.
            oldw: The input image width.
            long_side_length: The longest side length for resizing.

        Returns:
            The new image height.
            The new image width.
        """
        scale = long_side_length * 1.0 / max(oldh, oldw)
        newh, neww = oldh * scale, oldw * scale
        # Round to the nearest integer (adding 0.5 before truncation).
        neww = int(neww + 0.5)
        newh = int(newh + 0.5)
        return (newh, neww)

    def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor:
        """Resize the image so that the longest side has the correct length.

        Expects batched images with shape BxCxHxW OR BxCxDxHxW and float format.

        Args:
            image: The input image.

        Returns:
            The resized image.
        """
        # NOTE(review): the 2d branch targets 'self.encoder.img_size' while the 3d branch
        # uses 'self.img_size' - confirm these are intended to agree for all encoders.
        if image.ndim == 4:  # i.e. 2d image
            target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.encoder.img_size)
            return F.interpolate(image, target_size, mode="bilinear", align_corners=False, antialias=True)
        elif image.ndim == 5:  # i.e. 3d volume
            B, C, Z, H, W = image.shape
            # Only the in-plane (H, W) dimensions are rescaled; the depth Z is kept as-is.
            target_size = self.get_preprocess_shape(H, W, self.img_size)
            return F.interpolate(image, (Z, *target_size), mode="trilinear", align_corners=False)
        else:
            raise ValueError("Expected 4d or 5d inputs, got", image.shape)

    def _as_stats(self, mean, std, device, dtype, is_3d: bool):
        """@private
        Turn per-channel mean / std sequences into broadcastable tensors on the input's device.
        """
        # Either 2d batch: (1, C, 1, 1) or 3d batch: (1, C, 1, 1, 1).
        view_shape = (1, -1, 1, 1, 1) if is_3d else (1, -1, 1, 1)
        pixel_mean = torch.tensor(mean, device=device, dtype=dtype).view(*view_shape)
        pixel_std = torch.tensor(std, device=device, dtype=dtype).view(*view_shape)
        return pixel_mean, pixel_std

    def preprocess(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, ...]]:
        """@private
        Normalize, optionally resize, and bottom/right-pad the input to the encoder's square
        input size. Returns the padded tensor and the pre-padding spatial shape (needed by
        `postprocess_masks` to undo the padding).
        """
        is_3d = (x.ndim == 5)
        device, dtype = x.device, x.dtype

        # Pick the normalization statistics matching the pretrained backbone.
        # NOTE(review): the SAM1 stats are in the 0-255 pixel range while the SAM2 / SAM3 / DINO
        # stats assume [0, 1] inputs - the caller's data scaling must match; confirm upstream.
        if self.use_sam_stats:
            if self.backbone == "sam2":
                mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
            elif self.backbone == "sam3":
                mean, std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
            else:  # sam1 / default
                mean, std = (123.675, 116.28, 103.53), (58.395, 57.12, 57.375)
        elif self.use_mae_stats:  # TODO: add mean std from mae / scalemae experiments (or open up arguments for this)
            raise NotImplementedError
        elif self.use_dino_stats:
            mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
        else:
            # No normalization: identity mean / std.
            mean, std = (0.0, 0.0, 0.0), (1.0, 1.0, 1.0)

        pixel_mean, pixel_std = self._as_stats(mean, std, device=device, dtype=dtype, is_3d=is_3d)

        if self.resize_input:
            x = self.resize_longest_side(x)
        # Remember the (unpadded) spatial shape for later unpadding.
        input_shape = x.shape[-3:] if is_3d else x.shape[-2:]

        x = (x - pixel_mean) / pixel_std
        h, w = x.shape[-2:]
        padh = self.encoder.img_size - h
        padw = self.encoder.img_size - w

        # Pad bottom / right only, so the image content stays anchored at the top-left corner.
        if is_3d:
            x = F.pad(x, (0, padw, 0, padh, 0, 0))
        else:
            x = F.pad(x, (0, padw, 0, padh))

        return x, input_shape

    def postprocess_masks(
        self, masks: torch.Tensor, input_size: Tuple[int, ...], original_size: Tuple[int, ...],
    ) -> torch.Tensor:
        """@private
        Undo the preprocessing: upsample to the encoder input size, crop away the padding
        (using `input_size`), and resize back to `original_size`.
        """
        if masks.ndim == 4:  # i.e. 2d labels
            masks = F.interpolate(
                masks,
                (self.encoder.img_size, self.encoder.img_size),
                mode="bilinear",
                align_corners=False,
            )
            # Crop the bottom/right padding that was added in 'preprocess'.
            masks = masks[..., : input_size[0], : input_size[1]]
            masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)

        elif masks.ndim == 5:  # i.e. 3d volumetric labels
            masks = F.interpolate(
                masks,
                (input_size[0], self.img_size, self.img_size),
                mode="trilinear",
                align_corners=False,
            )
            masks = masks[..., :input_size[0], :input_size[1], :input_size[2]]
            masks = F.interpolate(masks, original_size, mode="trilinear", align_corners=False)

        else:
            raise ValueError("Expected 4d or 5d labels, got", masks.shape)

        return masks
389
390
class UNETR(UNETRBase):
    """A (2d-only) UNet Transformer using a vision transformer as encoder and a convolutional decoder.

    See `UNETRBase` for the documentation of the constructor arguments and the supported
    'backbone' x 'encoder' combinations.
    """
    def __init__(
        self,
        img_size: int = 1024,
        backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
        encoder: Optional[Union[nn.Module, str]] = "vit_b",
        decoder: Optional[nn.Module] = None,
        out_channels: int = 1,
        use_sam_stats: bool = False,
        use_mae_stats: bool = False,
        use_dino_stats: bool = False,
        resize_input: bool = True,
        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
        final_activation: Optional[Union[str, nn.Module]] = None,
        use_skip_connection: bool = True,
        embed_dim: Optional[int] = None,
        use_conv_transpose: bool = False,
        **kwargs
    ) -> None:

        super().__init__(
            img_size=img_size,
            backbone=backbone,
            encoder=encoder,
            decoder=decoder,
            out_channels=out_channels,
            use_sam_stats=use_sam_stats,
            use_mae_stats=use_mae_stats,
            use_dino_stats=use_dino_stats,
            resize_input=resize_input,
            encoder_checkpoint=encoder_checkpoint,
            final_activation=final_activation,
            use_skip_connection=use_skip_connection,
            embed_dim=embed_dim,
            use_conv_transpose=use_conv_transpose,
            **kwargs,
        )

        encoder = self.encoder

        # Determine the number of input channels the instantiated encoder expects.
        if backbone == "sam2" and hasattr(encoder, "trunk"):
            in_chans = encoder.trunk.patch_embed.proj.in_channels
        elif hasattr(encoder, "in_chans"):
            in_chans = encoder.in_chans
        else:  # `nn.Module` ViT backbone.
            try:
                in_chans = encoder.patch_embed.proj.in_channels
            except AttributeError:  # for getting the input channels while using 'vit_t' from MobileSam
                in_chans = encoder.patch_embed.seq[0].c.in_channels

        # parameters for the decoder network
        depth = 3
        initial_features = 64
        gain = 2
        # Feature widths from coarsest to finest, e.g. [512, 256, 128, 64] for the defaults.
        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
        scale_factors = depth * [2]
        self.out_channels = out_channels

        # choice of upsampler - to use (bilinear interpolation + conv) or conv transpose
        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d

        self.decoder = decoder or Decoder(
            features=features_decoder,
            scale_factors=scale_factors[::-1],
            conv_block_impl=ConvBlock2d,
            sampler_impl=_upsampler,
        )

        if use_skip_connection:
            # With skip connections, each deconv branch maps a (same-resolution) encoder feature
            # map to the resolution / width expected at the corresponding decoder level:
            # one, two, or three successive 2x upsampling blocks.
            self.deconv1 = Deconv2DBlock(
                in_channels=self.embed_dim,
                out_channels=features_decoder[0],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv2 = nn.Sequential(
                Deconv2DBlock(
                    in_channels=self.embed_dim,
                    out_channels=features_decoder[0],
                    use_conv_transpose=use_conv_transpose,
                ),
                Deconv2DBlock(
                    in_channels=features_decoder[0],
                    out_channels=features_decoder[1],
                    use_conv_transpose=use_conv_transpose,
                )
            )
            self.deconv3 = nn.Sequential(
                Deconv2DBlock(
                    in_channels=self.embed_dim,
                    out_channels=features_decoder[0],
                    use_conv_transpose=use_conv_transpose,
                ),
                Deconv2DBlock(
                    in_channels=features_decoder[0],
                    out_channels=features_decoder[1],
                    use_conv_transpose=use_conv_transpose,
                ),
                Deconv2DBlock(
                    in_channels=features_decoder[1],
                    out_channels=features_decoder[2],
                    use_conv_transpose=use_conv_transpose,
                )
            )
            # The finest skip is computed directly from the (preprocessed) input image.
            self.deconv4 = ConvBlock2d(in_chans, features_decoder[-1])
        else:
            # Without skip connections the deconv blocks form a sequential upsampling chain.
            self.deconv1 = Deconv2DBlock(
                in_channels=self.embed_dim,
                out_channels=features_decoder[0],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv2 = Deconv2DBlock(
                in_channels=features_decoder[0],
                out_channels=features_decoder[1],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv3 = Deconv2DBlock(
                in_channels=features_decoder[1],
                out_channels=features_decoder[2],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv4 = Deconv2DBlock(
                in_channels=features_decoder[2],
                out_channels=features_decoder[3],
                use_conv_transpose=use_conv_transpose,
            )

        # Projects the encoder embedding to the coarsest decoder width.
        self.base = ConvBlock2d(self.embed_dim, features_decoder[0])
        # 1x1 output projection to the requested number of channels.
        self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1)
        self.deconv_out = _upsampler(
            scale_factor=2, in_channels=features_decoder[-1], out_channels=features_decoder[-1]
        )
        # Fuses the upsampled decoder output with the finest (z0) feature map.
        self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the UNETR to the input data.

        Args:
            x: The input tensor.

        Returns:
            The UNETR output.
        """
        original_shape = x.shape[-2:]

        # Reshape the inputs to the shape expected by the encoder
        # and normalize the inputs if normalization is part of the model.
        x, input_shape = self.preprocess(x)

        encoder_outputs = self.encoder(x)

        if isinstance(encoder_outputs[-1], list):
            # `encoder_outputs` can be arranged in only two forms:
            #   - either we only return the image embeddings
            #   - or, we return the image embeddings and the "list" of global attention layers
            z12, from_encoder = encoder_outputs
        else:
            z12 = encoder_outputs

        # NOTE(review): when 'use_skip_connection' is True this assumes the encoder returned the
        # list of global attention features above - otherwise 'from_encoder' is unbound; confirm
        # the supported encoders always provide it in that configuration.
        if self.use_skip_connection:
            # Reverse so the deepest encoder feature comes first.
            from_encoder = from_encoder[::-1]
            z9 = self.deconv1(from_encoder[0])
            z6 = self.deconv2(from_encoder[1])
            z3 = self.deconv3(from_encoder[2])
            z0 = self.deconv4(x)

        else:
            # Sequentially upsample the final embedding instead of using encoder skips.
            z9 = self.deconv1(z12)
            z6 = self.deconv2(z9)
            z3 = self.deconv3(z6)
            z0 = self.deconv4(z3)

        updated_from_encoder = [z9, z6, z3]

        x = self.base(z12)
        x = self.decoder(x, encoder_inputs=updated_from_encoder)
        x = self.deconv_out(x)

        # Fuse with the finest-resolution feature map and project to the output channels.
        x = torch.cat([x, z0], dim=1)
        x = self.decoder_head(x)

        x = self.out_conv(x)
        if self.final_activation is not None:
            x = self.final_activation(x)

        # Undo resizing / padding so the output matches the original input resolution.
        x = self.postprocess_masks(x, input_shape, original_shape)
        return x
579
580
class UNETR2D(UNETR):
    """A two-dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.

    This is an alias for `UNETR` (which is 2d-only) and adds no behavior of its own.
    """
585
586
class UNETR3D(UNETRBase):
    """A three dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.

    The (2d) vision transformer is applied per z-slice and the resulting feature volume is
    decoded with 3d convolutional blocks. See `UNETRBase` for the shared constructor arguments;
    the additional `use_strip_pooling` flag toggles depth strip-pooling in the 3d blocks.
    """
    def __init__(
        self,
        img_size: int = 1024,
        backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
        encoder: Optional[Union[nn.Module, str]] = "hvit_b",
        decoder: Optional[nn.Module] = None,
        out_channels: int = 1,
        use_sam_stats: bool = False,
        use_mae_stats: bool = False,
        use_dino_stats: bool = False,
        resize_input: bool = True,
        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
        final_activation: Optional[Union[str, nn.Module]] = None,
        use_skip_connection: bool = False,
        embed_dim: Optional[int] = None,
        use_conv_transpose: bool = False,
        use_strip_pooling: bool = True,
        **kwargs
    ):
        # The 3d variant only supports the sequential (no-skip, interpolation-based) setup.
        if use_skip_connection:
            raise NotImplementedError("The framework cannot handle skip connections atm.")
        if use_conv_transpose:
            raise NotImplementedError("It's not enabled to switch between interpolation and transposed convolutions.")

        # Sort the `embed_dim` out
        embed_dim = 256 if embed_dim is None else embed_dim

        super().__init__(
            img_size=img_size,
            backbone=backbone,
            encoder=encoder,
            decoder=decoder,
            out_channels=out_channels,
            use_sam_stats=use_sam_stats,
            use_mae_stats=use_mae_stats,
            use_dino_stats=use_dino_stats,
            resize_input=resize_input,
            encoder_checkpoint=encoder_checkpoint,
            final_activation=final_activation,
            use_skip_connection=use_skip_connection,
            embed_dim=embed_dim,
            use_conv_transpose=use_conv_transpose,
            **kwargs,
        )

        # The 3d convolutional decoder.
        # First, get the important parameters for the decoder.
        depth = 3
        initial_features = 64
        gain = 2
        # Feature widths from coarsest to finest, e.g. [512, 256, 128, 64] for the defaults.
        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
        # Upsample only in-plane (H, W); the depth axis keeps its resolution.
        scale_factors = [1, 2, 2]
        self.out_channels = out_channels

        # The mapping blocks.
        self.deconv1 = Deconv3DBlock(
            in_channels=embed_dim,
            out_channels=features_decoder[0],
            scale_factor=scale_factors,
            use_strip_pooling=use_strip_pooling,
        )
        self.deconv2 = Deconv3DBlock(
            in_channels=features_decoder[0],
            out_channels=features_decoder[1],
            scale_factor=scale_factors,
            use_strip_pooling=use_strip_pooling,
        )
        self.deconv3 = Deconv3DBlock(
            in_channels=features_decoder[1],
            out_channels=features_decoder[2],
            scale_factor=scale_factors,
            use_strip_pooling=use_strip_pooling,
        )
        self.deconv4 = Deconv3DBlock(
            in_channels=features_decoder[2],
            out_channels=features_decoder[3],
            scale_factor=scale_factors,
            use_strip_pooling=use_strip_pooling,
        )

        # The core decoder block.
        self.decoder = decoder or Decoder(
            features=features_decoder,
            scale_factors=[scale_factors] * depth,
            conv_block_impl=partial(ConvBlock3dWithStrip, use_strip_pooling=use_strip_pooling),
            sampler_impl=Upsampler3d,
        )

        # And the final upsampler to match the expected dimensions.
        self.deconv_out = Deconv3DBlock(  # NOTE: changed `end_up` to `deconv_out`
            in_channels=features_decoder[-1],
            out_channels=features_decoder[-1],
            scale_factor=scale_factors,
            use_strip_pooling=use_strip_pooling,
        )

        # Additional conjunction blocks.
        self.base = ConvBlock3dWithStrip(
            in_channels=embed_dim,
            out_channels=features_decoder[0],
            use_strip_pooling=use_strip_pooling,
        )

        # And the output layers.
        self.decoder_head = ConvBlock3dWithStrip(
            in_channels=2 * features_decoder[-1],
            out_channels=features_decoder[-1],
            use_strip_pooling=use_strip_pooling,
        )
        self.out_conv = nn.Conv3d(features_decoder[-1], out_channels, 1)

    def forward(self, x: torch.Tensor):
        """Forward pass of the UNETR-3D model.

        Args:
            x: Inputs of expected shape (B, C, Z, Y, X), where Z considers flexible inputs.

        Returns:
            The UNETR output.
        """
        B, C, Z, H, W = x.shape
        original_shape = (Z, H, W)

        # Preprocessing step
        x, input_shape = self.preprocess(x)

        # Run the image encoder.
        # The 2d encoder is applied slice-by-slice along Z; the per-slice embeddings
        # (first element of the encoder output) are stacked back into a volume.
        curr_features = torch.stack([self.encoder(x[:, :, i])[0] for i in range(Z)], dim=2)

        # Prepare the counterparts for the decoder.
        # NOTE: The section below is sequential, there's no skip connections atm.
        z9 = self.deconv1(curr_features)
        z6 = self.deconv2(z9)
        z3 = self.deconv3(z6)
        z0 = self.deconv4(z3)

        updated_from_encoder = [z9, z6, z3]

        # Align the features through the base block.
        x = self.base(curr_features)
        # Run the decoder
        x = self.decoder(x, encoder_inputs=updated_from_encoder)
        x = self.deconv_out(x)  # NOTE before `end_up`

        # And the final output head.
        x = torch.cat([x, z0], dim=1)
        x = self.decoder_head(x)
        x = self.out_conv(x)
        if self.final_activation is not None:
            x = self.final_activation(x)

        # Postprocess the output back to original size.
        x = self.postprocess_masks(x, input_shape, original_shape)
        return x
744
745#
746#  ADDITIONAL FUNCTIONALITIES
747#
748
749
750def _strip_pooling_layers(enabled, channels) -> nn.Module:
751    return DepthStripPooling(channels) if enabled else nn.Identity()
752
753
class DepthStripPooling(nn.Module):
    """@private
    """
    def __init__(self, channels: int, reduction: int = 4):
        """Strip pooling along the depth dimension only.

        For 3D inputs (Z > 1): global context is aggregated across depth by adaptive
        average pooling to Z=1, passed through a small 1x1x1 bottleneck MLP, and then
        broadcast back over Z to modulate the original features via a gated residual.

        For 2D inputs (Z == 1): the block is a no-op and returns the input unchanged.

        Args:
            channels: The number of input (and output) channels.
            reduction: The channel reduction factor of the hidden layer.
        """
        super().__init__()
        bottleneck = max(1, channels // reduction)
        self.conv1 = nn.Conv3d(channels, bottleneck, kernel_size=1)
        self.bn1 = nn.BatchNorm3d(bottleneck)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(bottleneck, channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.dim() != 5:
            raise ValueError(f"DepthStripPooling expects 5D tensors as input, got '{x.shape}'.")

        B, C, Z, H, W = x.shape
        # A single z-slice (i.e. the 2d case) passes through untouched.
        if Z == 1:
            return x

        # Collapse only the depth axis: (B, C, Z, H, W) -> (B, C, 1, H, W).
        context = F.adaptive_avg_pool3d(x, output_size=(1, H, W))
        context = self.relu(self.bn1(self.conv1(context)))
        context = self.conv2(context)
        # Broadcast the depth-collapsed context back to every slice.
        gate = torch.sigmoid(context).expand(B, C, Z, H, W)

        # Gated residual fusion.
        return x * gate + x
795
796
class Deconv3DBlock(nn.Module):
    """@private
    """
    def __init__(
        self,
        scale_factor,
        in_channels,
        out_channels,
        kernel_size=3,
        anisotropic_kernel=True,
        use_strip_pooling=True,
    ):
        """Upsample by `scale_factor`, then refine with conv -> batch norm -> ReLU,
        optionally followed by a depth strip-pooling block.
        """
        super().__init__()
        conv_kwargs = dict(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            padding=(kernel_size - 1) // 2,
        )
        # Adapt kernel size / padding per axis for anisotropic scale factors.
        if anisotropic_kernel:
            conv_kwargs = _update_conv_kwargs(conv_kwargs, scale_factor)

        layers = [
            Upsampler3d(scale_factor, in_channels, out_channels),
            nn.Conv3d(**conv_kwargs),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(True),
            _strip_pooling_layers(enabled=use_strip_pooling, channels=out_channels),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)
829
830
class ConvBlock3dWithStrip(nn.Module):
    """@private
    """
    def __init__(
        self, in_channels: int, out_channels: int, use_strip_pooling: bool = True, **kwargs
    ):
        """A 3d convolutional block, optionally followed by depth strip pooling."""
        super().__init__()
        strip = _strip_pooling_layers(enabled=use_strip_pooling, channels=out_channels)
        self.block = nn.Sequential(ConvBlock3d(in_channels, out_channels, **kwargs), strip)

    def forward(self, x):
        return self.block(x)
845
846
class SingleDeconv2DBlock(nn.Module):
    """@private
    """
    def __init__(self, scale_factor, in_channels, out_channels):
        """A single transposed convolution upsampling the input by `scale_factor`.

        Args:
            scale_factor: The spatial upsampling factor, used as kernel size and stride.
                NOTE: Previously this argument was accepted but ignored and the block
                always upsampled by a factor of 2; honoring it is backward-compatible
                with all in-file call sites, which pass `scale_factor=2`.
            in_channels: The number of input channels.
            out_channels: The number of output channels.
        """
        super().__init__()
        self.block = nn.ConvTranspose2d(
            in_channels, out_channels,
            kernel_size=scale_factor, stride=scale_factor, padding=0, output_padding=0,
        )

    def forward(self, x):
        return self.block(x)
856
857
class SingleConv2DBlock(nn.Module):
    """@private
    """
    def __init__(self, in_channels, out_channels, kernel_size):
        """A shape-preserving 2d convolution (stride 1, 'same'-style padding for odd kernels)."""
        super().__init__()
        same_padding = (kernel_size - 1) // 2
        self.block = nn.Conv2d(
            in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=same_padding
        )

    def forward(self, x):
        return self.block(x)
869
870
class Conv2DBlock(nn.Module):
    """@private
    """
    def __init__(self, in_channels, out_channels, kernel_size=3):
        """Conv -> batch norm -> ReLU, preserving the spatial resolution."""
        super().__init__()
        layers = [
            SingleConv2DBlock(in_channels, out_channels, kernel_size),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)
884
885
class Deconv2DBlock(nn.Module):
    """@private
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, use_conv_transpose=True):
        """Upsample by a factor of 2, then conv -> batch norm -> ReLU.

        Uses a transposed convolution for upsampling by default, otherwise a
        resize-convolution upsampler.
        """
        super().__init__()
        if use_conv_transpose:
            upsampler = SingleDeconv2DBlock(scale_factor=2, in_channels=in_channels, out_channels=out_channels)
        else:
            upsampler = Upsampler2d(scale_factor=2, in_channels=in_channels, out_channels=out_channels)
        self.block = nn.Sequential(
            upsampler,
            SingleConv2DBlock(out_channels, out_channels, kernel_size),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),
        )

    def forward(self, x):
        return self.block(x)
class UNETRBase(torch.nn.modules.module.Module):
 34class UNETRBase(nn.Module):
 35    """Base class for implementing a UNETR.
 36
 37    Args:
 38        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
 39        backbone: The name of the vision transformer implementation.
 40            One of "sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"
 41            (see all combinations below)
 42        encoder: The vision transformer. Can either be a name, such as "vit_b"
 43            (see all combinations for this below) or a torch module.
 44        decoder: The convolutional decoder.
 45        out_channels: The number of output channels of the UNETR.
 46        use_sam_stats: Whether to normalize the input data with the statistics of the
 47            pretrained SAM / SAM2 / SAM3 model.
 48        use_dino_stats: Whether to normalize the input data with the statistics of the
 49            pretrained DINOv2 / DINOv3 model.
 50        use_mae_stats: Whether to normalize the input data with the statistics of the pretrained MAE model.
 51        resize_input: Whether to resize the input images to match `img_size`.
 52            By default, it resizes the inputs to match the `img_size`.
 53        encoder_checkpoint: Checkpoint for initializing the vision transformer.
 54            Can either be a filepath or an already loaded checkpoint.
 55        final_activation: The activation to apply to the UNETR output.
 56        use_skip_connection: Whether to use skip connections. By default, it uses skip connections.
 57        embed_dim: The embedding dimensionality, corresponding to the output dimension of the vision transformer.
 58        use_conv_transpose: Whether to use transposed convolutions instead of resampling for upsampling.
 59            By default, it uses resampling for upsampling.
 60
 61        NOTE: The currently supported combinations of 'backbone' x 'encoder' are the following:
 62
 63        SAM_family_models:
 64            - 'sam' x 'vit_b'
 65            - 'sam' x 'vit_l'
 66            - 'sam' x 'vit_h'
 67            - 'sam2' x 'hvit_t'
 68            - 'sam2' x 'hvit_s'
 69            - 'sam2' x 'hvit_b'
 70            - 'sam2' x 'hvit_l'
 71            - 'sam3' x 'vit_pe'
 72            - 'cellpose_sam' x 'vit_l'
 73
 74        DINO_family_models:
 75            - 'dinov2' x 'vit_s'
 76            - 'dinov2' x 'vit_b'
 77            - 'dinov2' x 'vit_l'
 78            - 'dinov2' x 'vit_g'
 79            - 'dinov2' x 'vit_s_reg4'
 80            - 'dinov2' x 'vit_b_reg4'
 81            - 'dinov2' x 'vit_l_reg4'
 82            - 'dinov2' x 'vit_g_reg4'
 83            - 'dinov3' x 'vit_s'
 84            - 'dinov3' x 'vit_s+'
 85            - 'dinov3' x 'vit_b'
 86            - 'dinov3' x 'vit_l'
 87            - 'dinov3' x 'vit_l+'
 88            - 'dinov3' x 'vit_h+'
 89            - 'dinov3' x 'vit_7b'
 90
 91        MAE_family_models:
 92            - 'mae' x 'vit_b'
 93            - 'mae' x 'vit_l'
 94            - 'mae' x 'vit_h'
 95            - 'scalemae' x 'vit_b'
 96            - 'scalemae' x 'vit_l'
 97            - 'scalemae' x 'vit_h'
 98    """
 99    def __init__(
100        self,
101        img_size: int = 1024,
102        backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
103        encoder: Optional[Union[nn.Module, str]] = "vit_b",
104        decoder: Optional[nn.Module] = None,
105        out_channels: int = 1,
106        use_sam_stats: bool = False,
107        use_mae_stats: bool = False,
108        use_dino_stats: bool = False,
109        resize_input: bool = True,
110        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
111        final_activation: Optional[Union[str, nn.Module]] = None,
112        use_skip_connection: bool = True,
113        embed_dim: Optional[int] = None,
114        use_conv_transpose: bool = False,
115        **kwargs
116    ) -> None:
117        super().__init__()
118
119        self.img_size = img_size
120        self.use_sam_stats = use_sam_stats
121        self.use_mae_stats = use_mae_stats
122        self.use_dino_stats = use_dino_stats
123        self.use_skip_connection = use_skip_connection
124        self.resize_input = resize_input
125        self.use_conv_transpose = use_conv_transpose
126        self.backbone = backbone
127
128        if isinstance(encoder, str):  # e.g. "vit_b" / "hvit_b" / "vit_pe"
129            print(f"Using {encoder} from {backbone.upper()}")
130            self.encoder = get_vision_transformer(img_size=img_size, backbone=backbone, model=encoder, **kwargs)
131
132            if encoder_checkpoint is not None:
133                self._load_encoder_from_checkpoint(backbone=backbone, encoder=encoder, checkpoint=encoder_checkpoint)
134
135            if embed_dim is None:
136                embed_dim = self.encoder.embed_dim
137
138            # For SAM1 encoder, if 'apply_neck' is applied, the embedding dimension must change.
139            if hasattr(self.encoder, "apply_neck") and self.encoder.apply_neck:
140                embed_dim = self.encoder.neck[2].out_channels  # the value is 256
141
142        else:  # `nn.Module` ViT backbone
143            self.encoder = encoder
144
145            have_neck = False
146            for name, _ in self.encoder.named_parameters():
147                if name.startswith("neck"):
148                    have_neck = True
149
150            if embed_dim is None:
151                if have_neck:
152                    embed_dim = self.encoder.neck[2].out_channels  # the value is 256
153                else:
154                    embed_dim = self.encoder.patch_embed.proj.out_channels
155
156        self.embed_dim = embed_dim
157        self.final_activation = self._get_activation(final_activation)
158
159    def _load_encoder_from_checkpoint(self, backbone, encoder, checkpoint):
160        """Function to load pretrained weights to the image encoder.
161        """
162        if isinstance(checkpoint, str):
163            if backbone == "sam" and isinstance(encoder, str):
164                # If we have a SAM encoder, then we first try to load the full SAM Model
165                # (using micro_sam) and otherwise fall back on directly loading the encoder state
166                # from the checkpoint
167                try:
168                    _, model = get_sam_model(model_type=encoder, checkpoint_path=checkpoint, return_sam=True)
169                    encoder_state = model.image_encoder.state_dict()
170                except Exception:
171                    # Try loading the encoder state directly from a checkpoint.
172                    encoder_state = torch.load(checkpoint, weights_only=False)
173
174            elif backbone == "cellpose_sam" and isinstance(encoder, str):
175                # The architecture matches CellposeSAM exactly (same rel_pos sizes),
176                # so weights load directly without any interpolation.
177                encoder_state = torch.load(checkpoint, map_location="cpu", weights_only=False)
178                # Handle DataParallel/DistributedDataParallel prefix.
179                if any(k.startswith("module.") for k in encoder_state.keys()):
180                    encoder_state = OrderedDict(
181                        {k[len("module."):]: v for k, v in encoder_state.items()}
182                    )
183                # Extract encoder weights from CellposeSAM checkpoint format (strip 'encoder.' prefix).
184                if any(k.startswith("encoder.") for k in encoder_state.keys()):
185                    encoder_state = OrderedDict(
186                        {k[len("encoder."):]: v for k, v in encoder_state.items() if k.startswith("encoder.")}
187                    )
188
189            elif backbone == "sam2" and isinstance(encoder, str):
190                # If we have a SAM2 encoder, then we first try to load the full SAM2 Model.
191                # (using micro_sam2) and otherwise fall back on directly loading the encoder state
192                # from the checkpoint
193                try:
194                    model = get_sam2_model(model_type=encoder, checkpoint_path=checkpoint)
195                    encoder_state = model.image_encoder.state_dict()
196                except Exception:
197                    # Try loading the encoder state directly from a checkpoint.
198                    encoder_state = torch.load(checkpoint, weights_only=False)
199
200            elif backbone == "sam3" and isinstance(encoder, str):
201                # If we have a SAM3 encoder, then we first try to load the full SAM3 Model.
202                # (using micro_sam3) and otherwise fall back on directly loading the encoder state
203                # from the checkpoint
204                try:
205                    model = get_sam3_model(checkpoint_path=checkpoint)
206                    encoder_state = model.backbone.vision_backbone.state_dict()
207                    # Let's align loading the encoder weights with expected parameter names
208                    encoder_state = {
209                        k[len("trunk."):] if k.startswith("trunk.") else k: v for k, v in encoder_state.items()
210                    }
211                    # And drop the 'convs' and 'sam2_convs' - these seem like some upsampling blocks.
212                    encoder_state = {
213                        k: v for k, v in encoder_state.items()
214                        if not (k.startswith("convs.") or k.startswith("sam2_convs."))
215                    }
216                except Exception:
217                    # Try loading the encoder state directly from a checkpoint.
218                    encoder_state = torch.load(checkpoint, weights_only=False)
219
220            elif backbone == "mae":
221                # vit initialization hints from:
222                #     - https://github.com/facebookresearch/mae/blob/main/main_finetune.py#L233-L242
223                encoder_state = torch.load(checkpoint, weights_only=False)["model"]
224                encoder_state = OrderedDict({
225                    k: v for k, v in encoder_state.items() if (k != "mask_token" and not k.startswith("decoder"))
226                })
227                # Let's remove the `head` from our current encoder (as the MAE pretrained don't expect it)
228                current_encoder_state = self.encoder.state_dict()
229                if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
230                    del self.encoder.head
231
232            elif backbone == "scalemae":
233                # Load the encoder state directly from a checkpoint.
234                encoder_state = torch.load(checkpoint)["model"]
235                encoder_state = OrderedDict({
236                    k: v for k, v in encoder_state.items()
237                    if not k.startswith(("mask_token", "decoder", "fcn", "fpn", "pos_embed"))
238                })
239
240                # Let's remove the `head` from our current encoder (as the MAE pretrained don't expect it)
241                current_encoder_state = self.encoder.state_dict()
242                if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
243                    del self.encoder.head
244
245                if "pos_embed" in current_encoder_state:  # NOTE: ScaleMAE uses 'pos. embeddings' in a diff. format.
246                    del self.encoder.pos_embed
247
248            elif backbone in ["dinov2", "dinov3"]:  # Load the encoder state directly from a checkpoint.
249                encoder_state = torch.load(checkpoint)
250
251            else:
252                raise ValueError(
253                    f"We don't support either the '{backbone}' backbone or the '{encoder}' model combination (or both)."
254                )
255
256        else:
257            encoder_state = checkpoint
258
259        self.encoder.load_state_dict(encoder_state)
260
261    def _get_activation(self, activation):
262        return_activation = None
263        if activation is None:
264            return None
265        if isinstance(activation, nn.Module):
266            return activation
267        if isinstance(activation, str):
268            return_activation = getattr(nn, activation, None)
269        if return_activation is None:
270            raise ValueError(f"Invalid activation: {activation}")
271
272        return return_activation()
273
274    @staticmethod
275    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
276        """Compute the output size given input size and target long side length.
277
278        Args:
279            oldh: The input image height.
280            oldw: The input image width.
281            long_side_length: The longest side length for resizing.
282
283        Returns:
284            The new image height.
285            The new image width.
286        """
287        scale = long_side_length * 1.0 / max(oldh, oldw)
288        newh, neww = oldh * scale, oldw * scale
289        neww = int(neww + 0.5)
290        newh = int(newh + 0.5)
291        return (newh, neww)
292
293    def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor:
294        """Resize the image so that the longest side has the correct length.
295
296        Expects batched images with shape BxCxHxW OR BxCxDxHxW and float format.
297
298        Args:
299            image: The input image.
300
301        Returns:
302            The resized image.
303        """
304        if image.ndim == 4:  # i.e. 2d image
305            target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.encoder.img_size)
306            return F.interpolate(image, target_size, mode="bilinear", align_corners=False, antialias=True)
307        elif image.ndim == 5:  # i.e. 3d volume
308            B, C, Z, H, W = image.shape
309            target_size = self.get_preprocess_shape(H, W, self.img_size)
310            return F.interpolate(image, (Z, *target_size), mode="trilinear", align_corners=False)
311        else:
312            raise ValueError("Expected 4d or 5d inputs, got", image.shape)
313
314    def _as_stats(self, mean, std, device, dtype, is_3d: bool):
315        """@private
316        """
317        # Either 2d batch: (1, C, 1, 1) or 3d batch: (1, C, 1, 1, 1).
318        view_shape = (1, -1, 1, 1, 1) if is_3d else (1, -1, 1, 1)
319        pixel_mean = torch.tensor(mean, device=device, dtype=dtype).view(*view_shape)
320        pixel_std = torch.tensor(std, device=device, dtype=dtype).view(*view_shape)
321        return pixel_mean, pixel_std
322
323    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
324        """@private
325        """
326        is_3d = (x.ndim == 5)
327        device, dtype = x.device, x.dtype
328
329        if self.use_sam_stats:
330            if self.backbone == "sam2":
331                mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
332            elif self.backbone == "sam3":
333                mean, std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
334            else:  # sam1 / default
335                mean, std = (123.675, 116.28, 103.53), (58.395, 57.12, 57.375)
336        elif self.use_mae_stats:  # TODO: add mean std from mae / scalemae experiments (or open up arguments for this)
337            raise NotImplementedError
338        elif self.use_dino_stats:
339            mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
340        else:
341            mean, std = (0.0, 0.0, 0.0), (1.0, 1.0, 1.0)
342
343        pixel_mean, pixel_std = self._as_stats(mean, std, device=device, dtype=dtype, is_3d=is_3d)
344
345        if self.resize_input:
346            x = self.resize_longest_side(x)
347        input_shape = x.shape[-3:] if is_3d else x.shape[-2:]
348
349        x = (x - pixel_mean) / pixel_std
350        h, w = x.shape[-2:]
351        padh = self.encoder.img_size - h
352        padw = self.encoder.img_size - w
353
354        if is_3d:
355            x = F.pad(x, (0, padw, 0, padh, 0, 0))
356        else:
357            x = F.pad(x, (0, padw, 0, padh))
358
359        return x, input_shape
360
361    def postprocess_masks(
362        self, masks: torch.Tensor, input_size: Tuple[int, ...], original_size: Tuple[int, ...],
363    ) -> torch.Tensor:
364        """@private
365        """
366        if masks.ndim == 4:  # i.e. 2d labels
367            masks = F.interpolate(
368                masks,
369                (self.encoder.img_size, self.encoder.img_size),
370                mode="bilinear",
371                align_corners=False,
372            )
373            masks = masks[..., : input_size[0], : input_size[1]]
374            masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
375
376        elif masks.ndim == 5:  # i.e. 3d volumetric labels
377            masks = F.interpolate(
378                masks,
379                (input_size[0], self.img_size, self.img_size),
380                mode="trilinear",
381                align_corners=False,
382            )
383            masks = masks[..., :input_size[0], :input_size[1], :input_size[2]]
384            masks = F.interpolate(masks, original_size, mode="trilinear", align_corners=False)
385
386        else:
387            raise ValueError("Expected 4d or 5d labels, got", masks.shape)
388
389        return masks

Base class for implementing a UNETR.

Arguments:
  • img_size: The size of the input for the image encoder. Input images will be resized to match this size.
  • backbone: The name of the vision transformer implementation. One of "sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3" (see all combinations below)
  • encoder: The vision transformer. Can either be a name, such as "vit_b" (see all combinations for this below) or a torch module.
  • decoder: The convolutional decoder.
  • out_channels: The number of output channels of the UNETR.
  • use_sam_stats: Whether to normalize the input data with the statistics of the pretrained SAM / SAM2 / SAM3 model.
  • use_dino_stats: Whether to normalize the input data with the statistics of the pretrained DINOv2 / DINOv3 model.
  • use_mae_stats: Whether to normalize the input data with the statistics of the pretrained MAE model.
  • resize_input: Whether to resize the input images to match img_size. By default, it resizes the inputs to match the img_size.
  • encoder_checkpoint: Checkpoint for initializing the vision transformer. Can either be a filepath or an already loaded checkpoint.
  • final_activation: The activation to apply to the UNETR output.
  • use_skip_connection: Whether to use skip connections. By default, it uses skip connections.
  • embed_dim: The embedding dimensionality, corresponding to the output dimension of the vision transformer.
  • use_conv_transpose: Whether to use transposed convolutions instead of resampling for upsampling. By default, it uses resampling for upsampling.
  • NOTE: The currently supported combinations of 'backbone' x 'encoder' are the following:
  • SAM_family_models:
    • 'sam' x 'vit_b'
    • 'sam' x 'vit_l'
    • 'sam' x 'vit_h'
    • 'sam2' x 'hvit_t'
    • 'sam2' x 'hvit_s'
    • 'sam2' x 'hvit_b'
    • 'sam2' x 'hvit_l'
    • 'sam3' x 'vit_pe'
    • 'cellpose_sam' x 'vit_l'
  • DINO_family_models:
    • 'dinov2' x 'vit_s'
    • 'dinov2' x 'vit_b'
    • 'dinov2' x 'vit_l'
    • 'dinov2' x 'vit_g'
    • 'dinov2' x 'vit_s_reg4'
    • 'dinov2' x 'vit_b_reg4'
    • 'dinov2' x 'vit_l_reg4'
    • 'dinov2' x 'vit_g_reg4'
    • 'dinov3' x 'vit_s'
    • 'dinov3' x 'vit_s+'
    • 'dinov3' x 'vit_b'
    • 'dinov3' x 'vit_l'
    • 'dinov3' x 'vit_l+'
    • 'dinov3' x 'vit_h+'
    • 'dinov3' x 'vit_7b'
  • MAE_family_models:
    • 'mae' x 'vit_b'
    • 'mae' x 'vit_l'
    • 'mae' x 'vit_h'
    • 'scalemae' x 'vit_b'
    • 'scalemae' x 'vit_l'
    • 'scalemae' x 'vit_h'
UNETRBase( img_size: int = 1024, backbone: Literal['sam', 'sam2', 'sam3', 'cellpose_sam', 'mae', 'scalemae', 'dinov2', 'dinov3'] = 'sam', encoder: Union[torch.nn.modules.module.Module, str, NoneType] = 'vit_b', decoder: Optional[torch.nn.modules.module.Module] = None, out_channels: int = 1, use_sam_stats: bool = False, use_mae_stats: bool = False, use_dino_stats: bool = False, resize_input: bool = True, encoder_checkpoint: Union[str, collections.OrderedDict, NoneType] = None, final_activation: Union[torch.nn.modules.module.Module, str, NoneType] = None, use_skip_connection: bool = True, embed_dim: Optional[int] = None, use_conv_transpose: bool = False, **kwargs)
 99    def __init__(
100        self,
101        img_size: int = 1024,
102        backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
103        encoder: Optional[Union[nn.Module, str]] = "vit_b",
104        decoder: Optional[nn.Module] = None,
105        out_channels: int = 1,
106        use_sam_stats: bool = False,
107        use_mae_stats: bool = False,
108        use_dino_stats: bool = False,
109        resize_input: bool = True,
110        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
111        final_activation: Optional[Union[str, nn.Module]] = None,
112        use_skip_connection: bool = True,
113        embed_dim: Optional[int] = None,
114        use_conv_transpose: bool = False,
115        **kwargs
116    ) -> None:
117        super().__init__()
118
119        self.img_size = img_size
120        self.use_sam_stats = use_sam_stats
121        self.use_mae_stats = use_mae_stats
122        self.use_dino_stats = use_dino_stats
123        self.use_skip_connection = use_skip_connection
124        self.resize_input = resize_input
125        self.use_conv_transpose = use_conv_transpose
126        self.backbone = backbone
127
128        if isinstance(encoder, str):  # e.g. "vit_b" / "hvit_b" / "vit_pe"
129            print(f"Using {encoder} from {backbone.upper()}")
130            self.encoder = get_vision_transformer(img_size=img_size, backbone=backbone, model=encoder, **kwargs)
131
132            if encoder_checkpoint is not None:
133                self._load_encoder_from_checkpoint(backbone=backbone, encoder=encoder, checkpoint=encoder_checkpoint)
134
135            if embed_dim is None:
136                embed_dim = self.encoder.embed_dim
137
138            # For SAM1 encoder, if 'apply_neck' is applied, the embedding dimension must change.
139            if hasattr(self.encoder, "apply_neck") and self.encoder.apply_neck:
140                embed_dim = self.encoder.neck[2].out_channels  # the value is 256
141
142        else:  # `nn.Module` ViT backbone
143            self.encoder = encoder
144
145            have_neck = False
146            for name, _ in self.encoder.named_parameters():
147                if name.startswith("neck"):
148                    have_neck = True
149
150            if embed_dim is None:
151                if have_neck:
152                    embed_dim = self.encoder.neck[2].out_channels  # the value is 256
153                else:
154                    embed_dim = self.encoder.patch_embed.proj.out_channels
155
156        self.embed_dim = embed_dim
157        self.final_activation = self._get_activation(final_activation)

Initialize the UNETR base model: build (or wrap) the vision-transformer encoder, optionally load pretrained encoder weights from a checkpoint, and infer the embedding dimension.

img_size
use_sam_stats
use_mae_stats
use_dino_stats
use_skip_connection
resize_input
use_conv_transpose
backbone
embed_dim
final_activation
@staticmethod
def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
274    @staticmethod
275    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
276        """Compute the output size given input size and target long side length.
277
278        Args:
279            oldh: The input image height.
280            oldw: The input image width.
281            long_side_length: The longest side length for resizing.
282
283        Returns:
284            The new image height.
285            The new image width.
286        """
287        scale = long_side_length * 1.0 / max(oldh, oldw)
288        newh, neww = oldh * scale, oldw * scale
289        neww = int(neww + 0.5)
290        newh = int(newh + 0.5)
291        return (newh, neww)

Compute the output size given input size and target long side length.

Arguments:
  • oldh: The input image height.
  • oldw: The input image width.
  • long_side_length: The longest side length for resizing.
Returns:

A `(new_height, new_width)` tuple, scaled so that the longer side equals `long_side_length`.

def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor:
293    def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor:
294        """Resize the image so that the longest side has the correct length.
295
296        Expects batched images with shape BxCxHxW OR BxCxDxHxW and float format.
297
298        Args:
299            image: The input image.
300
301        Returns:
302            The resized image.
303        """
304        if image.ndim == 4:  # i.e. 2d image
305            target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.encoder.img_size)
306            return F.interpolate(image, target_size, mode="bilinear", align_corners=False, antialias=True)
307        elif image.ndim == 5:  # i.e. 3d volume
308            B, C, Z, H, W = image.shape
309            target_size = self.get_preprocess_shape(H, W, self.img_size)
310            return F.interpolate(image, (Z, *target_size), mode="trilinear", align_corners=False)
311        else:
312            raise ValueError("Expected 4d or 5d inputs, got", image.shape)

Resize the image so that the longest side has the correct length.

Expects batched images with shape BxCxHxW OR BxCxDxHxW and float format.

Arguments:
  • image: The input image.
Returns:

The resized image.

class UNETR(UNETRBase):
    """A (2d-only) UNet Transformer using a vision transformer as encoder and a convolutional decoder.
    """
    def __init__(
        self,
        img_size: int = 1024,
        backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
        encoder: Optional[Union[nn.Module, str]] = "vit_b",
        decoder: Optional[nn.Module] = None,
        out_channels: int = 1,
        use_sam_stats: bool = False,
        use_mae_stats: bool = False,
        use_dino_stats: bool = False,
        resize_input: bool = True,
        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
        final_activation: Optional[Union[str, nn.Module]] = None,
        use_skip_connection: bool = True,
        embed_dim: Optional[int] = None,
        use_conv_transpose: bool = False,
        **kwargs
    ) -> None:

        super().__init__(
            img_size=img_size,
            backbone=backbone,
            encoder=encoder,
            decoder=decoder,
            out_channels=out_channels,
            use_sam_stats=use_sam_stats,
            use_mae_stats=use_mae_stats,
            use_dino_stats=use_dino_stats,
            resize_input=resize_input,
            encoder_checkpoint=encoder_checkpoint,
            final_activation=final_activation,
            use_skip_connection=use_skip_connection,
            embed_dim=embed_dim,
            use_conv_transpose=use_conv_transpose,
            **kwargs,
        )

        # Number of input channels of the (now initialized) encoder; needed for `deconv4`
        # when skip connections feed the raw input image into the decoder head.
        in_chans = self._get_input_channels(self.encoder, backbone)

        # Parameters for the convolutional decoder network.
        depth = 3
        initial_features = 64
        gain = 2
        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
        scale_factors = depth * [2]
        self.out_channels = out_channels

        # Choice of upsampler: (bilinear interpolation + conv) or transposed convolution.
        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d

        self.decoder = decoder or Decoder(
            features=features_decoder,
            scale_factors=scale_factors[::-1],
            conv_block_impl=ConvBlock2d,
            sampler_impl=_upsampler,
        )

        if use_skip_connection:
            # With skip connections, each deconv path maps an intermediate encoder feature
            # map (all at `embed_dim` channels) down to the matching decoder level,
            # and `deconv4` maps the raw input image to the finest level.
            self.deconv1 = Deconv2DBlock(
                in_channels=self.embed_dim,
                out_channels=features_decoder[0],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv2 = nn.Sequential(
                Deconv2DBlock(
                    in_channels=self.embed_dim,
                    out_channels=features_decoder[0],
                    use_conv_transpose=use_conv_transpose,
                ),
                Deconv2DBlock(
                    in_channels=features_decoder[0],
                    out_channels=features_decoder[1],
                    use_conv_transpose=use_conv_transpose,
                )
            )
            self.deconv3 = nn.Sequential(
                Deconv2DBlock(
                    in_channels=self.embed_dim,
                    out_channels=features_decoder[0],
                    use_conv_transpose=use_conv_transpose,
                ),
                Deconv2DBlock(
                    in_channels=features_decoder[0],
                    out_channels=features_decoder[1],
                    use_conv_transpose=use_conv_transpose,
                ),
                Deconv2DBlock(
                    in_channels=features_decoder[1],
                    out_channels=features_decoder[2],
                    use_conv_transpose=use_conv_transpose,
                )
            )
            self.deconv4 = ConvBlock2d(in_chans, features_decoder[-1])
        else:
            # Without skip connections the deconv blocks form a sequential upsampling
            # chain starting from the final image embedding.
            self.deconv1 = Deconv2DBlock(
                in_channels=self.embed_dim,
                out_channels=features_decoder[0],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv2 = Deconv2DBlock(
                in_channels=features_decoder[0],
                out_channels=features_decoder[1],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv3 = Deconv2DBlock(
                in_channels=features_decoder[1],
                out_channels=features_decoder[2],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv4 = Deconv2DBlock(
                in_channels=features_decoder[2],
                out_channels=features_decoder[3],
                use_conv_transpose=use_conv_transpose,
            )

        # Conjunction and output layers.
        self.base = ConvBlock2d(self.embed_dim, features_decoder[0])
        self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1)
        self.deconv_out = _upsampler(
            scale_factor=2, in_channels=features_decoder[-1], out_channels=features_decoder[-1]
        )
        self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1])

    @staticmethod
    def _get_input_channels(encoder: nn.Module, backbone: str) -> int:
        """Infer the number of input channels of the given vision transformer encoder.

        Args:
            encoder: The initialized vision transformer.
            backbone: The name of the vision transformer implementation.

        Returns:
            The number of input channels.
        """
        if backbone == "sam2" and hasattr(encoder, "trunk"):
            return encoder.trunk.patch_embed.proj.in_channels
        if hasattr(encoder, "in_chans"):
            return encoder.in_chans
        try:  # Generic `nn.Module` ViT backbone.
            return encoder.patch_embed.proj.in_channels
        except AttributeError:
            # 'vit_t' from MobileSAM nests the first convolution differently.
            return encoder.patch_embed.seq[0].c.in_channels

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the UNETR to the input data.

        Args:
            x: The input tensor.

        Returns:
            The UNETR output.
        """
        original_shape = x.shape[-2:]

        # Reshape the inputs to the shape expected by the encoder
        # and normalize the inputs if normalization is part of the model.
        x, input_shape = self.preprocess(x)

        encoder_outputs = self.encoder(x)

        # `encoder_outputs` can be arranged in only two forms:
        #   - either we only return the image embeddings
        #   - or, we return the image embeddings and the "list" of global attention layers
        from_encoder = None
        if isinstance(encoder_outputs[-1], list):
            z12, from_encoder = encoder_outputs
        else:
            z12 = encoder_outputs

        if self.use_skip_connection:
            if from_encoder is None:
                # Previously this raised an opaque NameError; fail with a clear message instead.
                raise ValueError(
                    "Skip connections were requested, but the encoder did not return intermediate features."
                )
            from_encoder = from_encoder[::-1]
            z9 = self.deconv1(from_encoder[0])
            z6 = self.deconv2(from_encoder[1])
            z3 = self.deconv3(from_encoder[2])
            z0 = self.deconv4(x)
        else:
            z9 = self.deconv1(z12)
            z6 = self.deconv2(z9)
            z3 = self.deconv3(z6)
            z0 = self.deconv4(z3)

        updated_from_encoder = [z9, z6, z3]

        x = self.base(z12)
        x = self.decoder(x, encoder_inputs=updated_from_encoder)
        x = self.deconv_out(x)

        x = torch.cat([x, z0], dim=1)
        x = self.decoder_head(x)

        x = self.out_conv(x)
        if self.final_activation is not None:
            x = self.final_activation(x)

        x = self.postprocess_masks(x, input_shape, original_shape)
        return x

A (2d-only) UNet Transformer using a vision transformer as encoder and a convolutional decoder.

UNETR( img_size: int = 1024, backbone: Literal['sam', 'sam2', 'sam3', 'cellpose_sam', 'mae', 'scalemae', 'dinov2', 'dinov3'] = 'sam', encoder: Union[torch.nn.modules.module.Module, str, NoneType] = 'vit_b', decoder: Optional[torch.nn.modules.module.Module] = None, out_channels: int = 1, use_sam_stats: bool = False, use_mae_stats: bool = False, use_dino_stats: bool = False, resize_input: bool = True, encoder_checkpoint: Union[str, collections.OrderedDict, NoneType] = None, final_activation: Union[torch.nn.modules.module.Module, str, NoneType] = None, use_skip_connection: bool = True, embed_dim: Optional[int] = None, use_conv_transpose: bool = False, **kwargs)
395    def __init__(
396        self,
397        img_size: int = 1024,
398        backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
399        encoder: Optional[Union[nn.Module, str]] = "vit_b",
400        decoder: Optional[nn.Module] = None,
401        out_channels: int = 1,
402        use_sam_stats: bool = False,
403        use_mae_stats: bool = False,
404        use_dino_stats: bool = False,
405        resize_input: bool = True,
406        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
407        final_activation: Optional[Union[str, nn.Module]] = None,
408        use_skip_connection: bool = True,
409        embed_dim: Optional[int] = None,
410        use_conv_transpose: bool = False,
411        **kwargs
412    ) -> None:
413
414        super().__init__(
415            img_size=img_size,
416            backbone=backbone,
417            encoder=encoder,
418            decoder=decoder,
419            out_channels=out_channels,
420            use_sam_stats=use_sam_stats,
421            use_mae_stats=use_mae_stats,
422            use_dino_stats=use_dino_stats,
423            resize_input=resize_input,
424            encoder_checkpoint=encoder_checkpoint,
425            final_activation=final_activation,
426            use_skip_connection=use_skip_connection,
427            embed_dim=embed_dim,
428            use_conv_transpose=use_conv_transpose,
429            **kwargs,
430        )
431
432        encoder = self.encoder
433
434        if backbone == "sam2" and hasattr(encoder, "trunk"):
435            in_chans = encoder.trunk.patch_embed.proj.in_channels
436        elif hasattr(encoder, "in_chans"):
437            in_chans = encoder.in_chans
438        else:  # `nn.Module` ViT backbone.
439            try:
440                in_chans = encoder.patch_embed.proj.in_channels
441            except AttributeError:  # for getting the input channels while using 'vit_t' from MobileSam
442                in_chans = encoder.patch_embed.seq[0].c.in_channels
443
444        # parameters for the decoder network
445        depth = 3
446        initial_features = 64
447        gain = 2
448        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
449        scale_factors = depth * [2]
450        self.out_channels = out_channels
451
452        # choice of upsampler - to use (bilinear interpolation + conv) or conv transpose
453        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d
454
455        self.decoder = decoder or Decoder(
456            features=features_decoder,
457            scale_factors=scale_factors[::-1],
458            conv_block_impl=ConvBlock2d,
459            sampler_impl=_upsampler,
460        )
461
462        if use_skip_connection:
463            self.deconv1 = Deconv2DBlock(
464                in_channels=self.embed_dim,
465                out_channels=features_decoder[0],
466                use_conv_transpose=use_conv_transpose,
467            )
468            self.deconv2 = nn.Sequential(
469                Deconv2DBlock(
470                    in_channels=self.embed_dim,
471                    out_channels=features_decoder[0],
472                    use_conv_transpose=use_conv_transpose,
473                ),
474                Deconv2DBlock(
475                    in_channels=features_decoder[0],
476                    out_channels=features_decoder[1],
477                    use_conv_transpose=use_conv_transpose,
478                )
479            )
480            self.deconv3 = nn.Sequential(
481                Deconv2DBlock(
482                    in_channels=self.embed_dim,
483                    out_channels=features_decoder[0],
484                    use_conv_transpose=use_conv_transpose,
485                ),
486                Deconv2DBlock(
487                    in_channels=features_decoder[0],
488                    out_channels=features_decoder[1],
489                    use_conv_transpose=use_conv_transpose,
490                ),
491                Deconv2DBlock(
492                    in_channels=features_decoder[1],
493                    out_channels=features_decoder[2],
494                    use_conv_transpose=use_conv_transpose,
495                )
496            )
497            self.deconv4 = ConvBlock2d(in_chans, features_decoder[-1])
498        else:
499            self.deconv1 = Deconv2DBlock(
500                in_channels=self.embed_dim,
501                out_channels=features_decoder[0],
502                use_conv_transpose=use_conv_transpose,
503            )
504            self.deconv2 = Deconv2DBlock(
505                in_channels=features_decoder[0],
506                out_channels=features_decoder[1],
507                use_conv_transpose=use_conv_transpose,
508            )
509            self.deconv3 = Deconv2DBlock(
510                in_channels=features_decoder[1],
511                out_channels=features_decoder[2],
512                use_conv_transpose=use_conv_transpose,
513            )
514            self.deconv4 = Deconv2DBlock(
515                in_channels=features_decoder[2],
516                out_channels=features_decoder[3],
517                use_conv_transpose=use_conv_transpose,
518            )
519
520        self.base = ConvBlock2d(self.embed_dim, features_decoder[0])
521        self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1)
522        self.deconv_out = _upsampler(
523            scale_factor=2, in_channels=features_decoder[-1], out_channels=features_decoder[-1]
524        )
525        self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1])

Construct the UNETR: initialize the vision-transformer encoder via the base class, then build the 2d convolutional decoder, deconvolution paths, and output head.

out_channels
decoder
base
out_conv
deconv_out
decoder_head
def forward(self, x: torch.Tensor) -> torch.Tensor:
527    def forward(self, x: torch.Tensor) -> torch.Tensor:
528        """Apply the UNETR to the input data.
529
530        Args:
531            x: The input tensor.
532
533        Returns:
534            The UNETR output.
535        """
536        original_shape = x.shape[-2:]
537
538        # Reshape the inputs to the shape expected by the encoder
539        # and normalize the inputs if normalization is part of the model.
540        x, input_shape = self.preprocess(x)
541
542        encoder_outputs = self.encoder(x)
543
544        if isinstance(encoder_outputs[-1], list):
545            # `encoder_outputs` can be arranged in only two forms:
546            #   - either we only return the image embeddings
547            #   - or, we return the image embeddings and the "list" of global attention layers
548            z12, from_encoder = encoder_outputs
549        else:
550            z12 = encoder_outputs
551
552        if self.use_skip_connection:
553            from_encoder = from_encoder[::-1]
554            z9 = self.deconv1(from_encoder[0])
555            z6 = self.deconv2(from_encoder[1])
556            z3 = self.deconv3(from_encoder[2])
557            z0 = self.deconv4(x)
558
559        else:
560            z9 = self.deconv1(z12)
561            z6 = self.deconv2(z9)
562            z3 = self.deconv3(z6)
563            z0 = self.deconv4(z3)
564
565        updated_from_encoder = [z9, z6, z3]
566
567        x = self.base(z12)
568        x = self.decoder(x, encoder_inputs=updated_from_encoder)
569        x = self.deconv_out(x)
570
571        x = torch.cat([x, z0], dim=1)
572        x = self.decoder_head(x)
573
574        x = self.out_conv(x)
575        if self.final_activation is not None:
576            x = self.final_activation(x)
577
578        x = self.postprocess_masks(x, input_shape, original_shape)
579        return x

Apply the UNETR to the input data.

Arguments:
  • x: The input tensor.
Returns:

The UNETR output.

class UNETR2D(UNETR):
    """A two-dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.

    Alias for `UNETR`, which is 2d-only; kept for a symmetric naming scheme with `UNETR3D`.
    """

A two-dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.

class UNETR3D(UNETRBase):
    """A three dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.
    """
    def __init__(
        self,
        img_size: int = 1024,
        backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
        encoder: Optional[Union[nn.Module, str]] = "hvit_b",
        decoder: Optional[nn.Module] = None,
        out_channels: int = 1,
        use_sam_stats: bool = False,
        use_mae_stats: bool = False,
        use_dino_stats: bool = False,
        resize_input: bool = True,
        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
        final_activation: Optional[Union[str, nn.Module]] = None,
        use_skip_connection: bool = False,
        embed_dim: Optional[int] = None,
        use_conv_transpose: bool = False,
        use_strip_pooling: bool = True,
        **kwargs
    ):
        # Fail fast on configurations the 3d variant does not support yet.
        if use_skip_connection:
            raise NotImplementedError("The framework cannot handle skip connections atm.")
        if use_conv_transpose:
            raise NotImplementedError("It's not enabled to switch between interpolation and transposed convolutions.")

        # Sort the `embed_dim` out
        embed_dim = 256 if embed_dim is None else embed_dim

        super().__init__(
            img_size=img_size,
            backbone=backbone,
            encoder=encoder,
            decoder=decoder,
            out_channels=out_channels,
            use_sam_stats=use_sam_stats,
            use_mae_stats=use_mae_stats,
            use_dino_stats=use_dino_stats,
            resize_input=resize_input,
            encoder_checkpoint=encoder_checkpoint,
            final_activation=final_activation,
            use_skip_connection=use_skip_connection,
            embed_dim=embed_dim,
            use_conv_transpose=use_conv_transpose,
            **kwargs,
        )

        # The 3d convolutional decoder.
        # First, get the important parameters for the decoder.
        depth = 3
        initial_features = 64
        gain = 2
        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
        # Upsample only in-plane (H, W); the depth axis Z is left unchanged.
        scale_factors = [1, 2, 2]
        self.out_channels = out_channels

        # The mapping blocks: a sequential chain from the embedding down to the finest level.
        self.deconv1 = Deconv3DBlock(
            in_channels=embed_dim,
            out_channels=features_decoder[0],
            scale_factor=scale_factors,
            use_strip_pooling=use_strip_pooling,
        )
        self.deconv2 = Deconv3DBlock(
            in_channels=features_decoder[0],
            out_channels=features_decoder[1],
            scale_factor=scale_factors,
            use_strip_pooling=use_strip_pooling,
        )
        self.deconv3 = Deconv3DBlock(
            in_channels=features_decoder[1],
            out_channels=features_decoder[2],
            scale_factor=scale_factors,
            use_strip_pooling=use_strip_pooling,
        )
        self.deconv4 = Deconv3DBlock(
            in_channels=features_decoder[2],
            out_channels=features_decoder[3],
            scale_factor=scale_factors,
            use_strip_pooling=use_strip_pooling,
        )

        # The core decoder block.
        self.decoder = decoder or Decoder(
            features=features_decoder,
            scale_factors=[scale_factors] * depth,
            conv_block_impl=partial(ConvBlock3dWithStrip, use_strip_pooling=use_strip_pooling),
            sampler_impl=Upsampler3d,
        )

        # And the final upsampler to match the expected dimensions.
        self.deconv_out = Deconv3DBlock(  # NOTE: changed `end_up` to `deconv_out`
            in_channels=features_decoder[-1],
            out_channels=features_decoder[-1],
            scale_factor=scale_factors,
            use_strip_pooling=use_strip_pooling,
        )

        # Additional conjunction blocks.
        self.base = ConvBlock3dWithStrip(
            in_channels=embed_dim,
            out_channels=features_decoder[0],
            use_strip_pooling=use_strip_pooling,
        )

        # And the output layers.
        self.decoder_head = ConvBlock3dWithStrip(
            in_channels=2 * features_decoder[-1],
            out_channels=features_decoder[-1],
            use_strip_pooling=use_strip_pooling,
        )
        self.out_conv = nn.Conv3d(features_decoder[-1], out_channels, 1)

    def forward(self, x: torch.Tensor):
        """Forward pass of the UNETR-3D model.

        Args:
            x: Inputs of expected shape (B, C, Z, Y, X), where Z considers flexible inputs.

        Returns:
            The UNETR output.
        """
        # Batch and channel sizes are not needed below; only the spatial extents are kept.
        _, _, Z, H, W = x.shape
        original_shape = (Z, H, W)

        # Preprocessing step
        x, input_shape = self.preprocess(x)

        # Run the 2d image encoder slice-by-slice along Z and stack the embeddings.
        # NOTE(review): indexing `[0]` assumes the encoder returns a tuple/list whose first
        # entry is the image embedding — confirm against the encoder implementations.
        curr_features = torch.stack([self.encoder(x[:, :, i])[0] for i in range(Z)], dim=2)

        # Prepare the counterparts for the decoder.
        # NOTE: The section below is sequential, there's no skip connections atm.
        z9 = self.deconv1(curr_features)
        z6 = self.deconv2(z9)
        z3 = self.deconv3(z6)
        z0 = self.deconv4(z3)

        updated_from_encoder = [z9, z6, z3]

        # Align the features through the base block.
        x = self.base(curr_features)
        # Run the decoder
        x = self.decoder(x, encoder_inputs=updated_from_encoder)
        x = self.deconv_out(x)  # NOTE before `end_up`

        # And the final output head.
        x = torch.cat([x, z0], dim=1)
        x = self.decoder_head(x)
        x = self.out_conv(x)
        if self.final_activation is not None:
            x = self.final_activation(x)

        # Postprocess the output back to original size.
        x = self.postprocess_masks(x, input_shape, original_shape)
        return x

A three dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.

UNETR3D( img_size: int = 1024, backbone: Literal['sam', 'sam2', 'sam3', 'cellpose_sam', 'mae', 'scalemae', 'dinov2', 'dinov3'] = 'sam', encoder: Union[torch.nn.modules.module.Module, str, NoneType] = 'hvit_b', decoder: Optional[torch.nn.modules.module.Module] = None, out_channels: int = 1, use_sam_stats: bool = False, use_mae_stats: bool = False, use_dino_stats: bool = False, resize_input: bool = True, encoder_checkpoint: Union[str, collections.OrderedDict, NoneType] = None, final_activation: Union[torch.nn.modules.module.Module, str, NoneType] = None, use_skip_connection: bool = False, embed_dim: Optional[int] = None, use_conv_transpose: bool = False, use_strip_pooling: bool = True, **kwargs)
591    def __init__(
592        self,
593        img_size: int = 1024,
594        backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
595        encoder: Optional[Union[nn.Module, str]] = "hvit_b",
596        decoder: Optional[nn.Module] = None,
597        out_channels: int = 1,
598        use_sam_stats: bool = False,
599        use_mae_stats: bool = False,
600        use_dino_stats: bool = False,
601        resize_input: bool = True,
602        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
603        final_activation: Optional[Union[str, nn.Module]] = None,
604        use_skip_connection: bool = False,
605        embed_dim: Optional[int] = None,
606        use_conv_transpose: bool = False,
607        use_strip_pooling: bool = True,
608        **kwargs
609    ):
610        if use_skip_connection:
611            raise NotImplementedError("The framework cannot handle skip connections atm.")
612        if use_conv_transpose:
613            raise NotImplementedError("It's not enabled to switch between interpolation and transposed convolutions.")
614
615        # Sort the `embed_dim` out
616        embed_dim = 256 if embed_dim is None else embed_dim
617
618        super().__init__(
619            img_size=img_size,
620            backbone=backbone,
621            encoder=encoder,
622            decoder=decoder,
623            out_channels=out_channels,
624            use_sam_stats=use_sam_stats,
625            use_mae_stats=use_mae_stats,
626            use_dino_stats=use_dino_stats,
627            resize_input=resize_input,
628            encoder_checkpoint=encoder_checkpoint,
629            final_activation=final_activation,
630            use_skip_connection=use_skip_connection,
631            embed_dim=embed_dim,
632            use_conv_transpose=use_conv_transpose,
633            **kwargs,
634        )
635
636        # The 3d convolutional decoder.
637        # First, get the important parameters for the decoder.
638        depth = 3
639        initial_features = 64
640        gain = 2
641        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
642        scale_factors = [1, 2, 2]
643        self.out_channels = out_channels
644
645        # The mapping blocks.
646        self.deconv1 = Deconv3DBlock(
647            in_channels=embed_dim,
648            out_channels=features_decoder[0],
649            scale_factor=scale_factors,
650            use_strip_pooling=use_strip_pooling,
651        )
652        self.deconv2 = Deconv3DBlock(
653            in_channels=features_decoder[0],
654            out_channels=features_decoder[1],
655            scale_factor=scale_factors,
656            use_strip_pooling=use_strip_pooling,
657        )
658        self.deconv3 = Deconv3DBlock(
659            in_channels=features_decoder[1],
660            out_channels=features_decoder[2],
661            scale_factor=scale_factors,
662            use_strip_pooling=use_strip_pooling,
663        )
664        self.deconv4 = Deconv3DBlock(
665            in_channels=features_decoder[2],
666            out_channels=features_decoder[3],
667            scale_factor=scale_factors,
668            use_strip_pooling=use_strip_pooling,
669        )
670
671        # The core decoder block.
672        self.decoder = decoder or Decoder(
673            features=features_decoder,
674            scale_factors=[scale_factors] * depth,
675            conv_block_impl=partial(ConvBlock3dWithStrip, use_strip_pooling=use_strip_pooling),
676            sampler_impl=Upsampler3d,
677        )
678
679        # And the final upsampler to match the expected dimensions.
680        self.deconv_out = Deconv3DBlock(  # NOTE: changed `end_up` to `deconv_out`
681            in_channels=features_decoder[-1],
682            out_channels=features_decoder[-1],
683            scale_factor=scale_factors,
684            use_strip_pooling=use_strip_pooling,
685        )
686
687        # Additional conjunction blocks.
688        self.base = ConvBlock3dWithStrip(
689            in_channels=embed_dim,
690            out_channels=features_decoder[0],
691            use_strip_pooling=use_strip_pooling,
692        )
693
694        # And the output layers.
695        self.decoder_head = ConvBlock3dWithStrip(
696            in_channels=2 * features_decoder[-1],
697            out_channels=features_decoder[-1],
698            use_strip_pooling=use_strip_pooling,
699        )
700        self.out_conv = nn.Conv3d(features_decoder[-1], out_channels, 1)

Construct the UNETR-3D: initialize the encoder via the base class, then build the 3d deconvolution chain, decoder, and output head (skip connections and transposed convolutions are not supported).

out_channels
deconv1
deconv2
deconv3
deconv4
decoder
deconv_out
base
decoder_head
out_conv
def forward(self, x: torch.Tensor):
702    def forward(self, x: torch.Tensor):
703        """Forward pass of the UNETR-3D model.
704
705        Args:
706            x: Inputs of expected shape (B, C, Z, Y, X), where Z considers flexible inputs.
707
708        Returns:
709            The UNETR output.
710        """
711        B, C, Z, H, W = x.shape
712        original_shape = (Z, H, W)
713
714        # Preprocessing step
715        x, input_shape = self.preprocess(x)
716
717        # Run the image encoder.
718        curr_features = torch.stack([self.encoder(x[:, :, i])[0] for i in range(Z)], dim=2)
719
720        # Prepare the counterparts for the decoder.
721        # NOTE: The section below is sequential, there's no skip connections atm.
722        z9 = self.deconv1(curr_features)
723        z6 = self.deconv2(z9)
724        z3 = self.deconv3(z6)
725        z0 = self.deconv4(z3)
726
727        updated_from_encoder = [z9, z6, z3]
728
729        # Align the features through the base block.
730        x = self.base(curr_features)
731        # Run the decoder
732        x = self.decoder(x, encoder_inputs=updated_from_encoder)
733        x = self.deconv_out(x)  # NOTE before `end_up`
734
735        # And the final output head.
736        x = torch.cat([x, z0], dim=1)
737        x = self.decoder_head(x)
738        x = self.out_conv(x)
739        if self.final_activation is not None:
740            x = self.final_activation(x)
741
742        # Postprocess the output back to original size.
743        x = self.postprocess_masks(x, input_shape, original_shape)
744        return x

Forward pass of the UNETR-3D model.

Arguments:
  • x: Inputs of expected shape (B, C, Z, Y, X), where Z considers flexible inputs.
Returns:

The UNETR output.