torch_em.model.unetr
1from functools import partial 2from collections import OrderedDict 3from typing import Optional, Tuple, Union, Literal 4 5import torch 6import torch.nn as nn 7import torch.nn.functional as F 8 9from .vit import get_vision_transformer 10from .unet import Decoder, ConvBlock2d, ConvBlock3d, Upsampler2d, Upsampler3d, _update_conv_kwargs 11 12try: 13 from micro_sam.util import get_sam_model 14except ImportError: 15 get_sam_model = None 16 17try: 18 from micro_sam2.util import get_sam2_model 19except ImportError: 20 get_sam2_model = None 21 22 23# 24# UNETR IMPLEMENTATION [Vision Transformer (ViT from SAM / MAE / ScaleMAE) + UNet Decoder from `torch_em`] 25# 26 27 28class UNETRBase(nn.Module): 29 """Base class for implementing a UNETR. 30 31 Args: 32 img_size: The size of the input for the image encoder. Input images will be resized to match this size. 33 backbone: The name of the vision transformer implementation. 34 One of "sam", "sam2", "mae", "scalemae", "dinov2", "dinov3". 35 encoder: The vision transformer. Can either be a name, such as "vit_b" or a torch module. 36 decoder: The convolutional decoder. 37 out_channels: The number of output channels of the UNETR. 38 use_sam_stats: Whether to normalize the input data with the statistics of the pretrained SAM model. 39 use_dino_stats: Whether to normalize the input data with the statistics of the pretrained DINOv3 model. 40 use_mae_stats: Whether to normalize the input data with the statistics of the pretrained MAE model. 41 resize_input: Whether to resize the input images to match `img_size`. 42 By default, it resizes the inputs to match the `img_size`. 43 encoder_checkpoint: Checkpoint for initializing the vision transformer. 44 Can either be a filepath or an already loaded checkpoint. 45 final_activation: The activation to apply to the UNETR output. 46 use_skip_connection: Whether to use skip connections. By default, it uses skip connections. 47 embed_dim: The embedding dimensionality, corresponding to the output dimension of the vision transformer. 48 use_conv_transpose: Whether to use transposed convolutions instead of resampling for upsampling. 49 By default, it uses resampling for upsampling. 50 """ 51 def __init__( 52 self, 53 img_size: int = 1024, 54 backbone: Literal["sam", "sam2", "mae", "scalemae", "dinov2", "dinov3"] = "sam", 55 encoder: Optional[Union[nn.Module, str]] = "vit_b", 56 decoder: Optional[nn.Module] = None, 57 out_channels: int = 1, 58 use_sam_stats: bool = False, 59 use_mae_stats: bool = False, 60 use_dino_stats: bool = False, 61 resize_input: bool = True, 62 encoder_checkpoint: Optional[Union[str, OrderedDict]] = None, 63 final_activation: Optional[Union[str, nn.Module]] = None, 64 use_skip_connection: bool = True, 65 embed_dim: Optional[int] = None, 66 use_conv_transpose: bool = False, 67 **kwargs 68 ) -> None: 69 super().__init__() 70 71 self.img_size = img_size 72 self.use_sam_stats = use_sam_stats 73 self.use_mae_stats = use_mae_stats 74 self.use_dino_stats = use_dino_stats 75 self.use_skip_connection = use_skip_connection 76 self.resize_input = resize_input 77 self.use_conv_transpose = use_conv_transpose 78 self.backbone = backbone 79 80 if isinstance(encoder, str): # e.g. 
"vit_b" / "hvit_b" 81 print(f"Using {encoder} from {backbone.upper()}") 82 self.encoder = get_vision_transformer(img_size=img_size, backbone=backbone, model=encoder, **kwargs) 83 84 if encoder_checkpoint is not None: 85 self._load_encoder_from_checkpoint(backbone=backbone, encoder=encoder, checkpoint=encoder_checkpoint) 86 87 if embed_dim is None: 88 embed_dim = self.encoder.embed_dim 89 90 else: # `nn.Module` ViT backbone 91 self.encoder = encoder 92 93 have_neck = False 94 for name, _ in self.encoder.named_parameters(): 95 if name.startswith("neck"): 96 have_neck = True 97 98 if embed_dim is None: 99 if have_neck: 100 embed_dim = self.encoder.neck[2].out_channels # the value is 256 101 else: 102 embed_dim = self.encoder.patch_embed.proj.out_channels 103 104 self.embed_dim = embed_dim 105 self.final_activation = self._get_activation(final_activation) 106 107 def _load_encoder_from_checkpoint(self, backbone, encoder, checkpoint): 108 """Function to load pretrained weights to the image encoder. 109 """ 110 if isinstance(checkpoint, str): 111 if backbone == "sam" and isinstance(encoder, str): 112 # If we have a SAM encoder, then we first try to load the full SAM Model 113 # (using micro_sam) and otherwise fall back on directly loading the encoder state 114 # from the checkpoint 115 try: 116 _, model = get_sam_model(model_type=encoder, checkpoint_path=checkpoint, return_sam=True) 117 encoder_state = model.image_encoder.state_dict() 118 except Exception: 119 # Try loading the encoder state directly from a checkpoint. 120 encoder_state = torch.load(checkpoint, weights_only=False) 121 122 elif backbone == "sam2" and isinstance(encoder, str): 123 # If we have a SAM2 encoder, then we first try to load the full SAM2 Model. 124 # (using micro_sam2) and otherwise fall back on directly loading the encoder state 125 # from the checkpoint 126 try: 127 model = get_sam2_model(model_type=encoder, checkpoint_path=checkpoint) 128 encoder_state = model.image_encoder.state_dict() 129 except Exception: 130 # Try loading the encoder state directly from a checkpoint. 131 encoder_state = torch.load(checkpoint, weights_only=False) 132 133 elif backbone == "mae": 134 # vit initialization hints from: 135 # - https://github.com/facebookresearch/mae/blob/main/main_finetune.py#L233-L242 136 encoder_state = torch.load(checkpoint, weights_only=False)["model"] 137 encoder_state = OrderedDict({ 138 k: v for k, v in encoder_state.items() if (k != "mask_token" and not k.startswith("decoder")) 139 }) 140 # Let's remove the `head` from our current encoder (as the MAE pretrained don't expect it) 141 current_encoder_state = self.encoder.state_dict() 142 if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state): 143 del self.encoder.head 144 145 elif backbone == "scalemae": 146 # Load the encoder state directly from a checkpoint. 147 encoder_state = torch.load(checkpoint)["model"] 148 encoder_state = OrderedDict({ 149 k: v for k, v in encoder_state.items() 150 if not k.startswith(("mask_token", "decoder", "fcn", "fpn", "pos_embed")) 151 }) 152 153 # Let's remove the `head` from our current encoder (as the MAE pretrained don't expect it) 154 current_encoder_state = self.encoder.state_dict() 155 if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state): 156 del self.encoder.head 157 158 if "pos_embed" in current_encoder_state: # NOTE: ScaleMAE uses 'pos. embeddings' in a diff. format. 
159 del self.encoder.pos_embed 160 161 elif backbone in ["dinov2", "dinov3"]: # Load the encoder state directly from a checkpoint. 162 encoder_state = torch.load(checkpoint) 163 164 else: 165 raise ValueError( 166 f"We don't support either the '{backbone}' backbone or the '{encoder}' model combination (or both)." 167 ) 168 169 else: 170 encoder_state = checkpoint 171 172 self.encoder.load_state_dict(encoder_state) 173 174 def _get_activation(self, activation): 175 return_activation = None 176 if activation is None: 177 return None 178 if isinstance(activation, nn.Module): 179 return activation 180 if isinstance(activation, str): 181 return_activation = getattr(nn, activation, None) 182 if return_activation is None: 183 raise ValueError(f"Invalid activation: {activation}") 184 185 return return_activation() 186 187 @staticmethod 188 def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 189 """Compute the output size given input size and target long side length. 190 191 Args: 192 oldh: The input image height. 193 oldw: The input image width. 194 long_side_length: The longest side length for resizing. 195 196 Returns: 197 The new image height. 198 The new image width. 199 """ 200 scale = long_side_length * 1.0 / max(oldh, oldw) 201 newh, neww = oldh * scale, oldw * scale 202 neww = int(neww + 0.5) 203 newh = int(newh + 0.5) 204 return (newh, neww) 205 206 def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor: 207 """Resize the image so that the longest side has the correct length. 208 209 Expects batched images with shape BxCxHxW OR BxCxDxHxW and float format. 210 211 Args: 212 image: The input image. 213 214 Returns: 215 The resized image. 216 """ 217 if image.ndim == 4: # i.e. 2d image 218 target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.encoder.img_size) 219 return F.interpolate(image, target_size, mode="bilinear", align_corners=False, antialias=True) 220 elif image.ndim == 5: # i.e. 3d volume 221 B, C, Z, H, W = image.shape 222 target_size = self.get_preprocess_shape(H, W, self.img_size) 223 return F.interpolate(image, (Z, *target_size), mode="trilinear", align_corners=False) 224 else: 225 raise ValueError("Expected 4d or 5d inputs, got", image.shape) 226 227 def _as_stats(self, mean, std, device, dtype, is_3d: bool): 228 """@private 229 """ 230 # Either 2d batch: (1, C, 1, 1) or 3d batch: (1, C, 1, 1, 1). 
231 view_shape = (1, -1, 1, 1, 1) if is_3d else (1, -1, 1, 1) 232 pixel_mean = torch.tensor(mean, device=device, dtype=dtype).view(*view_shape) 233 pixel_std = torch.tensor(std, device=device, dtype=dtype).view(*view_shape) 234 return pixel_mean, pixel_std 235 236 def preprocess(self, x: torch.Tensor) -> torch.Tensor: 237 """@private 238 """ 239 device = x.device 240 is_3d = (x.ndim == 5) 241 device, dtype = x.device, x.dtype 242 243 if self.use_sam_stats: 244 mean, std = (123.675, 116.28, 103.53), (58.395, 57.12, 57.375) 245 elif self.use_mae_stats: # TODO: add mean std from mae / scalemae experiments (or open up arguments for this) 246 raise NotImplementedError 247 elif self.use_dino_stats or (self.use_sam_stats and self.backbone == "sam2"): 248 mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225) 249 else: 250 mean, std = (0.0, 0.0, 0.0), (1.0, 1.0, 1.0) 251 252 pixel_mean, pixel_std = self._as_stats(mean, std, device=device, dtype=dtype, is_3d=is_3d) 253 254 if self.resize_input: 255 x = self.resize_longest_side(x) 256 input_shape = x.shape[-3:] if is_3d else x.shape[-2:] 257 258 x = (x - pixel_mean) / pixel_std 259 h, w = x.shape[-2:] 260 padh = self.encoder.img_size - h 261 padw = self.encoder.img_size - w 262 263 if is_3d: 264 x = F.pad(x, (0, padw, 0, padh, 0, 0)) 265 else: 266 x = F.pad(x, (0, padw, 0, padh)) 267 268 return x, input_shape 269 270 def postprocess_masks( 271 self, masks: torch.Tensor, input_size: Tuple[int, ...], original_size: Tuple[int, ...], 272 ) -> torch.Tensor: 273 """@private 274 """ 275 if masks.ndim == 4: # i.e. 2d labels 276 masks = F.interpolate( 277 masks, 278 (self.encoder.img_size, self.encoder.img_size), 279 mode="bilinear", 280 align_corners=False, 281 ) 282 masks = masks[..., : input_size[0], : input_size[1]] 283 masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) 284 285 elif masks.ndim == 5: # i.e. 3d volumetric labels 286 masks = F.interpolate( 287 masks, 288 (input_size[0], self.img_size, self.img_size), 289 mode="trilinear", 290 align_corners=False, 291 ) 292 masks = masks[..., :input_size[0], :input_size[1], :input_size[2]] 293 masks = F.interpolate(masks, original_size, mode="trilinear", align_corners=False) 294 295 else: 296 raise ValueError("Expected 4d or 5d labels, got", masks.shape) 297 298 return masks 299 300 301class UNETR(UNETRBase): 302 """A (2d-only) UNet Transformer using a vision transformer as encoder and a convolutional decoder. 
303 """ 304 def __init__( 305 self, 306 img_size: int = 1024, 307 backbone: Literal["sam", "sam2", "mae", "scalemae", "dinov2", "dinov3"] = "sam", 308 encoder: Optional[Union[nn.Module, str]] = "vit_b", 309 decoder: Optional[nn.Module] = None, 310 out_channels: int = 1, 311 use_sam_stats: bool = False, 312 use_mae_stats: bool = False, 313 use_dino_stats: bool = False, 314 resize_input: bool = True, 315 encoder_checkpoint: Optional[Union[str, OrderedDict]] = None, 316 final_activation: Optional[Union[str, nn.Module]] = None, 317 use_skip_connection: bool = True, 318 embed_dim: Optional[int] = None, 319 use_conv_transpose: bool = False, 320 **kwargs 321 ) -> None: 322 323 super().__init__( 324 img_size=img_size, 325 backbone=backbone, 326 encoder=encoder, 327 decoder=decoder, 328 out_channels=out_channels, 329 use_sam_stats=use_sam_stats, 330 use_mae_stats=use_mae_stats, 331 use_dino_stats=use_dino_stats, 332 resize_input=resize_input, 333 encoder_checkpoint=encoder_checkpoint, 334 final_activation=final_activation, 335 use_skip_connection=use_skip_connection, 336 embed_dim=embed_dim, 337 use_conv_transpose=use_conv_transpose, 338 **kwargs, 339 ) 340 341 encoder = self.encoder 342 343 if backbone == "sam2" and hasattr(encoder, "trunk"): 344 in_chans = encoder.trunk.patch_embed.proj.in_channels 345 elif hasattr(encoder, "in_chans"): 346 in_chans = encoder.in_chans 347 else: # `nn.Module` ViT backbone. 348 try: 349 in_chans = encoder.patch_embed.proj.in_channels 350 except AttributeError: # for getting the input channels while using 'vit_t' from MobileSam 351 in_chans = encoder.patch_embed.seq[0].c.in_channels 352 353 # parameters for the decoder network 354 depth = 3 355 initial_features = 64 356 gain = 2 357 features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1] 358 scale_factors = depth * [2] 359 self.out_channels = out_channels 360 361 # choice of upsampler - to use (bilinear interpolation + conv) or conv transpose 362 _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d 363 364 self.decoder = decoder or Decoder( 365 features=features_decoder, 366 scale_factors=scale_factors[::-1], 367 conv_block_impl=ConvBlock2d, 368 sampler_impl=_upsampler, 369 ) 370 371 if use_skip_connection: 372 self.deconv1 = Deconv2DBlock( 373 in_channels=self.embed_dim, 374 out_channels=features_decoder[0], 375 use_conv_transpose=use_conv_transpose, 376 ) 377 self.deconv2 = nn.Sequential( 378 Deconv2DBlock( 379 in_channels=self.embed_dim, 380 out_channels=features_decoder[0], 381 use_conv_transpose=use_conv_transpose, 382 ), 383 Deconv2DBlock( 384 in_channels=features_decoder[0], 385 out_channels=features_decoder[1], 386 use_conv_transpose=use_conv_transpose, 387 ) 388 ) 389 self.deconv3 = nn.Sequential( 390 Deconv2DBlock( 391 in_channels=self.embed_dim, 392 out_channels=features_decoder[0], 393 use_conv_transpose=use_conv_transpose, 394 ), 395 Deconv2DBlock( 396 in_channels=features_decoder[0], 397 out_channels=features_decoder[1], 398 use_conv_transpose=use_conv_transpose, 399 ), 400 Deconv2DBlock( 401 in_channels=features_decoder[1], 402 out_channels=features_decoder[2], 403 use_conv_transpose=use_conv_transpose, 404 ) 405 ) 406 self.deconv4 = ConvBlock2d(in_chans, features_decoder[-1]) 407 else: 408 self.deconv1 = Deconv2DBlock( 409 in_channels=self.embed_dim, 410 out_channels=features_decoder[0], 411 use_conv_transpose=use_conv_transpose, 412 ) 413 self.deconv2 = Deconv2DBlock( 414 in_channels=features_decoder[0], 415 out_channels=features_decoder[1], 416 
use_conv_transpose=use_conv_transpose, 417 ) 418 self.deconv3 = Deconv2DBlock( 419 in_channels=features_decoder[1], 420 out_channels=features_decoder[2], 421 use_conv_transpose=use_conv_transpose, 422 ) 423 self.deconv4 = Deconv2DBlock( 424 in_channels=features_decoder[2], 425 out_channels=features_decoder[3], 426 use_conv_transpose=use_conv_transpose, 427 ) 428 429 self.base = ConvBlock2d(self.embed_dim, features_decoder[0]) 430 self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1) 431 self.deconv_out = _upsampler( 432 scale_factor=2, in_channels=features_decoder[-1], out_channels=features_decoder[-1] 433 ) 434 self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1]) 435 436 def forward(self, x: torch.Tensor) -> torch.Tensor: 437 """Apply the UNETR to the input data. 438 439 Args: 440 x: The input tensor. 441 442 Returns: 443 The UNETR output. 444 """ 445 original_shape = x.shape[-2:] 446 447 # Reshape the inputs to the shape expected by the encoder 448 # and normalize the inputs if normalization is part of the model. 449 x, input_shape = self.preprocess(x) 450 451 encoder_outputs = self.encoder(x) 452 453 if isinstance(encoder_outputs[-1], list): 454 # `encoder_outputs` can be arranged in only two forms: 455 # - either we only return the image embeddings 456 # - or, we return the image embeddings and the "list" of global attention layers 457 z12, from_encoder = encoder_outputs 458 else: 459 z12 = encoder_outputs 460 461 if self.use_skip_connection: 462 from_encoder = from_encoder[::-1] 463 z9 = self.deconv1(from_encoder[0]) 464 z6 = self.deconv2(from_encoder[1]) 465 z3 = self.deconv3(from_encoder[2]) 466 z0 = self.deconv4(x) 467 468 else: 469 z9 = self.deconv1(z12) 470 z6 = self.deconv2(z9) 471 z3 = self.deconv3(z6) 472 z0 = self.deconv4(z3) 473 474 updated_from_encoder = [z9, z6, z3] 475 476 x = self.base(z12) 477 x = self.decoder(x, encoder_inputs=updated_from_encoder) 478 x = self.deconv_out(x) 479 480 x = torch.cat([x, z0], dim=1) 481 x = self.decoder_head(x) 482 483 x = self.out_conv(x) 484 if self.final_activation is not None: 485 x = self.final_activation(x) 486 487 x = self.postprocess_masks(x, input_shape, original_shape) 488 return x 489 490 491class UNETR2D(UNETR): 492 """A two-dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder. 493 """ 494 pass 495 496 497class UNETR3D(UNETRBase): 498 """A three dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder. 
499 """ 500 def __init__( 501 self, 502 img_size: int = 1024, 503 backbone: Literal["sam", "sam2", "mae", "scalemae", "dinov2", "dinov3"] = "sam", 504 encoder: Optional[Union[nn.Module, str]] = "hvit_b", 505 decoder: Optional[nn.Module] = None, 506 out_channels: int = 1, 507 use_sam_stats: bool = False, 508 use_mae_stats: bool = False, 509 use_dino_stats: bool = False, 510 resize_input: bool = True, 511 encoder_checkpoint: Optional[Union[str, OrderedDict]] = None, 512 final_activation: Optional[Union[str, nn.Module]] = None, 513 use_skip_connection: bool = False, 514 embed_dim: Optional[int] = None, 515 use_conv_transpose: bool = False, 516 use_strip_pooling: bool = True, 517 **kwargs 518 ): 519 if use_skip_connection: 520 raise NotImplementedError("The framework cannot handle skip connections atm.") 521 if use_conv_transpose: 522 raise NotImplementedError("It's not enabled to switch between interpolation and transposed convolutions.") 523 524 # Sort the `embed_dim` out 525 embed_dim = 256 if embed_dim is None else embed_dim 526 527 super().__init__( 528 img_size=img_size, 529 backbone=backbone, 530 encoder=encoder, 531 decoder=decoder, 532 out_channels=out_channels, 533 use_sam_stats=use_sam_stats, 534 use_mae_stats=use_mae_stats, 535 use_dino_stats=use_dino_stats, 536 resize_input=resize_input, 537 encoder_checkpoint=encoder_checkpoint, 538 final_activation=final_activation, 539 use_skip_connection=use_skip_connection, 540 embed_dim=embed_dim, 541 use_conv_transpose=use_conv_transpose, 542 **kwargs, 543 ) 544 545 # The 3d convolutional decoder. 546 # First, get the important parameters for the decoder. 547 depth = 3 548 initial_features = 64 549 gain = 2 550 features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1] 551 scale_factors = [1, 2, 2] 552 self.out_channels = out_channels 553 554 # The mapping blocks. 555 self.deconv1 = Deconv3DBlock( 556 in_channels=embed_dim, 557 out_channels=features_decoder[0], 558 scale_factor=scale_factors, 559 use_strip_pooling=use_strip_pooling, 560 ) 561 self.deconv2 = Deconv3DBlock( 562 in_channels=features_decoder[0], 563 out_channels=features_decoder[1], 564 scale_factor=scale_factors, 565 use_strip_pooling=use_strip_pooling, 566 ) 567 self.deconv3 = Deconv3DBlock( 568 in_channels=features_decoder[1], 569 out_channels=features_decoder[2], 570 scale_factor=scale_factors, 571 use_strip_pooling=use_strip_pooling, 572 ) 573 self.deconv4 = Deconv3DBlock( 574 in_channels=features_decoder[2], 575 out_channels=features_decoder[3], 576 scale_factor=scale_factors, 577 use_strip_pooling=use_strip_pooling, 578 ) 579 580 # The core decoder block. 581 self.decoder = decoder or Decoder( 582 features=features_decoder, 583 scale_factors=[scale_factors] * depth, 584 conv_block_impl=partial(ConvBlock3dWithStrip, use_strip_pooling=use_strip_pooling), 585 sampler_impl=Upsampler3d, 586 ) 587 588 # And the final upsampler to match the expected dimensions. 589 self.deconv_out = Deconv3DBlock( # NOTE: changed `end_up` to `deconv_out` 590 in_channels=features_decoder[-1], 591 out_channels=features_decoder[-1], 592 scale_factor=scale_factors, 593 use_strip_pooling=use_strip_pooling, 594 ) 595 596 # Additional conjunction blocks. 597 self.base = ConvBlock3dWithStrip( 598 in_channels=embed_dim, 599 out_channels=features_decoder[0], 600 use_strip_pooling=use_strip_pooling, 601 ) 602 603 # And the output layers. 
604 self.decoder_head = ConvBlock3dWithStrip( 605 in_channels=2 * features_decoder[-1], 606 out_channels=features_decoder[-1], 607 use_strip_pooling=use_strip_pooling, 608 ) 609 self.out_conv = nn.Conv3d(features_decoder[-1], out_channels, 1) 610 611 def forward(self, x: torch.Tensor): 612 """Forward pass of the UNETR-3D model. 613 614 Args: 615 x: Inputs of expected shape (B, C, Z, Y, X), where Z considers flexible inputs. 616 617 Returns: 618 The UNETR output. 619 """ 620 B, C, Z, H, W = x.shape 621 original_shape = (Z, H, W) 622 623 # Preprocessing step 624 x, input_shape = self.preprocess(x) 625 626 # Run the image encoder. 627 curr_features = torch.stack([self.encoder(x[:, :, i])[0] for i in range(Z)], dim=2) 628 629 # Prepare the counterparts for the decoder. 630 # NOTE: The section below is sequential, there's no skip connections atm. 631 z9 = self.deconv1(curr_features) 632 z6 = self.deconv2(z9) 633 z3 = self.deconv3(z6) 634 z0 = self.deconv4(z3) 635 636 updated_from_encoder = [z9, z6, z3] 637 638 # Align the features through the base block. 639 x = self.base(curr_features) 640 # Run the decoder 641 x = self.decoder(x, encoder_inputs=updated_from_encoder) 642 x = self.deconv_out(x) # NOTE before `end_up` 643 644 # And the final output head. 645 x = torch.cat([x, z0], dim=1) 646 x = self.decoder_head(x) 647 x = self.out_conv(x) 648 if self.final_activation is not None: 649 x = self.final_activation(x) 650 651 # Postprocess the output back to original size. 652 x = self.postprocess_masks(x, input_shape, original_shape) 653 return x 654 655# 656# ADDITIONAL FUNCTIONALITIES 657# 658 659 660def _strip_pooling_layers(enabled, channels) -> nn.Module: 661 return DepthStripPooling(channels) if enabled else nn.Identity() 662 663 664class DepthStripPooling(nn.Module): 665 """@private 666 """ 667 def __init__(self, channels: int, reduction: int = 4): 668 """Block for strip pooling along the depth dimension (only). 669 670 eg. for 3D (Z > 1) - it aggregates global context across depth by adaptive avg pooling 671 to Z=1, and then passes through a small 1x1x1 MLP, then broadcasts it back to Z to 672 modulate the original features (using a gated residual). 673 674 For 2D (Z == 1 or Z == 3): returns input unchanged (no-op). 675 676 Args: 677 channels: The output channels. 678 reduction: The reduction of the hidden layers. 679 """ 680 super().__init__() 681 hidden = max(1, channels // reduction) 682 self.conv1 = nn.Conv3d(channels, hidden, kernel_size=1) 683 self.bn1 = nn.BatchNorm3d(hidden) 684 self.relu = nn.ReLU(inplace=True) 685 self.conv2 = nn.Conv3d(hidden, channels, kernel_size=1) 686 687 def forward(self, x: torch.Tensor) -> torch.Tensor: 688 if x.dim() != 5: 689 raise ValueError(f"DepthStripPooling expects 5D tensors as input, got '{x.shape}'.") 690 691 B, C, Z, H, W = x.shape 692 if Z == 1 or Z == 3: # i.e. 2d-as-1-slice or RGB_2d-as-1-slice. 693 return x # We simply do nothing there. 694 695 # We pool only along the depth dimension: i.e. 
target shape (B, C, 1, H, W) 696 feat = F.adaptive_avg_pool3d(x, output_size=(1, H, W)) 697 feat = self.conv1(feat) 698 feat = self.bn1(feat) 699 feat = self.relu(feat) 700 feat = self.conv2(feat) 701 gate = torch.sigmoid(feat).expand(B, C, Z, H, W) # Broadcast the collapsed depth context back to all slices 702 703 # Gated residual fusion 704 return x * gate + x 705 706 707class Deconv3DBlock(nn.Module): 708 """@private 709 """ 710 def __init__( 711 self, 712 scale_factor, 713 in_channels, 714 out_channels, 715 kernel_size=3, 716 anisotropic_kernel=True, 717 use_strip_pooling=True, 718 ): 719 super().__init__() 720 conv_block_kwargs = { 721 "in_channels": out_channels, 722 "out_channels": out_channels, 723 "kernel_size": kernel_size, 724 "padding": ((kernel_size - 1) // 2), 725 } 726 if anisotropic_kernel: 727 conv_block_kwargs = _update_conv_kwargs(conv_block_kwargs, scale_factor) 728 729 self.block = nn.Sequential( 730 Upsampler3d(scale_factor, in_channels, out_channels), 731 nn.Conv3d(**conv_block_kwargs), 732 nn.BatchNorm3d(out_channels), 733 nn.ReLU(True), 734 _strip_pooling_layers(enabled=use_strip_pooling, channels=out_channels), 735 ) 736 737 def forward(self, x): 738 return self.block(x) 739 740 741class ConvBlock3dWithStrip(nn.Module): 742 """@private 743 """ 744 def __init__( 745 self, in_channels: int, out_channels: int, use_strip_pooling: bool = True, **kwargs 746 ): 747 super().__init__() 748 self.block = nn.Sequential( 749 ConvBlock3d(in_channels, out_channels, **kwargs), 750 _strip_pooling_layers(enabled=use_strip_pooling, channels=out_channels), 751 ) 752 753 def forward(self, x): 754 return self.block(x) 755 756 757class SingleDeconv2DBlock(nn.Module): 758 """@private 759 """ 760 def __init__(self, scale_factor, in_channels, out_channels): 761 super().__init__() 762 self.block = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, padding=0, output_padding=0) 763 764 def forward(self, x): 765 return self.block(x) 766 767 768class SingleConv2DBlock(nn.Module): 769 """@private 770 """ 771 def __init__(self, in_channels, out_channels, kernel_size): 772 super().__init__() 773 self.block = nn.Conv2d( 774 in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=((kernel_size - 1) // 2) 775 ) 776 777 def forward(self, x): 778 return self.block(x) 779 780 781class Conv2DBlock(nn.Module): 782 """@private 783 """ 784 def __init__(self, in_channels, out_channels, kernel_size=3): 785 super().__init__() 786 self.block = nn.Sequential( 787 SingleConv2DBlock(in_channels, out_channels, kernel_size), 788 nn.BatchNorm2d(out_channels), 789 nn.ReLU(True) 790 ) 791 792 def forward(self, x): 793 return self.block(x) 794 795 796class Deconv2DBlock(nn.Module): 797 """@private 798 """ 799 def __init__(self, in_channels, out_channels, kernel_size=3, use_conv_transpose=True): 800 super().__init__() 801 _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d 802 self.block = nn.Sequential( 803 _upsampler(scale_factor=2, in_channels=in_channels, out_channels=out_channels), 804 SingleConv2DBlock(out_channels, out_channels, kernel_size), 805 nn.BatchNorm2d(out_channels), 806 nn.ReLU(True) 807 ) 808 809 def forward(self, x): 810 return self.block(x)
class UNETRBase(nn.Module):
Base class for implementing a UNETR.
Arguments:
- img_size: The size of the input for the image encoder. Input images will be resized to match this size.
- backbone: The name of the vision transformer implementation. One of "sam", "sam2", "mae", "scalemae", "dinov2", "dinov3".
- encoder: The vision transformer. Can either be a name, such as "vit_b", or a torch module.
- decoder: The convolutional decoder.
- out_channels: The number of output channels of the UNETR.
- use_sam_stats: Whether to normalize the input data with the statistics of the pretrained SAM model.
- use_dino_stats: Whether to normalize the input data with the statistics of the pretrained DINOv3 model.
- use_mae_stats: Whether to normalize the input data with the statistics of the pretrained MAE model.
- resize_input: Whether to resize the input images to match img_size. By default, it resizes the inputs to match the img_size.
- encoder_checkpoint: Checkpoint for initializing the vision transformer. Can either be a filepath or an already loaded checkpoint.
- final_activation: The activation to apply to the UNETR output.
- use_skip_connection: Whether to use skip connections. By default, it uses skip connections.
- embed_dim: The embedding dimensionality, corresponding to the output dimension of the vision transformer.
- use_conv_transpose: Whether to use transposed convolutions instead of resampling for upsampling. By default, it uses resampling for upsampling.
def __init__(
    self,
    img_size: int = 1024,
    backbone: Literal["sam", "sam2", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
    encoder: Optional[Union[nn.Module, str]] = "vit_b",
    decoder: Optional[nn.Module] = None,
    out_channels: int = 1,
    use_sam_stats: bool = False,
    use_mae_stats: bool = False,
    use_dino_stats: bool = False,
    resize_input: bool = True,
    encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
    final_activation: Optional[Union[str, nn.Module]] = None,
    use_skip_connection: bool = True,
    embed_dim: Optional[int] = None,
    use_conv_transpose: bool = False,
    **kwargs
) -> None:
Initialize internal Module state, shared by both nn.Module and ScriptModule.
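These constructor arguments are shared by the concrete subclasses documented below. A minimal construction sketch using the 2D UNETR, assuming the SAM image-encoder dependencies required by the "sam" backbone are installed; the commented checkpoint path is purely illustrative:

import torch
from torch_em.model.unetr import UNETR

# Build a 2D UNETR with a randomly initialized SAM ViT-B image encoder.
model = UNETR(
    img_size=1024,
    backbone="sam",
    encoder="vit_b",
    out_channels=2,
    use_sam_stats=True,          # normalize inputs with the SAM pixel statistics
    final_activation="Sigmoid",  # resolved via getattr(torch.nn, "Sigmoid")
    use_skip_connection=True,
    # encoder_checkpoint="/path/to/sam_vit_b.pth",  # optional pretrained weights (illustrative path)
)
print(model.embed_dim)  # embedding dimensionality reported by the vision transformer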
@staticmethod
def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
Compute the output size given input size and target long side length.
Arguments:
- oldh: The input image height.
- oldw: The input image width.
- long_side_length: The longest side length for resizing.
Returns:
A tuple with the new image height and the new image width.
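Since this is a plain static helper, the arithmetic is easy to check directly; a small call with illustrative values:

from torch_em.model.unetr import UNETRBase

# scale = 1024 / 640 = 1.6, so (480, 640) -> (768, 1024): the aspect ratio is
# preserved and the longest side is scaled to the requested length.
print(UNETRBase.get_preprocess_shape(480, 640, long_side_length=1024))  # (768, 1024)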
def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor:
Resize the image so that the longest side has the correct length.
Expects batched images with shape BxCxHxW OR BxCxDxHxW and float format.
Arguments:
- image: The input image.
Returns:
The resized image.
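A short sketch of the expected shape behaviour, assuming a SAM ViT-B encoder whose img_size is 1024 (the default configuration):

import torch
from torch_em.model.unetr import UNETR

model = UNETR(backbone="sam", encoder="vit_b")  # encoder.img_size == 1024 in this setup
image = torch.randn(1, 3, 480, 640)             # BxCxHxW, float
resized = model.resize_longest_side(image)
print(resized.shape)                            # torch.Size([1, 3, 768, 1024])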
class UNETR(UNETRBase):
A (2d-only) UNet Transformer using a vision transformer as encoder and a convolutional decoder.
def __init__(
    self,
    img_size: int = 1024,
    backbone: Literal["sam", "sam2", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
    encoder: Optional[Union[nn.Module, str]] = "vit_b",
    decoder: Optional[nn.Module] = None,
    out_channels: int = 1,
    use_sam_stats: bool = False,
    use_mae_stats: bool = False,
    use_dino_stats: bool = False,
    resize_input: bool = True,
    encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
    final_activation: Optional[Union[str, nn.Module]] = None,
    use_skip_connection: bool = True,
    embed_dim: Optional[int] = None,
    use_conv_transpose: bool = False,
    **kwargs
) -> None:
Initialize internal Module state, shared by both nn.Module and ScriptModule.
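The convolutional decoder is configured with fixed values inside __init__ (depth 3, 64 initial features, gain 2). The resulting feature widths and scale factors, reproduced here for illustration:

# Mirrors the decoder parameters hard-coded in UNETR.__init__.
depth, initial_features, gain = 3, 64, 2
features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
scale_factors = depth * [2]
print(features_decoder)  # [512, 256, 128, 64]
print(scale_factors)     # [2, 2, 2]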
def forward(self, x: torch.Tensor) -> torch.Tensor:
Apply the UNETR to the input data.
Arguments:
- x: The input tensor.
Returns:
The UNETR output.
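A minimal inference sketch, assuming the SAM ViT-B encoder can be constructed locally; input resizing, normalization and the final resize back to the input shape all happen inside forward:

import torch
from torch_em.model.unetr import UNETR

model = UNETR(backbone="sam", encoder="vit_b", out_channels=1)
model.eval()

x = torch.randn(1, 3, 512, 512)  # preprocess resizes and normalizes this internally
with torch.no_grad():
    y = model(x)
print(y.shape)  # torch.Size([1, 1, 512, 512]); postprocess_masks restores the input size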
class UNETR2D(UNETR):
A two-dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.
class UNETR3D(UNETRBase):
    """A three-dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.
    """
    def __init__(
        self,
        img_size: int = 1024,
        backbone: Literal["sam", "sam2", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
        encoder: Optional[Union[nn.Module, str]] = "hvit_b",
        decoder: Optional[nn.Module] = None,
        out_channels: int = 1,
        use_sam_stats: bool = False,
        use_mae_stats: bool = False,
        use_dino_stats: bool = False,
        resize_input: bool = True,
        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
        final_activation: Optional[Union[str, nn.Module]] = None,
        use_skip_connection: bool = False,
        embed_dim: Optional[int] = None,
        use_conv_transpose: bool = False,
        use_strip_pooling: bool = True,
        **kwargs
    ):
        if use_skip_connection:
            raise NotImplementedError("The framework cannot handle skip connections atm.")
        if use_conv_transpose:
            raise NotImplementedError("It's not enabled to switch between interpolation and transposed convolutions.")

        # Sort the `embed_dim` out.
        embed_dim = 256 if embed_dim is None else embed_dim

        super().__init__(
            img_size=img_size,
            backbone=backbone,
            encoder=encoder,
            decoder=decoder,
            out_channels=out_channels,
            use_sam_stats=use_sam_stats,
            use_mae_stats=use_mae_stats,
            use_dino_stats=use_dino_stats,
            resize_input=resize_input,
            encoder_checkpoint=encoder_checkpoint,
            final_activation=final_activation,
            use_skip_connection=use_skip_connection,
            embed_dim=embed_dim,
            use_conv_transpose=use_conv_transpose,
            **kwargs,
        )

        # The 3d convolutional decoder.
        # First, get the important parameters for the decoder.
        depth = 3
        initial_features = 64
        gain = 2
        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
        scale_factors = [1, 2, 2]
        self.out_channels = out_channels

        # The mapping blocks.
        self.deconv1 = Deconv3DBlock(
            in_channels=embed_dim,
            out_channels=features_decoder[0],
            scale_factor=scale_factors,
            use_strip_pooling=use_strip_pooling,
        )
        self.deconv2 = Deconv3DBlock(
            in_channels=features_decoder[0],
            out_channels=features_decoder[1],
            scale_factor=scale_factors,
            use_strip_pooling=use_strip_pooling,
        )
        self.deconv3 = Deconv3DBlock(
            in_channels=features_decoder[1],
            out_channels=features_decoder[2],
            scale_factor=scale_factors,
            use_strip_pooling=use_strip_pooling,
        )
        self.deconv4 = Deconv3DBlock(
            in_channels=features_decoder[2],
            out_channels=features_decoder[3],
            scale_factor=scale_factors,
            use_strip_pooling=use_strip_pooling,
        )

        # The core decoder block.
        self.decoder = decoder or Decoder(
            features=features_decoder,
            scale_factors=[scale_factors] * depth,
            conv_block_impl=partial(ConvBlock3dWithStrip, use_strip_pooling=use_strip_pooling),
            sampler_impl=Upsampler3d,
        )

        # And the final upsampler to match the expected dimensions.
        self.deconv_out = Deconv3DBlock(  # NOTE: renamed from `end_up` to `deconv_out`.
            in_channels=features_decoder[-1],
            out_channels=features_decoder[-1],
            scale_factor=scale_factors,
            use_strip_pooling=use_strip_pooling,
        )

        # Additional conjunction blocks.
        self.base = ConvBlock3dWithStrip(
            in_channels=embed_dim,
            out_channels=features_decoder[0],
            use_strip_pooling=use_strip_pooling,
        )

        # And the output layers.
        self.decoder_head = ConvBlock3dWithStrip(
            in_channels=2 * features_decoder[-1],
            out_channels=features_decoder[-1],
            use_strip_pooling=use_strip_pooling,
        )
        self.out_conv = nn.Conv3d(features_decoder[-1], out_channels, 1)

    def forward(self, x: torch.Tensor):
        """Forward pass of the UNETR-3D model.

        Args:
            x: Inputs of expected shape (B, C, Z, Y, X), where the number of slices Z may vary.

        Returns:
            The UNETR output.
        """
        B, C, Z, H, W = x.shape
        original_shape = (Z, H, W)

        # Preprocessing step.
        x, input_shape = self.preprocess(x)

        # Run the image encoder slice-wise and stack the per-slice features along the depth axis.
        curr_features = torch.stack([self.encoder(x[:, :, i])[0] for i in range(Z)], dim=2)

        # Prepare the counterparts for the decoder.
        # NOTE: The section below is sequential; there are no skip connections at the moment.
        z9 = self.deconv1(curr_features)
        z6 = self.deconv2(z9)
        z3 = self.deconv3(z6)
        z0 = self.deconv4(z3)

        updated_from_encoder = [z9, z6, z3]

        # Align the features through the base block.
        x = self.base(curr_features)
        # Run the decoder.
        x = self.decoder(x, encoder_inputs=updated_from_encoder)
        x = self.deconv_out(x)  # NOTE: formerly `end_up`.

        # And the final output head.
        x = torch.cat([x, z0], dim=1)
        x = self.decoder_head(x)
        x = self.out_conv(x)
        if self.final_activation is not None:
            x = self.final_activation(x)

        # Postprocess the output back to the original size.
        x = self.postprocess_masks(x, input_shape, original_shape)
        return x
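The decoder configuration assembled in `__init__` above can be inspected in isolation. The sketch below uses only plain PyTorch (no `torch_em`) to reproduce the channel progression and to illustrate the anisotropic scale factor `[1, 2, 2]`, which keeps the depth axis fixed while doubling Y and X at every level; `F.interpolate` serves only as a stand-in for the `Upsampler3d` used by the actual decoder.

import torch
import torch.nn.functional as F

# Channel progression computed in UNETR3D.__init__ (depth=3, initial_features=64, gain=2).
depth, initial_features, gain = 3, 64, 2
features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
print(features_decoder)  # [512, 256, 128, 64]

# Anisotropic upsampling: Z stays fixed, Y and X are doubled at each decoder level.
scale_factors = [1, 2, 2]
feat = torch.randn(1, 512, 8, 16, 16)  # (B, C, Z, Y, X) encoder-sized features
up = F.interpolate(feat, scale_factor=tuple(scale_factors), mode="trilinear", align_corners=False)
print(up.shape)  # torch.Size([1, 512, 8, 32, 32])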
A three-dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.
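A minimal usage sketch, assuming `torch_em` is installed together with `micro_sam2` (required for the default "sam2"/"hvit_b" backbone). The argument values mirror the constructor defaults shown above; `out_channels=2` and the single-channel input are arbitrary example choices, with grayscale handling left to the model's own preprocessing.

import torch
from torch_em.model.unetr import UNETR3D

# Build the model with the documented defaults; no encoder checkpoint is loaded here.
model = UNETR3D(backbone="sam2", encoder="hvit_b", out_channels=2)

# A small (B, C, Z, Y, X) volume; Z is flexible, Y/X are resized internally to img_size.
x = torch.randn(1, 1, 4, 512, 512)
with torch.no_grad():
    y = model(x)
print(y.shape)  # expected: torch.Size([1, 2, 4, 512, 512])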
Forward pass of the UNETR-3D model.

Args:
    x: Inputs of expected shape (B, C, Z, Y, X), where the number of slices Z may vary.

Returns:
    The UNETR output.
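The slice-wise encoding in `forward` (running the 2D image encoder over every Z-slice and stacking the per-slice feature maps along a new depth axis) can be illustrated without `torch_em`. The `ToyEncoder` below is a hypothetical stand-in for the vision transformer; the real encoder also returns intermediate features, which is why the source indexes its output with `[0]`.

import torch
import torch.nn as nn

class ToyEncoder(nn.Module):
    """Stand-in for the vision transformer: a patch embedding that maps
    (B, C, Y, X) to (B, embed_dim, Y/16, X/16)."""
    def __init__(self, in_channels=3, embed_dim=256, patch=16):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch, stride=patch)

    def forward(self, x):
        return self.proj(x)

encoder = ToyEncoder()
volume = torch.randn(1, 3, 8, 256, 256)  # (B, C, Z, Y, X)

# Encode each Z-slice independently and stack along a new depth axis (dim=2),
# mirroring the torch.stack call in UNETR3D.forward.
features = torch.stack(
    [encoder(volume[:, :, i]) for i in range(volume.shape[2])], dim=2
)
print(features.shape)  # torch.Size([1, 256, 8, 16, 16])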