torch_em.model.unetr

from collections import OrderedDict
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from .vit import get_vision_transformer
from .unet import Decoder, ConvBlock2d, Upsampler2d

try:
    from micro_sam.util import get_sam_model
except ImportError:
    get_sam_model = None

try:
    from micro_sam2.util import get_sam2_model
except ImportError:
    get_sam2_model = None


#
# UNETR IMPLEMENTATION [Vision Transformer (ViT from SAM / MAE / ScaleMAE) + UNet Decoder from `torch_em`]
#


class UNETR(nn.Module):
    """A U-Net Transformer using a vision transformer as encoder and a convolutional decoder.

    A minimal usage example is shown at the end of this page.

    Args:
        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
        backbone: The name of the vision transformer implementation. One of "sam", "sam2", "mae" or "scalemae".
        encoder: The vision transformer. Can either be a name, such as "vit_b", or a torch module.
        decoder: The convolutional decoder.
        out_channels: The number of output channels of the UNETR.
        use_sam_stats: Whether to normalize the input data with the statistics of the pretrained SAM model.
        use_mae_stats: Whether to normalize the input data with the statistics of the pretrained MAE model.
        resize_input: Whether to resize the input images to match `img_size`.
            By default, the inputs are resized to match `img_size`.
        encoder_checkpoint: Checkpoint for initializing the vision transformer.
            Can either be a filepath or an already loaded checkpoint.
        final_activation: The activation to apply to the UNETR output.
        use_skip_connection: Whether to use skip connections. By default, skip connections are used.
        embed_dim: The embedding dimensionality, corresponding to the output dimension of the vision transformer.
        use_conv_transpose: Whether to use transposed convolutions instead of resampling for upsampling.
            By default, resampling is used.
    """
    def _load_encoder_from_checkpoint(self, backbone, encoder, checkpoint):
        """Load pretrained weights into the image encoder.
        """
        if isinstance(checkpoint, str):
            if backbone == "sam" and isinstance(encoder, str):
                # If we have a SAM encoder, we first try to load the full SAM model
                # (using micro_sam) and otherwise fall back on directly loading the encoder state
                # from the checkpoint.
                try:
                    _, model = get_sam_model(model_type=encoder, checkpoint_path=checkpoint, return_sam=True)
                    encoder_state = model.image_encoder.state_dict()
                except Exception:
                    # Try loading the encoder state directly from a checkpoint.
                    encoder_state = torch.load(checkpoint, weights_only=False)

            elif backbone == "sam2" and isinstance(encoder, str):
                # If we have a SAM2 encoder, we first try to load the full SAM2 model
                # (using micro_sam2) and otherwise fall back on directly loading the encoder state
                # from the checkpoint.
                try:
                    model = get_sam2_model(model_type=encoder, checkpoint_path=checkpoint)
                    encoder_state = model.image_encoder.state_dict()
                except Exception:
                    # Try loading the encoder state directly from a checkpoint.
                    encoder_state = torch.load(checkpoint, weights_only=False)

            elif backbone == "mae":
                # ViT initialization hints from:
                #     - https://github.com/facebookresearch/mae/blob/main/main_finetune.py#L233-L242
                encoder_state = torch.load(checkpoint, weights_only=False)["model"]
                encoder_state = OrderedDict({
                    k: v for k, v in encoder_state.items() if (k != "mask_token" and not k.startswith("decoder"))
                })
                # Remove the `head` from our current encoder (the MAE pretrained weights do not include it).
                current_encoder_state = self.encoder.state_dict()
                if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
                    del self.encoder.head

            elif backbone == "scalemae":
                # Load the encoder state directly from a checkpoint.
                encoder_state = torch.load(checkpoint)["model"]
                encoder_state = OrderedDict({
                    k: v for k, v in encoder_state.items()
                    if not k.startswith(("mask_token", "decoder", "fcn", "fpn", "pos_embed"))
                })

                # Remove the `head` from our current encoder (the ScaleMAE pretrained weights do not include it).
                current_encoder_state = self.encoder.state_dict()
                if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
                    del self.encoder.head

                if "pos_embed" in current_encoder_state:  # NOTE: ScaleMAE uses positional embeddings in a different format.
                    del self.encoder.pos_embed

        else:
            encoder_state = checkpoint

        self.encoder.load_state_dict(encoder_state)

    def __init__(
        self,
        img_size: int = 1024,
        backbone: str = "sam",
        encoder: Optional[Union[nn.Module, str]] = "vit_b",
        decoder: Optional[nn.Module] = None,
        out_channels: int = 1,
        use_sam_stats: bool = False,
        use_mae_stats: bool = False,
        resize_input: bool = True,
        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
        final_activation: Optional[Union[str, nn.Module]] = None,
        use_skip_connection: bool = True,
        embed_dim: Optional[int] = None,
        use_conv_transpose: bool = False,
        **kwargs
    ) -> None:
        super().__init__()

        self.use_sam_stats = use_sam_stats
        self.use_mae_stats = use_mae_stats
        self.use_skip_connection = use_skip_connection
        self.resize_input = resize_input

        if isinstance(encoder, str):  # "vit_b" / "vit_l" / "vit_h"
            print(f"Using {encoder} from {backbone.upper()}")
            self.encoder = get_vision_transformer(img_size=img_size, backbone=backbone, model=encoder, **kwargs)

            if encoder_checkpoint is not None:
                self._load_encoder_from_checkpoint(backbone, encoder, encoder_checkpoint)

            if backbone == "sam2":
                in_chans = self.encoder.trunk.patch_embed.proj.in_channels
            else:
                in_chans = self.encoder.in_chans

            if embed_dim is None:
                embed_dim = self.encoder.embed_dim

        else:  # `nn.Module` ViT backbone
            self.encoder = encoder

            have_neck = False
            for name, _ in self.encoder.named_parameters():
                if name.startswith("neck"):
                    have_neck = True

            if embed_dim is None:
                if have_neck:
                    embed_dim = self.encoder.neck[2].out_channels  # the value is 256
                else:
                    embed_dim = self.encoder.patch_embed.proj.out_channels

            try:
                in_chans = self.encoder.patch_embed.proj.in_channels
            except AttributeError:  # for getting the input channels while using vit_t from MobileSam
                in_chans = self.encoder.patch_embed.seq[0].c.in_channels

        # parameters for the decoder network
        depth = 3
        initial_features = 64
        gain = 2
        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
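        # With the defaults above this yields features_decoder = [512, 256, 128, 64],
        # i.e. the decoder starts from 512 channels and halves the feature count per stage.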
        scale_factors = depth * [2]
        self.out_channels = out_channels

        # choice of upsampler - to use (bilinear interpolation + conv) or conv transpose
        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d

        if decoder is None:
            self.decoder = Decoder(
                features=features_decoder,
                scale_factors=scale_factors[::-1],
                conv_block_impl=ConvBlock2d,
                sampler_impl=_upsampler,
            )
        else:
            self.decoder = decoder

        if use_skip_connection:
            self.deconv1 = Deconv2DBlock(
                in_channels=embed_dim,
                out_channels=features_decoder[0],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv2 = nn.Sequential(
                Deconv2DBlock(
                    in_channels=embed_dim,
                    out_channels=features_decoder[0],
                    use_conv_transpose=use_conv_transpose,
                ),
                Deconv2DBlock(
                    in_channels=features_decoder[0],
                    out_channels=features_decoder[1],
                    use_conv_transpose=use_conv_transpose,
                )
            )
            self.deconv3 = nn.Sequential(
                Deconv2DBlock(
                    in_channels=embed_dim,
                    out_channels=features_decoder[0],
                    use_conv_transpose=use_conv_transpose,
                ),
                Deconv2DBlock(
                    in_channels=features_decoder[0],
                    out_channels=features_decoder[1],
                    use_conv_transpose=use_conv_transpose,
                ),
                Deconv2DBlock(
                    in_channels=features_decoder[1],
                    out_channels=features_decoder[2],
                    use_conv_transpose=use_conv_transpose,
                )
            )
            self.deconv4 = ConvBlock2d(in_chans, features_decoder[-1])
        else:
            self.deconv1 = Deconv2DBlock(
                in_channels=embed_dim,
                out_channels=features_decoder[0],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv2 = Deconv2DBlock(
                in_channels=features_decoder[0],
                out_channels=features_decoder[1],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv3 = Deconv2DBlock(
                in_channels=features_decoder[1],
                out_channels=features_decoder[2],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv4 = Deconv2DBlock(
                in_channels=features_decoder[2],
                out_channels=features_decoder[3],
                use_conv_transpose=use_conv_transpose,
            )

        self.base = ConvBlock2d(embed_dim, features_decoder[0])
        self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1)
        self.deconv_out = _upsampler(
            scale_factor=2, in_channels=features_decoder[-1], out_channels=features_decoder[-1]
        )
        self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1])
        self.final_activation = self._get_activation(final_activation)

    def _get_activation(self, activation):
        return_activation = None
        if activation is None:
            return None
        if isinstance(activation, nn.Module):
            return activation
        if isinstance(activation, str):
            return_activation = getattr(nn, activation, None)
        if return_activation is None:
            raise ValueError(f"Invalid activation: {activation}")

        return return_activation()

    @staticmethod
    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
        """Compute the output size given input size and target long side length.

        Args:
            oldh: The input image height.
            oldw: The input image width.
            long_side_length: The longest side length for resizing.

        Returns:
            The new image height.
            The new image width.
        """
        scale = long_side_length * 1.0 / max(oldh, oldw)
        newh, neww = oldh * scale, oldw * scale
        neww = int(neww + 0.5)
        newh = int(newh + 0.5)
        return (newh, neww)
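        # Example: a 512 x 768 input with long_side_length=1024 is scaled by 1024 / 768,
        # which gives an output size of (683, 1024).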

    def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor:
        """Resize the image so that the longest side has the correct length.

        Expects batched images with shape BxCxHxW and float format.

        Args:
            image: The input image.

        Returns:
            The resized image.
        """
        target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.encoder.img_size)
        return F.interpolate(
            image, target_size, mode="bilinear", align_corners=False, antialias=True
        )

    def preprocess(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]:
        """@private
        """
        device = x.device

        if self.use_sam_stats:
            pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(1, -1, 1, 1).to(device)
            pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(1, -1, 1, 1).to(device)
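            # These are the default pixel_mean / pixel_std values of Segment Anything,
            # i.e. the standard ImageNet statistics scaled to the 0-255 range.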
        elif self.use_mae_stats:
            # TODO: add mean std from mae / scalemae experiments (or open up arguments for this)
            raise NotImplementedError
        else:
            pixel_mean = torch.Tensor([0.0, 0.0, 0.0]).view(1, -1, 1, 1).to(device)
            pixel_std = torch.Tensor([1.0, 1.0, 1.0]).view(1, -1, 1, 1).to(device)

        if self.resize_input:
            x = self.resize_longest_side(x)
        input_shape = x.shape[-2:]

        x = (x - pixel_mean) / pixel_std
        h, w = x.shape[-2:]
        padh = self.encoder.img_size - h
        padw = self.encoder.img_size - w
        x = F.pad(x, (0, padw, 0, padh))
        return x, input_shape

    def postprocess_masks(
        self, masks: torch.Tensor, input_size: Tuple[int, ...], original_size: Tuple[int, ...],
    ) -> torch.Tensor:
        """@private
        """
        masks = F.interpolate(
            masks,
            (self.encoder.img_size, self.encoder.img_size),
            mode="bilinear",
            align_corners=False,
        )
        masks = masks[..., : input_size[0], : input_size[1]]
        masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
        return masks

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the UNETR to the input data.

        Args:
            x: The input tensor.

        Returns:
            The UNETR output.
        """
        original_shape = x.shape[-2:]

        # Reshape the inputs to the shape expected by the encoder
        # and normalize the inputs if normalization is part of the model.
        x, input_shape = self.preprocess(x)

        use_skip_connection = getattr(self, "use_skip_connection", True)

        encoder_outputs = self.encoder(x)

        if isinstance(encoder_outputs[-1], list):
            # `encoder_outputs` comes in one of two forms:
            #   - only the image embeddings, or
            #   - the image embeddings together with the list of outputs of the global attention layers
            z12, from_encoder = encoder_outputs
        else:
            z12 = encoder_outputs

        if use_skip_connection:
            from_encoder = from_encoder[::-1]
            z9 = self.deconv1(from_encoder[0])
            z6 = self.deconv2(from_encoder[1])
            z3 = self.deconv3(from_encoder[2])
            z0 = self.deconv4(x)

        else:
            z9 = self.deconv1(z12)
            z6 = self.deconv2(z9)
            z3 = self.deconv3(z6)
            z0 = self.deconv4(z3)

        updated_from_encoder = [z9, z6, z3]

        x = self.base(z12)
        x = self.decoder(x, encoder_inputs=updated_from_encoder)
        x = self.deconv_out(x)

        x = torch.cat([x, z0], dim=1)
        x = self.decoder_head(x)

        x = self.out_conv(x)
        if self.final_activation is not None:
            x = self.final_activation(x)

        x = self.postprocess_masks(x, input_shape, original_shape)
        return x


#
#  ADDITIONAL FUNCTIONALITIES
#


class SingleDeconv2DBlock(nn.Module):
    """@private
    """
    def __init__(self, scale_factor, in_channels, out_channels):
        super().__init__()
        # NOTE: `scale_factor` is accepted for API compatibility with `Upsampler2d`;
        # the transposed convolution below always upsamples by a factor of 2.
        self.block = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, padding=0, output_padding=0)

    def forward(self, x):
        return self.block(x)


class SingleConv2DBlock(nn.Module):
    """@private
    """
    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        self.block = nn.Conv2d(
            in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=((kernel_size - 1) // 2)
        )

    def forward(self, x):
        return self.block(x)


class Conv2DBlock(nn.Module):
    """@private
    """
    def __init__(self, in_channels, out_channels, kernel_size=3):
        super().__init__()
        self.block = nn.Sequential(
            SingleConv2DBlock(in_channels, out_channels, kernel_size),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True)
        )

    def forward(self, x):
        return self.block(x)


class Deconv2DBlock(nn.Module):
    """@private
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, use_conv_transpose=True):
        super().__init__()
        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d
        self.block = nn.Sequential(
            _upsampler(scale_factor=2, in_channels=in_channels, out_channels=out_channels),
            SingleConv2DBlock(out_channels, out_channels, kernel_size),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True)
        )

    def forward(self, x):
        return self.block(x)
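
A minimal usage sketch, assuming `torch_em` and the underlying SAM vision transformer implementation are installed: it constructs a UNETR with a SAM ViT-B encoder and runs a forward pass on a random image. No `encoder_checkpoint` is passed here, so the encoder is not loaded from pretrained weights.

import torch
from torch_em.model.unetr import UNETR

# Build a UNETR with a SAM ViT-B encoder and a sigmoid over the single output channel.
model = UNETR(
    img_size=1024,               # inputs are resized / padded to this size internally
    backbone="sam",              # one of "sam", "sam2", "mae", "scalemae"
    encoder="vit_b",             # ViT variant; a custom nn.Module can be passed instead
    out_channels=1,
    final_activation="Sigmoid",  # resolved via getattr(torch.nn, "Sigmoid")
)
model.eval()

x = torch.rand(1, 3, 512, 512)   # BxCxHxW float input
with torch.no_grad():
    y = model(x)
print(y.shape)  # expected: torch.Size([1, 1, 512, 512]) - the output is resized back to the input shape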