torch_em.model.unetr

from collections import OrderedDict
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from .vit import get_vision_transformer
from .unet import Decoder, ConvBlock2d, Upsampler2d

try:
    from micro_sam.util import get_sam_model
except ImportError:
    get_sam_model = None

try:
    from micro_sam2.util import get_sam2_model
except ImportError:
    get_sam2_model = None


#
# UNETR IMPLEMENTATION [Vision Transformer (ViT from SAM / MAE / ScaleMAE) + UNet Decoder from `torch_em`]
#


class UNETR(nn.Module):
    """A U-Net Transformer using a vision transformer as encoder and a convolutional decoder.

    Args:
        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
        backbone: The name of the vision transformer implementation.
            One of "sam", "sam2", "mae", "scalemae" or "dinov3".
        encoder: The vision transformer. Can either be a name, such as "vit_b", or a torch module.
        decoder: The convolutional decoder.
        out_channels: The number of output channels of the UNETR.
        use_sam_stats: Whether to normalize the input data with the statistics of the pretrained SAM model.
        use_dino_stats: Whether to normalize the input data with the statistics of the pretrained DINOv3 model.
        use_mae_stats: Whether to normalize the input data with the statistics of the pretrained MAE model.
        resize_input: Whether to resize the input images to match `img_size`.
            By default, it resizes the inputs to match the `img_size`.
        encoder_checkpoint: Checkpoint for initializing the vision transformer.
            Can either be a filepath or an already loaded checkpoint.
        final_activation: The activation to apply to the UNETR output.
        use_skip_connection: Whether to use skip connections. By default, it uses skip connections.
        embed_dim: The embedding dimensionality, corresponding to the output dimension of the vision transformer.
        use_conv_transpose: Whether to use transposed convolutions instead of resampling for upsampling.
            By default, it uses resampling for upsampling.
    """
    def _load_encoder_from_checkpoint(self, backbone, encoder, checkpoint):
        """Function to load pretrained weights into the image encoder.
        """
        if isinstance(checkpoint, str):
            if backbone == "sam" and isinstance(encoder, str):
                # If we have a SAM encoder, then we first try to load the full SAM model
                # (using micro_sam) and otherwise fall back on directly loading the encoder state
                # from the checkpoint.
                try:
                    _, model = get_sam_model(model_type=encoder, checkpoint_path=checkpoint, return_sam=True)
                    encoder_state = model.image_encoder.state_dict()
                except Exception:
                    # Try loading the encoder state directly from a checkpoint.
                    encoder_state = torch.load(checkpoint, weights_only=False)

            elif backbone == "sam2" and isinstance(encoder, str):
                # If we have a SAM2 encoder, then we first try to load the full SAM2 model
                # (using micro_sam2) and otherwise fall back on directly loading the encoder state
                # from the checkpoint.
                try:
                    model = get_sam2_model(model_type=encoder, checkpoint_path=checkpoint)
                    encoder_state = model.image_encoder.state_dict()
                except Exception:
                    # Try loading the encoder state directly from a checkpoint.
                    encoder_state = torch.load(checkpoint, weights_only=False)

            elif backbone == "mae":
                # ViT initialization hints from:
                #     - https://github.com/facebookresearch/mae/blob/main/main_finetune.py#L233-L242
                encoder_state = torch.load(checkpoint, weights_only=False)["model"]
                encoder_state = OrderedDict({
                    k: v for k, v in encoder_state.items() if (k != "mask_token" and not k.startswith("decoder"))
                })
                # Remove the `head` from our current encoder (as the MAE pretrained weights don't contain it).
                current_encoder_state = self.encoder.state_dict()
                if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
                    del self.encoder.head

            elif backbone == "scalemae":
                # Load the encoder state directly from a checkpoint.
                encoder_state = torch.load(checkpoint)["model"]
                encoder_state = OrderedDict({
                    k: v for k, v in encoder_state.items()
                    if not k.startswith(("mask_token", "decoder", "fcn", "fpn", "pos_embed"))
                })

                # Remove the `head` from our current encoder (as the MAE pretrained weights don't contain it).
                current_encoder_state = self.encoder.state_dict()
                if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
                    del self.encoder.head

                if "pos_embed" in current_encoder_state:  # NOTE: ScaleMAE uses 'pos. embeddings' in a diff. format.
                    del self.encoder.pos_embed

            elif backbone == "dinov3":  # Load the encoder state directly from a checkpoint.
                encoder_state = torch.load(checkpoint)

            else:
                raise ValueError(
                    f"The combination of backbone '{backbone}' and encoder '{encoder}' is not supported."
                )

        else:
            encoder_state = checkpoint

        self.encoder.load_state_dict(encoder_state)

    def __init__(
        self,
        img_size: int = 1024,
        backbone: str = "sam",
        encoder: Optional[Union[nn.Module, str]] = "vit_b",
        decoder: Optional[nn.Module] = None,
        out_channels: int = 1,
        use_sam_stats: bool = False,
        use_mae_stats: bool = False,
        use_dino_stats: bool = False,
        resize_input: bool = True,
        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
        final_activation: Optional[Union[str, nn.Module]] = None,
        use_skip_connection: bool = True,
        embed_dim: Optional[int] = None,
        use_conv_transpose: bool = False,
        **kwargs
    ) -> None:
        super().__init__()

        self.use_sam_stats = use_sam_stats
        self.use_mae_stats = use_mae_stats
        self.use_dino_stats = use_dino_stats
        self.use_skip_connection = use_skip_connection
        self.resize_input = resize_input

        if isinstance(encoder, str):  # e.g. "vit_b" / "vit_l" / "vit_h"
            print(f"Using {encoder} from {backbone.upper()}")
            self.encoder = get_vision_transformer(img_size=img_size, backbone=backbone, model=encoder, **kwargs)

            if encoder_checkpoint is not None:
                self._load_encoder_from_checkpoint(backbone=backbone, encoder=encoder, checkpoint=encoder_checkpoint)

            if backbone == "sam2":
                in_chans = self.encoder.trunk.patch_embed.proj.in_channels
            else:
                in_chans = self.encoder.in_chans

            if embed_dim is None:
                embed_dim = self.encoder.embed_dim

        else:  # `nn.Module` ViT backbone
            self.encoder = encoder

            have_neck = False
            for name, _ in self.encoder.named_parameters():
                if name.startswith("neck"):
                    have_neck = True

            if embed_dim is None:
                if have_neck:
                    embed_dim = self.encoder.neck[2].out_channels  # the value is 256
                else:
                    embed_dim = self.encoder.patch_embed.proj.out_channels

            try:
                in_chans = self.encoder.patch_embed.proj.in_channels
            except AttributeError:  # for getting the input channels while using 'vit_t' from MobileSam
                in_chans = self.encoder.patch_embed.seq[0].c.in_channels

        # parameters for the decoder network
        depth = 3
        initial_features = 64
        gain = 2
        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
        scale_factors = depth * [2]
        self.out_channels = out_channels

        # choice of upsampler - to use (bilinear interpolation + conv) or conv transpose
        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d

        if decoder is None:
            self.decoder = Decoder(
                features=features_decoder,
                scale_factors=scale_factors[::-1],
                conv_block_impl=ConvBlock2d,
                sampler_impl=_upsampler,
            )
        else:
            self.decoder = decoder

        if use_skip_connection:
            self.deconv1 = Deconv2DBlock(
                in_channels=embed_dim,
                out_channels=features_decoder[0],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv2 = nn.Sequential(
                Deconv2DBlock(
                    in_channels=embed_dim,
                    out_channels=features_decoder[0],
                    use_conv_transpose=use_conv_transpose,
                ),
                Deconv2DBlock(
                    in_channels=features_decoder[0],
                    out_channels=features_decoder[1],
                    use_conv_transpose=use_conv_transpose,
                )
            )
            self.deconv3 = nn.Sequential(
                Deconv2DBlock(
                    in_channels=embed_dim,
                    out_channels=features_decoder[0],
                    use_conv_transpose=use_conv_transpose,
                ),
                Deconv2DBlock(
                    in_channels=features_decoder[0],
                    out_channels=features_decoder[1],
                    use_conv_transpose=use_conv_transpose,
                ),
                Deconv2DBlock(
                    in_channels=features_decoder[1],
                    out_channels=features_decoder[2],
                    use_conv_transpose=use_conv_transpose,
                )
            )
            self.deconv4 = ConvBlock2d(in_chans, features_decoder[-1])
        else:
            self.deconv1 = Deconv2DBlock(
                in_channels=embed_dim,
                out_channels=features_decoder[0],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv2 = Deconv2DBlock(
                in_channels=features_decoder[0],
                out_channels=features_decoder[1],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv3 = Deconv2DBlock(
                in_channels=features_decoder[1],
                out_channels=features_decoder[2],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv4 = Deconv2DBlock(
                in_channels=features_decoder[2],
                out_channels=features_decoder[3],
                use_conv_transpose=use_conv_transpose,
            )

        self.base = ConvBlock2d(embed_dim, features_decoder[0])
        self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1)
        self.deconv_out = _upsampler(
            scale_factor=2, in_channels=features_decoder[-1], out_channels=features_decoder[-1]
        )
        self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1])
        self.final_activation = self._get_activation(final_activation)

    def _get_activation(self, activation):
        return_activation = None
        if activation is None:
            return None
        if isinstance(activation, nn.Module):
            return activation
        if isinstance(activation, str):
            return_activation = getattr(nn, activation, None)
        if return_activation is None:
            raise ValueError(f"Invalid activation: {activation}")

        return return_activation()

    @staticmethod
    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
        """Compute the output size given input size and target long side length.

        Args:
            oldh: The input image height.
            oldw: The input image width.
            long_side_length: The longest side length for resizing.

        Returns:
            The new image height.
            The new image width.
        """
        scale = long_side_length * 1.0 / max(oldh, oldw)
        newh, neww = oldh * scale, oldw * scale
        neww = int(neww + 0.5)
        newh = int(newh + 0.5)
        return (newh, neww)

    def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor:
        """Resize the image so that the longest side has the correct length.

        Expects batched images with shape BxCxHxW and float format.

        Args:
            image: The input image.

        Returns:
            The resized image.
        """
        target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.encoder.img_size)
        return F.interpolate(
            image, target_size, mode="bilinear", align_corners=False, antialias=True
        )

    def preprocess(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]:
        """@private
        """
        device = x.device

        if self.use_sam_stats:
            pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(1, -1, 1, 1).to(device)
            pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(1, -1, 1, 1).to(device)
        elif self.use_mae_stats:  # TODO: add mean std from mae / scalemae experiments (or open up arguments for this)
            raise NotImplementedError
        elif self.use_dino_stats:
            pixel_mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1).to(device)
            pixel_std = torch.Tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1).to(device)
        else:
            pixel_mean = torch.Tensor([0.0, 0.0, 0.0]).view(1, -1, 1, 1).to(device)
            pixel_std = torch.Tensor([1.0, 1.0, 1.0]).view(1, -1, 1, 1).to(device)

        if self.resize_input:
            x = self.resize_longest_side(x)
        input_shape = x.shape[-2:]

        x = (x - pixel_mean) / pixel_std
        h, w = x.shape[-2:]
        padh = self.encoder.img_size - h
        padw = self.encoder.img_size - w
        x = F.pad(x, (0, padw, 0, padh))
        return x, input_shape

    def postprocess_masks(
        self, masks: torch.Tensor, input_size: Tuple[int, ...], original_size: Tuple[int, ...],
    ) -> torch.Tensor:
        """@private
        """
        masks = F.interpolate(
            masks,
            (self.encoder.img_size, self.encoder.img_size),
            mode="bilinear",
            align_corners=False,
        )
        masks = masks[..., : input_size[0], : input_size[1]]
        masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
        return masks

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the UNETR to the input data.

        Args:
            x: The input tensor.

        Returns:
            The UNETR output.
        """
        original_shape = x.shape[-2:]

        # Reshape the inputs to the shape expected by the encoder
        # and normalize the inputs if normalization is part of the model.
        x, input_shape = self.preprocess(x)

        use_skip_connection = getattr(self, "use_skip_connection", True)

        encoder_outputs = self.encoder(x)

        if isinstance(encoder_outputs[-1], list):
            # `encoder_outputs` can be arranged in only two forms:
            #   - either we only return the image embeddings
            #   - or, we return the image embeddings and the "list" of global attention layers
            z12, from_encoder = encoder_outputs
        else:
            z12 = encoder_outputs

        if use_skip_connection:
            from_encoder = from_encoder[::-1]
            z9 = self.deconv1(from_encoder[0])
            z6 = self.deconv2(from_encoder[1])
            z3 = self.deconv3(from_encoder[2])
            z0 = self.deconv4(x)

        else:
            z9 = self.deconv1(z12)
            z6 = self.deconv2(z9)
            z3 = self.deconv3(z6)
            z0 = self.deconv4(z3)

        updated_from_encoder = [z9, z6, z3]

        x = self.base(z12)
        x = self.decoder(x, encoder_inputs=updated_from_encoder)
        x = self.deconv_out(x)

        x = torch.cat([x, z0], dim=1)
        x = self.decoder_head(x)

        x = self.out_conv(x)
        if self.final_activation is not None:
            x = self.final_activation(x)

        x = self.postprocess_masks(x, input_shape, original_shape)
        return x


#
#  ADDITIONAL FUNCTIONALITIES
#


class SingleDeconv2DBlock(nn.Module):
    """@private
    """
    def __init__(self, scale_factor, in_channels, out_channels):
        super().__init__()
        # NOTE: `scale_factor` is accepted for interface compatibility with `Upsampler2d`;
        # the transposed convolution below always upsamples by a factor of 2.
        self.block = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size=2, stride=2, padding=0, output_padding=0
        )

    def forward(self, x):
        return self.block(x)


class SingleConv2DBlock(nn.Module):
    """@private
    """
    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        self.block = nn.Conv2d(
            in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=((kernel_size - 1) // 2)
        )

    def forward(self, x):
        return self.block(x)


class Conv2DBlock(nn.Module):
    """@private
    """
    def __init__(self, in_channels, out_channels, kernel_size=3):
        super().__init__()
        self.block = nn.Sequential(
            SingleConv2DBlock(in_channels, out_channels, kernel_size),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True)
        )

    def forward(self, x):
        return self.block(x)


class Deconv2DBlock(nn.Module):
    """@private
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, use_conv_transpose=True):
        super().__init__()
        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d
        self.block = nn.Sequential(
            _upsampler(scale_factor=2, in_channels=in_channels, out_channels=out_channels),
            SingleConv2DBlock(out_channels, out_channels, kernel_size),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True)
        )

    def forward(self, x):
        return self.block(x)
class UNETR(torch.nn.modules.module.Module):

A U-Net Transformer using a vision transformer as encoder and a convolutional decoder.

Arguments:
  • img_size: The size of the input for the image encoder. Input images will be resized to match this size.
  • backbone: The name of the vision transformer implementation. One of "sam", "sam2", "mae", "scalemae" or "dinov3".
  • encoder: The vision transformer. Can either be a name, such as "vit_b" or a torch module.
  • decoder: The convolutional decoder.
  • out_channels: The number of output channels of the UNETR.
  • use_sam_stats: Whether to normalize the input data with the statistics of the pretrained SAM model.
  • use_dino_stats: Whether to normalize the input data with the statistics of the pretrained DINOv3 model.
  • use_mae_stats: Whether to normalize the input data with the statistics of the pretrained MAE model.
  • resize_input: Whether to resize the input images to match img_size. By default, it resizes the inputs to match the img_size.
  • encoder_checkpoint: Checkpoint for initializing the vision transformer. Can either be a filepath or an already loaded checkpoint.
  • final_activation: The activation to apply to the UNETR output.
  • use_skip_connection: Whether to use skip connections. By default, it uses skip connections.
  • embed_dim: The embedding dimensionality, corresponding to the output dimension of the vision transformer.
  • use_conv_transpose: Whether to use transposed convolutions instead of resampling for upsampling. By default, it uses resampling for upsampling.
UNETR( img_size: int = 1024, backbone: str = 'sam', encoder: Union[torch.nn.modules.module.Module, str, NoneType] = 'vit_b', decoder: Optional[torch.nn.modules.module.Module] = None, out_channels: int = 1, use_sam_stats: bool = False, use_mae_stats: bool = False, use_dino_stats: bool = False, resize_input: bool = True, encoder_checkpoint: Union[str, collections.OrderedDict, NoneType] = None, final_activation: Union[torch.nn.modules.module.Module, str, NoneType] = None, use_skip_connection: bool = True, embed_dim: Optional[int] = None, use_conv_transpose: bool = False, **kwargs)
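
A minimal construction sketch (not part of the library docs), assuming the SAM image encoder dependencies used by `get_vision_transformer` are available; the argument values are example choices:

import torch
from torch_em.model.unetr import UNETR

# Without `encoder_checkpoint` the ViT encoder is randomly initialized,
# so no pretrained SAM weights are needed to instantiate the model.
model = UNETR(
    img_size=1024,               # input size expected by the image encoder
    backbone="sam",              # which ViT implementation to use
    encoder="vit_b",             # encoder variant
    out_channels=1,              # e.g. a single foreground channel
    use_sam_stats=True,          # normalize inputs with SAM's pixel statistics
    final_activation="Sigmoid",  # resolved via getattr(torch.nn, "Sigmoid")
)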

@staticmethod
def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:

Compute the output size given input size and target long side length.

Arguments:
  • oldh: The input image height.
  • oldw: The input image width.
  • long_side_length: The longest side length for resizing.
Returns:

The new image height. The new image width.
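
For concreteness, a small sketch of the arithmetic (the input sizes are arbitrary examples): the longest side is scaled to `long_side_length` and the shorter side is scaled proportionally and rounded.

from torch_em.model.unetr import UNETR

# Scale a 512 x 768 image so that its longest side becomes 1024:
# scale = 1024 / 768 ≈ 1.333, so 512 * 1.333 ≈ 683 and 768 stays at 1024.
print(UNETR.get_preprocess_shape(512, 768, 1024))    # (683, 1024)
print(UNETR.get_preprocess_shape(1024, 1024, 1024))  # (1024, 1024) - already the target size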

def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor:

Resize the image so that the longest side has the correct length.

Expects batched images with shape BxCxHxW and float format.

Arguments:
  • image: The input image.
Returns:

The resized image.
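
A short sketch of the expected behavior, assuming a UNETR instance built as in the construction example above (the method combines `get_preprocess_shape` with bilinear interpolation):

import torch
from torch_em.model.unetr import UNETR

model = UNETR(backbone="sam", encoder="vit_b")  # encoder.img_size defaults to 1024

image = torch.rand(1, 3, 512, 768)              # batched BxCxHxW float image
resized = model.resize_longest_side(image)
print(resized.shape)                            # torch.Size([1, 3, 683, 1024])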

def forward(self, x: torch.Tensor) -> torch.Tensor:

Apply the UNETR to the input data.

Arguments:
  • x: The input tensor.
Returns:

The UNETR output.
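
A hedged inference sketch, reusing the constructor call from the example above; the batch shape is an arbitrary choice. Inputs are resized and padded to the encoder's `img_size` in `preprocess`, and the output is resized back to the original spatial shape in `postprocess_masks`:

import torch
from torch_em.model.unetr import UNETR

model = UNETR(backbone="sam", encoder="vit_b", out_channels=1, final_activation="Sigmoid")
model.eval()

x = torch.randn(2, 3, 512, 512)  # dummy RGB batch
with torch.no_grad():
    y = model(x)
print(y.shape)  # torch.Size([2, 1, 512, 512]) - spatial size matches the input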