torch_em.model.unetr

  1from collections import OrderedDict
  2from typing import Optional, Tuple, Union
  3
  4import torch
  5import torch.nn as nn
  6import torch.nn.functional as F
  7
  8from .vit import get_vision_transformer
  9from .unet import Decoder, ConvBlock2d, Upsampler2d
 10
 11try:
 12    from micro_sam.util import get_sam_model
 13except ImportError:
 14    get_sam_model = None
 15
 16
 17#
 18# UNETR IMPLEMENTATION [Vision Transformer (ViT from MAE / ViT from SAM) + UNet Decoder from `torch_em`]
 19#
 20
 21
 22class UNETR(nn.Module):
 23    """A U-Net Transformer using a vision transformer as encoder and a convolutional decoder.
 24
 25    Args:
 26        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
 27        backbone: The name of the vision transformer implementation. One of "sam" or "mae".
 28        encoder: The vision transformer. Can either be a name, such as "vit_b", or a torch module.
 29        decoder: The convolutional decoder.
 30        out_channels: The number of output channels of the UNETR.
 31        use_sam_stats: Whether to normalize the input data with the statistics of the pretrained SAM model.
 32        use_mae_stats: Whether to normalize the input data with the statistics of the pretrained MAE model.
 33        resize_input: Whether to resize the input images to match `img_size`.
 34        encoder_checkpoint: Checkpoint for initializing the vision transformer.
 35            Can either be a filepath or an already loaded checkpoint.
 36        final_activation: The activation to apply to the UNETR output.
 37        use_skip_connection: Whether to use skip connections.
 38        embed_dim: The embedding dimensionality, corresponding to the output dimension of the vision transformer.
 39        use_conv_transpose: Whether to use transposed convolutions instead of resampling for upsampling.
 40    """
 41    def _load_encoder_from_checkpoint(self, backbone, encoder, checkpoint):
 42        """Function to load pretrained weights to the image encoder.
 43        """
 44        if isinstance(checkpoint, str):
 45            if backbone == "sam" and isinstance(encoder, str):
 46                # If we have a SAM encoder, then we first try to load the full SAM Model
 47                # (using micro_sam) and otherwise fall back on directly loading the encoder state
 48                # from the checkpoint
 49                try:
 50                    _, model = get_sam_model(model_type=encoder, checkpoint_path=checkpoint, return_sam=True)
 51                    encoder_state = model.image_encoder.state_dict()
 52                except Exception:
 53                    # Try loading the encoder state directly from a checkpoint.
 54                    encoder_state = torch.load(checkpoint, weights_only=False)
 55
 56            elif backbone == "mae":
 57                # vit initialization hints from:
 58                #     - https://github.com/facebookresearch/mae/blob/main/main_finetune.py#L233-L242
 59                encoder_state = torch.load(checkpoint, weights_only=False)["model"]
 60                encoder_state = OrderedDict({
 61                    k: v for k, v in encoder_state.items() if (k != "mask_token" and not k.startswith("decoder"))
 62                })
 63                # Remove the `head` from our current encoder (the MAE pretrained weights do not contain it).
 64                current_encoder_state = self.encoder.state_dict()
 65                if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
 66                    del self.encoder.head
 67
 68        else:
 69            encoder_state = checkpoint
 70
 71        self.encoder.load_state_dict(encoder_state)
 72
 73    def __init__(
 74        self,
 75        img_size: int = 1024,
 76        backbone: str = "sam",
 77        encoder: Optional[Union[nn.Module, str]] = "vit_b",
 78        decoder: Optional[nn.Module] = None,
 79        out_channels: int = 1,
 80        use_sam_stats: bool = False,
 81        use_mae_stats: bool = False,
 82        resize_input: bool = True,
 83        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
 84        final_activation: Optional[Union[str, nn.Module]] = None,
 85        use_skip_connection: bool = True,
 86        embed_dim: Optional[int] = None,
 87        use_conv_transpose: bool = True,
 88    ) -> None:
 89        super().__init__()
 90
 91        self.use_sam_stats = use_sam_stats
 92        self.use_mae_stats = use_mae_stats
 93        self.use_skip_connection = use_skip_connection
 94        self.resize_input = resize_input
 95
 96        if isinstance(encoder, str):  # "vit_b" / "vit_l" / "vit_h"
 97            print(f"Using {encoder} from {backbone.upper()}")
 98            self.encoder = get_vision_transformer(img_size=img_size, backbone=backbone, model=encoder)
 99            if encoder_checkpoint is not None:
100                self._load_encoder_from_checkpoint(backbone, encoder, encoder_checkpoint)
101
102            in_chans = self.encoder.in_chans
103            if embed_dim is None:
104                embed_dim = self.encoder.embed_dim
105
106        else:  # `nn.Module` ViT backbone
107            self.encoder = encoder
108
109            have_neck = False
110            for name, _ in self.encoder.named_parameters():
111                if name.startswith("neck"):
112                    have_neck = True
113
114            if embed_dim is None:
115                if have_neck:
116                    embed_dim = self.encoder.neck[2].out_channels  # the value is 256
117                else:
118                    embed_dim = self.encoder.patch_embed.proj.out_channels
119
120            try:
121                in_chans = self.encoder.patch_embed.proj.in_channels
122            except AttributeError:  # for getting the input channels while using vit_t from MobileSam
123                in_chans = self.encoder.patch_embed.seq[0].c.in_channels
124
125        # parameters for the decoder network
126        depth = 3
127        initial_features = 64
128        gain = 2
129        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
130        scale_factors = depth * [2]
131        self.out_channels = out_channels
132
133        # Choice of upsampler: transposed convolution (if use_conv_transpose) or bilinear interpolation + conv.
134        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d
135
136        if decoder is None:
137            self.decoder = Decoder(
138                features=features_decoder,
139                scale_factors=scale_factors[::-1],
140                conv_block_impl=ConvBlock2d,
141                sampler_impl=_upsampler,
142            )
143        else:
144            self.decoder = decoder
145
146        if use_skip_connection:
147            self.deconv1 = Deconv2DBlock(embed_dim, features_decoder[0])
148            self.deconv2 = nn.Sequential(
149                Deconv2DBlock(embed_dim, features_decoder[0]),
150                Deconv2DBlock(features_decoder[0], features_decoder[1])
151            )
152            self.deconv3 = nn.Sequential(
153                Deconv2DBlock(embed_dim, features_decoder[0]),
154                Deconv2DBlock(features_decoder[0], features_decoder[1]),
155                Deconv2DBlock(features_decoder[1], features_decoder[2])
156            )
157            self.deconv4 = ConvBlock2d(in_chans, features_decoder[-1])
158        else:
159            self.deconv1 = Deconv2DBlock(embed_dim, features_decoder[0])
160            self.deconv2 = Deconv2DBlock(features_decoder[0], features_decoder[1])
161            self.deconv3 = Deconv2DBlock(features_decoder[1], features_decoder[2])
162            self.deconv4 = Deconv2DBlock(features_decoder[2], features_decoder[3])
163
164        self.base = ConvBlock2d(embed_dim, features_decoder[0])
165        self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1)
166        self.deconv_out = _upsampler(
167            scale_factor=2, in_channels=features_decoder[-1], out_channels=features_decoder[-1]
168        )
169        self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1])
170        self.final_activation = self._get_activation(final_activation)
171
172    def _get_activation(self, activation):
173        return_activation = None
174        if activation is None:
175            return None
176        if isinstance(activation, nn.Module):
177            return activation
178        if isinstance(activation, str):
179            return_activation = getattr(nn, activation, None)
180        if return_activation is None:
181            raise ValueError(f"Invalid activation: {activation}")
182        return return_activation()
183
184    @staticmethod
185    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
186        """Compute the output size given input size and target long side length.
187
188        Args:
189            oldh: The input image height.
190            oldw: The input image width.
191            long_side_length: The longest side length for resizing.
192
193        Returns:
194            The new image height.
195            The new image width.
196        """
197        scale = long_side_length * 1.0 / max(oldh, oldw)
198        newh, neww = oldh * scale, oldw * scale
199        neww = int(neww + 0.5)
200        newh = int(newh + 0.5)
201        return (newh, neww)
202
203    def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor:
204        """Resize the image so that the longest side has the correct length.
205
206        Expects batched images with shape BxCxHxW and float format.
207
208        Args:
209            image: The input image.
210
211        Returns:
212            The resized image.
213        """
214        target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.encoder.img_size)
215        return F.interpolate(
216            image, target_size, mode="bilinear", align_corners=False, antialias=True
217        )
218
219    def preprocess(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]:
220        """@private
221        """
222        device = x.device
223
224        if self.use_sam_stats:
225            pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(1, -1, 1, 1).to(device)
226            pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(1, -1, 1, 1).to(device)
227        elif self.use_mae_stats:
228            # TODO: add mean std from mae experiments (or open up arguments for this)
229            raise NotImplementedError
230        else:
231            pixel_mean = torch.Tensor([0.0, 0.0, 0.0]).view(1, -1, 1, 1).to(device)
232            pixel_std = torch.Tensor([1.0, 1.0, 1.0]).view(1, -1, 1, 1).to(device)
233
234        if self.resize_input:
235            x = self.resize_longest_side(x)
236        input_shape = x.shape[-2:]
237
238        x = (x - pixel_mean) / pixel_std
239        h, w = x.shape[-2:]
240        padh = self.encoder.img_size - h
241        padw = self.encoder.img_size - w
242        x = F.pad(x, (0, padw, 0, padh))
243        return x, input_shape
244
245    def postprocess_masks(
246        self,
247        masks: torch.Tensor,
248        input_size: Tuple[int, ...],
249        original_size: Tuple[int, ...],
250    ) -> torch.Tensor:
251        """@private
252        """
253        masks = F.interpolate(
254            masks,
255            (self.encoder.img_size, self.encoder.img_size),
256            mode="bilinear",
257            align_corners=False,
258        )
259        masks = masks[..., : input_size[0], : input_size[1]]
260        masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
261        return masks
262
263    def forward(self, x: torch.Tensor) -> torch.Tensor:
264        """Apply the UNETR to the input data.
265
266        Args:
267            x: The input tensor.
268
269        Returns:
270            The UNETR output.
271        """
272        original_shape = x.shape[-2:]
273
274        # Reshape the inputs to the shape expected by the encoder
275        # and normalize the inputs if normalization is part of the model.
276        x, input_shape = self.preprocess(x)
277
278        use_skip_connection = getattr(self, "use_skip_connection", True)
279
280        encoder_outputs = self.encoder(x)
281
282        if isinstance(encoder_outputs[-1], list):
283            # `encoder_outputs` can be arranged in only two forms:
284            #   - either we only return the image embeddings
285            #   - or, we return the image embeddings and the "list" of global attention layers
286            z12, from_encoder = encoder_outputs
287        else:
288            z12 = encoder_outputs
289
290        if use_skip_connection:
291            from_encoder = from_encoder[::-1]
292            z9 = self.deconv1(from_encoder[0])
293            z6 = self.deconv2(from_encoder[1])
294            z3 = self.deconv3(from_encoder[2])
295            z0 = self.deconv4(x)
296
297        else:
298            z9 = self.deconv1(z12)
299            z6 = self.deconv2(z9)
300            z3 = self.deconv3(z6)
301            z0 = self.deconv4(z3)
302
303        updated_from_encoder = [z9, z6, z3]
304
305        x = self.base(z12)
306        x = self.decoder(x, encoder_inputs=updated_from_encoder)
307        x = self.deconv_out(x)
308
309        x = torch.cat([x, z0], dim=1)
310        x = self.decoder_head(x)
311
312        x = self.out_conv(x)
313        if self.final_activation is not None:
314            x = self.final_activation(x)
315
316        x = self.postprocess_masks(x, input_shape, original_shape)
317        return x
318
319
320#
321#  ADDITIONAL FUNCTIONALITIES
322#
323
324
325class SingleDeconv2DBlock(nn.Module):
326    """@private
327    """
328    def __init__(self, scale_factor, in_channels, out_channels):
329        super().__init__()
330        self.block = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, padding=0, output_padding=0)
331
332    def forward(self, x):
333        return self.block(x)
334
335
336class SingleConv2DBlock(nn.Module):
337    """@private
338    """
339    def __init__(self, in_channels, out_channels, kernel_size):
340        super().__init__()
341        self.block = nn.Conv2d(
342            in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=((kernel_size - 1) // 2)
343        )
344
345    def forward(self, x):
346        return self.block(x)
347
348
349class Conv2DBlock(nn.Module):
350    """@private
351    """
352    def __init__(self, in_channels, out_channels, kernel_size=3):
353        super().__init__()
354        self.block = nn.Sequential(
355            SingleConv2DBlock(in_channels, out_channels, kernel_size),
356            nn.BatchNorm2d(out_channels),
357            nn.ReLU(True)
358        )
359
360    def forward(self, x):
361        return self.block(x)
362
363
364class Deconv2DBlock(nn.Module):
365    """@private
366    """
367    def __init__(self, in_channels, out_channels, kernel_size=3, use_conv_transpose=True):
368        super().__init__()
369        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d
370        self.block = nn.Sequential(
371            _upsampler(scale_factor=2, in_channels=in_channels, out_channels=out_channels),
372            SingleConv2DBlock(out_channels, out_channels, kernel_size),
373            nn.BatchNorm2d(out_channels),
374            nn.ReLU(True)
375        )
376
377    def forward(self, x):
378        return self.block(x)
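
The deconvolution blocks above double the spatial resolution at each step: SingleDeconv2DBlock uses a stride-2 transposed convolution, and Deconv2DBlock follows the upsampling with a 3x3 convolution, batch norm and ReLU. A minimal shape check (illustrative only; it assumes these helper classes are imported from this module, torch_em.model.unetr):

    import torch
    from torch_em.model.unetr import SingleDeconv2DBlock, Deconv2DBlock

    x = torch.randn(1, 256, 32, 32)  # e.g. ViT embeddings reshaped to a 2d feature grid
    up = SingleDeconv2DBlock(scale_factor=2, in_channels=256, out_channels=128)
    print(up(x).shape)     # torch.Size([1, 128, 64, 64])

    block = Deconv2DBlock(in_channels=256, out_channels=128)
    print(block(x).shape)  # torch.Size([1, 128, 64, 64])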
class UNETR(torch.nn.modules.module.Module):

A U-Net Transformer using a vision transformer as encoder and a convolutional decoder.

Arguments:
  • img_size: The size of the input for the image encoder. Input images will be resized to match this size.
  • backbone: The name of the vision transformer implementation. One of "sam" or "mae".
  • encoder: The vision transformer. Can either be a name, such as "vit_b", or a torch module.
  • decoder: The convolutional decoder.
  • out_channels: The number of output channels of the UNETR.
  • use_sam_stats: Whether to normalize the input data with the statistics of the pretrained SAM model.
  • use_mae_stats: Whether to normalize the input data with the statistics of the pretrained MAE model.
  • resize_input: Whether to resize the input images to match img_size.
  • encoder_checkpoint: Checkpoint for initializing the vision transformer. Can either be a filepath or an already loaded checkpoint.
  • final_activation: The activation to apply to the UNETR output.
  • use_skip_connection: Whether to use skip connections.
  • embed_dim: The embedding dimensionality, corresponding to the output dimension of the vision transformer.
  • use_conv_transpose: Whether to use transposed convolutions instead of resampling for upsampling.
UNETR( img_size: int = 1024, backbone: str = 'sam', encoder: Union[torch.nn.modules.module.Module, str, NoneType] = 'vit_b', decoder: Optional[torch.nn.modules.module.Module] = None, out_channels: int = 1, use_sam_stats: bool = False, use_mae_stats: bool = False, resize_input: bool = True, encoder_checkpoint: Union[str, collections.OrderedDict, NoneType] = None, final_activation: Union[torch.nn.modules.module.Module, str, NoneType] = None, use_skip_connection: bool = True, embed_dim: Optional[int] = None, use_conv_transpose: bool = True)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

use_sam_stats
use_mae_stats
use_skip_connection
resize_input
out_channels
base
out_conv
deconv_out
decoder_head
final_activation
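
A minimal construction and inference sketch (not part of the module; it assumes UNETR is importable from torch_em.model, that the vision-transformer dependencies required by get_vision_transformer for the "sam" backbone are installed, and that the checkpoint path below is only a placeholder):

    import torch
    from torch_em.model import UNETR

    model = UNETR(
        img_size=1024,               # inputs are resized and padded to this size for the ViT encoder
        backbone="sam",
        encoder="vit_b",
        out_channels=1,
        use_sam_stats=True,          # normalize inputs with the SAM pixel statistics
        final_activation="Sigmoid",
        # encoder_checkpoint="/path/to/sam_checkpoint.pth",  # optional, placeholder path
    )

    x = torch.randn(1, 3, 512, 512)  # BxCxHxW float input
    with torch.no_grad():
        y = model(x)
    print(y.shape)                   # torch.Size([1, 1, 512, 512]); the output is resized back to the input shape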
@staticmethod
def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:

Compute the output size given input size and target long side length.

Arguments:
  • oldh: The input image height.
  • oldw: The input image width.
  • long_side_length: The longest side length for resizing.
Returns:
  The new image height and the new image width.
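
For example, a 768x512 image resized to a longest side of 1024 keeps its aspect ratio:

    # scale = 1024 / 768, so 768 -> 1024 and 512 -> int(512 * 1024 / 768 + 0.5) = 683
    UNETR.get_preprocess_shape(oldh=768, oldw=512, long_side_length=1024)  # returns (1024, 683)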

def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor:

Resize the image so that the longest side has the correct length.

Expects batched images with shape BxCxHxW and float format.

Arguments:
  • image: The input image.
Returns:
  The resized image.
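
In preprocess, this resizing is combined with bottom/right zero-padding so that the encoder always receives a square input of size img_size x img_size. A standalone sketch of those two steps (illustrative, using torch.nn.functional directly instead of the methods of this class):

    import torch
    import torch.nn.functional as F

    img_size = 1024                      # the encoder's expected input size
    image = torch.randn(1, 3, 768, 512)  # BxCxHxW float input

    # Resize so that the longest side matches img_size: 768x512 -> 1024x683.
    scale = img_size / max(image.shape[-2:])
    target = tuple(int(s * scale + 0.5) for s in image.shape[-2:])
    image = F.interpolate(image, target, mode="bilinear", align_corners=False, antialias=True)

    # Pad the bottom and right so the input becomes a square of img_size x img_size.
    h, w = image.shape[-2:]
    image = F.pad(image, (0, img_size - w, 0, img_size - h))
    print(image.shape)                   # torch.Size([1, 3, 1024, 1024])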

def forward(self, x: torch.Tensor) -> torch.Tensor:

Apply the UNETR to the input data.

Arguments:
  • x: The input tensor.
Returns:
  The UNETR output.
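
As a concrete illustration of this contract (under the same assumptions as the constructor sketch above): the output keeps the batch size, has out_channels channels, and is resized back to the spatial shape of the input by postprocess_masks, also for non-square inputs.

    import torch
    from torch_em.model import UNETR

    model = UNETR(backbone="sam", encoder="vit_b", out_channels=2)
    with torch.no_grad():
        y = model(torch.randn(1, 3, 512, 384))  # non-square input
    print(y.shape)  # torch.Size([1, 2, 512, 384])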