torch_em.model.unetr
from functools import partial
from collections import OrderedDict
from typing import Optional, Tuple, Union, Literal

import torch
import torch.nn as nn
import torch.nn.functional as F

from .vit import get_vision_transformer
from .unet import Decoder, ConvBlock2d, ConvBlock3d, Upsampler2d, Upsampler3d, _update_conv_kwargs

try:
    from micro_sam.util import get_sam_model
except ImportError:
    get_sam_model = None

try:
    from micro_sam2.util import get_sam2_model
except ImportError:
    get_sam2_model = None

try:
    from micro_sam3.util import get_sam3_model
except ImportError:
    get_sam3_model = None


#
# UNETR IMPLEMENTATION [Vision Transformer (ViT from SAM / SAM2 / SAM3 / DINOv2 / DINOv3 / MAE / ScaleMAE) + UNet Decoder from `torch_em`]  # noqa
#


class UNETRBase(nn.Module):
    """Base class for implementing a UNETR.

    Args:
        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
        backbone: The name of the vision transformer implementation.
            One of "sam", "sam2", "sam3", "mae", "scalemae", "dinov2", "dinov3" (see all combinations below).
        encoder: The vision transformer. Can either be a name, such as "vit_b"
            (see all combinations for this below) or a torch module.
        decoder: The convolutional decoder.
        out_channels: The number of output channels of the UNETR.
        use_sam_stats: Whether to normalize the input data with the statistics of the
            pretrained SAM / SAM2 / SAM3 model.
        use_dino_stats: Whether to normalize the input data with the statistics of the
            pretrained DINOv2 / DINOv3 model.
        use_mae_stats: Whether to normalize the input data with the statistics of the pretrained MAE model.
        resize_input: Whether to resize the input images to match `img_size`.
            By default, it resizes the inputs to match the `img_size`.
        encoder_checkpoint: Checkpoint for initializing the vision transformer.
            Can either be a filepath or an already loaded checkpoint.
        final_activation: The activation to apply to the UNETR output.
        use_skip_connection: Whether to use skip connections. By default, it uses skip connections.
        embed_dim: The embedding dimensionality, corresponding to the output dimension of the vision transformer.
        use_conv_transpose: Whether to use transposed convolutions instead of resampling for upsampling.
            By default, it uses resampling for upsampling.

    NOTE: The currently supported combinations of 'backbone' x 'encoder' are the following:

    SAM_family_models:
        - 'sam' x 'vit_b'
        - 'sam' x 'vit_l'
        - 'sam' x 'vit_h'
        - 'sam2' x 'hvit_t'
        - 'sam2' x 'hvit_s'
        - 'sam2' x 'hvit_b'
        - 'sam2' x 'hvit_l'
        - 'sam3' x 'vit_pe'

    DINO_family_models:
        - 'dinov2' x 'vit_s'
        - 'dinov2' x 'vit_b'
        - 'dinov2' x 'vit_l'
        - 'dinov2' x 'vit_g'
        - 'dinov2' x 'vit_s_reg4'
        - 'dinov2' x 'vit_b_reg4'
        - 'dinov2' x 'vit_l_reg4'
        - 'dinov2' x 'vit_g_reg4'
        - 'dinov3' x 'vit_s'
        - 'dinov3' x 'vit_s+'
        - 'dinov3' x 'vit_b'
        - 'dinov3' x 'vit_l'
        - 'dinov3' x 'vit_l+'
        - 'dinov3' x 'vit_h+'
        - 'dinov3' x 'vit_7b'

    MAE_family_models:
        - 'mae' x 'vit_b'
        - 'mae' x 'vit_l'
        - 'mae' x 'vit_h'
        - 'scalemae' x 'vit_b'
        - 'scalemae' x 'vit_l'
        - 'scalemae' x 'vit_h'
    """
    def __init__(
        self,
        img_size: int = 1024,
        backbone: Literal["sam", "sam2", "sam3", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
        encoder: Optional[Union[nn.Module, str]] = "vit_b",
        decoder: Optional[nn.Module] = None,
        out_channels: int = 1,
        use_sam_stats: bool = False,
        use_mae_stats: bool = False,
        use_dino_stats: bool = False,
        resize_input: bool = True,
        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
        final_activation: Optional[Union[str, nn.Module]] = None,
        use_skip_connection: bool = True,
        embed_dim: Optional[int] = None,
        use_conv_transpose: bool = False,
        **kwargs
    ) -> None:
        super().__init__()

        self.img_size = img_size
        self.use_sam_stats = use_sam_stats
        self.use_mae_stats = use_mae_stats
        self.use_dino_stats = use_dino_stats
        self.use_skip_connection = use_skip_connection
        self.resize_input = resize_input
        self.use_conv_transpose = use_conv_transpose
        self.backbone = backbone

        if isinstance(encoder, str):  # e.g. "vit_b" / "hvit_b" / "vit_pe"
            print(f"Using {encoder} from {backbone.upper()}")
            self.encoder = get_vision_transformer(img_size=img_size, backbone=backbone, model=encoder, **kwargs)

            if encoder_checkpoint is not None:
                self._load_encoder_from_checkpoint(backbone=backbone, encoder=encoder, checkpoint=encoder_checkpoint)

            if embed_dim is None:
                embed_dim = self.encoder.embed_dim

        else:  # `nn.Module` ViT backbone
            self.encoder = encoder

            have_neck = False
            for name, _ in self.encoder.named_parameters():
                if name.startswith("neck"):
                    have_neck = True

            if embed_dim is None:
                if have_neck:
                    embed_dim = self.encoder.neck[2].out_channels  # the value is 256
                else:
                    embed_dim = self.encoder.patch_embed.proj.out_channels

        self.embed_dim = embed_dim
        self.final_activation = self._get_activation(final_activation)

    def _load_encoder_from_checkpoint(self, backbone, encoder, checkpoint):
        """Function to load pretrained weights to the image encoder.
        """
        if isinstance(checkpoint, str):
            if backbone == "sam" and isinstance(encoder, str):
                # If we have a SAM encoder, then we first try to load the full SAM model
                # (using micro_sam) and otherwise fall back on directly loading the encoder state
                # from the checkpoint.
                try:
                    _, model = get_sam_model(model_type=encoder, checkpoint_path=checkpoint, return_sam=True)
                    encoder_state = model.image_encoder.state_dict()
                except Exception:
                    # Try loading the encoder state directly from a checkpoint.
                    encoder_state = torch.load(checkpoint, weights_only=False)

            elif backbone == "sam2" and isinstance(encoder, str):
                # If we have a SAM2 encoder, then we first try to load the full SAM2 model
                # (using micro_sam2) and otherwise fall back on directly loading the encoder state
                # from the checkpoint.
                try:
                    model = get_sam2_model(model_type=encoder, checkpoint_path=checkpoint)
                    encoder_state = model.image_encoder.state_dict()
                except Exception:
                    # Try loading the encoder state directly from a checkpoint.
                    encoder_state = torch.load(checkpoint, weights_only=False)

            elif backbone == "sam3" and isinstance(encoder, str):
                # If we have a SAM3 encoder, then we first try to load the full SAM3 model
                # (using micro_sam3) and otherwise fall back on directly loading the encoder state
                # from the checkpoint.
                try:
                    model = get_sam3_model(checkpoint_path=checkpoint)
                    encoder_state = model.backbone.vision_backbone.state_dict()
                    # Let's align loading the encoder weights with expected parameter names.
                    encoder_state = {
                        k[len("trunk."):] if k.startswith("trunk.") else k: v for k, v in encoder_state.items()
                    }
                    # And drop the 'convs' and 'sam2_convs' - these seem like some upsampling blocks.
                    encoder_state = {
                        k: v for k, v in encoder_state.items()
                        if not (k.startswith("convs.") or k.startswith("sam2_convs."))
                    }
                except Exception:
                    # Try loading the encoder state directly from a checkpoint.
                    encoder_state = torch.load(checkpoint, weights_only=False)

            elif backbone == "mae":
                # vit initialization hints from:
                # - https://github.com/facebookresearch/mae/blob/main/main_finetune.py#L233-L242
                encoder_state = torch.load(checkpoint, weights_only=False)["model"]
                encoder_state = OrderedDict({
                    k: v for k, v in encoder_state.items() if (k != "mask_token" and not k.startswith("decoder"))
                })
                # Let's remove the `head` from our current encoder (as the MAE pretrained weights don't expect it).
                current_encoder_state = self.encoder.state_dict()
                if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
                    del self.encoder.head

            elif backbone == "scalemae":
                # Load the encoder state directly from a checkpoint.
                encoder_state = torch.load(checkpoint)["model"]
                encoder_state = OrderedDict({
                    k: v for k, v in encoder_state.items()
                    if not k.startswith(("mask_token", "decoder", "fcn", "fpn", "pos_embed"))
                })

                # Let's remove the `head` from our current encoder (as the MAE pretrained weights don't expect it).
                current_encoder_state = self.encoder.state_dict()
                if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
                    del self.encoder.head

                if "pos_embed" in current_encoder_state:  # NOTE: ScaleMAE uses 'pos. embeddings' in a diff. format.
                    del self.encoder.pos_embed

            elif backbone in ["dinov2", "dinov3"]:  # Load the encoder state directly from a checkpoint.
                encoder_state = torch.load(checkpoint)

            else:
                raise ValueError(
                    f"We don't support either the '{backbone}' backbone or the '{encoder}' model combination (or both)."
                )

        else:
            encoder_state = checkpoint

        self.encoder.load_state_dict(encoder_state)

    def _get_activation(self, activation):
        return_activation = None
        if activation is None:
            return None
        if isinstance(activation, nn.Module):
            return activation
        if isinstance(activation, str):
            return_activation = getattr(nn, activation, None)
        if return_activation is None:
            raise ValueError(f"Invalid activation: {activation}")

        return return_activation()

    @staticmethod
    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
        """Compute the output size given input size and target long side length.

        Args:
            oldh: The input image height.
            oldw: The input image width.
            long_side_length: The longest side length for resizing.

        Returns:
            The new image height.
            The new image width.
        """
        scale = long_side_length * 1.0 / max(oldh, oldw)
        newh, neww = oldh * scale, oldw * scale
        neww = int(neww + 0.5)
        newh = int(newh + 0.5)
        return (newh, neww)

    def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor:
        """Resize the image so that the longest side has the correct length.

        Expects batched images with shape BxCxHxW OR BxCxDxHxW and float format.

        Args:
            image: The input image.

        Returns:
            The resized image.
        """
        if image.ndim == 4:  # i.e. 2d image
            target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.encoder.img_size)
            return F.interpolate(image, target_size, mode="bilinear", align_corners=False, antialias=True)
        elif image.ndim == 5:  # i.e. 3d volume
            B, C, Z, H, W = image.shape
            target_size = self.get_preprocess_shape(H, W, self.img_size)
            return F.interpolate(image, (Z, *target_size), mode="trilinear", align_corners=False)
        else:
            raise ValueError("Expected 4d or 5d inputs, got", image.shape)

    def _as_stats(self, mean, std, device, dtype, is_3d: bool):
        """@private
        """
        # Either 2d batch: (1, C, 1, 1) or 3d batch: (1, C, 1, 1, 1).
        view_shape = (1, -1, 1, 1, 1) if is_3d else (1, -1, 1, 1)
        pixel_mean = torch.tensor(mean, device=device, dtype=dtype).view(*view_shape)
        pixel_std = torch.tensor(std, device=device, dtype=dtype).view(*view_shape)
        return pixel_mean, pixel_std

    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        """@private
        """
        is_3d = (x.ndim == 5)
        device, dtype = x.device, x.dtype

        if self.use_sam_stats:
            mean, std = (123.675, 116.28, 103.53), (58.395, 57.12, 57.375)
        elif self.use_mae_stats:  # TODO: add mean std from mae / scalemae experiments (or open up arguments for this)
            raise NotImplementedError
        elif self.use_dino_stats or (self.use_sam_stats and self.backbone == "sam2"):
            mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
        elif self.use_sam_stats and self.backbone == "sam3":
            mean, std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
        else:
            mean, std = (0.0, 0.0, 0.0), (1.0, 1.0, 1.0)

        pixel_mean, pixel_std = self._as_stats(mean, std, device=device, dtype=dtype, is_3d=is_3d)

        if self.resize_input:
            x = self.resize_longest_side(x)
        input_shape = x.shape[-3:] if is_3d else x.shape[-2:]

        x = (x - pixel_mean) / pixel_std
        h, w = x.shape[-2:]
        padh = self.encoder.img_size - h
        padw = self.encoder.img_size - w

        if is_3d:
            x = F.pad(x, (0, padw, 0, padh, 0, 0))
        else:
            x = F.pad(x, (0, padw, 0, padh))

        return x, input_shape

    def postprocess_masks(
        self, masks: torch.Tensor, input_size: Tuple[int, ...], original_size: Tuple[int, ...],
    ) -> torch.Tensor:
        """@private
        """
        if masks.ndim == 4:  # i.e. 2d labels
            masks = F.interpolate(
                masks,
                (self.encoder.img_size, self.encoder.img_size),
                mode="bilinear",
                align_corners=False,
            )
            masks = masks[..., :input_size[0], :input_size[1]]
            masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)

        elif masks.ndim == 5:  # i.e. 3d volumetric labels
            masks = F.interpolate(
                masks,
                (input_size[0], self.img_size, self.img_size),
                mode="trilinear",
                align_corners=False,
            )
            masks = masks[..., :input_size[0], :input_size[1], :input_size[2]]
            masks = F.interpolate(masks, original_size, mode="trilinear", align_corners=False)

        else:
            raise ValueError("Expected 4d or 5d labels, got", masks.shape)

        return masks


class UNETR(UNETRBase):
    """A (2d-only) UNet Transformer using a vision transformer as encoder and a convolutional decoder.
    """
    def __init__(
        self,
        img_size: int = 1024,
        backbone: Literal["sam", "sam2", "sam3", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
        encoder: Optional[Union[nn.Module, str]] = "vit_b",
        decoder: Optional[nn.Module] = None,
        out_channels: int = 1,
        use_sam_stats: bool = False,
        use_mae_stats: bool = False,
        use_dino_stats: bool = False,
        resize_input: bool = True,
        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
        final_activation: Optional[Union[str, nn.Module]] = None,
        use_skip_connection: bool = True,
        embed_dim: Optional[int] = None,
        use_conv_transpose: bool = False,
        **kwargs
    ) -> None:

        super().__init__(
            img_size=img_size,
            backbone=backbone,
            encoder=encoder,
            decoder=decoder,
            out_channels=out_channels,
            use_sam_stats=use_sam_stats,
            use_mae_stats=use_mae_stats,
            use_dino_stats=use_dino_stats,
            resize_input=resize_input,
            encoder_checkpoint=encoder_checkpoint,
            final_activation=final_activation,
            use_skip_connection=use_skip_connection,
            embed_dim=embed_dim,
            use_conv_transpose=use_conv_transpose,
            **kwargs,
        )

        encoder = self.encoder

        if backbone == "sam2" and hasattr(encoder, "trunk"):
            in_chans = encoder.trunk.patch_embed.proj.in_channels
        elif hasattr(encoder, "in_chans"):
            in_chans = encoder.in_chans
        else:  # `nn.Module` ViT backbone.
            try:
                in_chans = encoder.patch_embed.proj.in_channels
            except AttributeError:  # for getting the input channels while using 'vit_t' from MobileSam
                in_chans = encoder.patch_embed.seq[0].c.in_channels

        # parameters for the decoder network
        depth = 3
        initial_features = 64
        gain = 2
        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
        scale_factors = depth * [2]
        self.out_channels = out_channels

        # choice of upsampler - to use (bilinear interpolation + conv) or conv transpose
        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d

        self.decoder = decoder or Decoder(
            features=features_decoder,
            scale_factors=scale_factors[::-1],
            conv_block_impl=ConvBlock2d,
            sampler_impl=_upsampler,
        )

        if use_skip_connection:
            self.deconv1 = Deconv2DBlock(
                in_channels=self.embed_dim,
                out_channels=features_decoder[0],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv2 = nn.Sequential(
                Deconv2DBlock(
                    in_channels=self.embed_dim,
                    out_channels=features_decoder[0],
                    use_conv_transpose=use_conv_transpose,
                ),
                Deconv2DBlock(
                    in_channels=features_decoder[0],
                    out_channels=features_decoder[1],
                    use_conv_transpose=use_conv_transpose,
                )
            )
            self.deconv3 = nn.Sequential(
                Deconv2DBlock(
                    in_channels=self.embed_dim,
                    out_channels=features_decoder[0],
                    use_conv_transpose=use_conv_transpose,
                ),
                Deconv2DBlock(
                    in_channels=features_decoder[0],
                    out_channels=features_decoder[1],
                    use_conv_transpose=use_conv_transpose,
                ),
                Deconv2DBlock(
                    in_channels=features_decoder[1],
                    out_channels=features_decoder[2],
                    use_conv_transpose=use_conv_transpose,
                )
            )
            self.deconv4 = ConvBlock2d(in_chans, features_decoder[-1])
        else:
            self.deconv1 = Deconv2DBlock(
                in_channels=self.embed_dim,
                out_channels=features_decoder[0],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv2 = Deconv2DBlock(
                in_channels=features_decoder[0],
                out_channels=features_decoder[1],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv3 = Deconv2DBlock(
                in_channels=features_decoder[1],
                out_channels=features_decoder[2],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv4 = Deconv2DBlock(
                in_channels=features_decoder[2],
                out_channels=features_decoder[3],
                use_conv_transpose=use_conv_transpose,
            )

        self.base = ConvBlock2d(self.embed_dim, features_decoder[0])
        self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1)
        self.deconv_out = _upsampler(
            scale_factor=2, in_channels=features_decoder[-1], out_channels=features_decoder[-1]
        )
        self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the UNETR to the input data.

        Args:
            x: The input tensor.

        Returns:
            The UNETR output.
        """
        original_shape = x.shape[-2:]

        # Reshape the inputs to the shape expected by the encoder
        # and normalize the inputs if normalization is part of the model.
        x, input_shape = self.preprocess(x)

        encoder_outputs = self.encoder(x)

        if isinstance(encoder_outputs[-1], list):
            # `encoder_outputs` can be arranged in only two forms:
            # - either we only return the image embeddings
            # - or, we return the image embeddings and the "list" of global attention layers
            z12, from_encoder = encoder_outputs
        else:
            z12 = encoder_outputs

        if self.use_skip_connection:
            from_encoder = from_encoder[::-1]
            z9 = self.deconv1(from_encoder[0])
            z6 = self.deconv2(from_encoder[1])
            z3 = self.deconv3(from_encoder[2])
            z0 = self.deconv4(x)
        else:
            z9 = self.deconv1(z12)
            z6 = self.deconv2(z9)
            z3 = self.deconv3(z6)
            z0 = self.deconv4(z3)

        updated_from_encoder = [z9, z6, z3]

        x = self.base(z12)
        x = self.decoder(x, encoder_inputs=updated_from_encoder)
        x = self.deconv_out(x)

        x = torch.cat([x, z0], dim=1)
        x = self.decoder_head(x)

        x = self.out_conv(x)
        if self.final_activation is not None:
            x = self.final_activation(x)

        x = self.postprocess_masks(x, input_shape, original_shape)
        return x


class UNETR2D(UNETR):
    """A two-dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.
    """
    pass


class UNETR3D(UNETRBase):
    """A three-dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.
    """
    def __init__(
        self,
        img_size: int = 1024,
        backbone: Literal["sam", "sam2", "sam3", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
        encoder: Optional[Union[nn.Module, str]] = "hvit_b",
        decoder: Optional[nn.Module] = None,
        out_channels: int = 1,
        use_sam_stats: bool = False,
        use_mae_stats: bool = False,
        use_dino_stats: bool = False,
        resize_input: bool = True,
        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
        final_activation: Optional[Union[str, nn.Module]] = None,
        use_skip_connection: bool = False,
        embed_dim: Optional[int] = None,
        use_conv_transpose: bool = False,
        use_strip_pooling: bool = True,
        **kwargs
    ):
        if use_skip_connection:
            raise NotImplementedError("The framework cannot handle skip connections atm.")
        if use_conv_transpose:
            raise NotImplementedError("It's not enabled to switch between interpolation and transposed convolutions.")

        # Sort the `embed_dim` out.
        embed_dim = 256 if embed_dim is None else embed_dim

        super().__init__(
            img_size=img_size,
            backbone=backbone,
            encoder=encoder,
            decoder=decoder,
            out_channels=out_channels,
            use_sam_stats=use_sam_stats,
            use_mae_stats=use_mae_stats,
            use_dino_stats=use_dino_stats,
            resize_input=resize_input,
            encoder_checkpoint=encoder_checkpoint,
            final_activation=final_activation,
            use_skip_connection=use_skip_connection,
            embed_dim=embed_dim,
            use_conv_transpose=use_conv_transpose,
            **kwargs,
        )

        # The 3d convolutional decoder.
        # First, get the important parameters for the decoder.
        depth = 3
        initial_features = 64
        gain = 2
        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
        scale_factors = [1, 2, 2]
        self.out_channels = out_channels

        # The mapping blocks.
        self.deconv1 = Deconv3DBlock(
            in_channels=embed_dim,
            out_channels=features_decoder[0],
            scale_factor=scale_factors,
            use_strip_pooling=use_strip_pooling,
        )
        self.deconv2 = Deconv3DBlock(
            in_channels=features_decoder[0],
            out_channels=features_decoder[1],
            scale_factor=scale_factors,
            use_strip_pooling=use_strip_pooling,
        )
        self.deconv3 = Deconv3DBlock(
            in_channels=features_decoder[1],
            out_channels=features_decoder[2],
            scale_factor=scale_factors,
            use_strip_pooling=use_strip_pooling,
        )
        self.deconv4 = Deconv3DBlock(
            in_channels=features_decoder[2],
            out_channels=features_decoder[3],
            scale_factor=scale_factors,
            use_strip_pooling=use_strip_pooling,
        )

        # The core decoder block.
        self.decoder = decoder or Decoder(
            features=features_decoder,
            scale_factors=[scale_factors] * depth,
            conv_block_impl=partial(ConvBlock3dWithStrip, use_strip_pooling=use_strip_pooling),
            sampler_impl=Upsampler3d,
        )

        # And the final upsampler to match the expected dimensions.
        self.deconv_out = Deconv3DBlock(  # NOTE: changed `end_up` to `deconv_out`
            in_channels=features_decoder[-1],
            out_channels=features_decoder[-1],
            scale_factor=scale_factors,
            use_strip_pooling=use_strip_pooling,
        )

        # Additional conjunction blocks.
        self.base = ConvBlock3dWithStrip(
            in_channels=embed_dim,
            out_channels=features_decoder[0],
            use_strip_pooling=use_strip_pooling,
        )

        # And the output layers.
        self.decoder_head = ConvBlock3dWithStrip(
            in_channels=2 * features_decoder[-1],
            out_channels=features_decoder[-1],
            use_strip_pooling=use_strip_pooling,
        )
        self.out_conv = nn.Conv3d(features_decoder[-1], out_channels, 1)

    def forward(self, x: torch.Tensor):
        """Forward pass of the UNETR-3D model.

        Args:
            x: Inputs of expected shape (B, C, Z, Y, X), where the number of slices Z is flexible.

        Returns:
            The UNETR output.
        """
        B, C, Z, H, W = x.shape
        original_shape = (Z, H, W)

        # Preprocessing step.
        x, input_shape = self.preprocess(x)

        # Run the image encoder.
        curr_features = torch.stack([self.encoder(x[:, :, i])[0] for i in range(Z)], dim=2)

        # Prepare the counterparts for the decoder.
        # NOTE: The section below is sequential, there are no skip connections atm.
        z9 = self.deconv1(curr_features)
        z6 = self.deconv2(z9)
        z3 = self.deconv3(z6)
        z0 = self.deconv4(z3)

        updated_from_encoder = [z9, z6, z3]

        # Align the features through the base block.
        x = self.base(curr_features)
        # Run the decoder.
        x = self.decoder(x, encoder_inputs=updated_from_encoder)
        x = self.deconv_out(x)  # NOTE: previously named `end_up`

        # And the final output head.
        x = torch.cat([x, z0], dim=1)
        x = self.decoder_head(x)
        x = self.out_conv(x)
        if self.final_activation is not None:
            x = self.final_activation(x)

        # Postprocess the output back to original size.
        x = self.postprocess_masks(x, input_shape, original_shape)
        return x


#
# ADDITIONAL FUNCTIONALITIES
#


def _strip_pooling_layers(enabled, channels) -> nn.Module:
    return DepthStripPooling(channels) if enabled else nn.Identity()


class DepthStripPooling(nn.Module):
    """@private
    """
    def __init__(self, channels: int, reduction: int = 4):
        """Block for strip pooling along the depth dimension (only).

        E.g. for 3D (Z > 1): it aggregates global context across depth by adaptive average pooling
        to Z=1, passes the result through a small 1x1x1 MLP, and then broadcasts it back to Z to
        modulate the original features (using a gated residual).

        For 2D (Z == 1 or Z == 3): returns the input unchanged (no-op).

        Args:
            channels: The output channels.
            reduction: The reduction of the hidden layers.
        """
        super().__init__()
        hidden = max(1, channels // reduction)
        self.conv1 = nn.Conv3d(channels, hidden, kernel_size=1)
        self.bn1 = nn.BatchNorm3d(hidden)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(hidden, channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.dim() != 5:
            raise ValueError(f"DepthStripPooling expects 5D tensors as input, got '{x.shape}'.")

        B, C, Z, H, W = x.shape
        if Z == 1 or Z == 3:  # i.e. 2d-as-1-slice or RGB_2d-as-1-slice.
            return x  # We simply do nothing there.

        # We pool only along the depth dimension: i.e. target shape (B, C, 1, H, W).
        feat = F.adaptive_avg_pool3d(x, output_size=(1, H, W))
        feat = self.conv1(feat)
        feat = self.bn1(feat)
        feat = self.relu(feat)
        feat = self.conv2(feat)
        gate = torch.sigmoid(feat).expand(B, C, Z, H, W)  # Broadcast the collapsed depth context back to all slices.

        # Gated residual fusion.
        return x * gate + x


class Deconv3DBlock(nn.Module):
    """@private
    """
    def __init__(
        self,
        scale_factor,
        in_channels,
        out_channels,
        kernel_size=3,
        anisotropic_kernel=True,
        use_strip_pooling=True,
    ):
        super().__init__()
        conv_block_kwargs = {
            "in_channels": out_channels,
            "out_channels": out_channels,
            "kernel_size": kernel_size,
            "padding": ((kernel_size - 1) // 2),
        }
        if anisotropic_kernel:
            conv_block_kwargs = _update_conv_kwargs(conv_block_kwargs, scale_factor)

        self.block = nn.Sequential(
            Upsampler3d(scale_factor, in_channels, out_channels),
            nn.Conv3d(**conv_block_kwargs),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(True),
            _strip_pooling_layers(enabled=use_strip_pooling, channels=out_channels),
        )

    def forward(self, x):
        return self.block(x)


class ConvBlock3dWithStrip(nn.Module):
    """@private
    """
    def __init__(
        self, in_channels: int, out_channels: int, use_strip_pooling: bool = True, **kwargs
    ):
        super().__init__()
        self.block = nn.Sequential(
            ConvBlock3d(in_channels, out_channels, **kwargs),
            _strip_pooling_layers(enabled=use_strip_pooling, channels=out_channels),
        )

    def forward(self, x):
        return self.block(x)


class SingleDeconv2DBlock(nn.Module):
    """@private
    """
    def __init__(self, scale_factor, in_channels, out_channels):
        super().__init__()
        self.block = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, padding=0, output_padding=0)

    def forward(self, x):
        return self.block(x)


class SingleConv2DBlock(nn.Module):
    """@private
    """
    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        self.block = nn.Conv2d(
            in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=((kernel_size - 1) // 2)
        )

    def forward(self, x):
        return self.block(x)


class Conv2DBlock(nn.Module):
    """@private
    """
    def __init__(self, in_channels, out_channels, kernel_size=3):
        super().__init__()
        self.block = nn.Sequential(
            SingleConv2DBlock(in_channels, out_channels, kernel_size),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True)
        )

    def forward(self, x):
        return self.block(x)


class Deconv2DBlock(nn.Module):
    """@private
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, use_conv_transpose=True):
        super().__init__()
        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d
        self.block = nn.Sequential(
            _upsampler(scale_factor=2, in_channels=in_channels, out_channels=out_channels),
            SingleConv2DBlock(out_channels, out_channels, kernel_size),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True)
        )

    def forward(self, x):
        return self.block(x)
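For orientation, here is a minimal usage sketch of the 2d model defined above. It assumes that the SAM image encoder implementation and its default weights are available in your environment; the input and output shapes are purely illustrative.

import torch
from torch_em.model.unetr import UNETR

# Build a 2d UNETR with a SAM ViT-B image encoder and a single output channel.
model = UNETR(backbone="sam", encoder="vit_b", out_channels=1, final_activation="Sigmoid")

# Any 2d input with 3 channels works; it is resized and padded to the encoder's expected size internally.
x = torch.randn(1, 3, 512, 512)
with torch.no_grad():
    y = model(x)
print(y.shape)  # torch.Size([1, 1, 512, 512]) - predictions are mapped back to the input resolution.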
class UNETRBase(nn.Module):
Base class for implementing a UNETR.
Arguments:
- img_size: The size of the input for the image encoder. Input images will be resized to match this size.
- backbone: The name of the vision transformer implementation. One of "sam", "sam2", "sam3", "mae", "scalemae", "dinov2", "dinov3" (see all combinations below).
- encoder: The vision transformer. Can either be a name, such as "vit_b" (see all combinations for this below) or a torch module.
- decoder: The convolutional decoder.
- out_channels: The number of output channels of the UNETR.
- use_sam_stats: Whether to normalize the input data with the statistics of the pretrained SAM / SAM2 / SAM3 model.
- use_dino_stats: Whether to normalize the input data with the statistics of the pretrained DINOv2 / DINOv3 model.
- use_mae_stats: Whether to normalize the input data with the statistics of the pretrained MAE model.
- resize_input: Whether to resize the input images to match `img_size`. By default, it resizes the inputs to match the `img_size`.
- encoder_checkpoint: Checkpoint for initializing the vision transformer. Can either be a filepath or an already loaded checkpoint.
- final_activation: The activation to apply to the UNETR output.
- use_skip_connection: Whether to use skip connections. By default, it uses skip connections.
- embed_dim: The embedding dimensionality, corresponding to the output dimension of the vision transformer.
- use_conv_transpose: Whether to use transposed convolutions instead of resampling for upsampling. By default, it uses resampling for upsampling.
NOTE: The currently supported combinations of 'backbone' x 'encoder' are the following (see the instantiation sketch after this list):
SAM_family_models:
- 'sam' x 'vit_b'
- 'sam' x 'vit_l'
- 'sam' x 'vit_h'
- 'sam2' x 'hvit_t'
- 'sam2' x 'hvit_s'
- 'sam2' x 'hvit_b'
- 'sam2' x 'hvit_l'
- 'sam3' x 'vit_pe'
DINO_family_models:
- 'dinov2' x 'vit_s'
- 'dinov2' x 'vit_b'
- 'dinov2' x 'vit_l'
- 'dinov2' x 'vit_g'
- 'dinov2' x 'vit_s_reg4'
- 'dinov2' x 'vit_b_reg4'
- 'dinov2' x 'vit_l_reg4'
- 'dinov2' x 'vit_g_reg4'
- 'dinov3' x 'vit_s'
- 'dinov3' x 'vit_s+'
- 'dinov3' x 'vit_b'
- 'dinov3' x 'vit_l'
- 'dinov3' x 'vit_l+'
- 'dinov3' x 'vit_h+'
- 'dinov3' x 'vit_7b'
MAE_family_models:
- 'mae' x 'vit_b'
- 'mae' x 'vit_l'
- 'mae' x 'vit_h'
- 'scalemae' x 'vit_b'
- 'scalemae' x 'vit_l'
- 'scalemae' x 'vit_h'
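As a quick illustration of the combinations listed above, two valid 'backbone' x 'encoder' pairs with their matching normalization flags might look as follows. This is a sketch: each combination requires the corresponding encoder implementation and pretrained weights to be available in your environment.

from torch_em.model.unetr import UNETR

# SAM family: ViT-B image encoder, normalized with SAM's pixel statistics.
sam_unetr = UNETR(backbone="sam", encoder="vit_b", out_channels=2, use_sam_stats=True)

# DINO family: DINOv2 ViT-S, normalized with ImageNet-style statistics.
dino_unetr = UNETR(backbone="dinov2", encoder="vit_s", out_channels=2, use_dino_stats=True)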
UNETRBase(
    img_size: int = 1024,
    backbone: Literal["sam", "sam2", "sam3", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
    encoder: Optional[Union[nn.Module, str]] = "vit_b",
    decoder: Optional[nn.Module] = None,
    out_channels: int = 1,
    use_sam_stats: bool = False,
    use_mae_stats: bool = False,
    use_dino_stats: bool = False,
    resize_input: bool = True,
    encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
    final_activation: Optional[Union[str, nn.Module]] = None,
    use_skip_connection: bool = True,
    embed_dim: Optional[int] = None,
    use_conv_transpose: bool = False,
    **kwargs
)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
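A minimal sketch of initializing the image encoder from pretrained weights via the `encoder_checkpoint` argument; the checkpoint path below is a placeholder and assumes SAM ViT-B weights (or a plain encoder state dict) on disk.

from torch_em.model.unetr import UNETR

checkpoint_path = "/path/to/sam_vit_b_01ec64.pth"  # placeholder path, not shipped with torch_em
model = UNETR(
    backbone="sam",
    encoder="vit_b",
    out_channels=1,
    encoder_checkpoint=checkpoint_path,
    use_sam_stats=True,
)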
@staticmethod
def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
Compute the output size given input size and target long side length.
Arguments:
- oldh: The input image height.
- oldw: The input image width.
- long_side_length: The longest side length for resizing.
Returns:
The new image height. The new image width.
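A worked example of the computation above, for a hypothetical 768 x 512 image resized so that its longest side becomes 1024:

from torch_em.model.unetr import UNETRBase

# scale = 1024 / 768 ~ 1.333, so (768, 512) -> (1024, 683) after rounding.
print(UNETRBase.get_preprocess_shape(oldh=768, oldw=512, long_side_length=1024))  # (1024, 683)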
def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor:
Resize the image so that the longest side has the correct length.
Expects batched images with shape BxCxHxW OR BxCxDxHxW and float format.
Arguments:
- image: The input image.
Returns:
The resized image.
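For example (a sketch, assuming a SAM ViT-B encoder with an internal image size of 1024), a 768 x 512 batch is rescaled so that its longer side matches the encoder resolution while the aspect ratio is preserved:

import torch
from torch_em.model.unetr import UNETR

model = UNETR(backbone="sam", encoder="vit_b")
image = torch.randn(1, 3, 768, 512)            # B x C x H x W
resized = model.resize_longest_side(image)
print(resized.shape)                           # torch.Size([1, 3, 1024, 683])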
class UNETR(UNETRBase):
A (2d-only) UNet Transformer using a vision transformer as encoder and a convolutional decoder.
    def __init__(
        self,
        img_size: int = 1024,
        backbone: Literal["sam", "sam2", "sam3", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
        encoder: Optional[Union[nn.Module, str]] = "vit_b",
        decoder: Optional[nn.Module] = None,
        out_channels: int = 1,
        use_sam_stats: bool = False,
        use_mae_stats: bool = False,
        use_dino_stats: bool = False,
        resize_input: bool = True,
        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
        final_activation: Optional[Union[str, nn.Module]] = None,
        use_skip_connection: bool = True,
        embed_dim: Optional[int] = None,
        use_conv_transpose: bool = False,
        **kwargs
    ) -> None:

        super().__init__(
            img_size=img_size,
            backbone=backbone,
            encoder=encoder,
            decoder=decoder,
            out_channels=out_channels,
            use_sam_stats=use_sam_stats,
            use_mae_stats=use_mae_stats,
            use_dino_stats=use_dino_stats,
            resize_input=resize_input,
            encoder_checkpoint=encoder_checkpoint,
            final_activation=final_activation,
            use_skip_connection=use_skip_connection,
            embed_dim=embed_dim,
            use_conv_transpose=use_conv_transpose,
            **kwargs,
        )

        encoder = self.encoder

        if backbone == "sam2" and hasattr(encoder, "trunk"):
            in_chans = encoder.trunk.patch_embed.proj.in_channels
        elif hasattr(encoder, "in_chans"):
            in_chans = encoder.in_chans
        else:  # `nn.Module` ViT backbone.
            try:
                in_chans = encoder.patch_embed.proj.in_channels
            except AttributeError:  # for getting the input channels while using 'vit_t' from MobileSam
                in_chans = encoder.patch_embed.seq[0].c.in_channels

        # parameters for the decoder network
        depth = 3
        initial_features = 64
        gain = 2
        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
        scale_factors = depth * [2]
        self.out_channels = out_channels

        # choice of upsampler - to use (bilinear interpolation + conv) or conv transpose
        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d

        self.decoder = decoder or Decoder(
            features=features_decoder,
            scale_factors=scale_factors[::-1],
            conv_block_impl=ConvBlock2d,
            sampler_impl=_upsampler,
        )

        if use_skip_connection:
            self.deconv1 = Deconv2DBlock(
                in_channels=self.embed_dim,
                out_channels=features_decoder[0],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv2 = nn.Sequential(
                Deconv2DBlock(
                    in_channels=self.embed_dim,
                    out_channels=features_decoder[0],
                    use_conv_transpose=use_conv_transpose,
                ),
                Deconv2DBlock(
                    in_channels=features_decoder[0],
                    out_channels=features_decoder[1],
                    use_conv_transpose=use_conv_transpose,
                )
            )
            self.deconv3 = nn.Sequential(
                Deconv2DBlock(
                    in_channels=self.embed_dim,
                    out_channels=features_decoder[0],
                    use_conv_transpose=use_conv_transpose,
                ),
                Deconv2DBlock(
                    in_channels=features_decoder[0],
                    out_channels=features_decoder[1],
                    use_conv_transpose=use_conv_transpose,
                ),
                Deconv2DBlock(
                    in_channels=features_decoder[1],
                    out_channels=features_decoder[2],
                    use_conv_transpose=use_conv_transpose,
                )
            )
            self.deconv4 = ConvBlock2d(in_chans, features_decoder[-1])
        else:
            self.deconv1 = Deconv2DBlock(
                in_channels=self.embed_dim,
                out_channels=features_decoder[0],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv2 = Deconv2DBlock(
                in_channels=features_decoder[0],
                out_channels=features_decoder[1],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv3 = Deconv2DBlock(
                in_channels=features_decoder[1],
                out_channels=features_decoder[2],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv4 = Deconv2DBlock(
                in_channels=features_decoder[2],
                out_channels=features_decoder[3],
                use_conv_transpose=use_conv_transpose,
            )

        self.base = ConvBlock2d(self.embed_dim, features_decoder[0])
        self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1)
        self.deconv_out = _upsampler(
            scale_factor=2, in_channels=features_decoder[-1], out_channels=features_decoder[-1]
        )
        self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1])
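For orientation, the decoder parameters chosen above resolve to four feature levels, from coarsest to finest. The following sketch is illustrative only and not part of the module; it shows the resulting channel widths and a hypothetical instantiation, assuming the default SAM "vit_b" encoder can be built (randomly initialized when no `encoder_checkpoint` is given):

    # Minimal sketch: decoder channel widths as computed in `__init__`.
    depth, initial_features, gain = 3, 64, 2
    features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
    print(features_decoder)  # [512, 256, 128, 64]

    # Hypothetical construction of the 2D model; the parameter choices are examples.
    from torch_em.model.unetr import UNETR

    model = UNETR(
        backbone="sam",            # use the SAM image encoder family
        encoder="vit_b",           # ViT-B variant
        out_channels=2,            # e.g. one foreground and one boundary channel
        use_skip_connection=True,  # default: use intermediate encoder features
    )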
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the UNETR to the input data.

        Args:
            x: The input tensor.

        Returns:
            The UNETR output.
        """
        original_shape = x.shape[-2:]

        # Reshape the inputs to the shape expected by the encoder
        # and normalize the inputs if normalization is part of the model.
        x, input_shape = self.preprocess(x)

        encoder_outputs = self.encoder(x)

        if isinstance(encoder_outputs[-1], list):
            # `encoder_outputs` can be arranged in only two forms:
            # - either we only return the image embeddings
            # - or, we return the image embeddings and the "list" of global attention layers
            z12, from_encoder = encoder_outputs
        else:
            z12 = encoder_outputs

        if self.use_skip_connection:
            from_encoder = from_encoder[::-1]
            z9 = self.deconv1(from_encoder[0])
            z6 = self.deconv2(from_encoder[1])
            z3 = self.deconv3(from_encoder[2])
            z0 = self.deconv4(x)

        else:
            z9 = self.deconv1(z12)
            z6 = self.deconv2(z9)
            z3 = self.deconv3(z6)
            z0 = self.deconv4(z3)

        updated_from_encoder = [z9, z6, z3]

        x = self.base(z12)
        x = self.decoder(x, encoder_inputs=updated_from_encoder)
        x = self.deconv_out(x)

        x = torch.cat([x, z0], dim=1)
        x = self.decoder_head(x)

        x = self.out_conv(x)
        if self.final_activation is not None:
            x = self.final_activation(x)

        x = self.postprocess_masks(x, input_shape, original_shape)
        return x
Apply the UNETR to the input data.
Arguments:
- x: The input tensor.
Returns:
The UNETR output.
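Because `preprocess` resizes the input to `img_size` for the encoder and `postprocess_masks` resizes the prediction back, the output spatial shape matches the input. A rough shape walk-through, reusing the hypothetical `model` from the sketch above (shapes assume the default resizing behaviour):

    import torch

    # Sketch of a single 2D forward pass.
    x = torch.randn(1, 3, 512, 512)   # (B, C, Y, X); 3 channels to match the SAM encoder
    with torch.no_grad():
        y = model(x)                  # internally run at 1024 x 1024, then resized back
    print(y.shape)                    # expected: torch.Size([1, 2, 512, 512])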
class UNETR2D(UNETR):
    """A two-dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.
    """
    pass
A two-dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.
class UNETR3D(UNETRBase):
    """A three-dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.
    """
    def __init__(
        self,
        img_size: int = 1024,
        backbone: Literal["sam", "sam2", "sam3", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
        encoder: Optional[Union[nn.Module, str]] = "hvit_b",
        decoder: Optional[nn.Module] = None,
        out_channels: int = 1,
        use_sam_stats: bool = False,
        use_mae_stats: bool = False,
        use_dino_stats: bool = False,
        resize_input: bool = True,
        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
        final_activation: Optional[Union[str, nn.Module]] = None,
        use_skip_connection: bool = False,
        embed_dim: Optional[int] = None,
        use_conv_transpose: bool = False,
        use_strip_pooling: bool = True,
        **kwargs
    ):
        if use_skip_connection:
            raise NotImplementedError("The framework cannot handle skip connections atm.")
        if use_conv_transpose:
            raise NotImplementedError("It's not enabled to switch between interpolation and transposed convolutions.")

        # Sort the `embed_dim` out
        embed_dim = 256 if embed_dim is None else embed_dim

        super().__init__(
            img_size=img_size,
            backbone=backbone,
            encoder=encoder,
            decoder=decoder,
            out_channels=out_channels,
            use_sam_stats=use_sam_stats,
            use_mae_stats=use_mae_stats,
            use_dino_stats=use_dino_stats,
            resize_input=resize_input,
            encoder_checkpoint=encoder_checkpoint,
            final_activation=final_activation,
            use_skip_connection=use_skip_connection,
            embed_dim=embed_dim,
            use_conv_transpose=use_conv_transpose,
            **kwargs,
        )

        # The 3d convolutional decoder.
        # First, get the important parameters for the decoder.
        depth = 3
        initial_features = 64
        gain = 2
        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
        scale_factors = [1, 2, 2]
        self.out_channels = out_channels

        # The mapping blocks.
        self.deconv1 = Deconv3DBlock(
            in_channels=embed_dim,
            out_channels=features_decoder[0],
            scale_factor=scale_factors,
            use_strip_pooling=use_strip_pooling,
        )
        self.deconv2 = Deconv3DBlock(
            in_channels=features_decoder[0],
            out_channels=features_decoder[1],
            scale_factor=scale_factors,
            use_strip_pooling=use_strip_pooling,
        )
        self.deconv3 = Deconv3DBlock(
            in_channels=features_decoder[1],
            out_channels=features_decoder[2],
            scale_factor=scale_factors,
            use_strip_pooling=use_strip_pooling,
        )
        self.deconv4 = Deconv3DBlock(
            in_channels=features_decoder[2],
            out_channels=features_decoder[3],
            scale_factor=scale_factors,
            use_strip_pooling=use_strip_pooling,
        )

        # The core decoder block.
        self.decoder = decoder or Decoder(
            features=features_decoder,
            scale_factors=[scale_factors] * depth,
            conv_block_impl=partial(ConvBlock3dWithStrip, use_strip_pooling=use_strip_pooling),
            sampler_impl=Upsampler3d,
        )

        # And the final upsampler to match the expected dimensions.
        self.deconv_out = Deconv3DBlock(  # NOTE: changed `end_up` to `deconv_out`
            in_channels=features_decoder[-1],
            out_channels=features_decoder[-1],
            scale_factor=scale_factors,
            use_strip_pooling=use_strip_pooling,
        )

        # Additional conjunction blocks.
        self.base = ConvBlock3dWithStrip(
            in_channels=embed_dim,
            out_channels=features_decoder[0],
            use_strip_pooling=use_strip_pooling,
        )

        # And the output layers.
        self.decoder_head = ConvBlock3dWithStrip(
            in_channels=2 * features_decoder[-1],
            out_channels=features_decoder[-1],
            use_strip_pooling=use_strip_pooling,
        )
        self.out_conv = nn.Conv3d(features_decoder[-1], out_channels, 1)

    def forward(self, x: torch.Tensor):
        """Forward pass of the UNETR-3D model.

        Args:
            x: Inputs of expected shape (B, C, Z, Y, X), where the number of slices Z may vary.

        Returns:
            The UNETR output.
        """
        B, C, Z, H, W = x.shape
        original_shape = (Z, H, W)

        # Preprocessing step
        x, input_shape = self.preprocess(x)

        # Run the image encoder.
        curr_features = torch.stack([self.encoder(x[:, :, i])[0] for i in range(Z)], dim=2)

        # Prepare the counterparts for the decoder.
        # NOTE: The section below is sequential; there are no skip connections at the moment.
        z9 = self.deconv1(curr_features)
        z6 = self.deconv2(z9)
        z3 = self.deconv3(z6)
        z0 = self.deconv4(z3)

        updated_from_encoder = [z9, z6, z3]

        # Align the features through the base block.
        x = self.base(curr_features)
        # Run the decoder
        x = self.decoder(x, encoder_inputs=updated_from_encoder)
        x = self.deconv_out(x)  # NOTE before `end_up`

        # And the final output head.
        x = torch.cat([x, z0], dim=1)
        x = self.decoder_head(x)
        x = self.out_conv(x)
        if self.final_activation is not None:
            x = self.final_activation(x)

        # Postprocess the output back to original size.
        x = self.postprocess_masks(x, input_shape, original_shape)
        return x
A three-dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.
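Construction of the 3D variant mirrors the 2D case, but skip connections and transposed convolutions are rejected up front, and `embed_dim` falls back to 256. A hypothetical instantiation (assuming the SAM2 hiera encoder "hvit_b" and its dependencies are available; the parameter choices are examples):

    from torch_em.model.unetr import UNETR3D

    # Hypothetical construction; `use_skip_connection=True` or `use_conv_transpose=True`
    # would raise NotImplementedError in this implementation.
    model_3d = UNETR3D(
        backbone="sam2",
        encoder="hvit_b",
        out_channels=1,
        use_strip_pooling=True,  # use the strip-pooling variant of the 3D conv blocks
    )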
Forward pass of the UNETR-3D model.
Arguments:
- x: Inputs of expected shape (B, C, Z, Y, X), where the number of slices Z may vary.
Returns:
The UNETR output.
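The central step of the 3D forward pass is running the 2D image encoder on each Z-slice independently and stacking the per-slice embeddings into a feature volume along a new depth axis. The self-contained sketch below reproduces just that pattern with a dummy encoder; all names are illustrative and not part of the module:

    import torch
    import torch.nn as nn

    # Dummy stand-in for the image encoder: (B, C, H, W) -> (B, 256, H // 16, W // 16).
    dummy_encoder = nn.Conv2d(3, 256, kernel_size=16, stride=16)

    x = torch.randn(2, 3, 8, 256, 256)  # (B, C, Z, Y, X) volume with 8 slices
    Z = x.shape[2]

    # Encode every slice separately and stack along a new depth dimension (dim=2),
    # mirroring `torch.stack([self.encoder(x[:, :, i])[0] for i in range(Z)], dim=2)`.
    slice_embeddings = [dummy_encoder(x[:, :, i]) for i in range(Z)]
    volume_features = torch.stack(slice_embeddings, dim=2)
    print(volume_features.shape)  # torch.Size([2, 256, 8, 16, 16])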