torch_em.model.unetr
from functools import partial
from collections import OrderedDict
from typing import Optional, Tuple, Union, Literal

import torch
import torch.nn as nn
import torch.nn.functional as F

from .vit import get_vision_transformer
from .unet import Decoder, ConvBlock2d, ConvBlock3d, Upsampler2d, Upsampler3d, _update_conv_kwargs

try:
    from micro_sam.util import get_sam_model
except ImportError:
    get_sam_model = None

try:
    from micro_sam.v2.util import get_sam2_model
except ImportError:
    get_sam2_model = None

try:
    from micro_sam3.util import get_sam3_model
except ImportError:
    get_sam3_model = None


#
# UNETR IMPLEMENTATION [Vision Transformer (ViT from SAM / CellposeSAM / SAM2 / SAM3 / DINOv2 / DINOv3 / MAE / ScaleMAE) + UNet Decoder from `torch_em`] # noqa
#


class UNETRBase(nn.Module):
    """Base class for implementing a UNETR.

    Args:
        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
        backbone: The name of the vision transformer implementation.
            One of "sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"
            (see all combinations below)
        encoder: The vision transformer. Can either be a name, such as "vit_b"
            (see all combinations for this below) or a torch module.
        decoder: The convolutional decoder.
        out_channels: The number of output channels of the UNETR.
        use_sam_stats: Whether to normalize the input data with the statistics of the
            pretrained SAM / SAM2 / SAM3 model.
        use_dino_stats: Whether to normalize the input data with the statistics of the
            pretrained DINOv2 / DINOv3 model.
        use_mae_stats: Whether to normalize the input data with the statistics of the pretrained MAE model.
        resize_input: Whether to resize the input images to match `img_size`.
            By default, it resizes the inputs to match the `img_size`.
        encoder_checkpoint: Checkpoint for initializing the vision transformer.
            Can either be a filepath or an already loaded checkpoint.
        final_activation: The activation to apply to the UNETR output.
        use_skip_connection: Whether to use skip connections. By default, it uses skip connections.
        embed_dim: The embedding dimensionality, corresponding to the output dimension of the vision transformer.
        use_conv_transpose: Whether to use transposed convolutions instead of resampling for upsampling.
            By default, it uses resampling for upsampling.

    NOTE: The currently supported combinations of 'backbone' x 'encoder' are the following:

    SAM_family_models:
        - 'sam' x 'vit_b'
        - 'sam' x 'vit_l'
        - 'sam' x 'vit_h'
        - 'sam2' x 'hvit_t'
        - 'sam2' x 'hvit_s'
        - 'sam2' x 'hvit_b'
        - 'sam2' x 'hvit_l'
        - 'sam3' x 'vit_pe'
        - 'cellpose_sam' x 'vit_l'

    DINO_family_models:
        - 'dinov2' x 'vit_s'
        - 'dinov2' x 'vit_b'
        - 'dinov2' x 'vit_l'
        - 'dinov2' x 'vit_g'
        - 'dinov2' x 'vit_s_reg4'
        - 'dinov2' x 'vit_b_reg4'
        - 'dinov2' x 'vit_l_reg4'
        - 'dinov2' x 'vit_g_reg4'
        - 'dinov3' x 'vit_s'
        - 'dinov3' x 'vit_s+'
        - 'dinov3' x 'vit_b'
        - 'dinov3' x 'vit_l'
        - 'dinov3' x 'vit_l+'
        - 'dinov3' x 'vit_h+'
        - 'dinov3' x 'vit_7b'

    MAE_family_models:
        - 'mae' x 'vit_b'
        - 'mae' x 'vit_l'
        - 'mae' x 'vit_h'
        - 'scalemae' x 'vit_b'
        - 'scalemae' x 'vit_l'
        - 'scalemae' x 'vit_h'
    """
    def __init__(
        self,
        img_size: int = 1024,
        backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
        encoder: Optional[Union[nn.Module, str]] = "vit_b",
        decoder: Optional[nn.Module] = None,
        out_channels: int = 1,
        use_sam_stats: bool = False,
        use_mae_stats: bool = False,
        use_dino_stats: bool = False,
        resize_input: bool = True,
        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
        final_activation: Optional[Union[str, nn.Module]] = None,
        use_skip_connection: bool = True,
        embed_dim: Optional[int] = None,
        use_conv_transpose: bool = False,
        **kwargs
    ) -> None:
        super().__init__()

        self.img_size = img_size
        self.use_sam_stats = use_sam_stats
        self.use_mae_stats = use_mae_stats
        self.use_dino_stats = use_dino_stats
        self.use_skip_connection = use_skip_connection
        self.resize_input = resize_input
        self.use_conv_transpose = use_conv_transpose
        self.backbone = backbone

        if isinstance(encoder, str):  # e.g. "vit_b" / "hvit_b" / "vit_pe"
            print(f"Using {encoder} from {backbone.upper()}")
            self.encoder = get_vision_transformer(img_size=img_size, backbone=backbone, model=encoder, **kwargs)

            if encoder_checkpoint is not None:
                self._load_encoder_from_checkpoint(backbone=backbone, encoder=encoder, checkpoint=encoder_checkpoint)

            if embed_dim is None:
                embed_dim = self.encoder.embed_dim

            # For SAM1 encoder, if 'apply_neck' is applied, the embedding dimension must change.
            if hasattr(self.encoder, "apply_neck") and self.encoder.apply_neck:
                embed_dim = self.encoder.neck[2].out_channels  # the value is 256

        else:  # `nn.Module` ViT backbone
            self.encoder = encoder

            # Detect whether the custom encoder ends in a 'neck' projection,
            # since that determines where to read the output dimensionality from.
            have_neck = False
            for name, _ in self.encoder.named_parameters():
                if name.startswith("neck"):
                    have_neck = True
                    break

            if embed_dim is None:
                if have_neck:
                    embed_dim = self.encoder.neck[2].out_channels  # the value is 256
                else:
                    embed_dim = self.encoder.patch_embed.proj.out_channels

        self.embed_dim = embed_dim
        self.final_activation = self._get_activation(final_activation)

    def _load_encoder_from_checkpoint(self, backbone, encoder, checkpoint):
        """Function to load pretrained weights to the image encoder.
        """
        if isinstance(checkpoint, str):
            if backbone == "sam" and isinstance(encoder, str):
                # If we have a SAM encoder, then we first try to load the full SAM Model
                # (using micro_sam) and otherwise fall back on directly loading the encoder state
                # from the checkpoint
                try:
                    _, model = get_sam_model(model_type=encoder, checkpoint_path=checkpoint, return_sam=True)
                    encoder_state = model.image_encoder.state_dict()
                except Exception:
                    # Try loading the encoder state directly from a checkpoint.
                    encoder_state = torch.load(checkpoint, weights_only=False)

            elif backbone == "cellpose_sam" and isinstance(encoder, str):
                # The architecture matches CellposeSAM exactly (same rel_pos sizes),
                # so weights load directly without any interpolation.
                encoder_state = torch.load(checkpoint, map_location="cpu", weights_only=False)
                # Handle DataParallel/DistributedDataParallel prefix.
                if any(k.startswith("module.") for k in encoder_state.keys()):
                    encoder_state = OrderedDict(
                        {k[len("module."):]: v for k, v in encoder_state.items()}
                    )
                # Extract encoder weights from CellposeSAM checkpoint format (strip 'encoder.' prefix).
                if any(k.startswith("encoder.") for k in encoder_state.keys()):
                    encoder_state = OrderedDict(
                        {k[len("encoder."):]: v for k, v in encoder_state.items() if k.startswith("encoder.")}
                    )

            elif backbone == "sam2" and isinstance(encoder, str):
                # If we have a SAM2 encoder, then we first try to load the full SAM2 Model.
                # (using micro_sam2) and otherwise fall back on directly loading the encoder state
                # from the checkpoint
                try:
                    model = get_sam2_model(model_type=encoder, checkpoint_path=checkpoint)
                    encoder_state = model.image_encoder.state_dict()
                except Exception:
                    # Try loading the encoder state directly from a checkpoint.
                    encoder_state = torch.load(checkpoint, weights_only=False)

            elif backbone == "sam3" and isinstance(encoder, str):
                # If we have a SAM3 encoder, then we first try to load the full SAM3 Model.
                # (using micro_sam3) and otherwise fall back on directly loading the encoder state
                # from the checkpoint
                try:
                    model = get_sam3_model(checkpoint_path=checkpoint)
                    encoder_state = model.backbone.vision_backbone.state_dict()
                    # Let's align loading the encoder weights with expected parameter names
                    encoder_state = {
                        k[len("trunk."):] if k.startswith("trunk.") else k: v for k, v in encoder_state.items()
                    }
                    # And drop the 'convs' and 'sam2_convs' - these seem like some upsampling blocks.
                    encoder_state = {
                        k: v for k, v in encoder_state.items()
                        if not (k.startswith("convs.") or k.startswith("sam2_convs."))
                    }
                except Exception:
                    # Try loading the encoder state directly from a checkpoint.
                    encoder_state = torch.load(checkpoint, weights_only=False)

            elif backbone == "mae":
                # vit initialization hints from:
                # - https://github.com/facebookresearch/mae/blob/main/main_finetune.py#L233-L242
                encoder_state = torch.load(checkpoint, weights_only=False)["model"]
                encoder_state = OrderedDict({
                    k: v for k, v in encoder_state.items() if (k != "mask_token" and not k.startswith("decoder"))
                })
                # Let's remove the `head` from our current encoder (as the MAE pretrained don't expect it)
                current_encoder_state = self.encoder.state_dict()
                if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
                    del self.encoder.head

            elif backbone == "scalemae":
                # Load the encoder state directly from a checkpoint.
                encoder_state = torch.load(checkpoint)["model"]
                encoder_state = OrderedDict({
                    k: v for k, v in encoder_state.items()
                    if not k.startswith(("mask_token", "decoder", "fcn", "fpn", "pos_embed"))
                })

                # Let's remove the `head` from our current encoder (as the MAE pretrained don't expect it)
                current_encoder_state = self.encoder.state_dict()
                if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
                    del self.encoder.head

                if "pos_embed" in current_encoder_state:  # NOTE: ScaleMAE uses 'pos. embeddings' in a diff. format.
                    del self.encoder.pos_embed

            elif backbone in ["dinov2", "dinov3"]:  # Load the encoder state directly from a checkpoint.
                encoder_state = torch.load(checkpoint)

            else:
                raise ValueError(
                    f"We don't support either the '{backbone}' backbone or the '{encoder}' model combination (or both)."
                )

        else:
            # An already-loaded state dict was passed directly.
            encoder_state = checkpoint

        self.encoder.load_state_dict(encoder_state)

    def _get_activation(self, activation):
        # Resolve an activation given as None, an nn.Module instance, or an
        # attribute name on `torch.nn` (e.g. "Sigmoid").
        return_activation = None
        if activation is None:
            return None
        if isinstance(activation, nn.Module):
            return activation
        if isinstance(activation, str):
            return_activation = getattr(nn, activation, None)
        if return_activation is None:
            raise ValueError(f"Invalid activation: {activation}")

        return return_activation()

    @staticmethod
    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
        """Compute the output size given input size and target long side length.

        Args:
            oldh: The input image height.
            oldw: The input image width.
            long_side_length: The longest side length for resizing.

        Returns:
            The new image height.
            The new image width.
        """
        scale = long_side_length * 1.0 / max(oldh, oldw)
        newh, neww = oldh * scale, oldw * scale
        neww = int(neww + 0.5)
        newh = int(newh + 0.5)
        return (newh, neww)

    def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor:
        """Resize the image so that the longest side has the correct length.

        Expects batched images with shape BxCxHxW OR BxCxDxHxW and float format.

        Args:
            image: The input image.

        Returns:
            The resized image.
        """
        if image.ndim == 4:  # i.e. 2d image
            target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.encoder.img_size)
            return F.interpolate(image, target_size, mode="bilinear", align_corners=False, antialias=True)
        elif image.ndim == 5:  # i.e. 3d volume
            # NOTE(review): the 2d path scales to `self.encoder.img_size` while the 3d path
            # uses `self.img_size` - these are usually equal, but confirm for custom encoders.
            B, C, Z, H, W = image.shape
            target_size = self.get_preprocess_shape(H, W, self.img_size)
            return F.interpolate(image, (Z, *target_size), mode="trilinear", align_corners=False)
        else:
            raise ValueError(f"Expected 4d or 5d inputs, got {image.shape}")

    def _as_stats(self, mean, std, device, dtype, is_3d: bool):
        """@private
        """
        # Either 2d batch: (1, C, 1, 1) or 3d batch: (1, C, 1, 1, 1).
        view_shape = (1, -1, 1, 1, 1) if is_3d else (1, -1, 1, 1)
        pixel_mean = torch.tensor(mean, device=device, dtype=dtype).view(*view_shape)
        pixel_std = torch.tensor(std, device=device, dtype=dtype).view(*view_shape)
        return pixel_mean, pixel_std

    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        """@private
        """
        is_3d = (x.ndim == 5)
        device, dtype = x.device, x.dtype

        # NOTE: The backbone-specific SAM cases must be checked before the generic
        # SAM case, otherwise SAM2 / SAM3 inputs would always get the SAM1 statistics
        # (the previous ordering made those branches unreachable).
        if self.use_sam_stats and self.backbone == "sam2":
            mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
        elif self.use_sam_stats and self.backbone == "sam3":
            mean, std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
        elif self.use_sam_stats:
            mean, std = (123.675, 116.28, 103.53), (58.395, 57.12, 57.375)
        elif self.use_mae_stats:  # TODO: add mean std from mae / scalemae experiments (or open up arguments for this)
            raise NotImplementedError
        elif self.use_dino_stats:
            mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
        else:
            mean, std = (0.0, 0.0, 0.0), (1.0, 1.0, 1.0)

        pixel_mean, pixel_std = self._as_stats(mean, std, device=device, dtype=dtype, is_3d=is_3d)

        if self.resize_input:
            x = self.resize_longest_side(x)
        input_shape = x.shape[-3:] if is_3d else x.shape[-2:]

        x = (x - pixel_mean) / pixel_std
        h, w = x.shape[-2:]
        padh = self.encoder.img_size - h
        padw = self.encoder.img_size - w

        # Pad the spatial dimensions up to the encoder's square input size.
        if is_3d:
            x = F.pad(x, (0, padw, 0, padh, 0, 0))
        else:
            x = F.pad(x, (0, padw, 0, padh))

        return x, input_shape

    def postprocess_masks(
        self, masks: torch.Tensor, input_size: Tuple[int, ...], original_size: Tuple[int, ...],
    ) -> torch.Tensor:
        """@private
        """
        if masks.ndim == 4:  # i.e. 2d labels
            masks = F.interpolate(
                masks,
                (self.encoder.img_size, self.encoder.img_size),
                mode="bilinear",
                align_corners=False,
            )
            # Crop away the padding before scaling back to the original resolution.
            masks = masks[..., : input_size[0], : input_size[1]]
            masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)

        elif masks.ndim == 5:  # i.e. 3d volumetric labels
            masks = F.interpolate(
                masks,
                (input_size[0], self.img_size, self.img_size),
                mode="trilinear",
                align_corners=False,
            )
            masks = masks[..., :input_size[0], :input_size[1], :input_size[2]]
            masks = F.interpolate(masks, original_size, mode="trilinear", align_corners=False)

        else:
            raise ValueError(f"Expected 4d or 5d labels, got {masks.shape}")

        return masks


class UNETR(UNETRBase):
    """A (2d-only) UNet Transformer using a vision transformer as encoder and a convolutional decoder.
    """
    def __init__(
        self,
        img_size: int = 1024,
        backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
        encoder: Optional[Union[nn.Module, str]] = "vit_b",
        decoder: Optional[nn.Module] = None,
        out_channels: int = 1,
        use_sam_stats: bool = False,
        use_mae_stats: bool = False,
        use_dino_stats: bool = False,
        resize_input: bool = True,
        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
        final_activation: Optional[Union[str, nn.Module]] = None,
        use_skip_connection: bool = True,
        embed_dim: Optional[int] = None,
        use_conv_transpose: bool = False,
        **kwargs
    ) -> None:

        super().__init__(
            img_size=img_size,
            backbone=backbone,
            encoder=encoder,
            decoder=decoder,
            out_channels=out_channels,
            use_sam_stats=use_sam_stats,
            use_mae_stats=use_mae_stats,
            use_dino_stats=use_dino_stats,
            resize_input=resize_input,
            encoder_checkpoint=encoder_checkpoint,
            final_activation=final_activation,
            use_skip_connection=use_skip_connection,
            embed_dim=embed_dim,
            use_conv_transpose=use_conv_transpose,
            **kwargs,
        )

        encoder = self.encoder

        # Figure out the number of input channels of the encoder, which differs per backbone.
        if backbone == "sam2" and hasattr(encoder, "trunk"):
            in_chans = encoder.trunk.patch_embed.proj.in_channels
        elif hasattr(encoder, "in_chans"):
            in_chans = encoder.in_chans
        else:  # `nn.Module` ViT backbone.
            try:
                in_chans = encoder.patch_embed.proj.in_channels
            except AttributeError:  # for getting the input channels while using 'vit_t' from MobileSam
                in_chans = encoder.patch_embed.seq[0].c.in_channels

        # parameters for the decoder network
        depth = 3
        initial_features = 64
        gain = 2
        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
        scale_factors = depth * [2]
        self.out_channels = out_channels

        # choice of upsampler - to use (bilinear interpolation + conv) or conv transpose
        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d

        self.decoder = decoder or Decoder(
            features=features_decoder,
            scale_factors=scale_factors[::-1],
            conv_block_impl=ConvBlock2d,
            sampler_impl=_upsampler,
        )

        if use_skip_connection:
            self.deconv1 = Deconv2DBlock(
                in_channels=self.embed_dim,
                out_channels=features_decoder[0],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv2 = nn.Sequential(
                Deconv2DBlock(
                    in_channels=self.embed_dim,
                    out_channels=features_decoder[0],
                    use_conv_transpose=use_conv_transpose,
                ),
                Deconv2DBlock(
                    in_channels=features_decoder[0],
                    out_channels=features_decoder[1],
                    use_conv_transpose=use_conv_transpose,
                )
            )
            self.deconv3 = nn.Sequential(
                Deconv2DBlock(
                    in_channels=self.embed_dim,
                    out_channels=features_decoder[0],
                    use_conv_transpose=use_conv_transpose,
                ),
                Deconv2DBlock(
                    in_channels=features_decoder[0],
                    out_channels=features_decoder[1],
                    use_conv_transpose=use_conv_transpose,
                ),
                Deconv2DBlock(
                    in_channels=features_decoder[1],
                    out_channels=features_decoder[2],
                    use_conv_transpose=use_conv_transpose,
                )
            )
            self.deconv4 = ConvBlock2d(in_chans, features_decoder[-1])
        else:
            self.deconv1 = Deconv2DBlock(
                in_channels=self.embed_dim,
                out_channels=features_decoder[0],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv2 = Deconv2DBlock(
                in_channels=features_decoder[0],
                out_channels=features_decoder[1],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv3 = Deconv2DBlock(
                in_channels=features_decoder[1],
                out_channels=features_decoder[2],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv4 = Deconv2DBlock(
                in_channels=features_decoder[2],
                out_channels=features_decoder[3],
                use_conv_transpose=use_conv_transpose,
            )

        self.base = ConvBlock2d(self.embed_dim, features_decoder[0])
        self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1)
        self.deconv_out = _upsampler(
            scale_factor=2, in_channels=features_decoder[-1], out_channels=features_decoder[-1]
        )
        self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the UNETR to the input data.

        Args:
            x: The input tensor.

        Returns:
            The UNETR output.
        """
        original_shape = x.shape[-2:]

        # Reshape the inputs to the shape expected by the encoder
        # and normalize the inputs if normalization is part of the model.
        x, input_shape = self.preprocess(x)

        encoder_outputs = self.encoder(x)

        if isinstance(encoder_outputs[-1], list):
            # `encoder_outputs` can be arranged in only two forms:
            # - either we only return the image embeddings
            # - or, we return the image embeddings and the "list" of global attention layers
            z12, from_encoder = encoder_outputs
        else:
            z12 = encoder_outputs
            from_encoder = None

        if self.use_skip_connection:
            if from_encoder is None:
                # Fail with a clear message instead of an unbound-local NameError.
                raise ValueError(
                    "The encoder did not return intermediate features, "
                    "which are required when 'use_skip_connection' is enabled."
                )
            from_encoder = from_encoder[::-1]
            z9 = self.deconv1(from_encoder[0])
            z6 = self.deconv2(from_encoder[1])
            z3 = self.deconv3(from_encoder[2])
            z0 = self.deconv4(x)

        else:
            z9 = self.deconv1(z12)
            z6 = self.deconv2(z9)
            z3 = self.deconv3(z6)
            z0 = self.deconv4(z3)

        updated_from_encoder = [z9, z6, z3]

        x = self.base(z12)
        x = self.decoder(x, encoder_inputs=updated_from_encoder)
        x = self.deconv_out(x)

        x = torch.cat([x, z0], dim=1)
        x = self.decoder_head(x)

        x = self.out_conv(x)
        if self.final_activation is not None:
            x = self.final_activation(x)

        x = self.postprocess_masks(x, input_shape, original_shape)
        return x


class UNETR2D(UNETR):
    """A two-dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.
    """
    pass


class UNETR3D(UNETRBase):
    """A three dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.
    """
    def __init__(
        self,
        img_size: int = 1024,
        backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
        encoder: Optional[Union[nn.Module, str]] = "hvit_b",
        decoder: Optional[nn.Module] = None,
        out_channels: int = 1,
        use_sam_stats: bool = False,
        use_mae_stats: bool = False,
        use_dino_stats: bool = False,
        resize_input: bool = True,
        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
        final_activation: Optional[Union[str, nn.Module]] = None,
        use_skip_connection: bool = False,
        embed_dim: Optional[int] = None,
        use_conv_transpose: bool = False,
        use_strip_pooling: bool = True,
        **kwargs
    ):
        if use_skip_connection:
            raise NotImplementedError("The framework cannot handle skip connections atm.")
        if use_conv_transpose:
            raise NotImplementedError("It's not enabled to switch between interpolation and transposed convolutions.")

        # Sort the `embed_dim` out
        embed_dim = 256 if embed_dim is None else embed_dim

        super().__init__(
            img_size=img_size,
            backbone=backbone,
            encoder=encoder,
            decoder=decoder,
            out_channels=out_channels,
            use_sam_stats=use_sam_stats,
            use_mae_stats=use_mae_stats,
            use_dino_stats=use_dino_stats,
            resize_input=resize_input,
            encoder_checkpoint=encoder_checkpoint,
            final_activation=final_activation,
            use_skip_connection=use_skip_connection,
            embed_dim=embed_dim,
            use_conv_transpose=use_conv_transpose,
            **kwargs,
        )

        # The 3d convolutional decoder.
        # First, get the important parameters for the decoder.
        depth = 3
        initial_features = 64
        gain = 2
        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
        scale_factors = [1, 2, 2]
        self.out_channels = out_channels

        # The mapping blocks.
        self.deconv1 = Deconv3DBlock(
            in_channels=embed_dim,
            out_channels=features_decoder[0],
            scale_factor=scale_factors,
            use_strip_pooling=use_strip_pooling,
        )
        self.deconv2 = Deconv3DBlock(
            in_channels=features_decoder[0],
            out_channels=features_decoder[1],
            scale_factor=scale_factors,
            use_strip_pooling=use_strip_pooling,
        )
        self.deconv3 = Deconv3DBlock(
            in_channels=features_decoder[1],
            out_channels=features_decoder[2],
            scale_factor=scale_factors,
            use_strip_pooling=use_strip_pooling,
        )
        self.deconv4 = Deconv3DBlock(
            in_channels=features_decoder[2],
            out_channels=features_decoder[3],
            scale_factor=scale_factors,
            use_strip_pooling=use_strip_pooling,
        )

        # The core decoder block.
        self.decoder = decoder or Decoder(
            features=features_decoder,
            scale_factors=[scale_factors] * depth,
            conv_block_impl=partial(ConvBlock3dWithStrip, use_strip_pooling=use_strip_pooling),
            sampler_impl=Upsampler3d,
        )

        # And the final upsampler to match the expected dimensions.
        self.deconv_out = Deconv3DBlock(  # NOTE: changed `end_up` to `deconv_out`
            in_channels=features_decoder[-1],
            out_channels=features_decoder[-1],
            scale_factor=scale_factors,
            use_strip_pooling=use_strip_pooling,
        )

        # Additional conjunction blocks.
        self.base = ConvBlock3dWithStrip(
            in_channels=embed_dim,
            out_channels=features_decoder[0],
            use_strip_pooling=use_strip_pooling,
        )

        # And the output layers.
        self.decoder_head = ConvBlock3dWithStrip(
            in_channels=2 * features_decoder[-1],
            out_channels=features_decoder[-1],
            use_strip_pooling=use_strip_pooling,
        )
        self.out_conv = nn.Conv3d(features_decoder[-1], out_channels, 1)

    def forward(self, x: torch.Tensor):
        """Forward pass of the UNETR-3D model.

        Args:
            x: Inputs of expected shape (B, C, Z, Y, X), where Z considers flexible inputs.

        Returns:
            The UNETR output.
        """
        B, C, Z, H, W = x.shape
        original_shape = (Z, H, W)

        # Preprocessing step
        x, input_shape = self.preprocess(x)

        # Run the image encoder slice-by-slice and restack along the depth axis.
        curr_features = torch.stack([self.encoder(x[:, :, i])[0] for i in range(Z)], dim=2)

        # Prepare the counterparts for the decoder.
        # NOTE: The section below is sequential, there's no skip connections atm.
        z9 = self.deconv1(curr_features)
        z6 = self.deconv2(z9)
        z3 = self.deconv3(z6)
        z0 = self.deconv4(z3)

        updated_from_encoder = [z9, z6, z3]

        # Align the features through the base block.
        x = self.base(curr_features)
        # Run the decoder
        x = self.decoder(x, encoder_inputs=updated_from_encoder)
        x = self.deconv_out(x)  # NOTE before `end_up`

        # And the final output head.
        x = torch.cat([x, z0], dim=1)
        x = self.decoder_head(x)
        x = self.out_conv(x)
        if self.final_activation is not None:
            x = self.final_activation(x)

        # Postprocess the output back to original size.
        x = self.postprocess_masks(x, input_shape, original_shape)
        return x


#
# ADDITIONAL FUNCTIONALITIES
#


def _strip_pooling_layers(enabled, channels) -> nn.Module:
    # Returns the depth-strip-pooling block when enabled, otherwise a pass-through.
    return DepthStripPooling(channels) if enabled else nn.Identity()


class DepthStripPooling(nn.Module):
    """@private
    """
    def __init__(self, channels: int, reduction: int = 4):
        """Block for strip pooling along the depth dimension (only).

        eg. for 3D (Z > 1) - it aggregates global context across depth by adaptive avg pooling
        to Z=1, and then passes through a small 1x1x1 MLP, then broadcasts it back to Z to
        modulate the original features (using a gated residual).

        For 2D (Z == 1): returns input unchanged (no-op).

        Args:
            channels: The output channels.
            reduction: The reduction of the hidden layers.
        """
        super().__init__()
        hidden = max(1, channels // reduction)
        self.conv1 = nn.Conv3d(channels, hidden, kernel_size=1)
        self.bn1 = nn.BatchNorm3d(hidden)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(hidden, channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.dim() != 5:
            raise ValueError(f"DepthStripPooling expects 5D tensors as input, got '{x.shape}'.")

        B, C, Z, H, W = x.shape
        if Z == 1:  # i.e. always the case of all 2d.
            return x  # We simply do nothing there.

        # We pool only along the depth dimension: i.e. target shape (B, C, 1, H, W)
        feat = F.adaptive_avg_pool3d(x, output_size=(1, H, W))
        feat = self.conv1(feat)
        feat = self.bn1(feat)
        feat = self.relu(feat)
        feat = self.conv2(feat)
        gate = torch.sigmoid(feat).expand(B, C, Z, H, W)  # Broadcast the collapsed depth context back to all slices

        # Gated residual fusion
        return x * gate + x


class Deconv3DBlock(nn.Module):
    """@private
    """
    def __init__(
        self,
        scale_factor,
        in_channels,
        out_channels,
        kernel_size=3,
        anisotropic_kernel=True,
        use_strip_pooling=True,
    ):
        super().__init__()
        conv_block_kwargs = {
            "in_channels": out_channels,
            "out_channels": out_channels,
            "kernel_size": kernel_size,
            "padding": ((kernel_size - 1) // 2),
        }
        if anisotropic_kernel:
            conv_block_kwargs = _update_conv_kwargs(conv_block_kwargs, scale_factor)

        self.block = nn.Sequential(
            Upsampler3d(scale_factor, in_channels, out_channels),
            nn.Conv3d(**conv_block_kwargs),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(True),
            _strip_pooling_layers(enabled=use_strip_pooling, channels=out_channels),
        )

    def forward(self, x):
        return self.block(x)


class ConvBlock3dWithStrip(nn.Module):
    """@private
    """
    def __init__(
        self, in_channels: int, out_channels: int, use_strip_pooling: bool = True, **kwargs
    ):
        super().__init__()
        self.block = nn.Sequential(
            ConvBlock3d(in_channels, out_channels, **kwargs),
            _strip_pooling_layers(enabled=use_strip_pooling, channels=out_channels),
        )

    def forward(self, x):
        return self.block(x)


class SingleDeconv2DBlock(nn.Module):
    """@private
    """
    def __init__(self, scale_factor, in_channels, out_channels):
        # NOTE: `scale_factor` is accepted (to match the Upsampler2d signature) but the
        # transposed convolution always upsamples by a fixed factor of 2.
        super().__init__()
        self.block = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, padding=0, output_padding=0)

    def forward(self, x):
        return self.block(x)


class SingleConv2DBlock(nn.Module):
    """@private
    """
    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        self.block = nn.Conv2d(
            in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=((kernel_size - 1) // 2)
        )

    def forward(self, x):
        return self.block(x)


class Conv2DBlock(nn.Module):
    """@private
    """
    def __init__(self, in_channels, out_channels, kernel_size=3):
        super().__init__()
        self.block = nn.Sequential(
            SingleConv2DBlock(in_channels, out_channels, kernel_size),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True)
        )

    def forward(self, x):
        return self.block(x)


class Deconv2DBlock(nn.Module):
    """@private
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, use_conv_transpose=True):
        super().__init__()
        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d
        self.block = nn.Sequential(
            _upsampler(scale_factor=2, in_channels=in_channels, out_channels=out_channels),
            SingleConv2DBlock(out_channels, out_channels, kernel_size),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True)
        )

    def forward(self, x):
        return self.block(x)
class UNETRBase(nn.Module):
    """Base class for implementing a UNETR.

    Args:
        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
        backbone: The name of the vision transformer implementation.
            One of "sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"
            (see all combinations below)
        encoder: The vision transformer. Can either be a name, such as "vit_b"
            (see all combinations for this below) or a torch module.
        decoder: The convolutional decoder. Stored and used by the subclasses.
        out_channels: The number of output channels of the UNETR. Used by the subclasses.
        use_sam_stats: Whether to normalize the input data with the statistics of the
            pretrained SAM / SAM2 / SAM3 model.
        use_dino_stats: Whether to normalize the input data with the statistics of the
            pretrained DINOv2 / DINOv3 model.
        use_mae_stats: Whether to normalize the input data with the statistics of the pretrained MAE model.
        resize_input: Whether to resize the input images to match `img_size`.
            By default, it resizes the inputs to match the `img_size`.
        encoder_checkpoint: Checkpoint for initializing the vision transformer.
            Can either be a filepath or an already loaded checkpoint.
        final_activation: The activation to apply to the UNETR output.
        use_skip_connection: Whether to use skip connections. By default, it uses skip connections.
        embed_dim: The embedding dimensionality, corresponding to the output dimension of the vision transformer.
        use_conv_transpose: Whether to use transposed convolutions instead of resampling for upsampling.
            By default, it uses resampling for upsampling.

    NOTE: The currently supported combinations of 'backbone' x 'encoder' are the following:

    SAM_family_models:
        - 'sam' x 'vit_b'
        - 'sam' x 'vit_l'
        - 'sam' x 'vit_h'
        - 'sam2' x 'hvit_t'
        - 'sam2' x 'hvit_s'
        - 'sam2' x 'hvit_b'
        - 'sam2' x 'hvit_l'
        - 'sam3' x 'vit_pe'
        - 'cellpose_sam' x 'vit_l'

    DINO_family_models:
        - 'dinov2' x 'vit_s'
        - 'dinov2' x 'vit_b'
        - 'dinov2' x 'vit_l'
        - 'dinov2' x 'vit_g'
        - 'dinov2' x 'vit_s_reg4'
        - 'dinov2' x 'vit_b_reg4'
        - 'dinov2' x 'vit_l_reg4'
        - 'dinov2' x 'vit_g_reg4'
        - 'dinov3' x 'vit_s'
        - 'dinov3' x 'vit_s+'
        - 'dinov3' x 'vit_b'
        - 'dinov3' x 'vit_l'
        - 'dinov3' x 'vit_l+'
        - 'dinov3' x 'vit_h+'
        - 'dinov3' x 'vit_7b'

    MAE_family_models:
        - 'mae' x 'vit_b'
        - 'mae' x 'vit_l'
        - 'mae' x 'vit_h'
        - 'scalemae' x 'vit_b'
        - 'scalemae' x 'vit_l'
        - 'scalemae' x 'vit_h'
    """
    def __init__(
        self,
        img_size: int = 1024,
        backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
        encoder: Optional[Union[nn.Module, str]] = "vit_b",
        decoder: Optional[nn.Module] = None,
        out_channels: int = 1,
        use_sam_stats: bool = False,
        use_mae_stats: bool = False,
        use_dino_stats: bool = False,
        resize_input: bool = True,
        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
        final_activation: Optional[Union[str, nn.Module]] = None,
        use_skip_connection: bool = True,
        embed_dim: Optional[int] = None,
        use_conv_transpose: bool = False,
        **kwargs
    ) -> None:
        super().__init__()

        self.img_size = img_size
        self.use_sam_stats = use_sam_stats
        self.use_mae_stats = use_mae_stats
        self.use_dino_stats = use_dino_stats
        self.use_skip_connection = use_skip_connection
        self.resize_input = resize_input
        self.use_conv_transpose = use_conv_transpose
        self.backbone = backbone

        if isinstance(encoder, str):  # e.g. "vit_b" / "hvit_b" / "vit_pe"
            print(f"Using {encoder} from {backbone.upper()}")
            self.encoder = get_vision_transformer(img_size=img_size, backbone=backbone, model=encoder, **kwargs)

            if encoder_checkpoint is not None:
                self._load_encoder_from_checkpoint(backbone=backbone, encoder=encoder, checkpoint=encoder_checkpoint)

            if embed_dim is None:
                embed_dim = self.encoder.embed_dim

            # For SAM1 encoder, if 'apply_neck' is applied, the embedding dimension must change.
            if hasattr(self.encoder, "apply_neck") and self.encoder.apply_neck:
                embed_dim = self.encoder.neck[2].out_channels  # the value is 256

        else:  # `nn.Module` ViT backbone
            self.encoder = encoder

            # Check whether the custom encoder exposes a SAM-style 'neck' projection.
            have_neck = False
            for name, _ in self.encoder.named_parameters():
                if name.startswith("neck"):
                    have_neck = True

            if embed_dim is None:
                if have_neck:
                    embed_dim = self.encoder.neck[2].out_channels  # the value is 256
                else:
                    embed_dim = self.encoder.patch_embed.proj.out_channels

        self.embed_dim = embed_dim
        self.final_activation = self._get_activation(final_activation)

    def _load_encoder_from_checkpoint(self, backbone, encoder, checkpoint):
        """Function to load pretrained weights to the image encoder.
        """
        if isinstance(checkpoint, str):
            if backbone == "sam" and isinstance(encoder, str):
                # If we have a SAM encoder, then we first try to load the full SAM Model
                # (using micro_sam) and otherwise fall back on directly loading the encoder state
                # from the checkpoint
                try:
                    _, model = get_sam_model(model_type=encoder, checkpoint_path=checkpoint, return_sam=True)
                    encoder_state = model.image_encoder.state_dict()
                except Exception:
                    # Try loading the encoder state directly from a checkpoint.
                    encoder_state = torch.load(checkpoint, weights_only=False)

            elif backbone == "cellpose_sam" and isinstance(encoder, str):
                # The architecture matches CellposeSAM exactly (same rel_pos sizes),
                # so weights load directly without any interpolation.
                encoder_state = torch.load(checkpoint, map_location="cpu", weights_only=False)
                # Handle DataParallel/DistributedDataParallel prefix.
                if any(k.startswith("module.") for k in encoder_state.keys()):
                    encoder_state = OrderedDict(
                        {k[len("module."):]: v for k, v in encoder_state.items()}
                    )
                # Extract encoder weights from CellposeSAM checkpoint format (strip 'encoder.' prefix).
                if any(k.startswith("encoder.") for k in encoder_state.keys()):
                    encoder_state = OrderedDict(
                        {k[len("encoder."):]: v for k, v in encoder_state.items() if k.startswith("encoder.")}
                    )

            elif backbone == "sam2" and isinstance(encoder, str):
                # If we have a SAM2 encoder, then we first try to load the full SAM2 Model.
                # (using micro_sam2) and otherwise fall back on directly loading the encoder state
                # from the checkpoint
                try:
                    model = get_sam2_model(model_type=encoder, checkpoint_path=checkpoint)
                    encoder_state = model.image_encoder.state_dict()
                except Exception:
                    # Try loading the encoder state directly from a checkpoint.
                    encoder_state = torch.load(checkpoint, weights_only=False)

            elif backbone == "sam3" and isinstance(encoder, str):
                # If we have a SAM3 encoder, then we first try to load the full SAM3 Model.
                # (using micro_sam3) and otherwise fall back on directly loading the encoder state
                # from the checkpoint
                try:
                    model = get_sam3_model(checkpoint_path=checkpoint)
                    encoder_state = model.backbone.vision_backbone.state_dict()
                    # Let's align loading the encoder weights with expected parameter names
                    encoder_state = {
                        k[len("trunk."):] if k.startswith("trunk.") else k: v for k, v in encoder_state.items()
                    }
                    # And drop the 'convs' and 'sam2_convs' - these seem like some upsampling blocks.
                    encoder_state = {
                        k: v for k, v in encoder_state.items()
                        if not (k.startswith("convs.") or k.startswith("sam2_convs."))
                    }
                except Exception:
                    # Try loading the encoder state directly from a checkpoint.
                    encoder_state = torch.load(checkpoint, weights_only=False)

            elif backbone == "mae":
                # vit initialization hints from:
                # - https://github.com/facebookresearch/mae/blob/main/main_finetune.py#L233-L242
                encoder_state = torch.load(checkpoint, weights_only=False)["model"]
                encoder_state = OrderedDict({
                    k: v for k, v in encoder_state.items() if (k != "mask_token" and not k.startswith("decoder"))
                })
                # Let's remove the `head` from our current encoder (as the MAE pretrained don't expect it)
                current_encoder_state = self.encoder.state_dict()
                if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
                    del self.encoder.head

            elif backbone == "scalemae":
                # Load the encoder state directly from a checkpoint.
                encoder_state = torch.load(checkpoint)["model"]
                encoder_state = OrderedDict({
                    k: v for k, v in encoder_state.items()
                    if not k.startswith(("mask_token", "decoder", "fcn", "fpn", "pos_embed"))
                })

                # Let's remove the `head` from our current encoder (as the MAE pretrained don't expect it)
                current_encoder_state = self.encoder.state_dict()
                if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
                    del self.encoder.head

                if "pos_embed" in current_encoder_state:  # NOTE: ScaleMAE uses 'pos. embeddings' in a diff. format.
                    del self.encoder.pos_embed

            elif backbone in ["dinov2", "dinov3"]:  # Load the encoder state directly from a checkpoint.
                encoder_state = torch.load(checkpoint)

            else:
                raise ValueError(
                    f"We don't support either the '{backbone}' backbone or the '{encoder}' model combination (or both)."
                )

        else:
            # An already loaded state dict was passed in directly.
            encoder_state = checkpoint

        self.encoder.load_state_dict(encoder_state)

    def _get_activation(self, activation):
        """Resolve `activation` (None, an `nn.Module`, or an attribute name of `torch.nn`) to a module instance."""
        return_activation = None
        if activation is None:
            return None
        if isinstance(activation, nn.Module):
            return activation
        if isinstance(activation, str):
            return_activation = getattr(nn, activation, None)
        if return_activation is None:
            raise ValueError(f"Invalid activation: {activation}")

        return return_activation()

    @staticmethod
    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
        """Compute the output size given input size and target long side length.

        Args:
            oldh: The input image height.
            oldw: The input image width.
            long_side_length: The longest side length for resizing.

        Returns:
            The new image height.
            The new image width.
        """
        scale = long_side_length * 1.0 / max(oldh, oldw)
        newh, neww = oldh * scale, oldw * scale
        # Round half up to the nearest integer.
        neww = int(neww + 0.5)
        newh = int(newh + 0.5)
        return (newh, neww)

    def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor:
        """Resize the image so that the longest side has the correct length.

        Expects batched images with shape BxCxHxW OR BxCxDxHxW and float format.

        Args:
            image: The input image.

        Returns:
            The resized image.
        """
        if image.ndim == 4:  # i.e. 2d image
            target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.encoder.img_size)
            return F.interpolate(image, target_size, mode="bilinear", align_corners=False, antialias=True)
        elif image.ndim == 5:  # i.e. 3d volume; only the planar (H, W) axes are rescaled, depth is kept.
            B, C, Z, H, W = image.shape
            target_size = self.get_preprocess_shape(H, W, self.img_size)
            return F.interpolate(image, (Z, *target_size), mode="trilinear", align_corners=False)
        else:
            raise ValueError("Expected 4d or 5d inputs, got", image.shape)

    def _as_stats(self, mean, std, device, dtype, is_3d: bool):
        """@private
        """
        # Either 2d batch: (1, C, 1, 1) or 3d batch: (1, C, 1, 1, 1).
        view_shape = (1, -1, 1, 1, 1) if is_3d else (1, -1, 1, 1)
        pixel_mean = torch.tensor(mean, device=device, dtype=dtype).view(*view_shape)
        pixel_std = torch.tensor(std, device=device, dtype=dtype).view(*view_shape)
        return pixel_mean, pixel_std

    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        """@private
        """
        is_3d = (x.ndim == 5)
        device, dtype = x.device, x.dtype

        # NOTE: The backbone-specific SAM2 / SAM3 checks must come before the generic SAM check,
        # otherwise they are shadowed by it and can never trigger.
        if self.use_sam_stats and self.backbone == "sam2":
            # SAM2 normalizes with the ImageNet statistics.
            mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
        elif self.use_sam_stats and self.backbone == "sam3":
            mean, std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
        elif self.use_sam_stats:
            mean, std = (123.675, 116.28, 103.53), (58.395, 57.12, 57.375)
        elif self.use_mae_stats:  # TODO: add mean std from mae / scalemae experiments (or open up arguments for this)
            raise NotImplementedError
        elif self.use_dino_stats:
            mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
        else:
            # No normalization.
            mean, std = (0.0, 0.0, 0.0), (1.0, 1.0, 1.0)

        pixel_mean, pixel_std = self._as_stats(mean, std, device=device, dtype=dtype, is_3d=is_3d)

        if self.resize_input:
            x = self.resize_longest_side(x)
        input_shape = x.shape[-3:] if is_3d else x.shape[-2:]

        x = (x - pixel_mean) / pixel_std
        # Pad the (H, W) plane on the bottom/right up to the square encoder input size.
        h, w = x.shape[-2:]
        padh = self.encoder.img_size - h
        padw = self.encoder.img_size - w

        if is_3d:
            x = F.pad(x, (0, padw, 0, padh, 0, 0))
        else:
            x = F.pad(x, (0, padw, 0, padh))

        return x, input_shape

    def postprocess_masks(
        self, masks: torch.Tensor, input_size: Tuple[int, ...], original_size: Tuple[int, ...],
    ) -> torch.Tensor:
        """@private
        """
        if masks.ndim == 4:  # i.e. 2d labels
            # Upsample to the padded encoder resolution, crop the padding, then resize to the original size.
            masks = F.interpolate(
                masks,
                (self.encoder.img_size, self.encoder.img_size),
                mode="bilinear",
                align_corners=False,
            )
            masks = masks[..., : input_size[0], : input_size[1]]
            masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)

        elif masks.ndim == 5:  # i.e. 3d volumetric labels
            masks = F.interpolate(
                masks,
                (input_size[0], self.img_size, self.img_size),
                mode="trilinear",
                align_corners=False,
            )
            masks = masks[..., :input_size[0], :input_size[1], :input_size[2]]
            masks = F.interpolate(masks, original_size, mode="trilinear", align_corners=False)

        else:
            raise ValueError("Expected 4d or 5d labels, got", masks.shape)

        return masks
Base class for implementing a UNETR.
Arguments:
- img_size: The size of the input for the image encoder. Input images will be resized to match this size.
- backbone: The name of the vision transformer implementation. One of "sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3" (see all combinations below)
- encoder: The vision transformer. Can either be a name, such as "vit_b" (see all combinations for this below) or a torch module.
- decoder: The convolutional decoder.
- out_channels: The number of output channels of the UNETR.
- use_sam_stats: Whether to normalize the input data with the statistics of the pretrained SAM / SAM2 / SAM3 model.
- use_dino_stats: Whether to normalize the input data with the statistics of the pretrained DINOv2 / DINOv3 model.
- use_mae_stats: Whether to normalize the input data with the statistics of the pretrained MAE model.
- resize_input: Whether to resize the input images to match `img_size`. By default, it resizes the inputs to match the `img_size`.
- encoder_checkpoint: Checkpoint for initializing the vision transformer. Can either be a filepath or an already loaded checkpoint.
- final_activation: The activation to apply to the UNETR output.
- use_skip_connection: Whether to use skip connections. By default, it uses skip connections.
- embed_dim: The embedding dimensionality, corresponding to the output dimension of the vision transformer.
- use_conv_transpose: Whether to use transposed convolutions instead of resampling for upsampling. By default, it uses resampling for upsampling.
- NOTE: The currently supported combinations of 'backbone' x 'encoder' are the following:
- SAM_family_models:
  - 'sam' x 'vit_b'
- 'sam' x 'vit_l'
- 'sam' x 'vit_h'
- 'sam2' x 'hvit_t'
- 'sam2' x 'hvit_s'
- 'sam2' x 'hvit_b'
- 'sam2' x 'hvit_l'
- 'sam3' x 'vit_pe'
- 'cellpose_sam' x 'vit_l'
- DINO_family_models:
  - 'dinov2' x 'vit_s'
- 'dinov2' x 'vit_b'
- 'dinov2' x 'vit_l'
- 'dinov2' x 'vit_g'
- 'dinov2' x 'vit_s_reg4'
- 'dinov2' x 'vit_b_reg4'
- 'dinov2' x 'vit_l_reg4'
- 'dinov2' x 'vit_g_reg4'
- 'dinov3' x 'vit_s'
- 'dinov3' x 'vit_s+'
- 'dinov3' x 'vit_b'
- 'dinov3' x 'vit_l'
- 'dinov3' x 'vit_l+'
- 'dinov3' x 'vit_h+'
- 'dinov3' x 'vit_7b'
- MAE_family_models:
  - 'mae' x 'vit_b'
- 'mae' x 'vit_l'
- 'mae' x 'vit_h'
- 'scalemae' x 'vit_b'
- 'scalemae' x 'vit_l'
- 'scalemae' x 'vit_h'
    def __init__(
        self,
        img_size: int = 1024,
        backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
        encoder: Optional[Union[nn.Module, str]] = "vit_b",
        decoder: Optional[nn.Module] = None,
        out_channels: int = 1,
        use_sam_stats: bool = False,
        use_mae_stats: bool = False,
        use_dino_stats: bool = False,
        resize_input: bool = True,
        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
        final_activation: Optional[Union[str, nn.Module]] = None,
        use_skip_connection: bool = True,
        embed_dim: Optional[int] = None,
        use_conv_transpose: bool = False,
        **kwargs
    ) -> None:
        """Set up the UNETR base state and construct (or adopt) the vision-transformer encoder.

        See the class docstring for the parameter descriptions. `decoder` and `out_channels`
        are not stored here; the subclasses consume them when building their decoders.
        """
        super().__init__()

        self.img_size = img_size
        self.use_sam_stats = use_sam_stats
        self.use_mae_stats = use_mae_stats
        self.use_dino_stats = use_dino_stats
        self.use_skip_connection = use_skip_connection
        self.resize_input = resize_input
        self.use_conv_transpose = use_conv_transpose
        self.backbone = backbone

        if isinstance(encoder, str):  # e.g. "vit_b" / "hvit_b" / "vit_pe"
            # Build the encoder by name; extra kwargs are forwarded to the ViT factory.
            print(f"Using {encoder} from {backbone.upper()}")
            self.encoder = get_vision_transformer(img_size=img_size, backbone=backbone, model=encoder, **kwargs)

            if encoder_checkpoint is not None:
                self._load_encoder_from_checkpoint(backbone=backbone, encoder=encoder, checkpoint=encoder_checkpoint)

            if embed_dim is None:
                embed_dim = self.encoder.embed_dim

            # For SAM1 encoder, if 'apply_neck' is applied, the embedding dimension must change.
            if hasattr(self.encoder, "apply_neck") and self.encoder.apply_neck:
                embed_dim = self.encoder.neck[2].out_channels  # the value is 256

        else:  # `nn.Module` ViT backbone
            self.encoder = encoder

            # Detect a SAM-style 'neck' projection on the custom encoder to pick the right embed_dim.
            have_neck = False
            for name, _ in self.encoder.named_parameters():
                if name.startswith("neck"):
                    have_neck = True

            if embed_dim is None:
                if have_neck:
                    embed_dim = self.encoder.neck[2].out_channels  # the value is 256
                else:
                    embed_dim = self.encoder.patch_embed.proj.out_channels

        self.embed_dim = embed_dim
        self.final_activation = self._get_activation(final_activation)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
274 @staticmethod 275 def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 276 """Compute the output size given input size and target long side length. 277 278 Args: 279 oldh: The input image height. 280 oldw: The input image width. 281 long_side_length: The longest side length for resizing. 282 283 Returns: 284 The new image height. 285 The new image width. 286 """ 287 scale = long_side_length * 1.0 / max(oldh, oldw) 288 newh, neww = oldh * scale, oldw * scale 289 neww = int(neww + 0.5) 290 newh = int(newh + 0.5) 291 return (newh, neww)
Compute the output size given input size and target long side length.
Arguments:
- oldh: The input image height.
- oldw: The input image width.
- long_side_length: The longest side length for resizing.
Returns:
The new image height. The new image width.
293 def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor: 294 """Resize the image so that the longest side has the correct length. 295 296 Expects batched images with shape BxCxHxW OR BxCxDxHxW and float format. 297 298 Args: 299 image: The input image. 300 301 Returns: 302 The resized image. 303 """ 304 if image.ndim == 4: # i.e. 2d image 305 target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.encoder.img_size) 306 return F.interpolate(image, target_size, mode="bilinear", align_corners=False, antialias=True) 307 elif image.ndim == 5: # i.e. 3d volume 308 B, C, Z, H, W = image.shape 309 target_size = self.get_preprocess_shape(H, W, self.img_size) 310 return F.interpolate(image, (Z, *target_size), mode="trilinear", align_corners=False) 311 else: 312 raise ValueError("Expected 4d or 5d inputs, got", image.shape)
Resize the image so that the longest side has the correct length.
Expects batched images with shape BxCxHxW OR BxCxDxHxW and float format.
Arguments:
- image: The input image.
Returns:
The resized image.
class UNETR(UNETRBase):
    """A (2d-only) UNet Transformer using a vision transformer as encoder and a convolutional decoder.
    """
    def __init__(
        self,
        img_size: int = 1024,
        backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
        encoder: Optional[Union[nn.Module, str]] = "vit_b",
        decoder: Optional[nn.Module] = None,
        out_channels: int = 1,
        use_sam_stats: bool = False,
        use_mae_stats: bool = False,
        use_dino_stats: bool = False,
        resize_input: bool = True,
        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
        final_activation: Optional[Union[str, nn.Module]] = None,
        use_skip_connection: bool = True,
        embed_dim: Optional[int] = None,
        use_conv_transpose: bool = False,
        **kwargs
    ) -> None:

        super().__init__(
            img_size=img_size,
            backbone=backbone,
            encoder=encoder,
            decoder=decoder,
            out_channels=out_channels,
            use_sam_stats=use_sam_stats,
            use_mae_stats=use_mae_stats,
            use_dino_stats=use_dino_stats,
            resize_input=resize_input,
            encoder_checkpoint=encoder_checkpoint,
            final_activation=final_activation,
            use_skip_connection=use_skip_connection,
            embed_dim=embed_dim,
            use_conv_transpose=use_conv_transpose,
            **kwargs,
        )

        encoder = self.encoder

        # Determine the number of input channels of the encoder for the `deconv4` raw-input branch.
        if backbone == "sam2" and hasattr(encoder, "trunk"):
            in_chans = encoder.trunk.patch_embed.proj.in_channels
        elif hasattr(encoder, "in_chans"):
            in_chans = encoder.in_chans
        else:  # `nn.Module` ViT backbone.
            try:
                in_chans = encoder.patch_embed.proj.in_channels
            except AttributeError:  # for getting the input channels while using 'vit_t' from MobileSam
                in_chans = encoder.patch_embed.seq[0].c.in_channels

        # parameters for the decoder network
        depth = 3
        initial_features = 64
        gain = 2
        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
        scale_factors = depth * [2]
        self.out_channels = out_channels

        # choice of upsampler - to use (bilinear interpolation + conv) or conv transpose
        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d

        self.decoder = decoder or Decoder(
            features=features_decoder,
            scale_factors=scale_factors[::-1],
            conv_block_impl=ConvBlock2d,
            sampler_impl=_upsampler,
        )

        if use_skip_connection:
            # Progressive upsampling chains that map intermediate encoder features
            # (all at embed_dim) to the decoder feature widths.
            self.deconv1 = Deconv2DBlock(
                in_channels=self.embed_dim,
                out_channels=features_decoder[0],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv2 = nn.Sequential(
                Deconv2DBlock(
                    in_channels=self.embed_dim,
                    out_channels=features_decoder[0],
                    use_conv_transpose=use_conv_transpose,
                ),
                Deconv2DBlock(
                    in_channels=features_decoder[0],
                    out_channels=features_decoder[1],
                    use_conv_transpose=use_conv_transpose,
                )
            )
            self.deconv3 = nn.Sequential(
                Deconv2DBlock(
                    in_channels=self.embed_dim,
                    out_channels=features_decoder[0],
                    use_conv_transpose=use_conv_transpose,
                ),
                Deconv2DBlock(
                    in_channels=features_decoder[0],
                    out_channels=features_decoder[1],
                    use_conv_transpose=use_conv_transpose,
                ),
                Deconv2DBlock(
                    in_channels=features_decoder[1],
                    out_channels=features_decoder[2],
                    use_conv_transpose=use_conv_transpose,
                )
            )
            # The raw (preprocessed) input is mapped directly to the finest feature width.
            self.deconv4 = ConvBlock2d(in_chans, features_decoder[-1])
        else:
            # Without skip connections the deconv blocks form a sequential upsampling chain.
            self.deconv1 = Deconv2DBlock(
                in_channels=self.embed_dim,
                out_channels=features_decoder[0],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv2 = Deconv2DBlock(
                in_channels=features_decoder[0],
                out_channels=features_decoder[1],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv3 = Deconv2DBlock(
                in_channels=features_decoder[1],
                out_channels=features_decoder[2],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv4 = Deconv2DBlock(
                in_channels=features_decoder[2],
                out_channels=features_decoder[3],
                use_conv_transpose=use_conv_transpose,
            )

        self.base = ConvBlock2d(self.embed_dim, features_decoder[0])
        self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1)
        self.deconv_out = _upsampler(
            scale_factor=2, in_channels=features_decoder[-1], out_channels=features_decoder[-1]
        )
        self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the UNETR to the input data.

        Args:
            x: The input tensor.

        Returns:
            The UNETR output.
        """
        original_shape = x.shape[-2:]

        # Reshape the inputs to the shape expected by the encoder
        # and normalize the inputs if normalization is part of the model.
        x, input_shape = self.preprocess(x)

        encoder_outputs = self.encoder(x)

        # `encoder_outputs` can be arranged in only two forms:
        # - either we only return the image embeddings
        # - or, we return the image embeddings and the "list" of global attention layers
        from_encoder = None
        if isinstance(encoder_outputs[-1], list):
            z12, from_encoder = encoder_outputs
        else:
            z12 = encoder_outputs

        if self.use_skip_connection:
            # Fail with a clear message instead of an unbound-variable NameError when the
            # encoder did not provide intermediate features.
            if from_encoder is None:
                raise ValueError(
                    "use_skip_connection=True requires an encoder that returns intermediate features."
                )
            from_encoder = from_encoder[::-1]
            z9 = self.deconv1(from_encoder[0])
            z6 = self.deconv2(from_encoder[1])
            z3 = self.deconv3(from_encoder[2])
            z0 = self.deconv4(x)

        else:
            z9 = self.deconv1(z12)
            z6 = self.deconv2(z9)
            z3 = self.deconv3(z6)
            z0 = self.deconv4(z3)

        updated_from_encoder = [z9, z6, z3]

        x = self.base(z12)
        x = self.decoder(x, encoder_inputs=updated_from_encoder)
        x = self.deconv_out(x)

        x = torch.cat([x, z0], dim=1)
        x = self.decoder_head(x)

        x = self.out_conv(x)
        if self.final_activation is not None:
            x = self.final_activation(x)

        # Undo padding/resizing so the output matches the original input resolution.
        x = self.postprocess_masks(x, input_shape, original_shape)
        return x
A (2d-only) UNet Transformer using a vision transformer as encoder and a convolutional decoder.
    def __init__(
        self,
        img_size: int = 1024,
        backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
        encoder: Optional[Union[nn.Module, str]] = "vit_b",
        decoder: Optional[nn.Module] = None,
        out_channels: int = 1,
        use_sam_stats: bool = False,
        use_mae_stats: bool = False,
        use_dino_stats: bool = False,
        resize_input: bool = True,
        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
        final_activation: Optional[Union[str, nn.Module]] = None,
        use_skip_connection: bool = True,
        embed_dim: Optional[int] = None,
        use_conv_transpose: bool = False,
        **kwargs
    ) -> None:
        """Build the 2d UNETR: the encoder is created by the base class, then the
        convolutional decoder and the skip-connection mapping blocks are assembled here.

        See the base class docstring for the parameter descriptions.
        """
        super().__init__(
            img_size=img_size,
            backbone=backbone,
            encoder=encoder,
            decoder=decoder,
            out_channels=out_channels,
            use_sam_stats=use_sam_stats,
            use_mae_stats=use_mae_stats,
            use_dino_stats=use_dino_stats,
            resize_input=resize_input,
            encoder_checkpoint=encoder_checkpoint,
            final_activation=final_activation,
            use_skip_connection=use_skip_connection,
            embed_dim=embed_dim,
            use_conv_transpose=use_conv_transpose,
            **kwargs,
        )

        encoder = self.encoder

        # Determine the number of input channels of the encoder (used by the raw-input branch below).
        if backbone == "sam2" and hasattr(encoder, "trunk"):
            in_chans = encoder.trunk.patch_embed.proj.in_channels
        elif hasattr(encoder, "in_chans"):
            in_chans = encoder.in_chans
        else:  # `nn.Module` ViT backbone.
            try:
                in_chans = encoder.patch_embed.proj.in_channels
            except AttributeError:  # for getting the input channels while using 'vit_t' from MobileSam
                in_chans = encoder.patch_embed.seq[0].c.in_channels

        # parameters for the decoder network
        depth = 3
        initial_features = 64
        gain = 2
        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
        scale_factors = depth * [2]
        self.out_channels = out_channels

        # choice of upsampler - to use (bilinear interpolation + conv) or conv transpose
        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d

        self.decoder = decoder or Decoder(
            features=features_decoder,
            scale_factors=scale_factors[::-1],
            conv_block_impl=ConvBlock2d,
            sampler_impl=_upsampler,
        )

        if use_skip_connection:
            # Chains of upsampling blocks that map intermediate encoder features
            # (all at embed_dim) down to the respective decoder feature widths.
            self.deconv1 = Deconv2DBlock(
                in_channels=self.embed_dim,
                out_channels=features_decoder[0],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv2 = nn.Sequential(
                Deconv2DBlock(
                    in_channels=self.embed_dim,
                    out_channels=features_decoder[0],
                    use_conv_transpose=use_conv_transpose,
                ),
                Deconv2DBlock(
                    in_channels=features_decoder[0],
                    out_channels=features_decoder[1],
                    use_conv_transpose=use_conv_transpose,
                )
            )
            self.deconv3 = nn.Sequential(
                Deconv2DBlock(
                    in_channels=self.embed_dim,
                    out_channels=features_decoder[0],
                    use_conv_transpose=use_conv_transpose,
                ),
                Deconv2DBlock(
                    in_channels=features_decoder[0],
                    out_channels=features_decoder[1],
                    use_conv_transpose=use_conv_transpose,
                ),
                Deconv2DBlock(
                    in_channels=features_decoder[1],
                    out_channels=features_decoder[2],
                    use_conv_transpose=use_conv_transpose,
                )
            )
            # The raw (preprocessed) input is mapped directly to the finest feature width.
            self.deconv4 = ConvBlock2d(in_chans, features_decoder[-1])
        else:
            # Without skip connections the deconv blocks form one sequential upsampling chain.
            self.deconv1 = Deconv2DBlock(
                in_channels=self.embed_dim,
                out_channels=features_decoder[0],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv2 = Deconv2DBlock(
                in_channels=features_decoder[0],
                out_channels=features_decoder[1],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv3 = Deconv2DBlock(
                in_channels=features_decoder[1],
                out_channels=features_decoder[2],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv4 = Deconv2DBlock(
                in_channels=features_decoder[2],
                out_channels=features_decoder[3],
                use_conv_transpose=use_conv_transpose,
            )

        self.base = ConvBlock2d(self.embed_dim, features_decoder[0])
        self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1)
        self.deconv_out = _upsampler(
            scale_factor=2, in_channels=features_decoder[-1], out_channels=features_decoder[-1]
        )
        self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1])
Initialize internal Module state, shared by both nn.Module and ScriptModule.
525 def forward(self, x: torch.Tensor) -> torch.Tensor: 526 """Apply the UNETR to the input data. 527 528 Args: 529 x: The input tensor. 530 531 Returns: 532 The UNETR output. 533 """ 534 original_shape = x.shape[-2:] 535 536 # Reshape the inputs to the shape expected by the encoder 537 # and normalize the inputs if normalization is part of the model. 538 x, input_shape = self.preprocess(x) 539 540 encoder_outputs = self.encoder(x) 541 542 if isinstance(encoder_outputs[-1], list): 543 # `encoder_outputs` can be arranged in only two forms: 544 # - either we only return the image embeddings 545 # - or, we return the image embeddings and the "list" of global attention layers 546 z12, from_encoder = encoder_outputs 547 else: 548 z12 = encoder_outputs 549 550 if self.use_skip_connection: 551 from_encoder = from_encoder[::-1] 552 z9 = self.deconv1(from_encoder[0]) 553 z6 = self.deconv2(from_encoder[1]) 554 z3 = self.deconv3(from_encoder[2]) 555 z0 = self.deconv4(x) 556 557 else: 558 z9 = self.deconv1(z12) 559 z6 = self.deconv2(z9) 560 z3 = self.deconv3(z6) 561 z0 = self.deconv4(z3) 562 563 updated_from_encoder = [z9, z6, z3] 564 565 x = self.base(z12) 566 x = self.decoder(x, encoder_inputs=updated_from_encoder) 567 x = self.deconv_out(x) 568 569 x = torch.cat([x, z0], dim=1) 570 x = self.decoder_head(x) 571 572 x = self.out_conv(x) 573 if self.final_activation is not None: 574 x = self.final_activation(x) 575 576 x = self.postprocess_masks(x, input_shape, original_shape) 577 return x
Apply the UNETR to the input data.
Arguments:
- x: The input tensor.
Returns:
The UNETR output.
580class UNETR2D(UNETR): 581 """A two-dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder. 582 """ 583 pass
A two-dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.
class UNETR3D(UNETRBase):
    """A three-dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.

    The vision transformer encodes each z-slice of the input volume independently; the stacked
    per-slice features are then decoded by a 3d convolutional decoder that upsamples only
    in-plane (the scale factor for the depth axis is 1), producing a volumetric prediction.

    Args:
        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
        backbone: The name of the vision transformer implementation.
        encoder: The vision transformer. Can either be a name, such as "hvit_b", or a torch module.
        decoder: The convolutional decoder. If None, a default anisotropic `Decoder` is created.
        out_channels: The number of output channels of the UNETR.
        use_sam_stats: Whether to normalize the input data with the statistics of the pretrained SAM model.
        use_mae_stats: Whether to normalize the input data with the statistics of the pretrained MAE model.
        use_dino_stats: Whether to normalize the input data with the statistics of the pretrained DINO model.
        resize_input: Whether to resize the input images to match `img_size`.
        encoder_checkpoint: Checkpoint for initializing the vision transformer.
        final_activation: The activation applied to the UNETR output.
        use_skip_connection: Whether to use skip connections. Not supported here; must be False.
        embed_dim: The embedding dimension of the encoder output. Defaults to 256 if not given.
        use_conv_transpose: Whether to use transposed convolutions for upsampling.
            Not supported here; must be False.
        use_strip_pooling: Whether to enable strip pooling in the 3d convolution blocks
            (passed through to `Deconv3DBlock` / `ConvBlock3dWithStrip`).
        kwargs: Additional keyword arguments forwarded to `UNETRBase`.
    """
    def __init__(
        self,
        img_size: int = 1024,
        backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
        encoder: Optional[Union[nn.Module, str]] = "hvit_b",
        decoder: Optional[nn.Module] = None,
        out_channels: int = 1,
        use_sam_stats: bool = False,
        use_mae_stats: bool = False,
        use_dino_stats: bool = False,
        resize_input: bool = True,
        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
        final_activation: Optional[Union[str, nn.Module]] = None,
        use_skip_connection: bool = False,
        embed_dim: Optional[int] = None,
        use_conv_transpose: bool = False,
        use_strip_pooling: bool = True,
        **kwargs
    ):
        # Guard clauses: these features exist in the 2d UNETR but are not implemented for 3d.
        if use_skip_connection:
            raise NotImplementedError("The framework cannot handle skip connections atm.")
        if use_conv_transpose:
            raise NotImplementedError("It's not enabled to switch between interpolation and transposed convolutions.")

        # Sort the `embed_dim` out.
        # NOTE(review): 256 presumably matches the default encoder output dimension — confirm.
        embed_dim = 256 if embed_dim is None else embed_dim

        super().__init__(
            img_size=img_size,
            backbone=backbone,
            encoder=encoder,
            decoder=decoder,
            out_channels=out_channels,
            use_sam_stats=use_sam_stats,
            use_mae_stats=use_mae_stats,
            use_dino_stats=use_dino_stats,
            resize_input=resize_input,
            encoder_checkpoint=encoder_checkpoint,
            final_activation=final_activation,
            use_skip_connection=use_skip_connection,
            embed_dim=embed_dim,
            use_conv_transpose=use_conv_transpose,
            **kwargs,
        )

        # The 3d convolutional decoder.
        # First, get the important parameters for the decoder.
        depth = 3
        initial_features = 64
        gain = 2
        # Feature widths from deepest to shallowest, i.e. [512, 256, 128, 64].
        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
        # Upsample only in-plane (Y, X); the depth axis (Z) is left unchanged.
        scale_factors = [1, 2, 2]
        self.out_channels = out_channels

        # The mapping blocks: a sequential chain that maps the encoder features
        # down to each decoder resolution (no skip connections, see `forward`).
        self.deconv1 = Deconv3DBlock(
            in_channels=embed_dim,
            out_channels=features_decoder[0],
            scale_factor=scale_factors,
            use_strip_pooling=use_strip_pooling,
        )
        self.deconv2 = Deconv3DBlock(
            in_channels=features_decoder[0],
            out_channels=features_decoder[1],
            scale_factor=scale_factors,
            use_strip_pooling=use_strip_pooling,
        )
        self.deconv3 = Deconv3DBlock(
            in_channels=features_decoder[1],
            out_channels=features_decoder[2],
            scale_factor=scale_factors,
            use_strip_pooling=use_strip_pooling,
        )
        self.deconv4 = Deconv3DBlock(
            in_channels=features_decoder[2],
            out_channels=features_decoder[3],
            scale_factor=scale_factors,
            use_strip_pooling=use_strip_pooling,
        )

        # The core decoder block. A custom `decoder` passed by the caller takes precedence.
        self.decoder = decoder or Decoder(
            features=features_decoder,
            scale_factors=[scale_factors] * depth,
            conv_block_impl=partial(ConvBlock3dWithStrip, use_strip_pooling=use_strip_pooling),
            sampler_impl=Upsampler3d,
        )

        # And the final upsampler to match the expected dimensions.
        self.deconv_out = Deconv3DBlock(  # NOTE: changed `end_up` to `deconv_out`
            in_channels=features_decoder[-1],
            out_channels=features_decoder[-1],
            scale_factor=scale_factors,
            use_strip_pooling=use_strip_pooling,
        )

        # Additional conjunction blocks: align the raw encoder features with the
        # deepest decoder feature width before running the decoder.
        self.base = ConvBlock3dWithStrip(
            in_channels=embed_dim,
            out_channels=features_decoder[0],
            use_strip_pooling=use_strip_pooling,
        )

        # And the output layers. The head consumes the concatenation of the decoder
        # output and `z0` (hence `2 *` input channels), followed by a 1x1x1 projection.
        self.decoder_head = ConvBlock3dWithStrip(
            in_channels=2 * features_decoder[-1],
            out_channels=features_decoder[-1],
            use_strip_pooling=use_strip_pooling,
        )
        self.out_conv = nn.Conv3d(features_decoder[-1], out_channels, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of the UNETR-3D model.

        Args:
            x: Inputs of expected shape (B, C, Z, Y, X), where the depth Z may vary between inputs.

        Returns:
            The UNETR output, resized back to the original spatial shape (B, out_channels, Z, Y, X).
        """
        B, C, Z, H, W = x.shape
        original_shape = (Z, H, W)

        # Preprocessing step (normalization / resizing, handled by the base class).
        x, input_shape = self.preprocess(x)

        # Run the image encoder on each z-slice independently and stack the per-slice
        # features back along the depth axis (dim 2).
        # NOTE(review): only the first encoder output is used — presumably the image
        # embedding; confirm against the encoder implementation.
        curr_features = torch.stack([self.encoder(x[:, :, i])[0] for i in range(Z)], dim=2)

        # Prepare the counterparts for the decoder.
        # NOTE: The section below is sequential, there's no skip connections atm.
        z9 = self.deconv1(curr_features)
        z6 = self.deconv2(z9)
        z3 = self.deconv3(z6)
        z0 = self.deconv4(z3)

        updated_from_encoder = [z9, z6, z3]

        # Align the features through the base block.
        x = self.base(curr_features)
        # Run the decoder
        x = self.decoder(x, encoder_inputs=updated_from_encoder)
        x = self.deconv_out(x)  # NOTE before `end_up`

        # And the final output head: fuse with the highest-resolution mapping block output.
        x = torch.cat([x, z0], dim=1)
        x = self.decoder_head(x)
        x = self.out_conv(x)
        if self.final_activation is not None:
            x = self.final_activation(x)

        # Postprocess the output back to original size.
        x = self.postprocess_masks(x, input_shape, original_shape)
        return x
A three-dimensional UNet Transformer using a vision transformer as the encoder and a convolutional decoder.
def __init__(
    self,
    img_size: int = 1024,
    backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
    encoder: Optional[Union[nn.Module, str]] = "hvit_b",
    decoder: Optional[nn.Module] = None,
    out_channels: int = 1,
    use_sam_stats: bool = False,
    use_mae_stats: bool = False,
    use_dino_stats: bool = False,
    resize_input: bool = True,
    encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
    final_activation: Optional[Union[str, nn.Module]] = None,
    use_skip_connection: bool = False,
    embed_dim: Optional[int] = None,
    use_conv_transpose: bool = False,
    use_strip_pooling: bool = True,
    **kwargs
):
    """Initialize the 3d UNETR: vision-transformer encoder plus an anisotropic 3d decoder.

    Args:
        img_size: The size of the input for the image encoder.
        backbone: The name of the vision transformer implementation.
        encoder: The vision transformer; a name such as "hvit_b" or a torch module.
        decoder: The convolutional decoder; if None a default `Decoder` is built.
        out_channels: The number of output channels.
        use_sam_stats: Whether to normalize inputs with pretrained SAM statistics.
        use_mae_stats: Whether to normalize inputs with pretrained MAE statistics.
        use_dino_stats: Whether to normalize inputs with pretrained DINO statistics.
        resize_input: Whether to resize inputs to match `img_size`.
        encoder_checkpoint: Checkpoint for initializing the vision transformer.
        final_activation: The activation applied to the output.
        use_skip_connection: Not supported; must be False.
        embed_dim: The encoder embedding dimension; defaults to 256 if not given.
        use_conv_transpose: Not supported; must be False.
        use_strip_pooling: Whether to enable strip pooling in the 3d conv blocks.
        kwargs: Additional keyword arguments forwarded to the base class.
    """
    # Guard clauses: these features are not implemented for the 3d variant.
    if use_skip_connection:
        raise NotImplementedError("The framework cannot handle skip connections atm.")
    if use_conv_transpose:
        raise NotImplementedError("It's not enabled to switch between interpolation and transposed convolutions.")

    # Sort the `embed_dim` out.
    # NOTE(review): 256 presumably matches the default encoder output dimension — confirm.
    embed_dim = 256 if embed_dim is None else embed_dim

    super().__init__(
        img_size=img_size,
        backbone=backbone,
        encoder=encoder,
        decoder=decoder,
        out_channels=out_channels,
        use_sam_stats=use_sam_stats,
        use_mae_stats=use_mae_stats,
        use_dino_stats=use_dino_stats,
        resize_input=resize_input,
        encoder_checkpoint=encoder_checkpoint,
        final_activation=final_activation,
        use_skip_connection=use_skip_connection,
        embed_dim=embed_dim,
        use_conv_transpose=use_conv_transpose,
        **kwargs,
    )

    # The 3d convolutional decoder.
    # First, get the important parameters for the decoder.
    depth = 3
    initial_features = 64
    gain = 2
    # Feature widths from deepest to shallowest, i.e. [512, 256, 128, 64].
    features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
    # Upsample only in-plane (Y, X); the depth axis (Z) is left unchanged.
    scale_factors = [1, 2, 2]
    self.out_channels = out_channels

    # The mapping blocks: a sequential chain producing decoder-side features
    # (no skip connections at the moment).
    self.deconv1 = Deconv3DBlock(
        in_channels=embed_dim,
        out_channels=features_decoder[0],
        scale_factor=scale_factors,
        use_strip_pooling=use_strip_pooling,
    )
    self.deconv2 = Deconv3DBlock(
        in_channels=features_decoder[0],
        out_channels=features_decoder[1],
        scale_factor=scale_factors,
        use_strip_pooling=use_strip_pooling,
    )
    self.deconv3 = Deconv3DBlock(
        in_channels=features_decoder[1],
        out_channels=features_decoder[2],
        scale_factor=scale_factors,
        use_strip_pooling=use_strip_pooling,
    )
    self.deconv4 = Deconv3DBlock(
        in_channels=features_decoder[2],
        out_channels=features_decoder[3],
        scale_factor=scale_factors,
        use_strip_pooling=use_strip_pooling,
    )

    # The core decoder block. A caller-supplied `decoder` takes precedence.
    self.decoder = decoder or Decoder(
        features=features_decoder,
        scale_factors=[scale_factors] * depth,
        conv_block_impl=partial(ConvBlock3dWithStrip, use_strip_pooling=use_strip_pooling),
        sampler_impl=Upsampler3d,
    )

    # And the final upsampler to match the expected dimensions.
    self.deconv_out = Deconv3DBlock(  # NOTE: changed `end_up` to `deconv_out`
        in_channels=features_decoder[-1],
        out_channels=features_decoder[-1],
        scale_factor=scale_factors,
        use_strip_pooling=use_strip_pooling,
    )

    # Additional conjunction blocks: align encoder features with the deepest decoder width.
    self.base = ConvBlock3dWithStrip(
        in_channels=embed_dim,
        out_channels=features_decoder[0],
        use_strip_pooling=use_strip_pooling,
    )

    # And the output layers. The head consumes decoder output concatenated with `z0`
    # (hence the `2 *`), followed by a 1x1x1 projection to `out_channels`.
    self.decoder_head = ConvBlock3dWithStrip(
        in_channels=2 * features_decoder[-1],
        out_channels=features_decoder[-1],
        use_strip_pooling=use_strip_pooling,
    )
    self.out_conv = nn.Conv3d(features_decoder[-1], out_channels, 1)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Forward pass of the UNETR-3D model.

    Args:
        x: Inputs of expected shape (B, C, Z, Y, X), where the depth Z may vary between inputs.

    Returns:
        The UNETR output, resized back to the original spatial shape.
    """
    B, C, Z, H, W = x.shape
    original_shape = (Z, H, W)

    # Preprocessing step (normalization / resizing, handled elsewhere in the class).
    x, input_shape = self.preprocess(x)

    # Run the image encoder per z-slice and stack features along the depth axis (dim 2).
    # NOTE(review): only the first encoder output is used — presumably the image
    # embedding; confirm against the encoder implementation.
    curr_features = torch.stack([self.encoder(x[:, :, i])[0] for i in range(Z)], dim=2)

    # Prepare the counterparts for the decoder.
    # NOTE: The section below is sequential, there's no skip connections atm.
    z9 = self.deconv1(curr_features)
    z6 = self.deconv2(z9)
    z3 = self.deconv3(z6)
    z0 = self.deconv4(z3)

    updated_from_encoder = [z9, z6, z3]

    # Align the features through the base block.
    x = self.base(curr_features)
    # Run the decoder
    x = self.decoder(x, encoder_inputs=updated_from_encoder)
    x = self.deconv_out(x)  # NOTE before `end_up`

    # And the final output head: fuse with the highest-resolution mapping block output.
    x = torch.cat([x, z0], dim=1)
    x = self.decoder_head(x)
    x = self.out_conv(x)
    if self.final_activation is not None:
        x = self.final_activation(x)

    # Postprocess the output back to original size.
    x = self.postprocess_masks(x, input_shape, original_shape)
    return x
Forward pass of the UNETR-3D model.
Arguments:
- x: Inputs of expected shape (B, C, Z, Y, X), where the depth Z may vary between inputs.
Returns:
The UNETR output.