torch_em.model.unetr
```python
from functools import partial
from collections import OrderedDict
from typing import Optional, Tuple, Union, Literal

import torch
import torch.nn as nn
import torch.nn.functional as F

from .vit import get_vision_transformer
from .unet import Decoder, ConvBlock2d, ConvBlock3d, Upsampler2d, Upsampler3d, _update_conv_kwargs

try:
    from micro_sam.util import get_sam_model
except ImportError:
    get_sam_model = None

try:
    from micro_sam2.util import get_sam2_model
except ImportError:
    get_sam2_model = None


#
# UNETR IMPLEMENTATION [Vision Transformer (ViT from SAM / MAE / ScaleMAE) + UNet Decoder from `torch_em`]
#


class UNETRBase(nn.Module):
    """Base class for implementing a UNETR.

    Args:
        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
        backbone: The name of the vision transformer implementation. One of "sam" or "mae".
        encoder: The vision transformer. Can either be a name, such as "vit_b" or a torch module.
        decoder: The convolutional decoder.
        out_channels: The number of output channels of the UNETR.
        use_sam_stats: Whether to normalize the input data with the statistics of the pretrained SAM model.
        use_dino_stats: Whether to normalize the input data with the statistics of the pretrained DINOv3 model.
        use_mae_stats: Whether to normalize the input data with the statistics of the pretrained MAE model.
        resize_input: Whether to resize the input images to match `img_size`.
            By default, it resizes the inputs to match the `img_size`.
        encoder_checkpoint: Checkpoint for initializing the vision transformer.
            Can either be a filepath or an already loaded checkpoint.
        final_activation: The activation to apply to the UNETR output.
        use_skip_connection: Whether to use skip connections. By default, it uses skip connections.
        embed_dim: The embedding dimensionality, corresponding to the output dimension of the vision transformer.
        use_conv_transpose: Whether to use transposed convolutions instead of resampling for upsampling.
            By default, it uses resampling for upsampling.
    """
    def __init__(
        self,
        img_size: int = 1024,
        backbone: Literal["sam", "sam2", "mae", "scalemae", "dinov3"] = "sam2",
        encoder: Optional[Union[nn.Module, str]] = "vit_b",
        decoder: Optional[nn.Module] = None,
        out_channels: int = 1,
        use_sam_stats: bool = False,
        use_mae_stats: bool = False,
        use_dino_stats: bool = False,
        resize_input: bool = True,
        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
        final_activation: Optional[Union[str, nn.Module]] = None,
        use_skip_connection: bool = True,
        embed_dim: Optional[int] = None,
        use_conv_transpose: bool = False,
        **kwargs
    ) -> None:
        super().__init__()

        self.img_size = img_size
        self.use_sam_stats = use_sam_stats
        self.use_mae_stats = use_mae_stats
        self.use_dino_stats = use_dino_stats
        self.use_skip_connection = use_skip_connection
        self.resize_input = resize_input
        self.use_conv_transpose = use_conv_transpose
        self.backbone = backbone

        if isinstance(encoder, str):  # e.g. "vit_b" / "hvit_b"
            print(f"Using {encoder} from {backbone.upper()}")
            self.encoder = get_vision_transformer(img_size=img_size, backbone=backbone, model=encoder, **kwargs)

            if encoder_checkpoint is not None:
                self._load_encoder_from_checkpoint(backbone=backbone, encoder=encoder, checkpoint=encoder_checkpoint)

            if embed_dim is None:
                embed_dim = self.encoder.embed_dim

        else:  # `nn.Module` ViT backbone
            self.encoder = encoder

            have_neck = False
            for name, _ in self.encoder.named_parameters():
                if name.startswith("neck"):
                    have_neck = True

            if embed_dim is None:
                if have_neck:
                    embed_dim = self.encoder.neck[2].out_channels  # the value is 256
                else:
                    embed_dim = self.encoder.patch_embed.proj.out_channels

        self.embed_dim = embed_dim
        self.final_activation = self._get_activation(final_activation)

    def _load_encoder_from_checkpoint(self, backbone, encoder, checkpoint):
        """Function to load pretrained weights to the image encoder.
        """
        if isinstance(checkpoint, str):
            if backbone == "sam" and isinstance(encoder, str):
                # If we have a SAM encoder, then we first try to load the full SAM Model
                # (using micro_sam) and otherwise fall back on directly loading the encoder state
                # from the checkpoint
                try:
                    _, model = get_sam_model(model_type=encoder, checkpoint_path=checkpoint, return_sam=True)
                    encoder_state = model.image_encoder.state_dict()
                except Exception:
                    # Try loading the encoder state directly from a checkpoint.
                    encoder_state = torch.load(checkpoint, weights_only=False)

            elif backbone == "sam2" and isinstance(encoder, str):
                # If we have a SAM2 encoder, then we first try to load the full SAM2 Model.
                # (using micro_sam2) and otherwise fall back on directly loading the encoder state
                # from the checkpoint
                try:
                    model = get_sam2_model(model_type=encoder, checkpoint_path=checkpoint)
                    encoder_state = model.image_encoder.state_dict()
                except Exception:
                    # Try loading the encoder state directly from a checkpoint.
                    encoder_state = torch.load(checkpoint, weights_only=False)

            elif backbone == "mae":
                # vit initialization hints from:
                # - https://github.com/facebookresearch/mae/blob/main/main_finetune.py#L233-L242
                encoder_state = torch.load(checkpoint, weights_only=False)["model"]
                encoder_state = OrderedDict({
                    k: v for k, v in encoder_state.items() if (k != "mask_token" and not k.startswith("decoder"))
                })
                # Let's remove the `head` from our current encoder (as the MAE pretrained don't expect it)
                current_encoder_state = self.encoder.state_dict()
                if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
                    del self.encoder.head

            elif backbone == "scalemae":
                # Load the encoder state directly from a checkpoint.
                encoder_state = torch.load(checkpoint)["model"]
                encoder_state = OrderedDict({
                    k: v for k, v in encoder_state.items()
                    if not k.startswith(("mask_token", "decoder", "fcn", "fpn", "pos_embed"))
                })

                # Let's remove the `head` from our current encoder (as the MAE pretrained don't expect it)
                current_encoder_state = self.encoder.state_dict()
                if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
                    del self.encoder.head

                if "pos_embed" in current_encoder_state:  # NOTE: ScaleMAE uses 'pos. embeddings' in a diff. format.
                    del self.encoder.pos_embed

            elif backbone == "dinov3":  # Load the encoder state directly from a checkpoint.
                encoder_state = torch.load(checkpoint)

            else:
                raise ValueError(
                    f"We don't support either the '{backbone}' backbone or the '{encoder}' model combination (or both)."
                )

        else:
            encoder_state = checkpoint

        self.encoder.load_state_dict(encoder_state)

    def _get_activation(self, activation):
        return_activation = None
        if activation is None:
            return None
        if isinstance(activation, nn.Module):
            return activation
        if isinstance(activation, str):
            return_activation = getattr(nn, activation, None)
        if return_activation is None:
            raise ValueError(f"Invalid activation: {activation}")

        return return_activation()

    @staticmethod
    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
        """Compute the output size given input size and target long side length.

        Args:
            oldh: The input image height.
            oldw: The input image width.
            long_side_length: The longest side length for resizing.

        Returns:
            The new image height.
            The new image width.
        """
        scale = long_side_length * 1.0 / max(oldh, oldw)
        newh, neww = oldh * scale, oldw * scale
        neww = int(neww + 0.5)
        newh = int(newh + 0.5)
        return (newh, neww)

    def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor:
        """Resize the image so that the longest side has the correct length.

        Expects batched images with shape BxCxHxW OR BxCxDxHxW and float format.

        Args:
            image: The input image.

        Returns:
            The resized image.
        """
        if image.ndim == 4:  # i.e. 2d image
            target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.encoder.img_size)
            return F.interpolate(image, target_size, mode="bilinear", align_corners=False, antialias=True)
        elif image.ndim == 5:  # i.e. 3d volume
            B, C, Z, H, W = image.shape
            target_size = self.get_preprocess_shape(H, W, self.img_size)
            return F.interpolate(image, (Z, *target_size), mode="trilinear", align_corners=False)
        else:
            raise ValueError("Expected 4d or 5d inputs, got", image.shape)

    def _as_stats(self, mean, std, device, dtype, is_3d: bool):
        """@private
        """
        # Either 2d batch: (1, C, 1, 1) or 3d batch: (1, C, 1, 1, 1).
        view_shape = (1, -1, 1, 1, 1) if is_3d else (1, -1, 1, 1)
        pixel_mean = torch.tensor(mean, device=device, dtype=dtype).view(*view_shape)
        pixel_std = torch.tensor(std, device=device, dtype=dtype).view(*view_shape)
        return pixel_mean, pixel_std

    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        """@private
        """
        device = x.device
        is_3d = (x.ndim == 5)
        device, dtype = x.device, x.dtype

        if self.use_sam_stats:
            mean, std = (123.675, 116.28, 103.53), (58.395, 57.12, 57.375)
        elif self.use_mae_stats:  # TODO: add mean std from mae / scalemae experiments (or open up arguments for this)
            raise NotImplementedError
        elif self.use_dino_stats or (self.use_sam_stats and self.backbone == "sam2"):
            mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
        else:
            mean, std = (0.0, 0.0, 0.0), (1.0, 1.0, 1.0)

        pixel_mean, pixel_std = self._as_stats(mean, std, device=device, dtype=dtype, is_3d=is_3d)

        if self.resize_input:
            x = self.resize_longest_side(x)
        input_shape = x.shape[-3:] if is_3d else x.shape[-2:]

        x = (x - pixel_mean) / pixel_std
        h, w = x.shape[-2:]
        padh = self.encoder.img_size - h
        padw = self.encoder.img_size - w

        if is_3d:
            x = F.pad(x, (0, padw, 0, padh, 0, 0))
        else:
            x = F.pad(x, (0, padw, 0, padh))

        return x, input_shape

    def postprocess_masks(
        self, masks: torch.Tensor, input_size: Tuple[int, ...], original_size: Tuple[int, ...],
    ) -> torch.Tensor:
        """@private
        """
        if masks.ndim == 4:  # i.e. 2d labels
            masks = F.interpolate(
                masks,
                (self.encoder.img_size, self.encoder.img_size),
                mode="bilinear",
                align_corners=False,
            )
            masks = masks[..., : input_size[0], : input_size[1]]
            masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)

        elif masks.ndim == 5:  # i.e. 3d volumetric labels
            masks = F.interpolate(
                masks,
                (input_size[0], self.img_size, self.img_size),
                mode="trilinear",
                align_corners=False,
            )
            masks = masks[..., :input_size[0], :input_size[1], :input_size[2]]
            masks = F.interpolate(masks, original_size, mode="trilinear", align_corners=False)

        else:
            raise ValueError("Expected 4d or 5d labels, got", masks.shape)

        return masks


class UNETR(UNETRBase):
    """A (2d-only) UNet Transformer using a vision transformer as encoder and a convolutional decoder.
    """
    def __init__(
        self,
        img_size: int = 1024,
        backbone: str = "sam",
        encoder: Optional[Union[nn.Module, str]] = "vit_b",
        decoder: Optional[nn.Module] = None,
        out_channels: int = 1,
        use_sam_stats: bool = False,
        use_mae_stats: bool = False,
        use_dino_stats: bool = False,
        resize_input: bool = True,
        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
        final_activation: Optional[Union[str, nn.Module]] = None,
        use_skip_connection: bool = True,
        embed_dim: Optional[int] = None,
        use_conv_transpose: bool = False,
        **kwargs
    ) -> None:

        super().__init__(
            img_size=img_size,
            backbone=backbone,
            encoder=encoder,
            decoder=decoder,
            out_channels=out_channels,
            use_sam_stats=use_sam_stats,
            use_mae_stats=use_mae_stats,
            use_dino_stats=use_dino_stats,
            resize_input=resize_input,
            encoder_checkpoint=encoder_checkpoint,
            final_activation=final_activation,
            use_skip_connection=use_skip_connection,
            embed_dim=embed_dim,
            use_conv_transpose=use_conv_transpose,
            **kwargs,
        )

        encoder = self.encoder

        if backbone == "sam2" and hasattr(encoder, "trunk"):
            in_chans = encoder.trunk.patch_embed.proj.in_channels
        elif hasattr(encoder, "in_chans"):
            in_chans = encoder.in_chans
        else:  # `nn.Module` ViT backbone.
            try:
                in_chans = encoder.patch_embed.proj.in_channels
            except AttributeError:  # for getting the input channels while using 'vit_t' from MobileSam
                in_chans = encoder.patch_embed.seq[0].c.in_channels

        # parameters for the decoder network
        depth = 3
        initial_features = 64
        gain = 2
        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
        scale_factors = depth * [2]
        self.out_channels = out_channels

        # choice of upsampler - to use (bilinear interpolation + conv) or conv transpose
        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d

        self.decoder = decoder or Decoder(
            features=features_decoder,
            scale_factors=scale_factors[::-1],
            conv_block_impl=ConvBlock2d,
            sampler_impl=_upsampler,
        )

        if use_skip_connection:
            self.deconv1 = Deconv2DBlock(
                in_channels=self.embed_dim,
                out_channels=features_decoder[0],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv2 = nn.Sequential(
                Deconv2DBlock(
                    in_channels=self.embed_dim,
                    out_channels=features_decoder[0],
                    use_conv_transpose=use_conv_transpose,
                ),
                Deconv2DBlock(
                    in_channels=features_decoder[0],
                    out_channels=features_decoder[1],
                    use_conv_transpose=use_conv_transpose,
                )
            )
            self.deconv3 = nn.Sequential(
                Deconv2DBlock(
                    in_channels=self.embed_dim,
                    out_channels=features_decoder[0],
                    use_conv_transpose=use_conv_transpose,
                ),
                Deconv2DBlock(
                    in_channels=features_decoder[0],
                    out_channels=features_decoder[1],
                    use_conv_transpose=use_conv_transpose,
                ),
                Deconv2DBlock(
                    in_channels=features_decoder[1],
                    out_channels=features_decoder[2],
                    use_conv_transpose=use_conv_transpose,
                )
            )
            self.deconv4 = ConvBlock2d(in_chans, features_decoder[-1])
        else:
            self.deconv1 = Deconv2DBlock(
                in_channels=self.embed_dim,
                out_channels=features_decoder[0],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv2 = Deconv2DBlock(
                in_channels=features_decoder[0],
                out_channels=features_decoder[1],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv3 = Deconv2DBlock(
                in_channels=features_decoder[1],
                out_channels=features_decoder[2],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv4 = Deconv2DBlock(
                in_channels=features_decoder[2],
                out_channels=features_decoder[3],
                use_conv_transpose=use_conv_transpose,
            )

        self.base = ConvBlock2d(self.embed_dim, features_decoder[0])
        self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1)
        self.deconv_out = _upsampler(
            scale_factor=2, in_channels=features_decoder[-1], out_channels=features_decoder[-1]
        )
        self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the UNETR to the input data.

        Args:
            x: The input tensor.

        Returns:
            The UNETR output.
        """
        original_shape = x.shape[-2:]

        # Reshape the inputs to the shape expected by the encoder
        # and normalize the inputs if normalization is part of the model.
        x, input_shape = self.preprocess(x)

        encoder_outputs = self.encoder(x)

        if isinstance(encoder_outputs[-1], list):
            # `encoder_outputs` can be arranged in only two forms:
            # - either we only return the image embeddings
            # - or, we return the image embeddings and the "list" of global attention layers
            z12, from_encoder = encoder_outputs
        else:
            z12 = encoder_outputs

        if self.use_skip_connection:
            from_encoder = from_encoder[::-1]
            z9 = self.deconv1(from_encoder[0])
            z6 = self.deconv2(from_encoder[1])
            z3 = self.deconv3(from_encoder[2])
            z0 = self.deconv4(x)

        else:
            z9 = self.deconv1(z12)
            z6 = self.deconv2(z9)
            z3 = self.deconv3(z6)
            z0 = self.deconv4(z3)

        updated_from_encoder = [z9, z6, z3]

        x = self.base(z12)
        x = self.decoder(x, encoder_inputs=updated_from_encoder)
        x = self.deconv_out(x)

        x = torch.cat([x, z0], dim=1)
        x = self.decoder_head(x)

        x = self.out_conv(x)
        if self.final_activation is not None:
            x = self.final_activation(x)

        x = self.postprocess_masks(x, input_shape, original_shape)
        return x


class UNETR2D(UNETR):
    """A two-dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.
    """
    pass


class UNETR3D(UNETRBase):
    """A three dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.
    """
    def __init__(
        self,
        img_size: int = 1024,
        backbone: Literal["sam", "sam2", "mae", "scalemae", "dinov3"] = "sam2",
        encoder: Optional[Union[nn.Module, str]] = "hvit_b",
        decoder: Optional[nn.Module] = None,
        out_channels: int = 1,
        use_sam_stats: bool = False,
        use_mae_stats: bool = False,
        use_dino_stats: bool = False,
        resize_input: bool = True,
        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
        final_activation: Optional[Union[str, nn.Module]] = None,
        use_skip_connection: bool = False,
        embed_dim: Optional[int] = None,
        use_conv_transpose: bool = False,
        use_strip_pooling: bool = True,
        **kwargs
    ):
        if use_skip_connection:
            raise NotImplementedError("The framework cannot handle skip connections atm.")
        if use_conv_transpose:
            raise NotImplementedError("It's not enabled to switch between interpolation and transposed convolutions.")

        super().__init__(
            img_size=img_size,
            backbone=backbone,
            encoder=encoder,
            decoder=decoder,
            out_channels=out_channels,
            use_sam_stats=use_sam_stats,
            use_mae_stats=use_mae_stats,
            use_dino_stats=use_dino_stats,
            resize_input=resize_input,
            encoder_checkpoint=encoder_checkpoint,
            final_activation=final_activation,
            use_skip_connection=use_skip_connection,
            embed_dim=embed_dim,
            use_conv_transpose=use_conv_transpose,
            **kwargs,
        )

        # Load the pretrained image encoder weights
        self.image_encoder = self.encoder

        # Step 2: the 3d convolutional decoder.
        # First, get the important parameters for the decoder.
        embed_dim = 256
        depth = 3
        initial_features = 64
        gain = 2
        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
        scale_factors = [1, 2, 2]
        self.out_channels = out_channels

        # The mapping blocks.
        self.deconv1 = Deconv3DBlock(
            in_channels=embed_dim,
            out_channels=features_decoder[0],
            scale_factor=scale_factors,
            use_strip_pooling=use_strip_pooling,
        )
        self.deconv2 = Deconv3DBlock(
            in_channels=features_decoder[0],
            out_channels=features_decoder[1],
            scale_factor=scale_factors,
            use_strip_pooling=use_strip_pooling,
        )
        self.deconv3 = Deconv3DBlock(
            in_channels=features_decoder[1],
            out_channels=features_decoder[2],
            scale_factor=scale_factors,
            use_strip_pooling=use_strip_pooling,
        )
        self.deconv4 = Deconv3DBlock(
            in_channels=features_decoder[2],
            out_channels=features_decoder[3],
            scale_factor=scale_factors,
            use_strip_pooling=use_strip_pooling,
        )

        # The core decoder block.
        self.decoder = decoder or Decoder(
            features=features_decoder,
            scale_factors=[scale_factors] * depth,
            conv_block_impl=partial(ConvBlock3dWithStrip, use_strip_pooling=use_strip_pooling),
            sampler_impl=Upsampler3d,
        )

        # And the final upsampler to match the expected dimensions.
        self.end_up = Deconv3DBlock(
            in_channels=features_decoder[-1],
            out_channels=features_decoder[-1],
            scale_factor=scale_factors,
            use_strip_pooling=use_strip_pooling,
        )

        # Additional conjunction blocks.
        self.base = ConvBlock3dWithStrip(
            in_channels=embed_dim,
            out_channels=features_decoder[0],
            use_strip_pooling=use_strip_pooling,
        )

        # And the output layers.
        self.decoder_head = ConvBlock3dWithStrip(
            in_channels=2 * features_decoder[-1],
            out_channels=features_decoder[-1],
            use_strip_pooling=use_strip_pooling,
        )
        self.out_conv = nn.Conv3d(features_decoder[-1], out_channels, 1)

    def forward(self, x: torch.Tensor):
        """Forward pass of the UNETR-3D model.

        Args:
            x: Inputs of expected shape (B, C, Z, Y, X), where Z considers flexible inputs.

        Returns:
            The UNETR output.
        """
        B, C, Z, H, W = x.shape
        original_shape = (Z, H, W)

        # Preprocessing step
        x, input_shape = self.preprocess(x)

        # Run the image encoder.
        curr_features = torch.stack([self.image_encoder(x[:, :, i])[0] for i in range(Z)], dim=2)

        # Prepare the counterparts for the decoder.
        # NOTE: The section below is sequential, there's no skip connections atm.
        z9 = self.deconv1(curr_features)
        z6 = self.deconv2(z9)
        z3 = self.deconv3(z6)
        z0 = self.deconv4(z3)

        # Align the features through the base block.
        x = self.base(curr_features)

        # Run the decoder.
        updated_from_encoder = [z9, z6, z3]
        x = self.decoder(x, encoder_inputs=updated_from_encoder)
        x = self.end_up(x)

        # And the final output head.
        x = torch.cat([x, z0], dim=1)
        x = self.decoder_head(x)
        x = self.out_conv(x)
        if self.final_activation is not None:
            x = self.final_activation(x)

        # Postprocess the output back to original size.
        x = self.postprocess_masks(x, input_shape, original_shape)
        return x


#
# ADDITIONAL FUNCTIONALITIES
#


def _strip_pooling_layers(enabled, channels) -> nn.Module:
    return DepthStripPooling(channels) if enabled else nn.Identity()


class DepthStripPooling(nn.Module):
    """@private
    """
    def __init__(self, channels: int, reduction: int = 4):
        """Block for strip pooling along the depth dimension (only).

        eg. for 3D (Z > 1) - it aggregates global context across depth by adaptive avg pooling
        to Z=1, and then passes through a small 1x1x1 MLP, then broadcasts it back to Z to
        modulate the original features (using a gated residual).

        For 2D (Z == 1 or Z == 3): returns input unchanged (no-op).

        Args:
            channels: The output channels.
            reduction: The reduction of the hidden layers.
        """
        super().__init__()
        hidden = max(1, channels // reduction)
        self.conv1 = nn.Conv3d(channels, hidden, kernel_size=1)
        self.bn1 = nn.BatchNorm3d(hidden)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(hidden, channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.dim() != 5:
            raise ValueError(f"DepthStripPooling expects 5D tensors as input, got '{x.shape}'.")

        B, C, Z, H, W = x.shape
        if Z == 1 or Z == 3:  # i.e. 2d-as-1-slice or RGB_2d-as-1-slice.
            return x  # We simply do nothing there.

        # We pool only along the depth dimension: i.e. target shape (B, C, 1, H, W)
        feat = F.adaptive_avg_pool3d(x, output_size=(1, H, W))
        feat = self.conv1(feat)
        feat = self.bn1(feat)
        feat = self.relu(feat)
        feat = self.conv2(feat)
        gate = torch.sigmoid(feat).expand(B, C, Z, H, W)  # Broadcast the collapsed depth context back to all slices

        # Gated residual fusion
        return x * gate + x


class Deconv3DBlock(nn.Module):
    """@private
    """
    def __init__(
        self,
        scale_factor,
        in_channels,
        out_channels,
        kernel_size=3,
        anisotropic_kernel=True,
        use_strip_pooling=True,
    ):
        super().__init__()
        conv_block_kwargs = {
            "in_channels": out_channels,
            "out_channels": out_channels,
            "kernel_size": kernel_size,
            "padding": ((kernel_size - 1) // 2),
        }
        if anisotropic_kernel:
            conv_block_kwargs = _update_conv_kwargs(conv_block_kwargs, scale_factor)

        self.block = nn.Sequential(
            Upsampler3d(scale_factor, in_channels, out_channels),
            nn.Conv3d(**conv_block_kwargs),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(True),
            _strip_pooling_layers(enabled=use_strip_pooling, channels=out_channels),
        )

    def forward(self, x):
        return self.block(x)


class ConvBlock3dWithStrip(nn.Module):
    """@private
    """
    def __init__(
        self, in_channels: int, out_channels: int, use_strip_pooling: bool = True, **kwargs
    ):
        super().__init__()
        self.block = nn.Sequential(
            ConvBlock3d(in_channels, out_channels, **kwargs),
            _strip_pooling_layers(enabled=use_strip_pooling, channels=out_channels),
        )

    def forward(self, x):
        return self.block(x)


class SingleDeconv2DBlock(nn.Module):
    """@private
    """
    def __init__(self, scale_factor, in_channels, out_channels):
        super().__init__()
        self.block = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, padding=0, output_padding=0)

    def forward(self, x):
        return self.block(x)


class SingleConv2DBlock(nn.Module):
    """@private
    """
    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        self.block = nn.Conv2d(
            in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=((kernel_size - 1) // 2)
        )

    def forward(self, x):
        return self.block(x)


class Conv2DBlock(nn.Module):
    """@private
    """
    def __init__(self, in_channels, out_channels, kernel_size=3):
        super().__init__()
        self.block = nn.Sequential(
            SingleConv2DBlock(in_channels, out_channels, kernel_size),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True)
        )

    def forward(self, x):
        return self.block(x)


class Deconv2DBlock(nn.Module):
    """@private
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, use_conv_transpose=True):
        super().__init__()
        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d
        self.block = nn.Sequential(
            _upsampler(scale_factor=2, in_channels=in_channels, out_channels=out_channels),
            SingleConv2DBlock(out_channels, out_channels, kernel_size),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True)
        )

    def forward(self, x):
        return self.block(x)
```
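The `encoder_checkpoint` argument handled by `_load_encoder_from_checkpoint` above accepts either a filepath or an already loaded encoder state dict. A minimal sketch of both variants, assuming a locally available SAM ViT-B checkpoint (the paths below are purely illustrative):

```python
import torch
from torch_em.model.unetr import UNETR

# Variant 1: pass a checkpoint path; for the "sam" backbone the loader first tries
# micro_sam's get_sam_model and falls back to torch.load otherwise.
model = UNETR(
    backbone="sam", encoder="vit_b", out_channels=1,
    encoder_checkpoint="/path/to/sam_vit_b.pth",  # hypothetical path, replace with a real checkpoint
)

# Variant 2: pass an already loaded encoder state dict instead of a path.
encoder_state = torch.load("/path/to/sam_vit_b_encoder.pth", weights_only=False)  # hypothetical path
model = UNETR(backbone="sam", encoder="vit_b", encoder_checkpoint=encoder_state)
```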
class UNETRBase(nn.Module)
Base class for implementing a UNETR.
Arguments:
- img_size: The size of the input for the image encoder. Input images will be resized to match this size.
- backbone: The name of the vision transformer implementation. One of "sam", "sam2", "mae", "scalemae" or "dinov3".
- encoder: The vision transformer. Can either be a name, such as "vit_b", or a torch module.
- decoder: The convolutional decoder.
- out_channels: The number of output channels of the UNETR.
- use_sam_stats: Whether to normalize the input data with the statistics of the pretrained SAM model.
- use_dino_stats: Whether to normalize the input data with the statistics of the pretrained DINOv3 model.
- use_mae_stats: Whether to normalize the input data with the statistics of the pretrained MAE model.
- resize_input: Whether to resize the input images to match `img_size`. By default, it resizes the inputs to match the `img_size`.
- encoder_checkpoint: Checkpoint for initializing the vision transformer. Can either be a filepath or an already loaded checkpoint.
- final_activation: The activation to apply to the UNETR output.
- use_skip_connection: Whether to use skip connections. By default, it uses skip connections.
- embed_dim: The embedding dimensionality, corresponding to the output dimension of the vision transformer.
- use_conv_transpose: Whether to use transposed convolutions instead of resampling for upsampling. By default, it uses resampling for upsampling.
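As a small usage sketch for these arguments: `final_activation` may be given either as an `nn.Module` instance or as the name of an attribute of `torch.nn` (an unknown name raises a `ValueError`). The example below uses the 2d `UNETR` subclass defined further down and assumes the "vit_b" SAM encoder definition can be built without pretrained weights:

```python
from torch import nn
from torch_em.model.unetr import UNETR

# "Sigmoid" is resolved via getattr(nn, "Sigmoid") and instantiated internally.
model = UNETR(backbone="sam", encoder="vit_b", out_channels=1, final_activation="Sigmoid")
print(type(model.final_activation))  # <class 'torch.nn.modules.activation.Sigmoid'>

# Passing a module instance directly is equivalent.
model = UNETR(backbone="sam", encoder="vit_b", out_channels=1, final_activation=nn.Sigmoid())
```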
UNETRBase.__init__(img_size=1024, backbone="sam2", encoder="vit_b", decoder=None, out_channels=1, use_sam_stats=False, use_mae_stats=False, use_dino_stats=False, resize_input=True, encoder_checkpoint=None, final_activation=None, use_skip_connection=True, embed_dim=None, use_conv_transpose=False, **kwargs)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
UNETRBase.get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int] (staticmethod)
Compute the output size given input size and target long side length.
Arguments:
- oldh: The input image height.
- oldw: The input image width.
- long_side_length: The longest side length for resizing.
Returns:
The new image height and the new image width.
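A worked example of the computation: for a 768 x 512 input and a target long side of 1024, the scale is 1024 / 768, so the height becomes 1024 and the width rounds to 683.

```python
from torch_em.model.unetr import UNETRBase

# scale = 1024 / 768 ~ 1.333; newh = int(768 * scale + 0.5) = 1024; neww = int(512 * scale + 0.5) = 683
print(UNETRBase.get_preprocess_shape(768, 512, 1024))  # (1024, 683)
```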
UNETRBase.resize_longest_side(image: torch.Tensor) -> torch.Tensor
Resize the image so that the longest side has the correct length.
Expects batched images with shape BxCxHxW OR BxCxDxHxW and float format.
Arguments:
- image: The input image.
Returns:
The resized image.
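A small sketch of the 2d case, assuming a `UNETR` instance whose encoder was built with the default `img_size=1024`: the longest side of the input is scaled to 1024 and the other side is scaled proportionally.

```python
import torch
from torch_em.model.unetr import UNETR

model = UNETR(backbone="sam", encoder="vit_b", img_size=1024)

x = torch.rand(1, 3, 512, 768)          # B x C x H x W
resized = model.resize_longest_side(x)  # longest side 768 -> 1024, height scales to 683
print(resized.shape)                    # torch.Size([1, 3, 683, 1024])
```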
class UNETR(UNETRBase)
A (2d-only) UNet Transformer using a vision transformer as encoder and a convolutional decoder.
UNETR.__init__(img_size=1024, backbone="sam", encoder="vit_b", decoder=None, out_channels=1, use_sam_stats=False, use_mae_stats=False, use_dino_stats=False, resize_input=True, encoder_checkpoint=None, final_activation=None, use_skip_connection=True, embed_dim=None, use_conv_transpose=False, **kwargs)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
UNETR.forward(x: torch.Tensor) -> torch.Tensor
Apply the UNETR to the input data.
Arguments:
- x: The input tensor.
Returns:
The UNETR output.
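A minimal end-to-end sketch, assuming the "vit_b" SAM image encoder can be instantiated without a pretrained checkpoint; `postprocess_masks` interpolates the prediction back to the input resolution.

```python
import torch
from torch_em.model.unetr import UNETR

model = UNETR(backbone="sam", encoder="vit_b", out_channels=2, final_activation="Sigmoid")
model.eval()

x = torch.rand(1, 3, 512, 512)  # B x C x H x W, three-channel input as expected by the SAM encoder
with torch.no_grad():
    y = model(x)
print(y.shape)  # torch.Size([1, 2, 512, 512])
```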
class UNETR2D(UNETR)
A two-dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.
class UNETR3D(UNETRBase)
A three-dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.

Accepts the same arguments as UNETRBase, plus:

Arguments:
- use_strip_pooling: Whether to use strip pooling in the 3D convolutional blocks.
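For orientation, here is a minimal usage sketch of UNETR3D with its documented defaults. It is illustrative only and assumes that the "sam2"/"hvit_b" backbone (and its dependencies) can be constructed on your system, and that a string final_activation is resolved to the corresponding torch.nn activation by the base class.

import torch
from torch_em.model.unetr import UNETR3D

model = UNETR3D(
    img_size=1024,
    backbone="sam2",
    encoder="hvit_b",
    out_channels=1,
    final_activation="Sigmoid",
)

# Inputs follow the (B, C, Z, Y, X) convention of `forward`;
# each Z-slice is passed through the 2D image encoder independently.
x = torch.randn(1, 1, 4, 512, 512)
with torch.no_grad():
    y = model(x)
print(y.shape)  # expected: torch.Size([1, 1, 4, 512, 512]), resized back to the input (Z, Y, X)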
Forward pass of the UNETR-3D model.
Arguments:
- x: Inputs of expected shape (B, C, Z, Y, X), where the number of slices Z may vary between inputs.
Returns:
The UNETR output.
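The decoder channel plan and the slice-wise encoding can be traced with a few lines. The snippet below simply re-evaluates the expressions from __init__ and forward; the per-slice feature shape (256 channels, 64x64) is only an illustrative placeholder, not a guaranteed encoder output size.

import torch

# Channel plan of the 3D decoder, as computed in __init__.
initial_features, gain, depth = 64, 2, 3
features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
print(features_decoder)  # [512, 256, 128, 64]

# scale_factors = [1, 2, 2]: every upsampling step doubles Y and X but leaves Z unchanged,
# so the number of slices is preserved throughout the decoder.

# The per-slice encoder features are stacked along a new depth axis (dim=2),
# mirroring `torch.stack([...], dim=2)` in `forward`.
per_slice = [torch.randn(1, 256, 64, 64) for _ in range(8)]  # 8 slices of shape (B, C', H', W')
volume_features = torch.stack(per_slice, dim=2)
print(volume_features.shape)  # torch.Size([1, 256, 8, 64, 64])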