torch_em.model.unetr

  1from functools import partial
  2from collections import OrderedDict
  3from typing import Optional, Tuple, Union, Literal
  4
  5import torch
  6import torch.nn as nn
  7import torch.nn.functional as F
  8
  9from .vit import get_vision_transformer
 10from .unet import Decoder, ConvBlock2d, ConvBlock3d, Upsampler2d, Upsampler3d, _update_conv_kwargs
 11
 12try:
 13    from micro_sam.util import get_sam_model
 14except ImportError:
 15    get_sam_model = None
 16
 17try:
 18    from micro_sam.v2.util import get_sam2_model
 19except ImportError:
 20    get_sam2_model = None
 21
 22try:
 23    from micro_sam3.util import get_sam3_model
 24except ImportError:
 25    get_sam3_model = None
 26
 27
 28#
 29# UNETR IMPLEMENTATION [Vision Transformer (ViT from SAM / CellposeSAM / SAM2 / SAM3 / DINOv2 / DINOv3 / MAE / ScaleMAE) + UNet Decoder from `torch_em`]  # noqa
 30#
 31
 32
 33class UNETRBase(nn.Module):
 34    """Base class for implementing a UNETR.
 35
 36    Args:
 37        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
 38        backbone: The name of the vision transformer implementation.
 39            One of "sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"
 40            (see all combinations below)
 41        encoder: The vision transformer. Can either be a name, such as "vit_b"
 42            (see all combinations for this below) or a torch module.
 43        decoder: The convolutional decoder.
 44        out_channels: The number of output channels of the UNETR.
 45        use_sam_stats: Whether to normalize the input data with the statistics of the
 46            pretrained SAM / SAM2 / SAM3 model.
 47        use_dino_stats: Whether to normalize the input data with the statistics of the
 48            pretrained DINOv2 / DINOv3 model.
 49        use_mae_stats: Whether to normalize the input data with the statistics of the pretrained MAE model.
 50        resize_input: Whether to resize the input images to match `img_size`.
 51            By default, it resizes the inputs to match the `img_size`.
 52        encoder_checkpoint: Checkpoint for initializing the vision transformer.
 53            Can either be a filepath or an already loaded checkpoint.
 54        final_activation: The activation to apply to the UNETR output.
 55        use_skip_connection: Whether to use skip connections. By default, it uses skip connections.
 56        embed_dim: The embedding dimensionality, corresponding to the output dimension of the vision transformer.
 57        use_conv_transpose: Whether to use transposed convolutions instead of resampling for upsampling.
 58            By default, it uses resampling for upsampling.
 59
 60        NOTE: The currently supported combinations of 'backbone' x 'encoder' are the following:
 61
 62        SAM_family_models:
 63            - 'sam' x 'vit_b'
 64            - 'sam' x 'vit_l'
 65            - 'sam' x 'vit_h'
 66            - 'sam2' x 'hvit_t'
 67            - 'sam2' x 'hvit_s'
 68            - 'sam2' x 'hvit_b'
 69            - 'sam2' x 'hvit_l'
 70            - 'sam3' x 'vit_pe'
 71            - 'cellpose_sam' x 'vit_l'
 72
 73        DINO_family_models:
 74            - 'dinov2' x 'vit_s'
 75            - 'dinov2' x 'vit_b'
 76            - 'dinov2' x 'vit_l'
 77            - 'dinov2' x 'vit_g'
 78            - 'dinov2' x 'vit_s_reg4'
 79            - 'dinov2' x 'vit_b_reg4'
 80            - 'dinov2' x 'vit_l_reg4'
 81            - 'dinov2' x 'vit_g_reg4'
 82            - 'dinov3' x 'vit_s'
 83            - 'dinov3' x 'vit_s+'
 84            - 'dinov3' x 'vit_b'
 85            - 'dinov3' x 'vit_l'
 86            - 'dinov3' x 'vit_l+'
 87            - 'dinov3' x 'vit_h+'
 88            - 'dinov3' x 'vit_7b'
 89
 90        MAE_family_models:
 91            - 'mae' x 'vit_b'
 92            - 'mae' x 'vit_l'
 93            - 'mae' x 'vit_h'
 94            - 'scalemae' x 'vit_b'
 95            - 'scalemae' x 'vit_l'
 96            - 'scalemae' x 'vit_h'
 97    """
 98    def __init__(
 99        self,
100        img_size: int = 1024,
101        backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
102        encoder: Optional[Union[nn.Module, str]] = "vit_b",
103        decoder: Optional[nn.Module] = None,
104        out_channels: int = 1,
105        use_sam_stats: bool = False,
106        use_mae_stats: bool = False,
107        use_dino_stats: bool = False,
108        resize_input: bool = True,
109        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
110        final_activation: Optional[Union[str, nn.Module]] = None,
111        use_skip_connection: bool = True,
112        embed_dim: Optional[int] = None,
113        use_conv_transpose: bool = False,
114        **kwargs
115    ) -> None:
116        super().__init__()
117
118        self.img_size = img_size
119        self.use_sam_stats = use_sam_stats
120        self.use_mae_stats = use_mae_stats
121        self.use_dino_stats = use_dino_stats
122        self.use_skip_connection = use_skip_connection
123        self.resize_input = resize_input
124        self.use_conv_transpose = use_conv_transpose
125        self.backbone = backbone
126
127        if isinstance(encoder, str):  # e.g. "vit_b" / "hvit_b" / "vit_pe"
128            print(f"Using {encoder} from {backbone.upper()}")
129            self.encoder = get_vision_transformer(img_size=img_size, backbone=backbone, model=encoder, **kwargs)
130
131            if encoder_checkpoint is not None:
132                self._load_encoder_from_checkpoint(backbone=backbone, encoder=encoder, checkpoint=encoder_checkpoint)
133
134            if embed_dim is None:
135                embed_dim = self.encoder.embed_dim
136
137            # For SAM1 encoder, if 'apply_neck' is applied, the embedding dimension must change.
138            if hasattr(self.encoder, "apply_neck") and self.encoder.apply_neck:
139                embed_dim = self.encoder.neck[2].out_channels  # the value is 256
140
141        else:  # `nn.Module` ViT backbone
142            self.encoder = encoder
143
144            have_neck = False
145            for name, _ in self.encoder.named_parameters():
146                if name.startswith("neck"):
147                    have_neck = True
148
149            if embed_dim is None:
150                if have_neck:
151                    embed_dim = self.encoder.neck[2].out_channels  # the value is 256
152                else:
153                    embed_dim = self.encoder.patch_embed.proj.out_channels
154
155        self.embed_dim = embed_dim
156        self.final_activation = self._get_activation(final_activation)
157
158    def _load_encoder_from_checkpoint(self, backbone, encoder, checkpoint):
159        """Function to load pretrained weights to the image encoder.
160        """
161        if isinstance(checkpoint, str):
162            if backbone == "sam" and isinstance(encoder, str):
163                # If we have a SAM encoder, then we first try to load the full SAM Model
164                # (using micro_sam) and otherwise fall back on directly loading the encoder state
165                # from the checkpoint
166                try:
167                    _, model = get_sam_model(model_type=encoder, checkpoint_path=checkpoint, return_sam=True)
168                    encoder_state = model.image_encoder.state_dict()
169                except Exception:
170                    # Try loading the encoder state directly from a checkpoint.
171                    encoder_state = torch.load(checkpoint, weights_only=False)
172
173            elif backbone == "cellpose_sam" and isinstance(encoder, str):
174                # The architecture matches CellposeSAM exactly (same rel_pos sizes),
175                # so weights load directly without any interpolation.
176                encoder_state = torch.load(checkpoint, map_location="cpu", weights_only=False)
177                # Handle DataParallel/DistributedDataParallel prefix.
178                if any(k.startswith("module.") for k in encoder_state.keys()):
179                    encoder_state = OrderedDict(
180                        {k[len("module."):]: v for k, v in encoder_state.items()}
181                    )
182                # Extract encoder weights from CellposeSAM checkpoint format (strip 'encoder.' prefix).
183                if any(k.startswith("encoder.") for k in encoder_state.keys()):
184                    encoder_state = OrderedDict(
185                        {k[len("encoder."):]: v for k, v in encoder_state.items() if k.startswith("encoder.")}
186                    )
187
188            elif backbone == "sam2" and isinstance(encoder, str):
189                # If we have a SAM2 encoder, then we first try to load the full SAM2 Model.
190                # (using micro_sam2) and otherwise fall back on directly loading the encoder state
191                # from the checkpoint
192                try:
193                    model = get_sam2_model(model_type=encoder, checkpoint_path=checkpoint)
194                    encoder_state = model.image_encoder.state_dict()
195                except Exception:
196                    # Try loading the encoder state directly from a checkpoint.
197                    encoder_state = torch.load(checkpoint, weights_only=False)
198
199            elif backbone == "sam3" and isinstance(encoder, str):
200                # If we have a SAM3 encoder, then we first try to load the full SAM3 Model.
201                # (using micro_sam3) and otherwise fall back on directly loading the encoder state
202                # from the checkpoint
203                try:
204                    model = get_sam3_model(checkpoint_path=checkpoint)
205                    encoder_state = model.backbone.vision_backbone.state_dict()
206                    # Let's align loading the encoder weights with expected parameter names
207                    encoder_state = {
208                        k[len("trunk."):] if k.startswith("trunk.") else k: v for k, v in encoder_state.items()
209                    }
210                    # And drop the 'convs' and 'sam2_convs' - these seem like some upsampling blocks.
211                    encoder_state = {
212                        k: v for k, v in encoder_state.items()
213                        if not (k.startswith("convs.") or k.startswith("sam2_convs."))
214                    }
215                except Exception:
216                    # Try loading the encoder state directly from a checkpoint.
217                    encoder_state = torch.load(checkpoint, weights_only=False)
218
219            elif backbone == "mae":
220                # vit initialization hints from:
221                #     - https://github.com/facebookresearch/mae/blob/main/main_finetune.py#L233-L242
222                encoder_state = torch.load(checkpoint, weights_only=False)["model"]
223                encoder_state = OrderedDict({
224                    k: v for k, v in encoder_state.items() if (k != "mask_token" and not k.startswith("decoder"))
225                })
226                # Let's remove the `head` from our current encoder (as the MAE pretrained don't expect it)
227                current_encoder_state = self.encoder.state_dict()
228                if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
229                    del self.encoder.head
230
231            elif backbone == "scalemae":
232                # Load the encoder state directly from a checkpoint.
233                encoder_state = torch.load(checkpoint)["model"]
234                encoder_state = OrderedDict({
235                    k: v for k, v in encoder_state.items()
236                    if not k.startswith(("mask_token", "decoder", "fcn", "fpn", "pos_embed"))
237                })
238
239                # Let's remove the `head` from our current encoder (as the MAE pretrained don't expect it)
240                current_encoder_state = self.encoder.state_dict()
241                if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
242                    del self.encoder.head
243
244                if "pos_embed" in current_encoder_state:  # NOTE: ScaleMAE uses 'pos. embeddings' in a diff. format.
245                    del self.encoder.pos_embed
246
247            elif backbone in ["dinov2", "dinov3"]:  # Load the encoder state directly from a checkpoint.
248                encoder_state = torch.load(checkpoint)
249
250            else:
251                raise ValueError(
252                    f"We don't support either the '{backbone}' backbone or the '{encoder}' model combination (or both)."
253                )
254
255        else:
256            encoder_state = checkpoint
257
258        self.encoder.load_state_dict(encoder_state)
259
260    def _get_activation(self, activation):
261        return_activation = None
262        if activation is None:
263            return None
264        if isinstance(activation, nn.Module):
265            return activation
266        if isinstance(activation, str):
267            return_activation = getattr(nn, activation, None)
268        if return_activation is None:
269            raise ValueError(f"Invalid activation: {activation}")
270
271        return return_activation()
272
273    @staticmethod
274    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
275        """Compute the output size given input size and target long side length.
276
277        Args:
278            oldh: The input image height.
279            oldw: The input image width.
280            long_side_length: The longest side length for resizing.
281
282        Returns:
283            The new image height.
284            The new image width.
285        """
286        scale = long_side_length * 1.0 / max(oldh, oldw)
287        newh, neww = oldh * scale, oldw * scale
288        neww = int(neww + 0.5)
289        newh = int(newh + 0.5)
290        return (newh, neww)
291
292    def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor:
293        """Resize the image so that the longest side has the correct length.
294
295        Expects batched images with shape BxCxHxW OR BxCxDxHxW and float format.
296
297        Args:
298            image: The input image.
299
300        Returns:
301            The resized image.
302        """
303        if image.ndim == 4:  # i.e. 2d image
304            target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.encoder.img_size)
305            return F.interpolate(image, target_size, mode="bilinear", align_corners=False, antialias=True)
306        elif image.ndim == 5:  # i.e. 3d volume
307            B, C, Z, H, W = image.shape
308            target_size = self.get_preprocess_shape(H, W, self.img_size)
309            return F.interpolate(image, (Z, *target_size), mode="trilinear", align_corners=False)
310        else:
311            raise ValueError("Expected 4d or 5d inputs, got", image.shape)
312
313    def _as_stats(self, mean, std, device, dtype, is_3d: bool):
314        """@private
315        """
316        # Either 2d batch: (1, C, 1, 1) or 3d batch: (1, C, 1, 1, 1).
317        view_shape = (1, -1, 1, 1, 1) if is_3d else (1, -1, 1, 1)
318        pixel_mean = torch.tensor(mean, device=device, dtype=dtype).view(*view_shape)
319        pixel_std = torch.tensor(std, device=device, dtype=dtype).view(*view_shape)
320        return pixel_mean, pixel_std
321
322    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
323        """@private
324        """
325        device = x.device
326        is_3d = (x.ndim == 5)
327        device, dtype = x.device, x.dtype
328
329        if self.use_sam_stats:
330            mean, std = (123.675, 116.28, 103.53), (58.395, 57.12, 57.375)
331        elif self.use_mae_stats:  # TODO: add mean std from mae / scalemae experiments (or open up arguments for this)
332            raise NotImplementedError
333        elif self.use_dino_stats or (self.use_sam_stats and self.backbone == "sam2"):
334            mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
335        elif self.use_sam_stats and self.backbone == "sam3":
336            mean, std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
337        else:
338            mean, std = (0.0, 0.0, 0.0), (1.0, 1.0, 1.0)
339
340        pixel_mean, pixel_std = self._as_stats(mean, std, device=device, dtype=dtype, is_3d=is_3d)
341
342        if self.resize_input:
343            x = self.resize_longest_side(x)
344        input_shape = x.shape[-3:] if is_3d else x.shape[-2:]
345
346        x = (x - pixel_mean) / pixel_std
347        h, w = x.shape[-2:]
348        padh = self.encoder.img_size - h
349        padw = self.encoder.img_size - w
350
351        if is_3d:
352            x = F.pad(x, (0, padw, 0, padh, 0, 0))
353        else:
354            x = F.pad(x, (0, padw, 0, padh))
355
356        return x, input_shape
357
358    def postprocess_masks(
359        self, masks: torch.Tensor, input_size: Tuple[int, ...], original_size: Tuple[int, ...],
360    ) -> torch.Tensor:
361        """@private
362        """
363        if masks.ndim == 4:  # i.e. 2d labels
364            masks = F.interpolate(
365                masks,
366                (self.encoder.img_size, self.encoder.img_size),
367                mode="bilinear",
368                align_corners=False,
369            )
370            masks = masks[..., : input_size[0], : input_size[1]]
371            masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
372
373        elif masks.ndim == 5:  # i.e. 3d volumetric labels
374            masks = F.interpolate(
375                masks,
376                (input_size[0], self.img_size, self.img_size),
377                mode="trilinear",
378                align_corners=False,
379            )
380            masks = masks[..., :input_size[0], :input_size[1], :input_size[2]]
381            masks = F.interpolate(masks, original_size, mode="trilinear", align_corners=False)
382
383        else:
384            raise ValueError("Expected 4d or 5d labels, got", masks.shape)
385
386        return masks
387
388
class UNETR(UNETRBase):
    """A (2d-only) UNet Transformer using a vision transformer as encoder and a convolutional decoder.

    See `UNETRBase` for the description of the constructor arguments.
    """
    def __init__(
        self,
        img_size: int = 1024,
        backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
        encoder: Optional[Union[nn.Module, str]] = "vit_b",
        decoder: Optional[nn.Module] = None,
        out_channels: int = 1,
        use_sam_stats: bool = False,
        use_mae_stats: bool = False,
        use_dino_stats: bool = False,
        resize_input: bool = True,
        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
        final_activation: Optional[Union[str, nn.Module]] = None,
        use_skip_connection: bool = True,
        embed_dim: Optional[int] = None,
        use_conv_transpose: bool = False,
        **kwargs
    ) -> None:

        super().__init__(
            img_size=img_size,
            backbone=backbone,
            encoder=encoder,
            decoder=decoder,
            out_channels=out_channels,
            use_sam_stats=use_sam_stats,
            use_mae_stats=use_mae_stats,
            use_dino_stats=use_dino_stats,
            resize_input=resize_input,
            encoder_checkpoint=encoder_checkpoint,
            final_activation=final_activation,
            use_skip_connection=use_skip_connection,
            embed_dim=embed_dim,
            use_conv_transpose=use_conv_transpose,
            **kwargs,
        )

        encoder = self.encoder

        # Determine the number of input (image) channels expected by the encoder.
        if backbone == "sam2" and hasattr(encoder, "trunk"):
            in_chans = encoder.trunk.patch_embed.proj.in_channels
        elif hasattr(encoder, "in_chans"):
            in_chans = encoder.in_chans
        else:  # `nn.Module` ViT backbone.
            try:
                in_chans = encoder.patch_embed.proj.in_channels
            except AttributeError:  # for getting the input channels while using 'vit_t' from MobileSam
                in_chans = encoder.patch_embed.seq[0].c.in_channels

        # parameters for the decoder network
        depth = 3
        initial_features = 64
        gain = 2
        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
        scale_factors = depth * [2]
        self.out_channels = out_channels

        # choice of upsampler - to use (bilinear interpolation + conv) or conv transpose
        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d

        self.decoder = decoder or Decoder(
            features=features_decoder,
            scale_factors=scale_factors[::-1],
            conv_block_impl=ConvBlock2d,
            sampler_impl=_upsampler,
        )

        if use_skip_connection:
            # Skip-connection path: each deconv maps an intermediate encoder feature map
            # (all at embed_dim channels) to the corresponding decoder resolution.
            self.deconv1 = Deconv2DBlock(
                in_channels=self.embed_dim,
                out_channels=features_decoder[0],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv2 = nn.Sequential(
                Deconv2DBlock(
                    in_channels=self.embed_dim,
                    out_channels=features_decoder[0],
                    use_conv_transpose=use_conv_transpose,
                ),
                Deconv2DBlock(
                    in_channels=features_decoder[0],
                    out_channels=features_decoder[1],
                    use_conv_transpose=use_conv_transpose,
                )
            )
            self.deconv3 = nn.Sequential(
                Deconv2DBlock(
                    in_channels=self.embed_dim,
                    out_channels=features_decoder[0],
                    use_conv_transpose=use_conv_transpose,
                ),
                Deconv2DBlock(
                    in_channels=features_decoder[0],
                    out_channels=features_decoder[1],
                    use_conv_transpose=use_conv_transpose,
                ),
                Deconv2DBlock(
                    in_channels=features_decoder[1],
                    out_channels=features_decoder[2],
                    use_conv_transpose=use_conv_transpose,
                )
            )
            # The last skip connection operates directly on the (preprocessed) input image.
            self.deconv4 = ConvBlock2d(in_chans, features_decoder[-1])
        else:
            # Sequential path: deconvs are chained, each upsampling the previous output.
            self.deconv1 = Deconv2DBlock(
                in_channels=self.embed_dim,
                out_channels=features_decoder[0],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv2 = Deconv2DBlock(
                in_channels=features_decoder[0],
                out_channels=features_decoder[1],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv3 = Deconv2DBlock(
                in_channels=features_decoder[1],
                out_channels=features_decoder[2],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv4 = Deconv2DBlock(
                in_channels=features_decoder[2],
                out_channels=features_decoder[3],
                use_conv_transpose=use_conv_transpose,
            )

        self.base = ConvBlock2d(self.embed_dim, features_decoder[0])
        self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1)
        self.deconv_out = _upsampler(
            scale_factor=2, in_channels=features_decoder[-1], out_channels=features_decoder[-1]
        )
        self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the UNETR to the input data.

        Args:
            x: The input tensor.

        Returns:
            The UNETR output.

        Raises:
            ValueError: If skip connections are enabled but the encoder does not
                return intermediate features.
        """
        original_shape = x.shape[-2:]

        # Reshape the inputs to the shape expected by the encoder
        # and normalize the inputs if normalization is part of the model.
        x, input_shape = self.preprocess(x)

        encoder_outputs = self.encoder(x)

        # `encoder_outputs` can be arranged in only two forms:
        #   - either we only return the image embeddings
        #   - or, we return the image embeddings and the "list" of global attention layers
        from_encoder = None
        if isinstance(encoder_outputs[-1], list):
            z12, from_encoder = encoder_outputs
        else:
            z12 = encoder_outputs

        if self.use_skip_connection:
            # Guard against encoders that only return the final embeddings; without this
            # check the code below would fail with a confusing unbound-variable error.
            if from_encoder is None:
                raise ValueError(
                    "The encoder did not return intermediate features, "
                    "which are required when using skip connections."
                )
            from_encoder = from_encoder[::-1]
            z9 = self.deconv1(from_encoder[0])
            z6 = self.deconv2(from_encoder[1])
            z3 = self.deconv3(from_encoder[2])
            z0 = self.deconv4(x)

        else:
            z9 = self.deconv1(z12)
            z6 = self.deconv2(z9)
            z3 = self.deconv3(z6)
            z0 = self.deconv4(z3)

        updated_from_encoder = [z9, z6, z3]

        x = self.base(z12)
        x = self.decoder(x, encoder_inputs=updated_from_encoder)
        x = self.deconv_out(x)

        x = torch.cat([x, z0], dim=1)
        x = self.decoder_head(x)

        x = self.out_conv(x)
        if self.final_activation is not None:
            x = self.final_activation(x)

        # Crop the padding and resize back to the original input shape.
        x = self.postprocess_masks(x, input_shape, original_shape)
        return x
577
578
class UNETR2D(UNETR):
    """A two-dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.

    Identical to `UNETR` (which only supports 2d); kept as the 2d counterpart to `UNETR3D`.
    """
    pass
583
584
class UNETR3D(UNETRBase):
    """A three dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.

    The 2d vision transformer is applied slice-wise along the depth axis and the resulting
    feature volume is decoded with a 3d convolutional decoder.
    See `UNETRBase` for the description of the shared constructor arguments.
    """
    def __init__(
        self,
        img_size: int = 1024,
        backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
        encoder: Optional[Union[nn.Module, str]] = "hvit_b",
        decoder: Optional[nn.Module] = None,
        out_channels: int = 1,
        use_sam_stats: bool = False,
        use_mae_stats: bool = False,
        use_dino_stats: bool = False,
        resize_input: bool = True,
        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
        final_activation: Optional[Union[str, nn.Module]] = None,
        use_skip_connection: bool = False,
        embed_dim: Optional[int] = None,
        use_conv_transpose: bool = False,
        use_strip_pooling: bool = True,
        **kwargs
    ):
        # The 3d variant does not support these options yet.
        if use_skip_connection:
            raise NotImplementedError("The framework cannot handle skip connections atm.")
        if use_conv_transpose:
            raise NotImplementedError("It's not enabled to switch between interpolation and transposed convolutions.")

        # Default embedding dimensionality (matches the 256-channel neck output).
        if embed_dim is None:
            embed_dim = 256

        super().__init__(
            img_size=img_size,
            backbone=backbone,
            encoder=encoder,
            decoder=decoder,
            out_channels=out_channels,
            use_sam_stats=use_sam_stats,
            use_mae_stats=use_mae_stats,
            use_dino_stats=use_dino_stats,
            resize_input=resize_input,
            encoder_checkpoint=encoder_checkpoint,
            final_activation=final_activation,
            use_skip_connection=use_skip_connection,
            embed_dim=embed_dim,
            use_conv_transpose=use_conv_transpose,
            **kwargs,
        )

        # Configuration of the 3d convolutional decoder:
        # feature widths per level, from coarsest (512) to finest (64).
        depth = 3
        initial_features = 64
        gain = 2
        features_decoder = [initial_features * gain ** lvl for lvl in range(depth, -1, -1)]
        # Upsample only in-plane (H, W); the depth axis Z is kept as-is.
        scale_factors = [1, 2, 2]
        self.out_channels = out_channels

        # The mapping blocks 'deconv1' ... 'deconv4', chaining
        # embed_dim -> 512 -> 256 -> 128 -> 64 channels.
        stage_channels = zip([embed_dim] + features_decoder[:-1], features_decoder)
        for stage, (chan_in, chan_out) in enumerate(stage_channels, start=1):
            setattr(
                self,
                f"deconv{stage}",
                Deconv3DBlock(
                    in_channels=chan_in,
                    out_channels=chan_out,
                    scale_factor=scale_factors,
                    use_strip_pooling=use_strip_pooling,
                ),
            )

        # The core decoder block.
        self.decoder = decoder or Decoder(
            features=features_decoder,
            scale_factors=[scale_factors] * depth,
            conv_block_impl=partial(ConvBlock3dWithStrip, use_strip_pooling=use_strip_pooling),
            sampler_impl=Upsampler3d,
        )

        # And the final upsampler to match the expected dimensions.
        self.deconv_out = Deconv3DBlock(  # NOTE: changed `end_up` to `deconv_out`
            in_channels=features_decoder[-1],
            out_channels=features_decoder[-1],
            scale_factor=scale_factors,
            use_strip_pooling=use_strip_pooling,
        )

        # Additional conjunction blocks.
        self.base = ConvBlock3dWithStrip(
            in_channels=embed_dim,
            out_channels=features_decoder[0],
            use_strip_pooling=use_strip_pooling,
        )

        # And the output layers.
        self.decoder_head = ConvBlock3dWithStrip(
            in_channels=2 * features_decoder[-1],
            out_channels=features_decoder[-1],
            use_strip_pooling=use_strip_pooling,
        )
        self.out_conv = nn.Conv3d(features_decoder[-1], out_channels, 1)

    def forward(self, x: torch.Tensor):
        """Forward pass of the UNETR-3D model.

        Args:
            x: Inputs of expected shape (B, C, Z, Y, X), where Z considers flexible inputs.

        Returns:
            The UNETR output.
        """
        n_slices = x.shape[2]
        original_shape = tuple(x.shape[-3:])

        # Preprocessing step (normalization, optional resizing and padding).
        x, input_shape = self.preprocess(x)

        # Encode every z-slice independently with the 2d image encoder
        # and stack the per-slice embeddings back along the depth axis.
        slice_embeddings = [self.encoder(x[:, :, z_idx])[0] for z_idx in range(n_slices)]
        features = torch.stack(slice_embeddings, dim=2)

        # Derive the decoder counterparts sequentially.
        # NOTE: There are no skip connections atm.
        d1 = self.deconv1(features)
        d2 = self.deconv2(d1)
        d3 = self.deconv3(d2)
        d4 = self.deconv4(d3)

        # Align the features through the base block, then run the decoder.
        out = self.decoder(self.base(features), encoder_inputs=[d1, d2, d3])
        out = self.deconv_out(out)  # NOTE before `end_up`

        # The final output head.
        out = self.decoder_head(torch.cat([out, d4], dim=1))
        out = self.out_conv(out)
        if self.final_activation is not None:
            out = self.final_activation(out)

        # Postprocess the output back to original size.
        return self.postprocess_masks(out, input_shape, original_shape)
742
743#
744#  ADDITIONAL FUNCTIONALITIES
745#
746
747
def _strip_pooling_layers(enabled, channels) -> nn.Module:
    """Return a depth strip-pooling block for `channels`, or a no-op when disabled."""
    if enabled:
        return DepthStripPooling(channels)
    return nn.Identity()
750
751
class DepthStripPooling(nn.Module):
    """@private
    """
    def __init__(self, channels: int, reduction: int = 4):
        """Strip pooling restricted to the depth axis (only).

        For 3d inputs (Z > 1) the block collapses the depth axis via adaptive
        average pooling (Z -> 1), transforms the pooled context with a small
        1x1x1 bottleneck MLP, broadcasts it back over all Z slices, and applies
        it as a sigmoid gate in a residual fashion.

        For 2d-like inputs (Z == 1) the block is a no-op.

        Args:
            channels: The number of input (and output) channels.
            reduction: The channel reduction factor for the hidden layer.
        """
        super().__init__()
        hidden_channels = max(1, channels // reduction)
        self.conv1 = nn.Conv3d(channels, hidden_channels, kernel_size=1)
        self.bn1 = nn.BatchNorm3d(hidden_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(hidden_channels, channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.dim() != 5:
            raise ValueError(f"DepthStripPooling expects 5D tensors as input, got '{x.shape}'.")

        batch, channels, depth, height, width = x.shape
        # With a single slice there is nothing to pool over, i.e. the 2d case.
        if depth == 1:
            return x

        # Collapse only the depth axis: (B, C, Z, H, W) -> (B, C, 1, H, W).
        context = F.adaptive_avg_pool3d(x, output_size=(1, height, width))
        context = self.relu(self.bn1(self.conv1(context)))
        context = self.conv2(context)

        # Broadcast the collapsed depth context back to every slice.
        gate = torch.sigmoid(context).expand(batch, channels, depth, height, width)
        # Gated residual fusion.
        return x + x * gate
793
794
class Deconv3DBlock(nn.Module):
    """@private
    """
    def __init__(
        self,
        scale_factor,
        in_channels,
        out_channels,
        kernel_size=3,
        anisotropic_kernel=True,
        use_strip_pooling=True,
    ):
        """3d upsampling block: upsampler, then conv + batch norm + ReLU.

        Args:
            scale_factor: The upsampling factor (may be anisotropic).
            in_channels: The number of input channels.
            out_channels: The number of output channels.
            kernel_size: The (same-padded) convolution kernel size.
            anisotropic_kernel: Whether to adapt kernel and padding to the scale factor.
            use_strip_pooling: Whether to append a depth strip-pooling layer.
        """
        super().__init__()
        conv_kwargs = dict(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            padding=(kernel_size - 1) // 2,
        )
        if anisotropic_kernel:
            # Adapt the kernel / padding to anisotropic scale factors.
            conv_kwargs = _update_conv_kwargs(conv_kwargs, scale_factor)

        layers = [
            Upsampler3d(scale_factor, in_channels, out_channels),
            nn.Conv3d(**conv_kwargs),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(True),
            _strip_pooling_layers(enabled=use_strip_pooling, channels=out_channels),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)
827
828
class ConvBlock3dWithStrip(nn.Module):
    """@private
    """
    def __init__(
        self, in_channels: int, out_channels: int, use_strip_pooling: bool = True, **kwargs
    ):
        """A 3d conv block, optionally followed by depth strip pooling.

        Args:
            in_channels: The number of input channels.
            out_channels: The number of output channels.
            use_strip_pooling: Whether to append a depth strip-pooling layer.
            kwargs: Additional keyword arguments passed to `ConvBlock3d`.
        """
        super().__init__()
        strip = _strip_pooling_layers(enabled=use_strip_pooling, channels=out_channels)
        self.block = nn.Sequential(ConvBlock3d(in_channels, out_channels, **kwargs), strip)

    def forward(self, x):
        return self.block(x)
843
844
class SingleDeconv2DBlock(nn.Module):
    """@private
    """
    def __init__(self, scale_factor, in_channels, out_channels):
        """A single 2d transposed convolution upsampling by `scale_factor`.

        Fix: previously `scale_factor` was accepted but silently ignored and
        the block always upsampled by a factor of 2. The argument is now
        honored; behavior is unchanged for the default factor 2.

        Args:
            scale_factor: The upsampling factor.
            in_channels: The number of input channels.
            out_channels: The number of output channels.
        """
        super().__init__()
        # kernel_size == stride gives exact, non-overlapping upsampling by the factor.
        self.block = nn.ConvTranspose2d(
            in_channels, out_channels,
            kernel_size=scale_factor, stride=scale_factor, padding=0, output_padding=0,
        )

    def forward(self, x):
        return self.block(x)
854
855
class SingleConv2DBlock(nn.Module):
    """@private
    """
    def __init__(self, in_channels, out_channels, kernel_size):
        """A single same-padded 2d convolution (expects an odd kernel size).

        Args:
            in_channels: The number of input channels.
            out_channels: The number of output channels.
            kernel_size: The convolution kernel size.
        """
        super().__init__()
        same_padding = (kernel_size - 1) // 2
        self.block = nn.Conv2d(
            in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=same_padding
        )

    def forward(self, x):
        return self.block(x)
867
868
class Conv2DBlock(nn.Module):
    """@private
    """
    def __init__(self, in_channels, out_channels, kernel_size=3):
        """Same-padded 2d convolution followed by batch norm and ReLU.

        Args:
            in_channels: The number of input channels.
            out_channels: The number of output channels.
            kernel_size: The convolution kernel size.
        """
        super().__init__()
        conv = SingleConv2DBlock(in_channels, out_channels, kernel_size)
        norm = nn.BatchNorm2d(out_channels)
        activation = nn.ReLU(True)
        self.block = nn.Sequential(conv, norm, activation)

    def forward(self, x):
        return self.block(x)
882
883
class Deconv2DBlock(nn.Module):
    """@private
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, use_conv_transpose=True):
        """2x upsampling followed by conv + batch norm + ReLU.

        Args:
            in_channels: The number of input channels.
            out_channels: The number of output channels.
            kernel_size: The kernel size of the convolution after upsampling.
            use_conv_transpose: Whether to upsample with a transposed convolution
                instead of interpolation-based resampling.
        """
        super().__init__()
        if use_conv_transpose:
            upsampler = SingleDeconv2DBlock(scale_factor=2, in_channels=in_channels, out_channels=out_channels)
        else:
            upsampler = Upsampler2d(scale_factor=2, in_channels=in_channels, out_channels=out_channels)
        self.block = nn.Sequential(
            upsampler,
            SingleConv2DBlock(out_channels, out_channels, kernel_size),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),
        )

    def forward(self, x):
        return self.block(x)
class UNETRBase(torch.nn.modules.module.Module):
 34class UNETRBase(nn.Module):
 35    """Base class for implementing a UNETR.
 36
 37    Args:
 38        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
 39        backbone: The name of the vision transformer implementation.
 40            One of "sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"
 41            (see all combinations below)
 42        encoder: The vision transformer. Can either be a name, such as "vit_b"
 43            (see all combinations for this below) or a torch module.
 44        decoder: The convolutional decoder.
 45        out_channels: The number of output channels of the UNETR.
 46        use_sam_stats: Whether to normalize the input data with the statistics of the
 47            pretrained SAM / SAM2 / SAM3 model.
 48        use_dino_stats: Whether to normalize the input data with the statistics of the
 49            pretrained DINOv2 / DINOv3 model.
 50        use_mae_stats: Whether to normalize the input data with the statistics of the pretrained MAE model.
 51        resize_input: Whether to resize the input images to match `img_size`.
 52            By default, it resizes the inputs to match the `img_size`.
 53        encoder_checkpoint: Checkpoint for initializing the vision transformer.
 54            Can either be a filepath or an already loaded checkpoint.
 55        final_activation: The activation to apply to the UNETR output.
 56        use_skip_connection: Whether to use skip connections. By default, it uses skip connections.
 57        embed_dim: The embedding dimensionality, corresponding to the output dimension of the vision transformer.
 58        use_conv_transpose: Whether to use transposed convolutions instead of resampling for upsampling.
 59            By default, it uses resampling for upsampling.
 60
 61        NOTE: The currently supported combinations of 'backbone' x 'encoder' are the following:
 62
 63        SAM_family_models:
 64            - 'sam' x 'vit_b'
 65            - 'sam' x 'vit_l'
 66            - 'sam' x 'vit_h'
 67            - 'sam2' x 'hvit_t'
 68            - 'sam2' x 'hvit_s'
 69            - 'sam2' x 'hvit_b'
 70            - 'sam2' x 'hvit_l'
 71            - 'sam3' x 'vit_pe'
 72            - 'cellpose_sam' x 'vit_l'
 73
 74        DINO_family_models:
 75            - 'dinov2' x 'vit_s'
 76            - 'dinov2' x 'vit_b'
 77            - 'dinov2' x 'vit_l'
 78            - 'dinov2' x 'vit_g'
 79            - 'dinov2' x 'vit_s_reg4'
 80            - 'dinov2' x 'vit_b_reg4'
 81            - 'dinov2' x 'vit_l_reg4'
 82            - 'dinov2' x 'vit_g_reg4'
 83            - 'dinov3' x 'vit_s'
 84            - 'dinov3' x 'vit_s+'
 85            - 'dinov3' x 'vit_b'
 86            - 'dinov3' x 'vit_l'
 87            - 'dinov3' x 'vit_l+'
 88            - 'dinov3' x 'vit_h+'
 89            - 'dinov3' x 'vit_7b'
 90
 91        MAE_family_models:
 92            - 'mae' x 'vit_b'
 93            - 'mae' x 'vit_l'
 94            - 'mae' x 'vit_h'
 95            - 'scalemae' x 'vit_b'
 96            - 'scalemae' x 'vit_l'
 97            - 'scalemae' x 'vit_h'
 98    """
 99    def __init__(
100        self,
101        img_size: int = 1024,
102        backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
103        encoder: Optional[Union[nn.Module, str]] = "vit_b",
104        decoder: Optional[nn.Module] = None,
105        out_channels: int = 1,
106        use_sam_stats: bool = False,
107        use_mae_stats: bool = False,
108        use_dino_stats: bool = False,
109        resize_input: bool = True,
110        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
111        final_activation: Optional[Union[str, nn.Module]] = None,
112        use_skip_connection: bool = True,
113        embed_dim: Optional[int] = None,
114        use_conv_transpose: bool = False,
115        **kwargs
116    ) -> None:
117        super().__init__()
118
119        self.img_size = img_size
120        self.use_sam_stats = use_sam_stats
121        self.use_mae_stats = use_mae_stats
122        self.use_dino_stats = use_dino_stats
123        self.use_skip_connection = use_skip_connection
124        self.resize_input = resize_input
125        self.use_conv_transpose = use_conv_transpose
126        self.backbone = backbone
127
128        if isinstance(encoder, str):  # e.g. "vit_b" / "hvit_b" / "vit_pe"
129            print(f"Using {encoder} from {backbone.upper()}")
130            self.encoder = get_vision_transformer(img_size=img_size, backbone=backbone, model=encoder, **kwargs)
131
132            if encoder_checkpoint is not None:
133                self._load_encoder_from_checkpoint(backbone=backbone, encoder=encoder, checkpoint=encoder_checkpoint)
134
135            if embed_dim is None:
136                embed_dim = self.encoder.embed_dim
137
138            # For SAM1 encoder, if 'apply_neck' is applied, the embedding dimension must change.
139            if hasattr(self.encoder, "apply_neck") and self.encoder.apply_neck:
140                embed_dim = self.encoder.neck[2].out_channels  # the value is 256
141
142        else:  # `nn.Module` ViT backbone
143            self.encoder = encoder
144
145            have_neck = False
146            for name, _ in self.encoder.named_parameters():
147                if name.startswith("neck"):
148                    have_neck = True
149
150            if embed_dim is None:
151                if have_neck:
152                    embed_dim = self.encoder.neck[2].out_channels  # the value is 256
153                else:
154                    embed_dim = self.encoder.patch_embed.proj.out_channels
155
156        self.embed_dim = embed_dim
157        self.final_activation = self._get_activation(final_activation)
158
159    def _load_encoder_from_checkpoint(self, backbone, encoder, checkpoint):
160        """Function to load pretrained weights to the image encoder.
161        """
162        if isinstance(checkpoint, str):
163            if backbone == "sam" and isinstance(encoder, str):
164                # If we have a SAM encoder, then we first try to load the full SAM Model
165                # (using micro_sam) and otherwise fall back on directly loading the encoder state
166                # from the checkpoint
167                try:
168                    _, model = get_sam_model(model_type=encoder, checkpoint_path=checkpoint, return_sam=True)
169                    encoder_state = model.image_encoder.state_dict()
170                except Exception:
171                    # Try loading the encoder state directly from a checkpoint.
172                    encoder_state = torch.load(checkpoint, weights_only=False)
173
174            elif backbone == "cellpose_sam" and isinstance(encoder, str):
175                # The architecture matches CellposeSAM exactly (same rel_pos sizes),
176                # so weights load directly without any interpolation.
177                encoder_state = torch.load(checkpoint, map_location="cpu", weights_only=False)
178                # Handle DataParallel/DistributedDataParallel prefix.
179                if any(k.startswith("module.") for k in encoder_state.keys()):
180                    encoder_state = OrderedDict(
181                        {k[len("module."):]: v for k, v in encoder_state.items()}
182                    )
183                # Extract encoder weights from CellposeSAM checkpoint format (strip 'encoder.' prefix).
184                if any(k.startswith("encoder.") for k in encoder_state.keys()):
185                    encoder_state = OrderedDict(
186                        {k[len("encoder."):]: v for k, v in encoder_state.items() if k.startswith("encoder.")}
187                    )
188
189            elif backbone == "sam2" and isinstance(encoder, str):
190                # If we have a SAM2 encoder, then we first try to load the full SAM2 Model.
191                # (using micro_sam2) and otherwise fall back on directly loading the encoder state
192                # from the checkpoint
193                try:
194                    model = get_sam2_model(model_type=encoder, checkpoint_path=checkpoint)
195                    encoder_state = model.image_encoder.state_dict()
196                except Exception:
197                    # Try loading the encoder state directly from a checkpoint.
198                    encoder_state = torch.load(checkpoint, weights_only=False)
199
200            elif backbone == "sam3" and isinstance(encoder, str):
201                # If we have a SAM3 encoder, then we first try to load the full SAM3 Model.
202                # (using micro_sam3) and otherwise fall back on directly loading the encoder state
203                # from the checkpoint
204                try:
205                    model = get_sam3_model(checkpoint_path=checkpoint)
206                    encoder_state = model.backbone.vision_backbone.state_dict()
207                    # Let's align loading the encoder weights with expected parameter names
208                    encoder_state = {
209                        k[len("trunk."):] if k.startswith("trunk.") else k: v for k, v in encoder_state.items()
210                    }
211                    # And drop the 'convs' and 'sam2_convs' - these seem like some upsampling blocks.
212                    encoder_state = {
213                        k: v for k, v in encoder_state.items()
214                        if not (k.startswith("convs.") or k.startswith("sam2_convs."))
215                    }
216                except Exception:
217                    # Try loading the encoder state directly from a checkpoint.
218                    encoder_state = torch.load(checkpoint, weights_only=False)
219
220            elif backbone == "mae":
221                # vit initialization hints from:
222                #     - https://github.com/facebookresearch/mae/blob/main/main_finetune.py#L233-L242
223                encoder_state = torch.load(checkpoint, weights_only=False)["model"]
224                encoder_state = OrderedDict({
225                    k: v for k, v in encoder_state.items() if (k != "mask_token" and not k.startswith("decoder"))
226                })
227                # Let's remove the `head` from our current encoder (as the MAE pretrained don't expect it)
228                current_encoder_state = self.encoder.state_dict()
229                if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
230                    del self.encoder.head
231
232            elif backbone == "scalemae":
233                # Load the encoder state directly from a checkpoint.
234                encoder_state = torch.load(checkpoint)["model"]
235                encoder_state = OrderedDict({
236                    k: v for k, v in encoder_state.items()
237                    if not k.startswith(("mask_token", "decoder", "fcn", "fpn", "pos_embed"))
238                })
239
240                # Let's remove the `head` from our current encoder (as the MAE pretrained don't expect it)
241                current_encoder_state = self.encoder.state_dict()
242                if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
243                    del self.encoder.head
244
245                if "pos_embed" in current_encoder_state:  # NOTE: ScaleMAE uses 'pos. embeddings' in a diff. format.
246                    del self.encoder.pos_embed
247
248            elif backbone in ["dinov2", "dinov3"]:  # Load the encoder state directly from a checkpoint.
249                encoder_state = torch.load(checkpoint)
250
251            else:
252                raise ValueError(
253                    f"We don't support either the '{backbone}' backbone or the '{encoder}' model combination (or both)."
254                )
255
256        else:
257            encoder_state = checkpoint
258
259        self.encoder.load_state_dict(encoder_state)
260
261    def _get_activation(self, activation):
262        return_activation = None
263        if activation is None:
264            return None
265        if isinstance(activation, nn.Module):
266            return activation
267        if isinstance(activation, str):
268            return_activation = getattr(nn, activation, None)
269        if return_activation is None:
270            raise ValueError(f"Invalid activation: {activation}")
271
272        return return_activation()
273
274    @staticmethod
275    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
276        """Compute the output size given input size and target long side length.
277
278        Args:
279            oldh: The input image height.
280            oldw: The input image width.
281            long_side_length: The longest side length for resizing.
282
283        Returns:
284            The new image height.
285            The new image width.
286        """
287        scale = long_side_length * 1.0 / max(oldh, oldw)
288        newh, neww = oldh * scale, oldw * scale
289        neww = int(neww + 0.5)
290        newh = int(newh + 0.5)
291        return (newh, neww)
292
293    def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor:
294        """Resize the image so that the longest side has the correct length.
295
296        Expects batched images with shape BxCxHxW OR BxCxDxHxW and float format.
297
298        Args:
299            image: The input image.
300
301        Returns:
302            The resized image.
303        """
304        if image.ndim == 4:  # i.e. 2d image
305            target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.encoder.img_size)
306            return F.interpolate(image, target_size, mode="bilinear", align_corners=False, antialias=True)
307        elif image.ndim == 5:  # i.e. 3d volume
308            B, C, Z, H, W = image.shape
309            target_size = self.get_preprocess_shape(H, W, self.img_size)
310            return F.interpolate(image, (Z, *target_size), mode="trilinear", align_corners=False)
311        else:
312            raise ValueError("Expected 4d or 5d inputs, got", image.shape)
313
314    def _as_stats(self, mean, std, device, dtype, is_3d: bool):
315        """@private
316        """
317        # Either 2d batch: (1, C, 1, 1) or 3d batch: (1, C, 1, 1, 1).
318        view_shape = (1, -1, 1, 1, 1) if is_3d else (1, -1, 1, 1)
319        pixel_mean = torch.tensor(mean, device=device, dtype=dtype).view(*view_shape)
320        pixel_std = torch.tensor(std, device=device, dtype=dtype).view(*view_shape)
321        return pixel_mean, pixel_std
322
323    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
324        """@private
325        """
326        device = x.device
327        is_3d = (x.ndim == 5)
328        device, dtype = x.device, x.dtype
329
330        if self.use_sam_stats:
331            mean, std = (123.675, 116.28, 103.53), (58.395, 57.12, 57.375)
332        elif self.use_mae_stats:  # TODO: add mean std from mae / scalemae experiments (or open up arguments for this)
333            raise NotImplementedError
334        elif self.use_dino_stats or (self.use_sam_stats and self.backbone == "sam2"):
335            mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
336        elif self.use_sam_stats and self.backbone == "sam3":
337            mean, std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
338        else:
339            mean, std = (0.0, 0.0, 0.0), (1.0, 1.0, 1.0)
340
341        pixel_mean, pixel_std = self._as_stats(mean, std, device=device, dtype=dtype, is_3d=is_3d)
342
343        if self.resize_input:
344            x = self.resize_longest_side(x)
345        input_shape = x.shape[-3:] if is_3d else x.shape[-2:]
346
347        x = (x - pixel_mean) / pixel_std
348        h, w = x.shape[-2:]
349        padh = self.encoder.img_size - h
350        padw = self.encoder.img_size - w
351
352        if is_3d:
353            x = F.pad(x, (0, padw, 0, padh, 0, 0))
354        else:
355            x = F.pad(x, (0, padw, 0, padh))
356
357        return x, input_shape
358
359    def postprocess_masks(
360        self, masks: torch.Tensor, input_size: Tuple[int, ...], original_size: Tuple[int, ...],
361    ) -> torch.Tensor:
362        """@private
363        """
364        if masks.ndim == 4:  # i.e. 2d labels
365            masks = F.interpolate(
366                masks,
367                (self.encoder.img_size, self.encoder.img_size),
368                mode="bilinear",
369                align_corners=False,
370            )
371            masks = masks[..., : input_size[0], : input_size[1]]
372            masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
373
374        elif masks.ndim == 5:  # i.e. 3d volumetric labels
375            masks = F.interpolate(
376                masks,
377                (input_size[0], self.img_size, self.img_size),
378                mode="trilinear",
379                align_corners=False,
380            )
381            masks = masks[..., :input_size[0], :input_size[1], :input_size[2]]
382            masks = F.interpolate(masks, original_size, mode="trilinear", align_corners=False)
383
384        else:
385            raise ValueError("Expected 4d or 5d labels, got", masks.shape)
386
387        return masks

Base class for implementing a UNETR.

Arguments:
  • img_size: The size of the input for the image encoder. Input images will be resized to match this size.
  • backbone: The name of the vision transformer implementation. One of "sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3" (see all combinations below)
  • encoder: The vision transformer. Can either be a name, such as "vit_b" (see all combinations for this below) or a torch module.
  • decoder: The convolutional decoder.
  • out_channels: The number of output channels of the UNETR.
  • use_sam_stats: Whether to normalize the input data with the statistics of the pretrained SAM / SAM2 / SAM3 model.
  • use_dino_stats: Whether to normalize the input data with the statistics of the pretrained DINOv2 / DINOv3 model.
  • use_mae_stats: Whether to normalize the input data with the statistics of the pretrained MAE model.
  • resize_input: Whether to resize the input images to match img_size. By default, it resizes the inputs to match the img_size.
  • encoder_checkpoint: Checkpoint for initializing the vision transformer. Can either be a filepath or an already loaded checkpoint.
  • final_activation: The activation to apply to the UNETR output.
  • use_skip_connection: Whether to use skip connections. By default, it uses skip connections.
  • embed_dim: The embedding dimensionality, corresponding to the output dimension of the vision transformer.
  • use_conv_transpose: Whether to use transposed convolutions instead of resampling for upsampling. By default, it uses resampling for upsampling.
  • NOTE: The currently supported combinations of 'backbone' x 'encoder' are the following:
  • SAM_family_models: - 'sam' x 'vit_b'
    • 'sam' x 'vit_l'
    • 'sam' x 'vit_h'
    • 'sam2' x 'hvit_t'
    • 'sam2' x 'hvit_s'
    • 'sam2' x 'hvit_b'
    • 'sam2' x 'hvit_l'
    • 'sam3' x 'vit_pe'
    • 'cellpose_sam' x 'vit_l'
  • DINO_family_models: - 'dinov2' x 'vit_s'
    • 'dinov2' x 'vit_b'
    • 'dinov2' x 'vit_l'
    • 'dinov2' x 'vit_g'
    • 'dinov2' x 'vit_s_reg4'
    • 'dinov2' x 'vit_b_reg4'
    • 'dinov2' x 'vit_l_reg4'
    • 'dinov2' x 'vit_g_reg4'
    • 'dinov3' x 'vit_s'
    • 'dinov3' x 'vit_s+'
    • 'dinov3' x 'vit_b'
    • 'dinov3' x 'vit_l'
    • 'dinov3' x 'vit_l+'
    • 'dinov3' x 'vit_h+'
    • 'dinov3' x 'vit_7b'
  • MAE_family_models: - 'mae' x 'vit_b'
    • 'mae' x 'vit_l'
    • 'mae' x 'vit_h'
    • 'scalemae' x 'vit_b'
    • 'scalemae' x 'vit_l'
    • 'scalemae' x 'vit_h'
UNETRBase( img_size: int = 1024, backbone: Literal['sam', 'sam2', 'sam3', 'cellpose_sam', 'mae', 'scalemae', 'dinov2', 'dinov3'] = 'sam', encoder: Union[torch.nn.modules.module.Module, str, NoneType] = 'vit_b', decoder: Optional[torch.nn.modules.module.Module] = None, out_channels: int = 1, use_sam_stats: bool = False, use_mae_stats: bool = False, use_dino_stats: bool = False, resize_input: bool = True, encoder_checkpoint: Union[str, collections.OrderedDict, NoneType] = None, final_activation: Union[torch.nn.modules.module.Module, str, NoneType] = None, use_skip_connection: bool = True, embed_dim: Optional[int] = None, use_conv_transpose: bool = False, **kwargs)
 99    def __init__(
100        self,
101        img_size: int = 1024,
102        backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
103        encoder: Optional[Union[nn.Module, str]] = "vit_b",
104        decoder: Optional[nn.Module] = None,
105        out_channels: int = 1,
106        use_sam_stats: bool = False,
107        use_mae_stats: bool = False,
108        use_dino_stats: bool = False,
109        resize_input: bool = True,
110        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
111        final_activation: Optional[Union[str, nn.Module]] = None,
112        use_skip_connection: bool = True,
113        embed_dim: Optional[int] = None,
114        use_conv_transpose: bool = False,
115        **kwargs
116    ) -> None:
117        super().__init__()
118
119        self.img_size = img_size
120        self.use_sam_stats = use_sam_stats
121        self.use_mae_stats = use_mae_stats
122        self.use_dino_stats = use_dino_stats
123        self.use_skip_connection = use_skip_connection
124        self.resize_input = resize_input
125        self.use_conv_transpose = use_conv_transpose
126        self.backbone = backbone
127
128        if isinstance(encoder, str):  # e.g. "vit_b" / "hvit_b" / "vit_pe"
129            print(f"Using {encoder} from {backbone.upper()}")
130            self.encoder = get_vision_transformer(img_size=img_size, backbone=backbone, model=encoder, **kwargs)
131
132            if encoder_checkpoint is not None:
133                self._load_encoder_from_checkpoint(backbone=backbone, encoder=encoder, checkpoint=encoder_checkpoint)
134
135            if embed_dim is None:
136                embed_dim = self.encoder.embed_dim
137
138            # For SAM1 encoder, if 'apply_neck' is applied, the embedding dimension must change.
139            if hasattr(self.encoder, "apply_neck") and self.encoder.apply_neck:
140                embed_dim = self.encoder.neck[2].out_channels  # the value is 256
141
142        else:  # `nn.Module` ViT backbone
143            self.encoder = encoder
144
145            have_neck = False
146            for name, _ in self.encoder.named_parameters():
147                if name.startswith("neck"):
148                    have_neck = True
149
150            if embed_dim is None:
151                if have_neck:
152                    embed_dim = self.encoder.neck[2].out_channels  # the value is 256
153                else:
154                    embed_dim = self.encoder.patch_embed.proj.out_channels
155
156        self.embed_dim = embed_dim
157        self.final_activation = self._get_activation(final_activation)

Initialize the UNETR base model: configure the input-normalization statistics, build (or adopt) the vision-transformer encoder, optionally load encoder weights from a checkpoint, and infer the embedding dimension.

img_size
use_sam_stats
use_mae_stats
use_dino_stats
use_skip_connection
resize_input
use_conv_transpose
backbone
embed_dim
final_activation
@staticmethod
def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
274    @staticmethod
275    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
276        """Compute the output size given input size and target long side length.
277
278        Args:
279            oldh: The input image height.
280            oldw: The input image width.
281            long_side_length: The longest side length for resizing.
282
283        Returns:
284            The new image height.
285            The new image width.
286        """
287        scale = long_side_length * 1.0 / max(oldh, oldw)
288        newh, neww = oldh * scale, oldw * scale
289        neww = int(neww + 0.5)
290        newh = int(newh + 0.5)
291        return (newh, neww)

Compute the output size given input size and target long side length.

Arguments:
  • oldh: The input image height.
  • oldw: The input image width.
  • long_side_length: The longest side length for resizing.
Returns:

A tuple `(new_height, new_width)` with the resized image dimensions.

def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor:
293    def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor:
294        """Resize the image so that the longest side has the correct length.
295
296        Expects batched images with shape BxCxHxW OR BxCxDxHxW and float format.
297
298        Args:
299            image: The input image.
300
301        Returns:
302            The resized image.
303        """
304        if image.ndim == 4:  # i.e. 2d image
305            target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.encoder.img_size)
306            return F.interpolate(image, target_size, mode="bilinear", align_corners=False, antialias=True)
307        elif image.ndim == 5:  # i.e. 3d volume
308            B, C, Z, H, W = image.shape
309            target_size = self.get_preprocess_shape(H, W, self.img_size)
310            return F.interpolate(image, (Z, *target_size), mode="trilinear", align_corners=False)
311        else:
312            raise ValueError("Expected 4d or 5d inputs, got", image.shape)

Resize the image so that the longest side has the correct length.

Expects batched images with shape BxCxHxW OR BxCxDxHxW and float format.

Arguments:
  • image: The input image.
Returns:

The resized image.

class UNETR(UNETRBase):
    """A (2d-only) UNet Transformer using a vision transformer as encoder and a convolutional decoder.

    The vision transformer embeddings are mapped to a feature pyramid via deconvolution blocks
    and merged by a UNet-style convolutional decoder.
    See `UNETRBase` for the documentation of the constructor arguments.
    """
    def __init__(
        self,
        img_size: int = 1024,
        backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
        encoder: Optional[Union[nn.Module, str]] = "vit_b",
        decoder: Optional[nn.Module] = None,
        out_channels: int = 1,
        use_sam_stats: bool = False,
        use_mae_stats: bool = False,
        use_dino_stats: bool = False,
        resize_input: bool = True,
        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
        final_activation: Optional[Union[str, nn.Module]] = None,
        use_skip_connection: bool = True,
        embed_dim: Optional[int] = None,
        use_conv_transpose: bool = False,
        **kwargs
    ) -> None:

        super().__init__(
            img_size=img_size,
            backbone=backbone,
            encoder=encoder,
            decoder=decoder,
            out_channels=out_channels,
            use_sam_stats=use_sam_stats,
            use_mae_stats=use_mae_stats,
            use_dino_stats=use_dino_stats,
            resize_input=resize_input,
            encoder_checkpoint=encoder_checkpoint,
            final_activation=final_activation,
            use_skip_connection=use_skip_connection,
            embed_dim=embed_dim,
            use_conv_transpose=use_conv_transpose,
            **kwargs,
        )

        encoder = self.encoder

        # Determine the number of input channels of the vision transformer backbone.
        if backbone == "sam2" and hasattr(encoder, "trunk"):
            in_chans = encoder.trunk.patch_embed.proj.in_channels
        elif hasattr(encoder, "in_chans"):
            in_chans = encoder.in_chans
        else:  # `nn.Module` ViT backbone.
            try:
                in_chans = encoder.patch_embed.proj.in_channels
            except AttributeError:  # for getting the input channels while using 'vit_t' from MobileSam
                in_chans = encoder.patch_embed.seq[0].c.in_channels

        # Parameters for the decoder network: feature sizes [512, 256, 128, 64].
        depth = 3
        initial_features = 64
        gain = 2
        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
        scale_factors = depth * [2]
        self.out_channels = out_channels

        # Choice of upsampler - to use (bilinear interpolation + conv) or conv transpose.
        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d

        # NOTE: the module creation order below is kept stable on purpose, so that
        # parameter / state-dict ordering stays compatible with existing checkpoints.
        self.decoder = decoder or Decoder(
            features=features_decoder,
            scale_factors=scale_factors[::-1],
            conv_block_impl=ConvBlock2d,
            sampler_impl=_upsampler,
        )

        if use_skip_connection:
            # Each deconv block upsamples one of the intermediate encoder features
            # to the resolution expected by the corresponding decoder level.
            self.deconv1 = Deconv2DBlock(
                in_channels=self.embed_dim,
                out_channels=features_decoder[0],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv2 = nn.Sequential(
                Deconv2DBlock(
                    in_channels=self.embed_dim,
                    out_channels=features_decoder[0],
                    use_conv_transpose=use_conv_transpose,
                ),
                Deconv2DBlock(
                    in_channels=features_decoder[0],
                    out_channels=features_decoder[1],
                    use_conv_transpose=use_conv_transpose,
                )
            )
            self.deconv3 = nn.Sequential(
                Deconv2DBlock(
                    in_channels=self.embed_dim,
                    out_channels=features_decoder[0],
                    use_conv_transpose=use_conv_transpose,
                ),
                Deconv2DBlock(
                    in_channels=features_decoder[0],
                    out_channels=features_decoder[1],
                    use_conv_transpose=use_conv_transpose,
                ),
                Deconv2DBlock(
                    in_channels=features_decoder[1],
                    out_channels=features_decoder[2],
                    use_conv_transpose=use_conv_transpose,
                )
            )
            # The last skip maps the (preprocessed) input image itself.
            self.deconv4 = ConvBlock2d(in_chans, features_decoder[-1])
        else:
            # Without skip connections the embeddings are upsampled sequentially.
            self.deconv1 = Deconv2DBlock(
                in_channels=self.embed_dim,
                out_channels=features_decoder[0],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv2 = Deconv2DBlock(
                in_channels=features_decoder[0],
                out_channels=features_decoder[1],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv3 = Deconv2DBlock(
                in_channels=features_decoder[1],
                out_channels=features_decoder[2],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv4 = Deconv2DBlock(
                in_channels=features_decoder[2],
                out_channels=features_decoder[3],
                use_conv_transpose=use_conv_transpose,
            )

        self.base = ConvBlock2d(self.embed_dim, features_decoder[0])
        self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1)
        self.deconv_out = _upsampler(
            scale_factor=2, in_channels=features_decoder[-1], out_channels=features_decoder[-1]
        )
        self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the UNETR to the input data.

        Args:
            x: The input tensor.

        Returns:
            The UNETR output.

        Raises:
            RuntimeError: If `use_skip_connection` is set but the encoder did not
                return intermediate features.
        """
        original_shape = x.shape[-2:]

        # Reshape the inputs to the shape expected by the encoder
        # and normalize the inputs if normalization is part of the model.
        x, input_shape = self.preprocess(x)

        encoder_outputs = self.encoder(x)

        # `encoder_outputs` can be arranged in only two forms:
        #   - either we only return the image embeddings
        #   - or, we return the image embeddings and the "list" of global attention layers
        from_encoder = None
        if isinstance(encoder_outputs[-1], list):
            z12, from_encoder = encoder_outputs
        else:
            z12 = encoder_outputs

        if self.use_skip_connection:
            # Previously this raised a NameError when the encoder only returned the
            # embeddings; fail with an informative error instead.
            if from_encoder is None:
                raise RuntimeError(
                    "The encoder did not return intermediate features, "
                    "which are required when 'use_skip_connection' is set."
                )
            from_encoder = from_encoder[::-1]
            z9 = self.deconv1(from_encoder[0])
            z6 = self.deconv2(from_encoder[1])
            z3 = self.deconv3(from_encoder[2])
            z0 = self.deconv4(x)

        else:
            z9 = self.deconv1(z12)
            z6 = self.deconv2(z9)
            z3 = self.deconv3(z6)
            z0 = self.deconv4(z3)

        updated_from_encoder = [z9, z6, z3]

        x = self.base(z12)
        x = self.decoder(x, encoder_inputs=updated_from_encoder)
        x = self.deconv_out(x)

        x = torch.cat([x, z0], dim=1)
        x = self.decoder_head(x)

        x = self.out_conv(x)
        if self.final_activation is not None:
            x = self.final_activation(x)

        x = self.postprocess_masks(x, input_shape, original_shape)
        return x

A (2d-only) UNet Transformer using a vision transformer as encoder and a convolutional decoder.

UNETR( img_size: int = 1024, backbone: Literal['sam', 'sam2', 'sam3', 'cellpose_sam', 'mae', 'scalemae', 'dinov2', 'dinov3'] = 'sam', encoder: Union[torch.nn.modules.module.Module, str, NoneType] = 'vit_b', decoder: Optional[torch.nn.modules.module.Module] = None, out_channels: int = 1, use_sam_stats: bool = False, use_mae_stats: bool = False, use_dino_stats: bool = False, resize_input: bool = True, encoder_checkpoint: Union[str, collections.OrderedDict, NoneType] = None, final_activation: Union[torch.nn.modules.module.Module, str, NoneType] = None, use_skip_connection: bool = True, embed_dim: Optional[int] = None, use_conv_transpose: bool = False, **kwargs)
393    def __init__(
394        self,
395        img_size: int = 1024,
396        backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
397        encoder: Optional[Union[nn.Module, str]] = "vit_b",
398        decoder: Optional[nn.Module] = None,
399        out_channels: int = 1,
400        use_sam_stats: bool = False,
401        use_mae_stats: bool = False,
402        use_dino_stats: bool = False,
403        resize_input: bool = True,
404        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
405        final_activation: Optional[Union[str, nn.Module]] = None,
406        use_skip_connection: bool = True,
407        embed_dim: Optional[int] = None,
408        use_conv_transpose: bool = False,
409        **kwargs
410    ) -> None:
411
412        super().__init__(
413            img_size=img_size,
414            backbone=backbone,
415            encoder=encoder,
416            decoder=decoder,
417            out_channels=out_channels,
418            use_sam_stats=use_sam_stats,
419            use_mae_stats=use_mae_stats,
420            use_dino_stats=use_dino_stats,
421            resize_input=resize_input,
422            encoder_checkpoint=encoder_checkpoint,
423            final_activation=final_activation,
424            use_skip_connection=use_skip_connection,
425            embed_dim=embed_dim,
426            use_conv_transpose=use_conv_transpose,
427            **kwargs,
428        )
429
430        encoder = self.encoder
431
432        if backbone == "sam2" and hasattr(encoder, "trunk"):
433            in_chans = encoder.trunk.patch_embed.proj.in_channels
434        elif hasattr(encoder, "in_chans"):
435            in_chans = encoder.in_chans
436        else:  # `nn.Module` ViT backbone.
437            try:
438                in_chans = encoder.patch_embed.proj.in_channels
439            except AttributeError:  # for getting the input channels while using 'vit_t' from MobileSam
440                in_chans = encoder.patch_embed.seq[0].c.in_channels
441
442        # parameters for the decoder network
443        depth = 3
444        initial_features = 64
445        gain = 2
446        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
447        scale_factors = depth * [2]
448        self.out_channels = out_channels
449
450        # choice of upsampler - to use (bilinear interpolation + conv) or conv transpose
451        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d
452
453        self.decoder = decoder or Decoder(
454            features=features_decoder,
455            scale_factors=scale_factors[::-1],
456            conv_block_impl=ConvBlock2d,
457            sampler_impl=_upsampler,
458        )
459
460        if use_skip_connection:
461            self.deconv1 = Deconv2DBlock(
462                in_channels=self.embed_dim,
463                out_channels=features_decoder[0],
464                use_conv_transpose=use_conv_transpose,
465            )
466            self.deconv2 = nn.Sequential(
467                Deconv2DBlock(
468                    in_channels=self.embed_dim,
469                    out_channels=features_decoder[0],
470                    use_conv_transpose=use_conv_transpose,
471                ),
472                Deconv2DBlock(
473                    in_channels=features_decoder[0],
474                    out_channels=features_decoder[1],
475                    use_conv_transpose=use_conv_transpose,
476                )
477            )
478            self.deconv3 = nn.Sequential(
479                Deconv2DBlock(
480                    in_channels=self.embed_dim,
481                    out_channels=features_decoder[0],
482                    use_conv_transpose=use_conv_transpose,
483                ),
484                Deconv2DBlock(
485                    in_channels=features_decoder[0],
486                    out_channels=features_decoder[1],
487                    use_conv_transpose=use_conv_transpose,
488                ),
489                Deconv2DBlock(
490                    in_channels=features_decoder[1],
491                    out_channels=features_decoder[2],
492                    use_conv_transpose=use_conv_transpose,
493                )
494            )
495            self.deconv4 = ConvBlock2d(in_chans, features_decoder[-1])
496        else:
497            self.deconv1 = Deconv2DBlock(
498                in_channels=self.embed_dim,
499                out_channels=features_decoder[0],
500                use_conv_transpose=use_conv_transpose,
501            )
502            self.deconv2 = Deconv2DBlock(
503                in_channels=features_decoder[0],
504                out_channels=features_decoder[1],
505                use_conv_transpose=use_conv_transpose,
506            )
507            self.deconv3 = Deconv2DBlock(
508                in_channels=features_decoder[1],
509                out_channels=features_decoder[2],
510                use_conv_transpose=use_conv_transpose,
511            )
512            self.deconv4 = Deconv2DBlock(
513                in_channels=features_decoder[2],
514                out_channels=features_decoder[3],
515                use_conv_transpose=use_conv_transpose,
516            )
517
518        self.base = ConvBlock2d(self.embed_dim, features_decoder[0])
519        self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1)
520        self.deconv_out = _upsampler(
521            scale_factor=2, in_channels=features_decoder[-1], out_channels=features_decoder[-1]
522        )
523        self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1])

Construct the 2d UNETR: configure the vision-transformer encoder via the base class, then build the convolutional decoder, the deconvolution (skip-connection) blocks, and the output head.

out_channels
decoder
base
out_conv
deconv_out
decoder_head
def forward(self, x: torch.Tensor) -> torch.Tensor:
525    def forward(self, x: torch.Tensor) -> torch.Tensor:
526        """Apply the UNETR to the input data.
527
528        Args:
529            x: The input tensor.
530
531        Returns:
532            The UNETR output.
533        """
534        original_shape = x.shape[-2:]
535
536        # Reshape the inputs to the shape expected by the encoder
537        # and normalize the inputs if normalization is part of the model.
538        x, input_shape = self.preprocess(x)
539
540        encoder_outputs = self.encoder(x)
541
542        if isinstance(encoder_outputs[-1], list):
543            # `encoder_outputs` can be arranged in only two forms:
544            #   - either we only return the image embeddings
545            #   - or, we return the image embeddings and the "list" of global attention layers
546            z12, from_encoder = encoder_outputs
547        else:
548            z12 = encoder_outputs
549
550        if self.use_skip_connection:
551            from_encoder = from_encoder[::-1]
552            z9 = self.deconv1(from_encoder[0])
553            z6 = self.deconv2(from_encoder[1])
554            z3 = self.deconv3(from_encoder[2])
555            z0 = self.deconv4(x)
556
557        else:
558            z9 = self.deconv1(z12)
559            z6 = self.deconv2(z9)
560            z3 = self.deconv3(z6)
561            z0 = self.deconv4(z3)
562
563        updated_from_encoder = [z9, z6, z3]
564
565        x = self.base(z12)
566        x = self.decoder(x, encoder_inputs=updated_from_encoder)
567        x = self.deconv_out(x)
568
569        x = torch.cat([x, z0], dim=1)
570        x = self.decoder_head(x)
571
572        x = self.out_conv(x)
573        if self.final_activation is not None:
574            x = self.final_activation(x)
575
576        x = self.postprocess_masks(x, input_shape, original_shape)
577        return x

Apply the UNETR to the input data.

Arguments:
  • x: The input tensor.
Returns:

The UNETR output.

class UNETR2D(UNETR):
    """A two-dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.

    Behaves identically to `UNETR`; it adds no functionality of its own.
    """
    pass

A two-dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.

class UNETR3D(UNETRBase):
    """A three dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.

    Each z-slice of the input volume is embedded independently by the (2d) vision transformer
    and the resulting feature volume is decoded by a 3d convolutional decoder.
    """
    def __init__(
        self,
        img_size: int = 1024,
        backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
        encoder: Optional[Union[nn.Module, str]] = "hvit_b",
        decoder: Optional[nn.Module] = None,
        out_channels: int = 1,
        use_sam_stats: bool = False,
        use_mae_stats: bool = False,
        use_dino_stats: bool = False,
        resize_input: bool = True,
        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
        final_activation: Optional[Union[str, nn.Module]] = None,
        use_skip_connection: bool = False,
        embed_dim: Optional[int] = None,
        use_conv_transpose: bool = False,
        use_strip_pooling: bool = True,
        **kwargs
    ):
        # Guard clauses for features the 3d variant does not support yet.
        if use_skip_connection:
            raise NotImplementedError("The framework cannot handle skip connections atm.")
        if use_conv_transpose:
            raise NotImplementedError("It's not enabled to switch between interpolation and transposed convolutions.")

        # Fall back to the default embedding dimension when none is given.
        embed_dim = 256 if embed_dim is None else embed_dim

        super().__init__(
            img_size=img_size,
            backbone=backbone,
            encoder=encoder,
            decoder=decoder,
            out_channels=out_channels,
            use_sam_stats=use_sam_stats,
            use_mae_stats=use_mae_stats,
            use_dino_stats=use_dino_stats,
            resize_input=resize_input,
            encoder_checkpoint=encoder_checkpoint,
            final_activation=final_activation,
            use_skip_connection=use_skip_connection,
            embed_dim=embed_dim,
            use_conv_transpose=use_conv_transpose,
            **kwargs,
        )

        # Feature sizes of the 3d convolutional decoder: [512, 256, 128, 64].
        n_levels = 3
        feats = [64 * (2 ** level) for level in range(n_levels + 1)]
        feats.reverse()
        # In-plane upsampling only; the z-axis is never rescaled.
        sampling = [1, 2, 2]
        self.out_channels = out_channels

        # NOTE: the module creation order below is kept stable on purpose, so that
        # parameter / state-dict ordering stays compatible with existing checkpoints.

        # Mapping blocks that turn the encoder embeddings into a feature pyramid.
        self.deconv1 = Deconv3DBlock(
            in_channels=embed_dim, out_channels=feats[0],
            scale_factor=sampling, use_strip_pooling=use_strip_pooling,
        )
        self.deconv2 = Deconv3DBlock(
            in_channels=feats[0], out_channels=feats[1],
            scale_factor=sampling, use_strip_pooling=use_strip_pooling,
        )
        self.deconv3 = Deconv3DBlock(
            in_channels=feats[1], out_channels=feats[2],
            scale_factor=sampling, use_strip_pooling=use_strip_pooling,
        )
        self.deconv4 = Deconv3DBlock(
            in_channels=feats[2], out_channels=feats[3],
            scale_factor=sampling, use_strip_pooling=use_strip_pooling,
        )

        # The core decoder block.
        self.decoder = decoder or Decoder(
            features=feats,
            scale_factors=[sampling] * n_levels,
            conv_block_impl=partial(ConvBlock3dWithStrip, use_strip_pooling=use_strip_pooling),
            sampler_impl=Upsampler3d,
        )

        # The final upsampler to match the expected output resolution.
        self.deconv_out = Deconv3DBlock(  # NOTE: changed `end_up` to `deconv_out`
            in_channels=feats[-1], out_channels=feats[-1],
            scale_factor=sampling, use_strip_pooling=use_strip_pooling,
        )

        # Block aligning the encoder embeddings with the first decoder level.
        self.base = ConvBlock3dWithStrip(
            in_channels=embed_dim, out_channels=feats[0], use_strip_pooling=use_strip_pooling,
        )

        # And the output layers.
        self.decoder_head = ConvBlock3dWithStrip(
            in_channels=2 * feats[-1], out_channels=feats[-1], use_strip_pooling=use_strip_pooling,
        )
        self.out_conv = nn.Conv3d(feats[-1], out_channels, 1)

    def forward(self, x: torch.Tensor):
        """Forward pass of the UNETR-3D model.

        Args:
            x: Inputs of expected shape (B, C, Z, Y, X), where Z considers flexible inputs.

        Returns:
            The UNETR output.
        """
        n_slices = x.shape[2]
        original_shape = tuple(x.shape[2:])

        # Bring the input to the shape / normalization expected by the encoder.
        x, input_shape = self.preprocess(x)

        # Embed each z-slice independently and re-assemble the volume along z.
        slice_embeddings = [self.encoder(x[:, :, z])[0] for z in range(n_slices)]
        curr_features = torch.stack(slice_embeddings, dim=2)

        # Sequentially map the embeddings to the decoder feature levels.
        # NOTE: this is purely sequential, there are no skip connections atm.
        z9 = self.deconv1(curr_features)
        z6 = self.deconv2(z9)
        z3 = self.deconv3(z6)
        z0 = self.deconv4(z3)

        # Align the features through the base block and run the decoder.
        x = self.base(curr_features)
        x = self.decoder(x, encoder_inputs=[z9, z6, z3])
        x = self.deconv_out(x)  # NOTE before `end_up`

        # The final output head.
        x = self.decoder_head(torch.cat([x, z0], dim=1))
        x = self.out_conv(x)
        if self.final_activation is not None:
            x = self.final_activation(x)

        # Postprocess the output back to original size.
        return self.postprocess_masks(x, input_shape, original_shape)

A three dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.

UNETR3D( img_size: int = 1024, backbone: Literal['sam', 'sam2', 'sam3', 'cellpose_sam', 'mae', 'scalemae', 'dinov2', 'dinov3'] = 'sam', encoder: Union[torch.nn.modules.module.Module, str, NoneType] = 'hvit_b', decoder: Optional[torch.nn.modules.module.Module] = None, out_channels: int = 1, use_sam_stats: bool = False, use_mae_stats: bool = False, use_dino_stats: bool = False, resize_input: bool = True, encoder_checkpoint: Union[str, collections.OrderedDict, NoneType] = None, final_activation: Union[torch.nn.modules.module.Module, str, NoneType] = None, use_skip_connection: bool = False, embed_dim: Optional[int] = None, use_conv_transpose: bool = False, use_strip_pooling: bool = True, **kwargs)
589    def __init__(
590        self,
591        img_size: int = 1024,
592        backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
593        encoder: Optional[Union[nn.Module, str]] = "hvit_b",
594        decoder: Optional[nn.Module] = None,
595        out_channels: int = 1,
596        use_sam_stats: bool = False,
597        use_mae_stats: bool = False,
598        use_dino_stats: bool = False,
599        resize_input: bool = True,
600        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
601        final_activation: Optional[Union[str, nn.Module]] = None,
602        use_skip_connection: bool = False,
603        embed_dim: Optional[int] = None,
604        use_conv_transpose: bool = False,
605        use_strip_pooling: bool = True,
606        **kwargs
607    ):
608        if use_skip_connection:
609            raise NotImplementedError("The framework cannot handle skip connections atm.")
610        if use_conv_transpose:
611            raise NotImplementedError("It's not enabled to switch between interpolation and transposed convolutions.")
612
613        # Sort the `embed_dim` out
614        embed_dim = 256 if embed_dim is None else embed_dim
615
616        super().__init__(
617            img_size=img_size,
618            backbone=backbone,
619            encoder=encoder,
620            decoder=decoder,
621            out_channels=out_channels,
622            use_sam_stats=use_sam_stats,
623            use_mae_stats=use_mae_stats,
624            use_dino_stats=use_dino_stats,
625            resize_input=resize_input,
626            encoder_checkpoint=encoder_checkpoint,
627            final_activation=final_activation,
628            use_skip_connection=use_skip_connection,
629            embed_dim=embed_dim,
630            use_conv_transpose=use_conv_transpose,
631            **kwargs,
632        )
633
634        # The 3d convolutional decoder.
635        # First, get the important parameters for the decoder.
636        depth = 3
637        initial_features = 64
638        gain = 2
639        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
640        scale_factors = [1, 2, 2]
641        self.out_channels = out_channels
642
643        # The mapping blocks.
644        self.deconv1 = Deconv3DBlock(
645            in_channels=embed_dim,
646            out_channels=features_decoder[0],
647            scale_factor=scale_factors,
648            use_strip_pooling=use_strip_pooling,
649        )
650        self.deconv2 = Deconv3DBlock(
651            in_channels=features_decoder[0],
652            out_channels=features_decoder[1],
653            scale_factor=scale_factors,
654            use_strip_pooling=use_strip_pooling,
655        )
656        self.deconv3 = Deconv3DBlock(
657            in_channels=features_decoder[1],
658            out_channels=features_decoder[2],
659            scale_factor=scale_factors,
660            use_strip_pooling=use_strip_pooling,
661        )
662        self.deconv4 = Deconv3DBlock(
663            in_channels=features_decoder[2],
664            out_channels=features_decoder[3],
665            scale_factor=scale_factors,
666            use_strip_pooling=use_strip_pooling,
667        )
668
669        # The core decoder block.
670        self.decoder = decoder or Decoder(
671            features=features_decoder,
672            scale_factors=[scale_factors] * depth,
673            conv_block_impl=partial(ConvBlock3dWithStrip, use_strip_pooling=use_strip_pooling),
674            sampler_impl=Upsampler3d,
675        )
676
677        # And the final upsampler to match the expected dimensions.
678        self.deconv_out = Deconv3DBlock(  # NOTE: changed `end_up` to `deconv_out`
679            in_channels=features_decoder[-1],
680            out_channels=features_decoder[-1],
681            scale_factor=scale_factors,
682            use_strip_pooling=use_strip_pooling,
683        )
684
685        # Additional conjunction blocks.
686        self.base = ConvBlock3dWithStrip(
687            in_channels=embed_dim,
688            out_channels=features_decoder[0],
689            use_strip_pooling=use_strip_pooling,
690        )
691
692        # And the output layers.
693        self.decoder_head = ConvBlock3dWithStrip(
694            in_channels=2 * features_decoder[-1],
695            out_channels=features_decoder[-1],
696            use_strip_pooling=use_strip_pooling,
697        )
698        self.out_conv = nn.Conv3d(features_decoder[-1], out_channels, 1)

Construct the 3d UNETR: configure the vision-transformer encoder via the base class, then build the 3d convolutional decoder, the mapping blocks, and the output head.

out_channels
deconv1
deconv2
deconv3
deconv4
decoder
deconv_out
base
decoder_head
out_conv
def forward(self, x: torch.Tensor):
700    def forward(self, x: torch.Tensor):
701        """Forward pass of the UNETR-3D model.
702
703        Args:
704            x: Inputs of expected shape (B, C, Z, Y, X), where Z considers flexible inputs.
705
706        Returns:
707            The UNETR output.
708        """
709        B, C, Z, H, W = x.shape
710        original_shape = (Z, H, W)
711
712        # Preprocessing step
713        x, input_shape = self.preprocess(x)
714
715        # Run the image encoder.
716        curr_features = torch.stack([self.encoder(x[:, :, i])[0] for i in range(Z)], dim=2)
717
718        # Prepare the counterparts for the decoder.
719        # NOTE: The section below is sequential, there's no skip connections atm.
720        z9 = self.deconv1(curr_features)
721        z6 = self.deconv2(z9)
722        z3 = self.deconv3(z6)
723        z0 = self.deconv4(z3)
724
725        updated_from_encoder = [z9, z6, z3]
726
727        # Align the features through the base block.
728        x = self.base(curr_features)
729        # Run the decoder
730        x = self.decoder(x, encoder_inputs=updated_from_encoder)
731        x = self.deconv_out(x)  # NOTE before `end_up`
732
733        # And the final output head.
734        x = torch.cat([x, z0], dim=1)
735        x = self.decoder_head(x)
736        x = self.out_conv(x)
737        if self.final_activation is not None:
738            x = self.final_activation(x)
739
740        # Postprocess the output back to original size.
741        x = self.postprocess_masks(x, input_shape, original_shape)
742        return x

Forward pass of the UNETR-3D model.

Arguments:
  • x: Inputs of expected shape (B, C, Z, Y, X), where Z considers flexible inputs.
Returns:

The UNETR output.