torch_em.model.unetr
from collections import OrderedDict
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from .vit import get_vision_transformer
from .unet import Decoder, ConvBlock2d, Upsampler2d

try:
    from micro_sam.util import get_sam_model
except ImportError:
    get_sam_model = None

try:
    from micro_sam2.util import get_sam2_model
except ImportError:
    get_sam2_model = None


#
# UNETR IMPLEMENTATION [Vision Transformer (ViT from SAM / MAE / ScaleMAE) + UNet Decoder from `torch_em`]
#


class UNETR(nn.Module):
    """A U-Net Transformer using a vision transformer as encoder and a convolutional decoder.

    Args:
        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
        backbone: The name of the vision transformer implementation. One of "sam", "sam2", "mae", "scalemae" or "dinov3".
        encoder: The vision transformer. Can either be a name, such as "vit_b", or a torch module.
        decoder: The convolutional decoder.
        out_channels: The number of output channels of the UNETR.
        use_sam_stats: Whether to normalize the input data with the statistics of the pretrained SAM model.
        use_dino_stats: Whether to normalize the input data with the statistics of the pretrained DINOv3 model.
        use_mae_stats: Whether to normalize the input data with the statistics of the pretrained MAE model.
        resize_input: Whether to resize the input images to match `img_size`.
            By default, it resizes the inputs to match the `img_size`.
        encoder_checkpoint: Checkpoint for initializing the vision transformer.
            Can either be a filepath or an already loaded checkpoint.
        final_activation: The activation to apply to the UNETR output.
        use_skip_connection: Whether to use skip connections. By default, it uses skip connections.
        embed_dim: The embedding dimensionality, corresponding to the output dimension of the vision transformer.
        use_conv_transpose: Whether to use transposed convolutions instead of resampling for upsampling.
            By default, it uses resampling for upsampling.
    """
    def _load_encoder_from_checkpoint(self, backbone, encoder, checkpoint):
        """Function to load pretrained weights to the image encoder.
        """
        if isinstance(checkpoint, str):
            if backbone == "sam" and isinstance(encoder, str):
                # If we have a SAM encoder, then we first try to load the full SAM Model
                # (using micro_sam) and otherwise fall back on directly loading the encoder state
                # from the checkpoint
                try:
                    _, model = get_sam_model(model_type=encoder, checkpoint_path=checkpoint, return_sam=True)
                    encoder_state = model.image_encoder.state_dict()
                except Exception:
                    # Try loading the encoder state directly from a checkpoint.
                    encoder_state = torch.load(checkpoint, weights_only=False)

            elif backbone == "sam2" and isinstance(encoder, str):
                # If we have a SAM2 encoder, then we first try to load the full SAM2 Model
                # (using micro_sam2) and otherwise fall back on directly loading the encoder state
                # from the checkpoint
                try:
                    model = get_sam2_model(model_type=encoder, checkpoint_path=checkpoint)
                    encoder_state = model.image_encoder.state_dict()
                except Exception:
                    # Try loading the encoder state directly from a checkpoint.
                    encoder_state = torch.load(checkpoint, weights_only=False)

            elif backbone == "mae":
                # vit initialization hints from:
                #     - https://github.com/facebookresearch/mae/blob/main/main_finetune.py#L233-L242
                encoder_state = torch.load(checkpoint, weights_only=False)["model"]
                encoder_state = OrderedDict({
                    k: v for k, v in encoder_state.items() if (k != "mask_token" and not k.startswith("decoder"))
                })
                # Let's remove the `head` from our current encoder (as the MAE pretrained don't expect it)
                current_encoder_state = self.encoder.state_dict()
                if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
                    del self.encoder.head

            elif backbone == "scalemae":
                # Load the encoder state directly from a checkpoint.
                encoder_state = torch.load(checkpoint)["model"]
                encoder_state = OrderedDict({
                    k: v for k, v in encoder_state.items()
                    if not k.startswith(("mask_token", "decoder", "fcn", "fpn", "pos_embed"))
                })

                # Let's remove the `head` from our current encoder (as the MAE pretrained don't expect it)
                current_encoder_state = self.encoder.state_dict()
                if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
                    del self.encoder.head

                if "pos_embed" in current_encoder_state:  # NOTE: ScaleMAE uses 'pos. embeddings' in a diff. format.
                    del self.encoder.pos_embed

            elif backbone == "dinov3":  # Load the encoder state directly from a checkpoint.
                encoder_state = torch.load(checkpoint)

            else:
                raise ValueError(
                    f"We don't support either the '{backbone}' backbone or the '{encoder}' model combination (or both)."
                )

        else:
            encoder_state = checkpoint

        self.encoder.load_state_dict(encoder_state)

    def __init__(
        self,
        img_size: int = 1024,
        backbone: str = "sam",
        encoder: Optional[Union[nn.Module, str]] = "vit_b",
        decoder: Optional[nn.Module] = None,
        out_channels: int = 1,
        use_sam_stats: bool = False,
        use_mae_stats: bool = False,
        use_dino_stats: bool = False,
        resize_input: bool = True,
        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
        final_activation: Optional[Union[str, nn.Module]] = None,
        use_skip_connection: bool = True,
        embed_dim: Optional[int] = None,
        use_conv_transpose: bool = False,
        **kwargs
    ) -> None:
        super().__init__()

        self.use_sam_stats = use_sam_stats
        self.use_mae_stats = use_mae_stats
        self.use_dino_stats = use_dino_stats
        self.use_skip_connection = use_skip_connection
        self.resize_input = resize_input

        if isinstance(encoder, str):  # e.g. "vit_b" / "vit_l" / "vit_h"
            print(f"Using {encoder} from {backbone.upper()}")
            self.encoder = get_vision_transformer(img_size=img_size, backbone=backbone, model=encoder, **kwargs)

            if encoder_checkpoint is not None:
                self._load_encoder_from_checkpoint(backbone=backbone, encoder=encoder, checkpoint=encoder_checkpoint)

            if backbone == "sam2":
                in_chans = self.encoder.trunk.patch_embed.proj.in_channels
            else:
                in_chans = self.encoder.in_chans

            if embed_dim is None:
                embed_dim = self.encoder.embed_dim

        else:  # `nn.Module` ViT backbone
            self.encoder = encoder

            have_neck = False
            for name, _ in self.encoder.named_parameters():
                if name.startswith("neck"):
                    have_neck = True

            if embed_dim is None:
                if have_neck:
                    embed_dim = self.encoder.neck[2].out_channels  # the value is 256
                else:
                    embed_dim = self.encoder.patch_embed.proj.out_channels

            try:
                in_chans = self.encoder.patch_embed.proj.in_channels
            except AttributeError:  # for getting the input channels while using 'vit_t' from MobileSam
                in_chans = self.encoder.patch_embed.seq[0].c.in_channels

        # parameters for the decoder network
        depth = 3
        initial_features = 64
        gain = 2
        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
        scale_factors = depth * [2]
        self.out_channels = out_channels

        # choice of upsampler - to use (bilinear interpolation + conv) or conv transpose
        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d

        if decoder is None:
            self.decoder = Decoder(
                features=features_decoder,
                scale_factors=scale_factors[::-1],
                conv_block_impl=ConvBlock2d,
                sampler_impl=_upsampler,
            )
        else:
            self.decoder = decoder

        if use_skip_connection:
            self.deconv1 = Deconv2DBlock(
                in_channels=embed_dim,
                out_channels=features_decoder[0],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv2 = nn.Sequential(
                Deconv2DBlock(
                    in_channels=embed_dim,
                    out_channels=features_decoder[0],
                    use_conv_transpose=use_conv_transpose,
                ),
                Deconv2DBlock(
                    in_channels=features_decoder[0],
                    out_channels=features_decoder[1],
                    use_conv_transpose=use_conv_transpose,
                )
            )
            self.deconv3 = nn.Sequential(
                Deconv2DBlock(
                    in_channels=embed_dim,
                    out_channels=features_decoder[0],
                    use_conv_transpose=use_conv_transpose,
                ),
                Deconv2DBlock(
                    in_channels=features_decoder[0],
                    out_channels=features_decoder[1],
                    use_conv_transpose=use_conv_transpose,
                ),
                Deconv2DBlock(
                    in_channels=features_decoder[1],
                    out_channels=features_decoder[2],
                    use_conv_transpose=use_conv_transpose,
                )
            )
            self.deconv4 = ConvBlock2d(in_chans, features_decoder[-1])
        else:
            self.deconv1 = Deconv2DBlock(
                in_channels=embed_dim,
                out_channels=features_decoder[0],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv2 = Deconv2DBlock(
                in_channels=features_decoder[0],
                out_channels=features_decoder[1],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv3 = Deconv2DBlock(
                in_channels=features_decoder[1],
                out_channels=features_decoder[2],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv4 = Deconv2DBlock(
                in_channels=features_decoder[2],
                out_channels=features_decoder[3],
                use_conv_transpose=use_conv_transpose,
            )

        self.base = ConvBlock2d(embed_dim, features_decoder[0])
        self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1)
        self.deconv_out = _upsampler(
            scale_factor=2, in_channels=features_decoder[-1], out_channels=features_decoder[-1]
        )
        self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1])
        self.final_activation = self._get_activation(final_activation)

    def _get_activation(self, activation):
        return_activation = None
        if activation is None:
            return None
        if isinstance(activation, nn.Module):
            return activation
        if isinstance(activation, str):
            return_activation = getattr(nn, activation, None)
        if return_activation is None:
            raise ValueError(f"Invalid activation: {activation}")

        return return_activation()

    @staticmethod
    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
        """Compute the output size given input size and target long side length.

        Args:
            oldh: The input image height.
            oldw: The input image width.
            long_side_length: The longest side length for resizing.

        Returns:
            The new image height.
            The new image width.
        """
        scale = long_side_length * 1.0 / max(oldh, oldw)
        newh, neww = oldh * scale, oldw * scale
        neww = int(neww + 0.5)
        newh = int(newh + 0.5)
        return (newh, neww)

    def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor:
        """Resize the image so that the longest side has the correct length.

        Expects batched images with shape BxCxHxW and float format.

        Args:
            image: The input image.

        Returns:
            The resized image.
        """
        target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.encoder.img_size)
        return F.interpolate(
            image, target_size, mode="bilinear", align_corners=False, antialias=True
        )

    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        """@private
        """
        device = x.device

        if self.use_sam_stats:
            pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(1, -1, 1, 1).to(device)
            pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(1, -1, 1, 1).to(device)
        elif self.use_mae_stats:  # TODO: add mean std from mae / scalemae experiments (or open up arguments for this)
            raise NotImplementedError
        elif self.use_dino_stats:
            pixel_mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1).to(device)
            pixel_std = torch.Tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1).to(device)
        else:
            pixel_mean = torch.Tensor([0.0, 0.0, 0.0]).view(1, -1, 1, 1).to(device)
            pixel_std = torch.Tensor([1.0, 1.0, 1.0]).view(1, -1, 1, 1).to(device)

        if self.resize_input:
            x = self.resize_longest_side(x)
        input_shape = x.shape[-2:]

        x = (x - pixel_mean) / pixel_std
        h, w = x.shape[-2:]
        padh = self.encoder.img_size - h
        padw = self.encoder.img_size - w
        x = F.pad(x, (0, padw, 0, padh))
        return x, input_shape

    def postprocess_masks(
        self, masks: torch.Tensor, input_size: Tuple[int, ...], original_size: Tuple[int, ...],
    ) -> torch.Tensor:
        """@private
        """
        masks = F.interpolate(
            masks,
            (self.encoder.img_size, self.encoder.img_size),
            mode="bilinear",
            align_corners=False,
        )
        masks = masks[..., : input_size[0], : input_size[1]]
        masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
        return masks

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the UNETR to the input data.

        Args:
            x: The input tensor.

        Returns:
            The UNETR output.
        """
        original_shape = x.shape[-2:]

        # Reshape the inputs to the shape expected by the encoder
        # and normalize the inputs if normalization is part of the model.
        x, input_shape = self.preprocess(x)

        use_skip_connection = getattr(self, "use_skip_connection", True)

        encoder_outputs = self.encoder(x)

        if isinstance(encoder_outputs[-1], list):
            # `encoder_outputs` can be arranged in only two forms:
            #   - either we only return the image embeddings
            #   - or, we return the image embeddings and the "list" of global attention layers
            z12, from_encoder = encoder_outputs
        else:
            z12 = encoder_outputs

        if use_skip_connection:
            from_encoder = from_encoder[::-1]
            z9 = self.deconv1(from_encoder[0])
            z6 = self.deconv2(from_encoder[1])
            z3 = self.deconv3(from_encoder[2])
            z0 = self.deconv4(x)

        else:
            z9 = self.deconv1(z12)
            z6 = self.deconv2(z9)
            z3 = self.deconv3(z6)
            z0 = self.deconv4(z3)

        updated_from_encoder = [z9, z6, z3]

        x = self.base(z12)
        x = self.decoder(x, encoder_inputs=updated_from_encoder)
        x = self.deconv_out(x)

        x = torch.cat([x, z0], dim=1)
        x = self.decoder_head(x)

        x = self.out_conv(x)
        if self.final_activation is not None:
            x = self.final_activation(x)

        x = self.postprocess_masks(x, input_shape, original_shape)
        return x


#
# ADDITIONAL FUNCTIONALITIES
#


class SingleDeconv2DBlock(nn.Module):
    """@private
    """
    def __init__(self, scale_factor, in_channels, out_channels):
        super().__init__()
        self.block = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, padding=0, output_padding=0)

    def forward(self, x):
        return self.block(x)


class SingleConv2DBlock(nn.Module):
    """@private
    """
    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        self.block = nn.Conv2d(
            in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=((kernel_size - 1) // 2)
        )

    def forward(self, x):
        return self.block(x)


class Conv2DBlock(nn.Module):
    """@private
    """
    def __init__(self, in_channels, out_channels, kernel_size=3):
        super().__init__()
        self.block = nn.Sequential(
            SingleConv2DBlock(in_channels, out_channels, kernel_size),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True)
        )

    def forward(self, x):
        return self.block(x)


class Deconv2DBlock(nn.Module):
    """@private
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, use_conv_transpose=True):
        super().__init__()
        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d
        self.block = nn.Sequential(
            _upsampler(scale_factor=2, in_channels=in_channels, out_channels=out_channels),
            SingleConv2DBlock(out_channels, out_channels, kernel_size),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True)
        )

    def forward(self, x):
        return self.block(x)
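A minimal usage sketch of the checkpoint-loading path above, assuming a local SAM ViT-B checkpoint file (the path below is hypothetical): if micro_sam is installed, the full SAM model is built and its image encoder state is copied; otherwise the state dict is loaded directly from the file.

import torch
from torch_em.model.unetr import UNETR

# Hypothetical local path to a SAM ViT-B checkpoint.
checkpoint = "/path/to/sam_vit_b_01ec64.pth"

# Passing `encoder_checkpoint` triggers `_load_encoder_from_checkpoint`, which
# first tries micro_sam's `get_sam_model` and falls back to `torch.load`.
model = UNETR(backbone="sam", encoder="vit_b", out_channels=1, encoder_checkpoint=checkpoint)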
class UNETR(torch.nn.modules.module.Module):
A U-Net Transformer using a vision transformer as encoder and a convolutional decoder.
Arguments:
- img_size: The size of the input for the image encoder. Input images will be resized to match this size.
- backbone: The name of the vision transformer implementation. One of "sam", "sam2", "mae", "scalemae" or "dinov3".
- encoder: The vision transformer. Can either be a name, such as "vit_b", or a torch module.
- decoder: The convolutional decoder.
- out_channels: The number of output channels of the UNETR.
- use_sam_stats: Whether to normalize the input data with the statistics of the pretrained SAM model.
- use_dino_stats: Whether to normalize the input data with the statistics of the pretrained DINOv3 model.
- use_mae_stats: Whether to normalize the input data with the statistics of the pretrained MAE model.
- resize_input: Whether to resize the input images to match `img_size`. By default, it resizes the inputs to match the `img_size`.
- encoder_checkpoint: Checkpoint for initializing the vision transformer. Can either be a filepath or an already loaded checkpoint.
- final_activation: The activation to apply to the UNETR output.
- use_skip_connection: Whether to use skip connections. By default, it uses skip connections.
- embed_dim: The embedding dimensionality, corresponding to the output dimension of the vision transformer.
- use_conv_transpose: Whether to use transposed convolutions instead of resampling for upsampling. By default, it uses resampling for upsampling.
UNETR(
    img_size: int = 1024,
    backbone: str = 'sam',
    encoder: Union[torch.nn.modules.module.Module, str, NoneType] = 'vit_b',
    decoder: Optional[torch.nn.modules.module.Module] = None,
    out_channels: int = 1,
    use_sam_stats: bool = False,
    use_mae_stats: bool = False,
    use_dino_stats: bool = False,
    resize_input: bool = True,
    encoder_checkpoint: Union[str, collections.OrderedDict, NoneType] = None,
    final_activation: Union[torch.nn.modules.module.Module, str, NoneType] = None,
    use_skip_connection: bool = True,
    embed_dim: Optional[int] = None,
    use_conv_transpose: bool = False,
    **kwargs
)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
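A hedged sketch of constructing the model from these arguments, assuming torch_em and the SAM image-encoder dependencies used by get_vision_transformer are installed; the choice of two output channels and a sigmoid activation is arbitrary, for illustration only.

from torch_em.model.unetr import UNETR

# Build a UNETR with a (randomly initialized) SAM ViT-B encoder, two output
# channels and a sigmoid over the predictions; inputs are normalized with the
# SAM pixel statistics.
model = UNETR(
    img_size=1024,
    backbone="sam",
    encoder="vit_b",
    out_channels=2,
    use_sam_stats=True,
    final_activation="Sigmoid",
)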
@staticmethod
def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
Compute the output size given input size and target long side length.
Arguments:
- oldh: The input image height.
- oldw: The input image width.
- long_side_length: The longest side length for resizing.
Returns:
A tuple with the new image height and the new image width.
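For example (a worked call, not part of the original docs): scaling a 768 x 512 (H x W) image so that its longest side becomes 1024 yields a 1024 x 683 target shape.

from torch_em.model.unetr import UNETR

# get_preprocess_shape is a staticmethod, so it can be called on the class.
new_h, new_w = UNETR.get_preprocess_shape(oldh=768, oldw=512, long_side_length=1024)
print(new_h, new_w)  # 1024 683  (scale = 1024 / 768, rounded to the nearest pixel)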
def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor:
Resize the image so that the longest side has the correct length.
Expects batched images with shape BxCxHxW and float format.
Arguments:
- image: The input image.
Returns:
The resized image.
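A small sketch of the expected behaviour, assuming `model` is the UNETR from the construction example above and its encoder uses img_size=1024: the longest side of the batch is scaled to 1024 and the shorter side is scaled proportionally.

import torch

image = torch.rand(1, 3, 512, 768)          # BxCxHxW, float
resized = model.resize_longest_side(image)  # longest side scaled to 1024
print(resized.shape)                        # torch.Size([1, 3, 683, 1024])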
def forward(self, x: torch.Tensor) -> torch.Tensor:
Apply the UNETR to the input data.
Arguments:
- x: The input tensor.
Returns:
The UNETR output.
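A hedged end-to-end sketch, reusing the `model` from the construction example above: the input is preprocessed (resized, normalized, padded) before the encoder, and the prediction is interpolated back in postprocess_masks, so the output keeps the input's spatial size.

import torch

x = torch.rand(1, 3, 512, 768)  # batched RGB input in BxCxHxW format
with torch.no_grad():
    y = model(x)
print(y.shape)  # expected: torch.Size([1, 2, 512, 768]) for out_channels=2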