torch_em.model.unetr
from collections import OrderedDict
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from .vit import get_vision_transformer
from .unet import Decoder, ConvBlock2d, Upsampler2d

try:
    from micro_sam.util import get_sam_model
except ImportError:
    get_sam_model = None

try:
    from micro_sam2.util import get_sam2_model
except ImportError:
    get_sam2_model = None


#
# UNETR IMPLEMENTATION [Vision Transformer (ViT from SAM / MAE / ScaleMAE) + UNet Decoder from `torch_em`]
#


class UNETR(nn.Module):
    """A U-Net Transformer using a vision transformer as encoder and a convolutional decoder.

    Args:
        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
        backbone: The name of the vision transformer implementation. One of "sam", "sam2", "mae" or "scalemae".
        encoder: The vision transformer. Can either be a name, such as "vit_b" or a torch module.
        decoder: The convolutional decoder.
        out_channels: The number of output channels of the UNETR.
        use_sam_stats: Whether to normalize the input data with the statistics of the pretrained SAM model.
        use_mae_stats: Whether to normalize the input data with the statistics of the pretrained MAE model.
        resize_input: Whether to resize the input images to match `img_size`.
            By default, it resizes the inputs to match the `img_size`.
        encoder_checkpoint: Checkpoint for initializing the vision transformer.
            Can either be a filepath or an already loaded checkpoint.
        final_activation: The activation to apply to the UNETR output.
        use_skip_connection: Whether to use skip connections. By default, it uses skip connections.
        embed_dim: The embedding dimensionality, corresponding to the output dimension of the vision transformer.
        use_conv_transpose: Whether to use transposed convolutions instead of resampling for upsampling.
            By default, it uses resampling for upsampling.
    """
    def _load_encoder_from_checkpoint(self, backbone, encoder, checkpoint):
        """Function to load pretrained weights to the image encoder.
        """
        if isinstance(checkpoint, str):
            if backbone == "sam" and isinstance(encoder, str):
                # If we have a SAM encoder, then we first try to load the full SAM Model
                # (using micro_sam) and otherwise fall back on directly loading the encoder state
                # from the checkpoint
                try:
                    _, model = get_sam_model(model_type=encoder, checkpoint_path=checkpoint, return_sam=True)
                    encoder_state = model.image_encoder.state_dict()
                except Exception:
                    # Try loading the encoder state directly from a checkpoint.
                    encoder_state = torch.load(checkpoint, weights_only=False)

            elif backbone == "sam2" and isinstance(encoder, str):
                # If we have a SAM2 encoder, then we first try to load the full SAM2 Model
                # (using micro_sam2) and otherwise fall back on directly loading the encoder state
                # from the checkpoint
                try:
                    model = get_sam2_model(model_type=encoder, checkpoint_path=checkpoint)
                    encoder_state = model.image_encoder.state_dict()
                except Exception:
                    # Try loading the encoder state directly from a checkpoint.
                    encoder_state = torch.load(checkpoint, weights_only=False)

            elif backbone == "mae":
                # vit initialization hints from:
                # - https://github.com/facebookresearch/mae/blob/main/main_finetune.py#L233-L242
                encoder_state = torch.load(checkpoint, weights_only=False)["model"]
                encoder_state = OrderedDict({
                    k: v for k, v in encoder_state.items() if (k != "mask_token" and not k.startswith("decoder"))
                })
                # Let's remove the `head` from our current encoder (as the MAE pretrained weights don't expect it)
                current_encoder_state = self.encoder.state_dict()
                if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
                    del self.encoder.head

            elif backbone == "scalemae":
                # Load the encoder state directly from a checkpoint.
                encoder_state = torch.load(checkpoint)["model"]
                encoder_state = OrderedDict({
                    k: v for k, v in encoder_state.items()
                    if not k.startswith(("mask_token", "decoder", "fcn", "fpn", "pos_embed"))
                })

                # Let's remove the `head` from our current encoder (as the MAE pretrained weights don't expect it)
                current_encoder_state = self.encoder.state_dict()
                if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
                    del self.encoder.head

                if "pos_embed" in current_encoder_state:  # NOTE: ScaleMAE uses 'pos. embeddings' in a diff. format.
                    del self.encoder.pos_embed

        else:
            encoder_state = checkpoint

        self.encoder.load_state_dict(encoder_state)

    def __init__(
        self,
        img_size: int = 1024,
        backbone: str = "sam",
        encoder: Optional[Union[nn.Module, str]] = "vit_b",
        decoder: Optional[nn.Module] = None,
        out_channels: int = 1,
        use_sam_stats: bool = False,
        use_mae_stats: bool = False,
        resize_input: bool = True,
        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
        final_activation: Optional[Union[str, nn.Module]] = None,
        use_skip_connection: bool = True,
        embed_dim: Optional[int] = None,
        use_conv_transpose: bool = False,
        **kwargs
    ) -> None:
        super().__init__()

        self.use_sam_stats = use_sam_stats
        self.use_mae_stats = use_mae_stats
        self.use_skip_connection = use_skip_connection
        self.resize_input = resize_input

        if isinstance(encoder, str):  # "vit_b" / "vit_l" / "vit_h"
            print(f"Using {encoder} from {backbone.upper()}")
            self.encoder = get_vision_transformer(img_size=img_size, backbone=backbone, model=encoder, **kwargs)

            if encoder_checkpoint is not None:
                self._load_encoder_from_checkpoint(backbone, encoder, encoder_checkpoint)

            if backbone == "sam2":
                in_chans = self.encoder.trunk.patch_embed.proj.in_channels
            else:
                in_chans = self.encoder.in_chans

            if embed_dim is None:
                embed_dim = self.encoder.embed_dim

        else:  # `nn.Module` ViT backbone
            self.encoder = encoder

            have_neck = False
            for name, _ in self.encoder.named_parameters():
                if name.startswith("neck"):
                    have_neck = True

            if embed_dim is None:
                if have_neck:
                    embed_dim = self.encoder.neck[2].out_channels  # the value is 256
                else:
                    embed_dim = self.encoder.patch_embed.proj.out_channels

            try:
                in_chans = self.encoder.patch_embed.proj.in_channels
            except AttributeError:  # for getting the input channels while using vit_t from MobileSam
                in_chans = self.encoder.patch_embed.seq[0].c.in_channels

        # parameters for the decoder network
        depth = 3
        initial_features = 64
        gain = 2
        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
        scale_factors = depth * [2]
        self.out_channels = out_channels

        # choice of upsampler - to use (bilinear interpolation + conv) or conv transpose
        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d

        if decoder is None:
            self.decoder = Decoder(
                features=features_decoder,
                scale_factors=scale_factors[::-1],
                conv_block_impl=ConvBlock2d,
                sampler_impl=_upsampler,
            )
        else:
            self.decoder = decoder

        if use_skip_connection:
            self.deconv1 = Deconv2DBlock(
                in_channels=embed_dim,
                out_channels=features_decoder[0],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv2 = nn.Sequential(
                Deconv2DBlock(
                    in_channels=embed_dim,
                    out_channels=features_decoder[0],
                    use_conv_transpose=use_conv_transpose,
                ),
                Deconv2DBlock(
                    in_channels=features_decoder[0],
                    out_channels=features_decoder[1],
                    use_conv_transpose=use_conv_transpose,
                )
            )
            self.deconv3 = nn.Sequential(
                Deconv2DBlock(
                    in_channels=embed_dim,
                    out_channels=features_decoder[0],
                    use_conv_transpose=use_conv_transpose,
                ),
                Deconv2DBlock(
                    in_channels=features_decoder[0],
                    out_channels=features_decoder[1],
                    use_conv_transpose=use_conv_transpose,
                ),
                Deconv2DBlock(
                    in_channels=features_decoder[1],
                    out_channels=features_decoder[2],
                    use_conv_transpose=use_conv_transpose,
                )
            )
            self.deconv4 = ConvBlock2d(in_chans, features_decoder[-1])
        else:
            self.deconv1 = Deconv2DBlock(
                in_channels=embed_dim,
                out_channels=features_decoder[0],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv2 = Deconv2DBlock(
                in_channels=features_decoder[0],
                out_channels=features_decoder[1],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv3 = Deconv2DBlock(
                in_channels=features_decoder[1],
                out_channels=features_decoder[2],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv4 = Deconv2DBlock(
                in_channels=features_decoder[2],
                out_channels=features_decoder[3],
                use_conv_transpose=use_conv_transpose,
            )

        self.base = ConvBlock2d(embed_dim, features_decoder[0])
        self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1)
        self.deconv_out = _upsampler(
            scale_factor=2, in_channels=features_decoder[-1], out_channels=features_decoder[-1]
        )
        self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1])
        self.final_activation = self._get_activation(final_activation)

    def _get_activation(self, activation):
        return_activation = None
        if activation is None:
            return None
        if isinstance(activation, nn.Module):
            return activation
        if isinstance(activation, str):
            return_activation = getattr(nn, activation, None)
        if return_activation is None:
            raise ValueError(f"Invalid activation: {activation}")

        return return_activation()

    @staticmethod
    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
        """Compute the output size given input size and target long side length.

        Args:
            oldh: The input image height.
            oldw: The input image width.
            long_side_length: The longest side length for resizing.

        Returns:
            The new image height.
            The new image width.
        """
        scale = long_side_length * 1.0 / max(oldh, oldw)
        newh, neww = oldh * scale, oldw * scale
        neww = int(neww + 0.5)
        newh = int(newh + 0.5)
        return (newh, neww)

    def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor:
        """Resize the image so that the longest side has the correct length.

        Expects batched images with shape BxCxHxW and float format.

        Args:
            image: The input image.

        Returns:
            The resized image.
        """
        target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.encoder.img_size)
        return F.interpolate(
            image, target_size, mode="bilinear", align_corners=False, antialias=True
        )

    def preprocess(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]:
        """@private
        """
        device = x.device

        if self.use_sam_stats:
            pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(1, -1, 1, 1).to(device)
            pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(1, -1, 1, 1).to(device)
        elif self.use_mae_stats:
            # TODO: add mean std from mae / scalemae experiments (or open up arguments for this)
            raise NotImplementedError
        else:
            pixel_mean = torch.Tensor([0.0, 0.0, 0.0]).view(1, -1, 1, 1).to(device)
            pixel_std = torch.Tensor([1.0, 1.0, 1.0]).view(1, -1, 1, 1).to(device)

        if self.resize_input:
            x = self.resize_longest_side(x)
        input_shape = x.shape[-2:]

        x = (x - pixel_mean) / pixel_std
        h, w = x.shape[-2:]
        padh = self.encoder.img_size - h
        padw = self.encoder.img_size - w
        x = F.pad(x, (0, padw, 0, padh))
        return x, input_shape

    def postprocess_masks(
        self, masks: torch.Tensor, input_size: Tuple[int, ...], original_size: Tuple[int, ...],
    ) -> torch.Tensor:
        """@private
        """
        masks = F.interpolate(
            masks,
            (self.encoder.img_size, self.encoder.img_size),
            mode="bilinear",
            align_corners=False,
        )
        masks = masks[..., : input_size[0], : input_size[1]]
        masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
        return masks

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the UNETR to the input data.

        Args:
            x: The input tensor.

        Returns:
            The UNETR output.
        """
        original_shape = x.shape[-2:]

        # Reshape the inputs to the shape expected by the encoder
        # and normalize the inputs if normalization is part of the model.
        x, input_shape = self.preprocess(x)

        use_skip_connection = getattr(self, "use_skip_connection", True)

        encoder_outputs = self.encoder(x)

        if isinstance(encoder_outputs[-1], list):
            # `encoder_outputs` can be arranged in only two forms:
            # - either we only return the image embeddings
            # - or, we return the image embeddings and the "list" of global attention layers
            z12, from_encoder = encoder_outputs
        else:
            z12 = encoder_outputs

        if use_skip_connection:
            from_encoder = from_encoder[::-1]
            z9 = self.deconv1(from_encoder[0])
            z6 = self.deconv2(from_encoder[1])
            z3 = self.deconv3(from_encoder[2])
            z0 = self.deconv4(x)

        else:
            z9 = self.deconv1(z12)
            z6 = self.deconv2(z9)
            z3 = self.deconv3(z6)
            z0 = self.deconv4(z3)

        updated_from_encoder = [z9, z6, z3]

        x = self.base(z12)
        x = self.decoder(x, encoder_inputs=updated_from_encoder)
        x = self.deconv_out(x)

        x = torch.cat([x, z0], dim=1)
        x = self.decoder_head(x)

        x = self.out_conv(x)
        if self.final_activation is not None:
            x = self.final_activation(x)

        x = self.postprocess_masks(x, input_shape, original_shape)
        return x


#
# ADDITIONAL FUNCTIONALITIES
#


class SingleDeconv2DBlock(nn.Module):
    """@private
    """
    def __init__(self, scale_factor, in_channels, out_channels):
        super().__init__()
        # `scale_factor` is accepted for interface compatibility with `Upsampler2d`;
        # the transposed convolution below always upsamples by a factor of 2.
        self.block = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, padding=0, output_padding=0)

    def forward(self, x):
        return self.block(x)


class SingleConv2DBlock(nn.Module):
    """@private
    """
    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        self.block = nn.Conv2d(
            in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=((kernel_size - 1) // 2)
        )

    def forward(self, x):
        return self.block(x)


class Conv2DBlock(nn.Module):
    """@private
    """
    def __init__(self, in_channels, out_channels, kernel_size=3):
        super().__init__()
        self.block = nn.Sequential(
            SingleConv2DBlock(in_channels, out_channels, kernel_size),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True)
        )

    def forward(self, x):
        return self.block(x)


class Deconv2DBlock(nn.Module):
    """@private
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, use_conv_transpose=True):
        super().__init__()
        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d
        self.block = nn.Sequential(
            _upsampler(scale_factor=2, in_channels=in_channels, out_channels=out_channels),
            SingleConv2DBlock(out_channels, out_channels, kernel_size),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True)
        )

    def forward(self, x):
        return self.block(x)
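A small sketch (not part of the module) of how the decoder feature widths in `__init__` are derived from the hard-coded `depth`, `initial_features` and `gain`:

depth, initial_features, gain = 3, 64, 2  # hard-coded defaults in UNETR.__init__
features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
print(features_decoder)  # [512, 256, 128, 64]
# With skip connections, deconv1-deconv3 map the embedding dimension to 512, 256
# and 128 channels respectively, and the final decoder stage works at 64 channels.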
class UNETR(torch.nn.modules.module.Module):
A U-Net Transformer using a vision transformer as encoder and a convolutional decoder.
Arguments:
- img_size: The size of the input for the image encoder. Input images will be resized to match this size.
- backbone: The name of the vision transformer implementation. One of "sam", "sam2", "mae" or "scalemae".
- encoder: The vision transformer. Can either be a name, such as "vit_b" or a torch module.
- decoder: The convolutional decoder.
- out_channels: The number of output channels of the UNETR.
- use_sam_stats: Whether to normalize the input data with the statistics of the pretrained SAM model.
- use_mae_stats: Whether to normalize the input data with the statistics of the pretrained MAE model.
- resize_input: Whether to resize the input images to match `img_size`. By default, it resizes the inputs to match the `img_size`.
- encoder_checkpoint: Checkpoint for initializing the vision transformer. Can either be a filepath or an already loaded checkpoint.
- final_activation: The activation to apply to the UNETR output.
- use_skip_connection: Whether to use skip connections. By default, it uses skip connections.
- embed_dim: The embedding dimensionality, corresponding to the output dimension of the vision transformer.
- use_conv_transpose: Whether to use transposed convolutions instead of resampling for upsampling. By default, it uses resampling for upsampling.
UNETR(
    img_size: int = 1024,
    backbone: str = 'sam',
    encoder: Union[torch.nn.modules.module.Module, str, NoneType] = 'vit_b',
    decoder: Optional[torch.nn.modules.module.Module] = None,
    out_channels: int = 1,
    use_sam_stats: bool = False,
    use_mae_stats: bool = False,
    resize_input: bool = True,
    encoder_checkpoint: Union[str, collections.OrderedDict, NoneType] = None,
    final_activation: Union[torch.nn.modules.module.Module, str, NoneType] = None,
    use_skip_connection: bool = True,
    embed_dim: Optional[int] = None,
    use_conv_transpose: bool = False,
    **kwargs
)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
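As a usage sketch (the encoder name, input size and activation below are illustrative; the encoder is randomly initialized because no `encoder_checkpoint` is passed, and the example assumes the SAM ViT implementation shipped with `torch_em` is available):

import torch
from torch_em.model.unetr import UNETR

# Build a UNETR with a SAM ViT-B encoder and a single output channel.
model = UNETR(
    img_size=1024,
    backbone="sam",
    encoder="vit_b",
    out_channels=1,
    final_activation="Sigmoid",  # resolved via getattr(nn, "Sigmoid")
)

x = torch.randn(1, 3, 512, 512)  # batched RGB input, BxCxHxW
with torch.no_grad():
    y = model(x)
print(y.shape)  # torch.Size([1, 1, 512, 512]); the output matches the input resolution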
@staticmethod
def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
Compute the output size given input size and target long side length.
Arguments:
- oldh: The input image height.
- oldw: The input image width.
- long_side_length: The longest side length for resizing.
Returns:
The new image height and the new image width.
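For example, a 720x1280 input with a target long side of 1024 is scaled by 1024 / 1280 = 0.8:

from torch_em.model.unetr import UNETR

newh, neww = UNETR.get_preprocess_shape(oldh=720, oldw=1280, long_side_length=1024)
print(newh, neww)  # 576 1024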
def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor:
Resize the image so that the longest side has the correct length.
Expects batched images with shape BxCxHxW and float format.
Arguments:
- image: The input image.
Returns:
The resized image.
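Continuing the constructor sketch above (so `model.encoder.img_size` is assumed to be 1024), only the longest side is fixed and the aspect ratio is preserved:

image = torch.randn(1, 3, 720, 1280)  # BxCxHxW float tensor
resized = model.resize_longest_side(image)
print(resized.shape)  # torch.Size([1, 3, 576, 1024])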
def forward(self, x: torch.Tensor) -> torch.Tensor:
Apply the UNETR to the input data.
Arguments:
- x: The input tensor.
Returns:
The UNETR output.
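A hedged sketch of the no-skip-connection path in `forward`: with `use_skip_connection=False` all decoder inputs are derived from the image embeddings `z12`, so no intermediate encoder features are needed (again assuming the SAM ViT-B encoder from `torch_em`):

import torch
from torch_em.model.unetr import UNETR

model_ns = UNETR(backbone="sam", encoder="vit_b", out_channels=2, use_skip_connection=False)
x = torch.randn(1, 3, 384, 384)
with torch.no_grad():
    y = model_ns(x)
print(y.shape)  # torch.Size([1, 2, 384, 384]); postprocess_masks restores the input resolution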