torch_em.model.unetr

  1from collections import OrderedDict
  2from typing import Optional, Tuple, Union
  3
  4import torch
  5import torch.nn as nn
  6import torch.nn.functional as F
  7
  8from .vit import get_vision_transformer
  9from .unet import Decoder, ConvBlock2d, Upsampler2d
 10
 11try:
 12    from micro_sam.util import get_sam_model
 13except ImportError:
 14    get_sam_model = None
 15
 16
 17#
 18# UNETR IMPLEMENTATION [Vision Transformer (ViT from SAM / MAE / ScaleMAE) + UNet Decoder from `torch_em`]
 19#
 20
 21
 22class UNETR(nn.Module):
 23    """A U-Net Transformer using a vision transformer as encoder and a convolutional decoder.
 24
 25    Args:
 26        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
 27        backbone: The name of the vision transformer implementation. One of "sam", "mae" or "scalemae".
 28        encoder: The vision transformer. Can either be a name, such as "vit_b" or a torch module.
 29        decoder: The convolutional decoder.
 30        out_channels: The number of output channels of the UNETR.
 31        use_sam_stats: Whether to normalize the input data with the statistics of the pretrained SAM model.
 32        use_mae_stats: Whether to normalize the input data with the statistics of the pretrained MAE model.
 33        resize_input: Whether to resize the input images to match `img_size`.
 34            By default, it resizes the inputs to match the `img_size`.
 35        encoder_checkpoint: Checkpoint for initializing the vision transformer.
 36            Can either be a filepath or an already loaded checkpoint.
 37        final_activation: The activation to apply to the UNETR output.
 38        use_skip_connection: Whether to use skip connections. By default, it uses skip connections.
 39        embed_dim: The embedding dimensionality, corresponding to the output dimension of the vision transformer.
 40        use_conv_transpose: Whether to use transposed convolutions instead of resampling for upsampling.
 41            By default, it uses resampling for upsampling.
 42    """
 43    def _load_encoder_from_checkpoint(self, backbone, encoder, checkpoint):
 44        """Load pretrained weights into the image encoder.
 45        """
 46        if isinstance(checkpoint, str):
 47            if backbone == "sam" and isinstance(encoder, str):
 48                # If we have a SAM encoder, then we first try to load the full SAM Model
 49                # (using micro_sam) and otherwise fall back on directly loading the encoder state
 50                # from the checkpoint
 51                try:
 52                    _, model = get_sam_model(model_type=encoder, checkpoint_path=checkpoint, return_sam=True)
 53                    encoder_state = model.image_encoder.state_dict()
 54                except Exception:
 55                    # Try loading the encoder state directly from a checkpoint.
 56                    encoder_state = torch.load(checkpoint, weights_only=False)
 57
 58            elif backbone == "mae":
 59                # vit initialization hints from:
 60                #     - https://github.com/facebookresearch/mae/blob/main/main_finetune.py#L233-L242
 61                encoder_state = torch.load(checkpoint, weights_only=False)["model"]
 62                encoder_state = OrderedDict({
 63                    k: v for k, v in encoder_state.items() if (k != "mask_token" and not k.startswith("decoder"))
 64                })
 65                # Remove the `head` from our current encoder (the MAE pretrained weights don't include it)
 66                current_encoder_state = self.encoder.state_dict()
 67                if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
 68                    del self.encoder.head
 69
 70            elif backbone == "scalemae":
 71                # Load the encoder state directly from a checkpoint.
 72                encoder_state = torch.load(checkpoint)["model"]
 73                encoder_state = OrderedDict({
 74                    k: v for k, v in encoder_state.items()
 75                    if not k.startswith(("mask_token", "decoder", "fcn", "fpn", "pos_embed"))
 76                })
 77
 78                # Remove the `head` from our current encoder (the ScaleMAE pretrained weights don't include it)
 79                current_encoder_state = self.encoder.state_dict()
 80                if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
 81                    del self.encoder.head
 82
 83                if "pos_embed" in current_encoder_state:  # NOTE: ScaleMAE stores positional embeddings in a different format.
 84                    del self.encoder.pos_embed
 85
 86        else:
 87            encoder_state = checkpoint
 88
 89        self.encoder.load_state_dict(encoder_state)
 90
 91    def __init__(
 92        self,
 93        img_size: int = 1024,
 94        backbone: str = "sam",
 95        encoder: Optional[Union[nn.Module, str]] = "vit_b",
 96        decoder: Optional[nn.Module] = None,
 97        out_channels: int = 1,
 98        use_sam_stats: bool = False,
 99        use_mae_stats: bool = False,
100        resize_input: bool = True,
101        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
102        final_activation: Optional[Union[str, nn.Module]] = None,
103        use_skip_connection: bool = True,
104        embed_dim: Optional[int] = None,
105        use_conv_transpose: bool = False,
106        **kwargs
107    ) -> None:
108        super().__init__()
109
110        self.use_sam_stats = use_sam_stats
111        self.use_mae_stats = use_mae_stats
112        self.use_skip_connection = use_skip_connection
113        self.resize_input = resize_input
114
115        if isinstance(encoder, str):  # "vit_b" / "vit_l" / "vit_h"
116            print(f"Using {encoder} from {backbone.upper()}")
117            self.encoder = get_vision_transformer(img_size=img_size, backbone=backbone, model=encoder, **kwargs)
118
119            if encoder_checkpoint is not None:
120                self._load_encoder_from_checkpoint(backbone, encoder, encoder_checkpoint)
121
122            in_chans = self.encoder.in_chans
123            if embed_dim is None:
124                embed_dim = self.encoder.embed_dim
125
126        else:  # `nn.Module` ViT backbone
127            self.encoder = encoder
128
129            have_neck = False
130            for name, _ in self.encoder.named_parameters():
131                if name.startswith("neck"):
132                    have_neck = True
133
134            if embed_dim is None:
135                if have_neck:
136                    embed_dim = self.encoder.neck[2].out_channels  # the value is 256
137                else:
138                    embed_dim = self.encoder.patch_embed.proj.out_channels
139
140            try:
141                in_chans = self.encoder.patch_embed.proj.in_channels
142            except AttributeError:  # for getting the input channels while using vit_t from MobileSam
143                in_chans = self.encoder.patch_embed.seq[0].c.in_channels
144
145        # parameters for the decoder network
146        depth = 3
147        initial_features = 64
148        gain = 2
149        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
150        scale_factors = depth * [2]
151        self.out_channels = out_channels
152
153        # Choice of upsampler: bilinear interpolation followed by a conv, or a transposed conv.
154        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d
155
156        if decoder is None:
157            self.decoder = Decoder(
158                features=features_decoder,
159                scale_factors=scale_factors[::-1],
160                conv_block_impl=ConvBlock2d,
161                sampler_impl=_upsampler,
162            )
163        else:
164            self.decoder = decoder
165
166        if use_skip_connection:
167            self.deconv1 = Deconv2DBlock(
168                in_channels=embed_dim,
169                out_channels=features_decoder[0],
170                use_conv_transpose=use_conv_transpose,
171            )
172            self.deconv2 = nn.Sequential(
173                Deconv2DBlock(
174                    in_channels=embed_dim,
175                    out_channels=features_decoder[0],
176                    use_conv_transpose=use_conv_transpose,
177                ),
178                Deconv2DBlock(
179                    in_channels=features_decoder[0],
180                    out_channels=features_decoder[1],
181                    use_conv_transpose=use_conv_transpose,
182                )
183            )
184            self.deconv3 = nn.Sequential(
185                Deconv2DBlock(
186                    in_channels=embed_dim,
187                    out_channels=features_decoder[0],
188                    use_conv_transpose=use_conv_transpose,
189                ),
190                Deconv2DBlock(
191                    in_channels=features_decoder[0],
192                    out_channels=features_decoder[1],
193                    use_conv_transpose=use_conv_transpose,
194                ),
195                Deconv2DBlock(
196                    in_channels=features_decoder[1],
197                    out_channels=features_decoder[2],
198                    use_conv_transpose=use_conv_transpose,
199                )
200            )
201            self.deconv4 = ConvBlock2d(in_chans, features_decoder[-1])
202        else:
203            self.deconv1 = Deconv2DBlock(
204                in_channels=embed_dim,
205                out_channels=features_decoder[0],
206                use_conv_transpose=use_conv_transpose,
207            )
208            self.deconv2 = Deconv2DBlock(
209                in_channels=features_decoder[0],
210                out_channels=features_decoder[1],
211                use_conv_transpose=use_conv_transpose,
212            )
213            self.deconv3 = Deconv2DBlock(
214                in_channels=features_decoder[1],
215                out_channels=features_decoder[2],
216                use_conv_transpose=use_conv_transpose,
217            )
218            self.deconv4 = Deconv2DBlock(
219                in_channels=features_decoder[2],
220                out_channels=features_decoder[3],
221                use_conv_transpose=use_conv_transpose,
222            )
223
224        self.base = ConvBlock2d(embed_dim, features_decoder[0])
225        self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1)
226        self.deconv_out = _upsampler(
227            scale_factor=2, in_channels=features_decoder[-1], out_channels=features_decoder[-1]
228        )
229        self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1])
230        self.final_activation = self._get_activation(final_activation)
231
232    def _get_activation(self, activation):
233        return_activation = None
234        if activation is None:
235            return None
236        if isinstance(activation, nn.Module):
237            return activation
238        if isinstance(activation, str):
239            return_activation = getattr(nn, activation, None)
240        if return_activation is None:
241            raise ValueError(f"Invalid activation: {activation}")
242
243        return return_activation()
244
245    @staticmethod
246    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
247        """Compute the output size given input size and target long side length.
248
249        Args:
250            oldh: The input image height.
251            oldw: The input image width.
252            long_side_length: The longest side length for resizing.
253
254        Returns:
255            The new image height.
256            The new image width.
257        """
258        scale = long_side_length * 1.0 / max(oldh, oldw)
259        newh, neww = oldh * scale, oldw * scale
260        neww = int(neww + 0.5)
261        newh = int(newh + 0.5)
262        return (newh, neww)
263
264    def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor:
265        """Resize the image so that the longest side has the correct length.
266
267        Expects batched images with shape BxCxHxW and float format.
268
269        Args:
270            image: The input image.
271
272        Returns:
273            The resized image.
274        """
275        target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.encoder.img_size)
276        return F.interpolate(
277            image, target_size, mode="bilinear", align_corners=False, antialias=True
278        )
279
280    def preprocess(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]:
281        """@private
282        """
283        device = x.device
284
285        if self.use_sam_stats:
286            pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(1, -1, 1, 1).to(device)
287            pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(1, -1, 1, 1).to(device)
288        elif self.use_mae_stats:
289            # TODO: add mean std from mae / scalemae experiments (or open up arguments for this)
290            raise NotImplementedError
291        else:
292            pixel_mean = torch.Tensor([0.0, 0.0, 0.0]).view(1, -1, 1, 1).to(device)
293            pixel_std = torch.Tensor([1.0, 1.0, 1.0]).view(1, -1, 1, 1).to(device)
294
295        if self.resize_input:
296            x = self.resize_longest_side(x)
297        input_shape = x.shape[-2:]
298
299        x = (x - pixel_mean) / pixel_std
300        h, w = x.shape[-2:]
301        padh = self.encoder.img_size - h
302        padw = self.encoder.img_size - w
303        x = F.pad(x, (0, padw, 0, padh))
304        return x, input_shape
305
306    def postprocess_masks(
307        self, masks: torch.Tensor, input_size: Tuple[int, ...], original_size: Tuple[int, ...],
308    ) -> torch.Tensor:
309        """@private
310        """
311        masks = F.interpolate(
312            masks,
313            (self.encoder.img_size, self.encoder.img_size),
314            mode="bilinear",
315            align_corners=False,
316        )
317        masks = masks[..., : input_size[0], : input_size[1]]
318        masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
319        return masks
320
321    def forward(self, x: torch.Tensor) -> torch.Tensor:
322        """Apply the UNETR to the input data.
323
324        Args:
325            x: The input tensor.
326
327        Returns:
328            The UNETR output.
329        """
330        original_shape = x.shape[-2:]
331
332        # Reshape the inputs to the shape expected by the encoder
333        # and normalize the inputs if normalization is part of the model.
334        x, input_shape = self.preprocess(x)
335
336        use_skip_connection = getattr(self, "use_skip_connection", True)
337
338        encoder_outputs = self.encoder(x)
339
340        if isinstance(encoder_outputs[-1], list):
341            # `encoder_outputs` comes in one of two forms:
342            #   - only the image embeddings, or
343            #   - the image embeddings together with the list of outputs from the global attention layers
344            z12, from_encoder = encoder_outputs
345        else:
346            z12 = encoder_outputs
347
348        if use_skip_connection:
349            from_encoder = from_encoder[::-1]
350            z9 = self.deconv1(from_encoder[0])
351            z6 = self.deconv2(from_encoder[1])
352            z3 = self.deconv3(from_encoder[2])
353            z0 = self.deconv4(x)
354
355        else:
356            z9 = self.deconv1(z12)
357            z6 = self.deconv2(z9)
358            z3 = self.deconv3(z6)
359            z0 = self.deconv4(z3)
360
361        updated_from_encoder = [z9, z6, z3]
362
363        x = self.base(z12)
364        x = self.decoder(x, encoder_inputs=updated_from_encoder)
365        x = self.deconv_out(x)
366
367        x = torch.cat([x, z0], dim=1)
368        x = self.decoder_head(x)
369
370        x = self.out_conv(x)
371        if self.final_activation is not None:
372            x = self.final_activation(x)
373
374        x = self.postprocess_masks(x, input_shape, original_shape)
375        return x
376
377
378#
379#  ADDITIONAL FUNCTIONALITIES
380#
381
382
383class SingleDeconv2DBlock(nn.Module):
384    """@private
385    """
386    def __init__(self, scale_factor, in_channels, out_channels):
387        super().__init__()
388        self.block = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, padding=0, output_padding=0)
389
390    def forward(self, x):
391        return self.block(x)
392
393
394class SingleConv2DBlock(nn.Module):
395    """@private
396    """
397    def __init__(self, in_channels, out_channels, kernel_size):
398        super().__init__()
399        self.block = nn.Conv2d(
400            in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=((kernel_size - 1) // 2)
401        )
402
403    def forward(self, x):
404        return self.block(x)
405
406
407class Conv2DBlock(nn.Module):
408    """@private
409    """
410    def __init__(self, in_channels, out_channels, kernel_size=3):
411        super().__init__()
412        self.block = nn.Sequential(
413            SingleConv2DBlock(in_channels, out_channels, kernel_size),
414            nn.BatchNorm2d(out_channels),
415            nn.ReLU(True)
416        )
417
418    def forward(self, x):
419        return self.block(x)
420
421
422class Deconv2DBlock(nn.Module):
423    """@private
424    """
425    def __init__(self, in_channels, out_channels, kernel_size=3, use_conv_transpose=True):
426        super().__init__()
427        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d
428        self.block = nn.Sequential(
429            _upsampler(scale_factor=2, in_channels=in_channels, out_channels=out_channels),
430            SingleConv2DBlock(out_channels, out_channels, kernel_size),
431            nn.BatchNorm2d(out_channels),
432            nn.ReLU(True)
433        )
434
435    def forward(self, x):
436        return self.block(x)
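
The helper blocks at the end of the module implement the two upsampling choices selected via `use_conv_transpose`. For orientation, a small shape check (a sketch only, assuming the import path from this page; the tensor sizes are arbitrary):

import torch
from torch_em.model.unetr import Deconv2DBlock, SingleDeconv2DBlock

x = torch.rand(1, 64, 32, 32)

# A transposed convolution with kernel size 2 and stride 2 doubles the spatial size.
deconv = SingleDeconv2DBlock(scale_factor=2, in_channels=64, out_channels=32)
print(deconv(x).shape)  # torch.Size([1, 32, 64, 64])

# Deconv2DBlock upsamples and then applies a 3x3 conv + BatchNorm + ReLU.
block = Deconv2DBlock(in_channels=64, out_channels=32, use_conv_transpose=True)
print(block(x).shape)   # torch.Size([1, 32, 64, 64])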
class UNETR(torch.nn.modules.module.Module):

A U-Net Transformer using a vision transformer as encoder and a convolutional decoder.

Arguments:
  • img_size: The size of the input for the image encoder. Input images will be resized to match this size.
  • backbone: The name of the vision transformer implementation. One of "sam", "mae" or "scalemae".
  • encoder: The vision transformer. Can either be a name, such as "vit_b" or a torch module.
  • decoder: The convolutional decoder.
  • out_channels: The number of output channels of the UNETR.
  • use_sam_stats: Whether to normalize the input data with the statistics of the pretrained SAM model.
  • use_mae_stats: Whether to normalize the input data with the statistics of the pretrained MAE model.
  • resize_input: Whether to resize the input images to match img_size. By default, it resizes the inputs to match the img_size.
  • encoder_checkpoint: Checkpoint for initializing the vision transformer. Can either be a filepath or an already loaded checkpoint.
  • final_activation: The activation to apply to the UNETR output.
  • use_skip_connection: Whether to use skip connections. By default, it uses skip connections.
  • embed_dim: The embedding dimensionality, corresponding to the output dimension of the vision transformer.
  • use_conv_transpose: Whether to use transposed convolutions instead of resampling for upsampling. By default, it uses resampling for upsampling.
UNETR(
    img_size: int = 1024,
    backbone: str = 'sam',
    encoder: Union[torch.nn.modules.module.Module, str, NoneType] = 'vit_b',
    decoder: Optional[torch.nn.modules.module.Module] = None,
    out_channels: int = 1,
    use_sam_stats: bool = False,
    use_mae_stats: bool = False,
    resize_input: bool = True,
    encoder_checkpoint: Union[str, collections.OrderedDict, NoneType] = None,
    final_activation: Union[torch.nn.modules.module.Module, str, NoneType] = None,
    use_skip_connection: bool = True,
    embed_dim: Optional[int] = None,
    use_conv_transpose: bool = False,
    **kwargs
)
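
For orientation, a minimal instantiation sketch. It assumes the import path from this page, that the SAM dependencies used by `get_vision_transformer` are installed, and picks `use_sam_stats` and `final_activation="Sigmoid"` purely for illustration:

import torch
from torch_em.model.unetr import UNETR

# Build a UNETR with a SAM ViT-B encoder and the default convolutional decoder.
# Passing encoder_checkpoint would additionally load pretrained encoder weights.
model = UNETR(
    img_size=1024,
    backbone="sam",
    encoder="vit_b",
    out_channels=1,
    use_sam_stats=True,          # normalize inputs with SAM's pixel statistics (0-255 range)
    final_activation="Sigmoid",  # resolved via getattr(nn, "Sigmoid")
)

x = 255 * torch.rand(1, 3, 512, 512)  # arbitrary input size; resized and padded internally
with torch.no_grad():
    y = model(x)
print(y.shape)  # torch.Size([1, 1, 512, 512]) - the output is resized back to the input shape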

@staticmethod
def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:

Compute the output size given input size and target long side length.

Arguments:
  • oldh: The input image height.
  • oldw: The input image width.
  • long_side_length: The longest side length for resizing.
Returns:
  The new image height.
  The new image width.
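
A quick numerical check of the rounding behaviour (the shapes are arbitrary; `get_preprocess_shape` is a static method, so no model instance is needed):

from torch_em.model.unetr import UNETR

# The long side is scaled to the target length, the short side proportionally,
# and both are rounded half-up to integers.
print(UNETR.get_preprocess_shape(768, 1024, 1024))  # (768, 1024) - already matches
print(UNETR.get_preprocess_shape(500, 2000, 1024))  # (256, 1024)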

def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor:

Resize the image so that the longest side has the correct length.

Expects batched images with shape BxCxHxW and float format.

Arguments:
  • image: The input image.
Returns:
  The resized image.
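
A brief usage sketch, assuming the `model` instantiated in the example above (so `model.encoder.img_size` is 1024); the input shape is arbitrary:

import torch

image = torch.rand(1, 3, 480, 640)          # BxCxHxW float input
resized = model.resize_longest_side(image)  # the longest side becomes model.encoder.img_size
print(resized.shape)                        # torch.Size([1, 3, 768, 1024])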

def forward(self, x: torch.Tensor) -> torch.Tensor:

Apply the UNETR to the input data.

Arguments:
  • x: The input tensor.
Returns:
  The UNETR output.
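
To close, a minimal inference sketch under the same assumptions as the instantiation example above (SAM ViT-B encoder, sigmoid final activation):

import torch

model.eval()
with torch.no_grad():
    probs = model(255 * torch.rand(2, 3, 256, 256))

# With a sigmoid final activation the output lies in [0, 1] and can be
# thresholded to obtain binary segmentation masks.
masks = (probs > 0.5).float()
print(masks.shape)  # torch.Size([2, 1, 256, 256])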