torch_em.model.unetr
from functools import partial
from collections import OrderedDict
from typing import Optional, Tuple, Union, Literal

import torch
import torch.nn as nn
import torch.nn.functional as F

from .vit import get_vision_transformer
from .unet import Decoder, ConvBlock2d, ConvBlock3d, Upsampler2d, Upsampler3d, _update_conv_kwargs

# The SAM-family packages are optional dependencies: when one is missing, the
# corresponding loader is set to None and only direct checkpoint loading is available.
try:
    from micro_sam.util import get_sam_model
except ImportError:
    get_sam_model = None

try:
    from micro_sam.v2.util import get_sam2_model
except ImportError:
    get_sam2_model = None

try:
    from micro_sam3.util import get_sam3_model
except ImportError:
    get_sam3_model = None


#
# UNETR IMPLEMENTATION [Vision Transformer (ViT from SAM / CellposeSAM / SAM2 / SAM3 / DINOv2 / DINOv3 / MAE / ScaleMAE) + UNet Decoder from `torch_em`]  # noqa
#


class UNETRBase(nn.Module):
    """Base class for implementing a UNETR.

    Args:
        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
        backbone: The name of the vision transformer implementation.
            One of "sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"
            (see all combinations below)
        encoder: The vision transformer. Can either be a name, such as "vit_b"
            (see all combinations for this below) or a torch module.
        decoder: The convolutional decoder.
        out_channels: The number of output channels of the UNETR.
        use_sam_stats: Whether to normalize the input data with the statistics of the
            pretrained SAM / SAM2 / SAM3 model.
        use_dino_stats: Whether to normalize the input data with the statistics of the
            pretrained DINOv2 / DINOv3 model.
        use_mae_stats: Whether to normalize the input data with the statistics of the pretrained MAE model.
        resize_input: Whether to resize the input images to match `img_size`.
            By default, it resizes the inputs to match the `img_size`.
        encoder_checkpoint: Checkpoint for initializing the vision transformer.
            Can either be a filepath or an already loaded checkpoint.
        final_activation: The activation to apply to the UNETR output.
        use_skip_connection: Whether to use skip connections. By default, it uses skip connections.
        embed_dim: The embedding dimensionality, corresponding to the output dimension of the vision transformer.
        use_conv_transpose: Whether to use transposed convolutions instead of resampling for upsampling.
            By default, it uses resampling for upsampling.

    NOTE: The currently supported combinations of 'backbone' x 'encoder' are the following:

    SAM_family_models:
        - 'sam' x 'vit_b'
        - 'sam' x 'vit_l'
        - 'sam' x 'vit_h'
        - 'sam2' x 'hvit_t'
        - 'sam2' x 'hvit_s'
        - 'sam2' x 'hvit_b'
        - 'sam2' x 'hvit_l'
        - 'sam3' x 'vit_pe'
        - 'cellpose_sam' x 'vit_l'

    DINO_family_models:
        - 'dinov2' x 'vit_s'
        - 'dinov2' x 'vit_b'
        - 'dinov2' x 'vit_l'
        - 'dinov2' x 'vit_g'
        - 'dinov2' x 'vit_s_reg4'
        - 'dinov2' x 'vit_b_reg4'
        - 'dinov2' x 'vit_l_reg4'
        - 'dinov2' x 'vit_g_reg4'
        - 'dinov3' x 'vit_s'
        - 'dinov3' x 'vit_s+'
        - 'dinov3' x 'vit_b'
        - 'dinov3' x 'vit_l'
        - 'dinov3' x 'vit_l+'
        - 'dinov3' x 'vit_h+'
        - 'dinov3' x 'vit_7b'

    MAE_family_models:
        - 'mae' x 'vit_b'
        - 'mae' x 'vit_l'
        - 'mae' x 'vit_h'
        - 'scalemae' x 'vit_b'
        - 'scalemae' x 'vit_l'
        - 'scalemae' x 'vit_h'
    """
    def __init__(
        self,
        img_size: int = 1024,
        backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
        encoder: Optional[Union[nn.Module, str]] = "vit_b",
        decoder: Optional[nn.Module] = None,
        out_channels: int = 1,
        use_sam_stats: bool = False,
        use_mae_stats: bool = False,
        use_dino_stats: bool = False,
        resize_input: bool = True,
        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
        final_activation: Optional[Union[str, nn.Module]] = None,
        use_skip_connection: bool = True,
        embed_dim: Optional[int] = None,
        use_conv_transpose: bool = False,
        **kwargs
    ) -> None:
        super().__init__()

        self.img_size = img_size
        self.use_sam_stats = use_sam_stats
        self.use_mae_stats = use_mae_stats
        self.use_dino_stats = use_dino_stats
        self.use_skip_connection = use_skip_connection
        self.resize_input = resize_input
        self.use_conv_transpose = use_conv_transpose
        self.backbone = backbone

        if isinstance(encoder, str):  # e.g. "vit_b" / "hvit_b" / "vit_pe"
            print(f"Using {encoder} from {backbone.upper()}")
            self.encoder = get_vision_transformer(img_size=img_size, backbone=backbone, model=encoder, **kwargs)

            if encoder_checkpoint is not None:
                self._load_encoder_from_checkpoint(backbone=backbone, encoder=encoder, checkpoint=encoder_checkpoint)

            if embed_dim is None:
                embed_dim = self.encoder.embed_dim

            # For SAM1 encoder, if 'apply_neck' is applied, the embedding dimension must change.
            if hasattr(self.encoder, "apply_neck") and self.encoder.apply_neck:
                embed_dim = self.encoder.neck[2].out_channels  # the value is 256

        else:  # `nn.Module` ViT backbone
            self.encoder = encoder

            # Detect whether the custom encoder carries a SAM-style "neck" projection.
            have_neck = False
            for name, _ in self.encoder.named_parameters():
                if name.startswith("neck"):
                    have_neck = True

            if embed_dim is None:
                if have_neck:
                    embed_dim = self.encoder.neck[2].out_channels  # the value is 256
                else:
                    embed_dim = self.encoder.patch_embed.proj.out_channels

        self.embed_dim = embed_dim
        self.final_activation = self._get_activation(final_activation)

    def _load_encoder_from_checkpoint(self, backbone, encoder, checkpoint):
        """Function to load pretrained weights to the image encoder.

        `checkpoint` may be a filepath (handled per backbone below) or an already
        loaded state dict, which is passed to the encoder as-is.
        """
        if isinstance(checkpoint, str):
            if backbone == "sam" and isinstance(encoder, str):
                # If we have a SAM encoder, then we first try to load the full SAM Model
                # (using micro_sam) and otherwise fall back on directly loading the encoder state
                # from the checkpoint
                try:
                    _, model = get_sam_model(model_type=encoder, checkpoint_path=checkpoint, return_sam=True)
                    encoder_state = model.image_encoder.state_dict()
                except Exception:
                    # Try loading the encoder state directly from a checkpoint.
                    encoder_state = torch.load(checkpoint, weights_only=False)

            elif backbone == "cellpose_sam" and isinstance(encoder, str):
                # The architecture matches CellposeSAM exactly (same rel_pos sizes),
                # so weights load directly without any interpolation.
                encoder_state = torch.load(checkpoint, map_location="cpu", weights_only=False)
                # Handle DataParallel/DistributedDataParallel prefix.
                if any(k.startswith("module.") for k in encoder_state.keys()):
                    encoder_state = OrderedDict(
                        {k[len("module."):]: v for k, v in encoder_state.items()}
                    )
                # Extract encoder weights from CellposeSAM checkpoint format (strip 'encoder.' prefix).
                if any(k.startswith("encoder.") for k in encoder_state.keys()):
                    encoder_state = OrderedDict(
                        {k[len("encoder."):]: v for k, v in encoder_state.items() if k.startswith("encoder.")}
                    )

            elif backbone == "sam2" and isinstance(encoder, str):
                # If we have a SAM2 encoder, then we first try to load the full SAM2 Model.
                # (using micro_sam2) and otherwise fall back on directly loading the encoder state
                # from the checkpoint
                try:
                    model = get_sam2_model(model_type=encoder, checkpoint_path=checkpoint)
                    encoder_state = model.image_encoder.state_dict()
                except Exception:
                    # Try loading the encoder state directly from a checkpoint.
                    encoder_state = torch.load(checkpoint, weights_only=False)

            elif backbone == "sam3" and isinstance(encoder, str):
                # If we have a SAM3 encoder, then we first try to load the full SAM3 Model.
                # (using micro_sam3) and otherwise fall back on directly loading the encoder state
                # from the checkpoint
                try:
                    model = get_sam3_model(checkpoint_path=checkpoint)
                    encoder_state = model.backbone.vision_backbone.state_dict()
                    # Let's align loading the encoder weights with expected parameter names
                    encoder_state = {
                        k[len("trunk."):] if k.startswith("trunk.") else k: v for k, v in encoder_state.items()
                    }
                    # And drop the 'convs' and 'sam2_convs' - these seem like some upsampling blocks.
                    encoder_state = {
                        k: v for k, v in encoder_state.items()
                        if not (k.startswith("convs.") or k.startswith("sam2_convs."))
                    }
                except Exception:
                    # Try loading the encoder state directly from a checkpoint.
                    encoder_state = torch.load(checkpoint, weights_only=False)

            elif backbone == "mae":
                # vit initialization hints from:
                # - https://github.com/facebookresearch/mae/blob/main/main_finetune.py#L233-L242
                encoder_state = torch.load(checkpoint, weights_only=False)["model"]
                encoder_state = OrderedDict({
                    k: v for k, v in encoder_state.items() if (k != "mask_token" and not k.startswith("decoder"))
                })
                # Let's remove the `head` from our current encoder (as the MAE pretrained don't expect it)
                current_encoder_state = self.encoder.state_dict()
                if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
                    del self.encoder.head

            elif backbone == "scalemae":
                # Load the encoder state directly from a checkpoint.
                encoder_state = torch.load(checkpoint)["model"]
                encoder_state = OrderedDict({
                    k: v for k, v in encoder_state.items()
                    if not k.startswith(("mask_token", "decoder", "fcn", "fpn", "pos_embed"))
                })

                # Let's remove the `head` from our current encoder (as the MAE pretrained don't expect it)
                current_encoder_state = self.encoder.state_dict()
                if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
                    del self.encoder.head

                if "pos_embed" in current_encoder_state:  # NOTE: ScaleMAE uses 'pos. embeddings' in a diff. format.
                    del self.encoder.pos_embed

            elif backbone in ["dinov2", "dinov3"]:  # Load the encoder state directly from a checkpoint.
                encoder_state = torch.load(checkpoint)

            else:
                raise ValueError(
                    f"We don't support either the '{backbone}' backbone or the '{encoder}' model combination (or both)."
                )

        else:
            encoder_state = checkpoint

        self.encoder.load_state_dict(encoder_state)

    def _get_activation(self, activation):
        # Resolve `activation` (None / nn.Module instance / torch.nn class name) to a module instance.
        return_activation = None
        if activation is None:
            return None
        if isinstance(activation, nn.Module):
            return activation
        if isinstance(activation, str):
            return_activation = getattr(nn, activation, None)
        if return_activation is None:
            raise ValueError(f"Invalid activation: {activation}")

        return return_activation()

    @staticmethod
    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
        """Compute the output size given input size and target long side length.

        Args:
            oldh: The input image height.
            oldw: The input image width.
            long_side_length: The longest side length for resizing.

        Returns:
            The new image height.
            The new image width.
        """
        scale = long_side_length * 1.0 / max(oldh, oldw)
        newh, neww = oldh * scale, oldw * scale
        # Round to nearest integer.
        neww = int(neww + 0.5)
        newh = int(newh + 0.5)
        return (newh, neww)

    def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor:
        """Resize the image so that the longest side has the correct length.

        Expects batched images with shape BxCxHxW OR BxCxDxHxW and float format.

        Args:
            image: The input image.

        Returns:
            The resized image.
        """
        if image.ndim == 4:  # i.e. 2d image
            target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.encoder.img_size)
            return F.interpolate(image, target_size, mode="bilinear", align_corners=False, antialias=True)
        elif image.ndim == 5:  # i.e. 3d volume
            # NOTE(review): the 2d branch uses `self.encoder.img_size` as the target while this
            # branch uses `self.img_size` — presumably these agree; confirm for custom encoders.
            B, C, Z, H, W = image.shape
            target_size = self.get_preprocess_shape(H, W, self.img_size)
            # Depth (Z) is kept unchanged; only the in-plane dimensions are resized.
            return F.interpolate(image, (Z, *target_size), mode="trilinear", align_corners=False)
        else:
            raise ValueError("Expected 4d or 5d inputs, got", image.shape)

    def _as_stats(self, mean, std, device, dtype, is_3d: bool):
        """@private
        """
        # Either 2d batch: (1, C, 1, 1) or 3d batch: (1, C, 1, 1, 1).
        view_shape = (1, -1, 1, 1, 1) if is_3d else (1, -1, 1, 1)
        pixel_mean = torch.tensor(mean, device=device, dtype=dtype).view(*view_shape)
        pixel_std = torch.tensor(std, device=device, dtype=dtype).view(*view_shape)
        return pixel_mean, pixel_std

    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        """@private
        """
        is_3d = (x.ndim == 5)
        device, dtype = x.device, x.dtype

        # NOTE(review): the SAM1 statistics below are on a 0-255 intensity scale while the
        # SAM2/SAM3/DINO statistics are on a 0-1 scale — inputs are presumably expected in
        # the matching range per backbone; confirm against the callers.
        if self.use_sam_stats:
            if self.backbone == "sam2":
                mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
            elif self.backbone == "sam3":
                mean, std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
            else:  # sam1 / default
                mean, std = (123.675, 116.28, 103.53), (58.395, 57.12, 57.375)
        elif self.use_mae_stats:  # TODO: add mean std from mae / scalemae experiments (or open up arguments for this)
            raise NotImplementedError
        elif self.use_dino_stats:
            mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
        else:
            # Identity normalization.
            mean, std = (0.0, 0.0, 0.0), (1.0, 1.0, 1.0)

        pixel_mean, pixel_std = self._as_stats(mean, std, device=device, dtype=dtype, is_3d=is_3d)

        if self.resize_input:
            x = self.resize_longest_side(x)
        # Remember the (possibly resized) shape before padding, for postprocessing.
        input_shape = x.shape[-3:] if is_3d else x.shape[-2:]

        x = (x - pixel_mean) / pixel_std
        # Pad bottom/right so that the spatial size matches the encoder's square input size.
        h, w = x.shape[-2:]
        padh = self.encoder.img_size - h
        padw = self.encoder.img_size - w

        if is_3d:
            x = F.pad(x, (0, padw, 0, padh, 0, 0))
        else:
            x = F.pad(x, (0, padw, 0, padh))

        return x, input_shape

    def postprocess_masks(
        self, masks: torch.Tensor, input_size: Tuple[int, ...], original_size: Tuple[int, ...],
    ) -> torch.Tensor:
        """@private
        """
        if masks.ndim == 4:  # i.e. 2d labels
            # Upscale to the encoder resolution, crop away the padding, then resize to the original size.
            masks = F.interpolate(
                masks,
                (self.encoder.img_size, self.encoder.img_size),
                mode="bilinear",
                align_corners=False,
            )
            masks = masks[..., : input_size[0], : input_size[1]]
            masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)

        elif masks.ndim == 5:  # i.e. 3d volumetric labels
            masks = F.interpolate(
                masks,
                (input_size[0], self.img_size, self.img_size),
                mode="trilinear",
                align_corners=False,
            )
            masks = masks[..., :input_size[0], :input_size[1], :input_size[2]]
            masks = F.interpolate(masks, original_size, mode="trilinear", align_corners=False)

        else:
            raise ValueError("Expected 4d or 5d labels, got", masks.shape)

        return masks
class UNETR(UNETRBase):
    """A (2d-only) UNet Transformer using a vision transformer as encoder and a convolutional decoder.
    """
    def __init__(
        self,
        img_size: int = 1024,
        backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
        encoder: Optional[Union[nn.Module, str]] = "vit_b",
        decoder: Optional[nn.Module] = None,
        out_channels: int = 1,
        use_sam_stats: bool = False,
        use_mae_stats: bool = False,
        use_dino_stats: bool = False,
        resize_input: bool = True,
        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
        final_activation: Optional[Union[str, nn.Module]] = None,
        use_skip_connection: bool = True,
        embed_dim: Optional[int] = None,
        use_conv_transpose: bool = False,
        **kwargs
    ) -> None:

        super().__init__(
            img_size=img_size,
            backbone=backbone,
            encoder=encoder,
            decoder=decoder,
            out_channels=out_channels,
            use_sam_stats=use_sam_stats,
            use_mae_stats=use_mae_stats,
            use_dino_stats=use_dino_stats,
            resize_input=resize_input,
            encoder_checkpoint=encoder_checkpoint,
            final_activation=final_activation,
            use_skip_connection=use_skip_connection,
            embed_dim=embed_dim,
            use_conv_transpose=use_conv_transpose,
            **kwargs,
        )

        encoder = self.encoder

        # Determine the number of input channels of the image encoder.
        if backbone == "sam2" and hasattr(encoder, "trunk"):
            in_chans = encoder.trunk.patch_embed.proj.in_channels
        elif hasattr(encoder, "in_chans"):
            in_chans = encoder.in_chans
        else:  # `nn.Module` ViT backbone.
            try:
                in_chans = encoder.patch_embed.proj.in_channels
            except AttributeError:  # for getting the input channels while using 'vit_t' from MobileSam
                in_chans = encoder.patch_embed.seq[0].c.in_channels

        # parameters for the decoder network
        depth = 3
        initial_features = 64
        gain = 2
        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
        scale_factors = depth * [2]
        self.out_channels = out_channels

        # choice of upsampler - to use (bilinear interpolation + conv) or conv transpose
        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d

        self.decoder = decoder or Decoder(
            features=features_decoder,
            scale_factors=scale_factors[::-1],
            conv_block_impl=ConvBlock2d,
            sampler_impl=_upsampler,
        )

        if use_skip_connection:
            # With skip connections, each deconv branch maps one intermediate encoder
            # feature map to the matching decoder resolution (deconv4 maps the raw input).
            self.deconv1 = Deconv2DBlock(
                in_channels=self.embed_dim,
                out_channels=features_decoder[0],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv2 = nn.Sequential(
                Deconv2DBlock(
                    in_channels=self.embed_dim,
                    out_channels=features_decoder[0],
                    use_conv_transpose=use_conv_transpose,
                ),
                Deconv2DBlock(
                    in_channels=features_decoder[0],
                    out_channels=features_decoder[1],
                    use_conv_transpose=use_conv_transpose,
                )
            )
            self.deconv3 = nn.Sequential(
                Deconv2DBlock(
                    in_channels=self.embed_dim,
                    out_channels=features_decoder[0],
                    use_conv_transpose=use_conv_transpose,
                ),
                Deconv2DBlock(
                    in_channels=features_decoder[0],
                    out_channels=features_decoder[1],
                    use_conv_transpose=use_conv_transpose,
                ),
                Deconv2DBlock(
                    in_channels=features_decoder[1],
                    out_channels=features_decoder[2],
                    use_conv_transpose=use_conv_transpose,
                )
            )
            self.deconv4 = ConvBlock2d(in_chans, features_decoder[-1])
        else:
            # Without skip connections, the deconv blocks form a sequential upsampling chain.
            self.deconv1 = Deconv2DBlock(
                in_channels=self.embed_dim,
                out_channels=features_decoder[0],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv2 = Deconv2DBlock(
                in_channels=features_decoder[0],
                out_channels=features_decoder[1],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv3 = Deconv2DBlock(
                in_channels=features_decoder[1],
                out_channels=features_decoder[2],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv4 = Deconv2DBlock(
                in_channels=features_decoder[2],
                out_channels=features_decoder[3],
                use_conv_transpose=use_conv_transpose,
            )

        self.base = ConvBlock2d(self.embed_dim, features_decoder[0])
        self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1)
        self.deconv_out = _upsampler(
            scale_factor=2, in_channels=features_decoder[-1], out_channels=features_decoder[-1]
        )
        self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the UNETR to the input data.

        Args:
            x: The input tensor.

        Returns:
            The UNETR output.
        """
        original_shape = x.shape[-2:]

        # Reshape the inputs to the shape expected by the encoder
        # and normalize the inputs if normalization is part of the model.
        x, input_shape = self.preprocess(x)

        encoder_outputs = self.encoder(x)

        if isinstance(encoder_outputs[-1], list):
            # `encoder_outputs` can be arranged in only two forms:
            # - either we only return the image embeddings
            # - or, we return the image embeddings and the "list" of global attention layers
            z12, from_encoder = encoder_outputs
        else:
            z12 = encoder_outputs

        if self.use_skip_connection:
            # NOTE(review): `from_encoder` is only bound when the encoder returned the
            # attention-layer list above — presumably skip connections are only enabled
            # with such encoders; otherwise this raises a NameError. Confirm with callers.
            from_encoder = from_encoder[::-1]
            z9 = self.deconv1(from_encoder[0])
            z6 = self.deconv2(from_encoder[1])
            z3 = self.deconv3(from_encoder[2])
            z0 = self.deconv4(x)

        else:
            z9 = self.deconv1(z12)
            z6 = self.deconv2(z9)
            z3 = self.deconv3(z6)
            z0 = self.deconv4(z3)

        updated_from_encoder = [z9, z6, z3]

        x = self.base(z12)
        x = self.decoder(x, encoder_inputs=updated_from_encoder)
        x = self.deconv_out(x)

        # Fuse with the highest-resolution features before the output head.
        x = torch.cat([x, z0], dim=1)
        x = self.decoder_head(x)

        x = self.out_conv(x)
        if self.final_activation is not None:
            x = self.final_activation(x)

        # Resize back to the original input resolution.
        x = self.postprocess_masks(x, input_shape, original_shape)
        return x


class UNETR2D(UNETR):
    """A two-dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.
    """
    pass
class UNETR3D(UNETRBase):
    """A three dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.
    """
    def __init__(
        self,
        img_size: int = 1024,
        backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
        encoder: Optional[Union[nn.Module, str]] = "hvit_b",
        decoder: Optional[nn.Module] = None,
        out_channels: int = 1,
        use_sam_stats: bool = False,
        use_mae_stats: bool = False,
        use_dino_stats: bool = False,
        resize_input: bool = True,
        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
        final_activation: Optional[Union[str, nn.Module]] = None,
        use_skip_connection: bool = False,
        embed_dim: Optional[int] = None,
        use_conv_transpose: bool = False,
        use_strip_pooling: bool = True,
        **kwargs
    ):
        # The 3d variant does not support these options yet; fail early and explicitly.
        if use_skip_connection:
            raise NotImplementedError("The framework cannot handle skip connections atm.")
        if use_conv_transpose:
            raise NotImplementedError("It's not enabled to switch between interpolation and transposed convolutions.")

        # Sort the `embed_dim` out
        embed_dim = 256 if embed_dim is None else embed_dim

        super().__init__(
            img_size=img_size,
            backbone=backbone,
            encoder=encoder,
            decoder=decoder,
            out_channels=out_channels,
            use_sam_stats=use_sam_stats,
            use_mae_stats=use_mae_stats,
            use_dino_stats=use_dino_stats,
            resize_input=resize_input,
            encoder_checkpoint=encoder_checkpoint,
            final_activation=final_activation,
            use_skip_connection=use_skip_connection,
            embed_dim=embed_dim,
            use_conv_transpose=use_conv_transpose,
            **kwargs,
        )

        # The 3d convolutional decoder.
        # First, get the important parameters for the decoder.
        depth = 3
        initial_features = 64
        gain = 2
        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
        # Anisotropic scaling: depth (Z) is kept, only in-plane dimensions are scaled.
        scale_factors = [1, 2, 2]
        self.out_channels = out_channels

        # The mapping blocks.
        self.deconv1 = Deconv3DBlock(
            in_channels=embed_dim,
            out_channels=features_decoder[0],
            scale_factor=scale_factors,
            use_strip_pooling=use_strip_pooling,
        )
        self.deconv2 = Deconv3DBlock(
            in_channels=features_decoder[0],
            out_channels=features_decoder[1],
            scale_factor=scale_factors,
            use_strip_pooling=use_strip_pooling,
        )
        self.deconv3 = Deconv3DBlock(
            in_channels=features_decoder[1],
            out_channels=features_decoder[2],
            scale_factor=scale_factors,
            use_strip_pooling=use_strip_pooling,
        )
        self.deconv4 = Deconv3DBlock(
            in_channels=features_decoder[2],
            out_channels=features_decoder[3],
            scale_factor=scale_factors,
            use_strip_pooling=use_strip_pooling,
        )

        # The core decoder block.
        self.decoder = decoder or Decoder(
            features=features_decoder,
            scale_factors=[scale_factors] * depth,
            conv_block_impl=partial(ConvBlock3dWithStrip, use_strip_pooling=use_strip_pooling),
            sampler_impl=Upsampler3d,
        )

        # And the final upsampler to match the expected dimensions.
        self.deconv_out = Deconv3DBlock(  # NOTE: changed `end_up` to `deconv_out`
            in_channels=features_decoder[-1],
            out_channels=features_decoder[-1],
            scale_factor=scale_factors,
            use_strip_pooling=use_strip_pooling,
        )

        # Additional conjunction blocks.
        self.base = ConvBlock3dWithStrip(
            in_channels=embed_dim,
            out_channels=features_decoder[0],
            use_strip_pooling=use_strip_pooling,
        )

        # And the output layers.
        self.decoder_head = ConvBlock3dWithStrip(
            in_channels=2 * features_decoder[-1],
            out_channels=features_decoder[-1],
            use_strip_pooling=use_strip_pooling,
        )
        self.out_conv = nn.Conv3d(features_decoder[-1], out_channels, 1)

    def forward(self, x: torch.Tensor):
        """Forward pass of the UNETR-3D model.

        Args:
            x: Inputs of expected shape (B, C, Z, Y, X), where Z considers flexible inputs.

        Returns:
            The UNETR output.
        """
        B, C, Z, H, W = x.shape
        original_shape = (Z, H, W)

        # Preprocessing step
        x, input_shape = self.preprocess(x)

        # Run the image encoder.
        # Each z-slice is passed through the 2d image encoder individually and the
        # per-slice embeddings are stacked back along the depth dimension.
        curr_features = torch.stack([self.encoder(x[:, :, i])[0] for i in range(Z)], dim=2)

        # Prepare the counterparts for the decoder.
        # NOTE: The section below is sequential, there's no skip connections atm.
        z9 = self.deconv1(curr_features)
        z6 = self.deconv2(z9)
        z3 = self.deconv3(z6)
        z0 = self.deconv4(z3)

        updated_from_encoder = [z9, z6, z3]

        # Align the features through the base block.
        x = self.base(curr_features)
        # Run the decoder
        x = self.decoder(x, encoder_inputs=updated_from_encoder)
        x = self.deconv_out(x)  # NOTE before `end_up`

        # And the final output head.
        x = torch.cat([x, z0], dim=1)
        x = self.decoder_head(x)
        x = self.out_conv(x)
        if self.final_activation is not None:
            x = self.final_activation(x)

        # Postprocess the output back to original size.
        x = self.postprocess_masks(x, input_shape, original_shape)
        return x


#
# ADDITIONAL FUNCTIONALITIES
#


def _strip_pooling_layers(enabled, channels) -> nn.Module:
    # Returns a depth strip-pooling block when enabled, otherwise a no-op module.
    return DepthStripPooling(channels) if enabled else nn.Identity()


class DepthStripPooling(nn.Module):
    """@private
    """
    def __init__(self, channels: int, reduction: int = 4):
        """Block for strip pooling along the depth dimension (only).

        eg. for 3D (Z > 1) - it aggregates global context across depth by adaptive avg pooling
        to Z=1, and then passes through a small 1x1x1 MLP, then broadcasts it back to Z to
        modulate the original features (using a gated residual).

        For 2D (Z == 1): returns input unchanged (no-op).

        Args:
            channels: The output channels.
            reduction: The reduction of the hidden layers.
        """
        super().__init__()
        hidden = max(1, channels // reduction)
        self.conv1 = nn.Conv3d(channels, hidden, kernel_size=1)
        self.bn1 = nn.BatchNorm3d(hidden)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(hidden, channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.dim() != 5:
            raise ValueError(f"DepthStripPooling expects 5D tensors as input, got '{x.shape}'.")

        B, C, Z, H, W = x.shape
        if Z == 1:  # i.e. always the case of all 2d.
            return x  # We simply do nothing there.

        # We pool only along the depth dimension: i.e. target shape (B, C, 1, H, W)
        feat = F.adaptive_avg_pool3d(x, output_size=(1, H, W))
        feat = self.conv1(feat)
        feat = self.bn1(feat)
        feat = self.relu(feat)
        feat = self.conv2(feat)
        gate = torch.sigmoid(feat).expand(B, C, Z, H, W)  # Broadcast the collapsed depth context back to all slices

        # Gated residual fusion
        return x * gate + x
class Deconv3DBlock(nn.Module):
    """@private
    """
    def __init__(
        self,
        scale_factor,
        in_channels,
        out_channels,
        kernel_size=3,
        anisotropic_kernel=True,
        use_strip_pooling=True,
    ):
        """Upsampling block: 3d upsampler followed by conv + batchnorm + ReLU
        and an optional depth strip-pooling layer.
        """
        super().__init__()
        # Keyword arguments of the convolution that follows the upsampler.
        conv_kwargs = dict(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            padding=(kernel_size - 1) // 2,
        )
        # Adapt kernel / padding to the (possibly anisotropic) scale factor.
        if anisotropic_kernel:
            conv_kwargs = _update_conv_kwargs(conv_kwargs, scale_factor)

        layers = [
            Upsampler3d(scale_factor, in_channels, out_channels),
            nn.Conv3d(**conv_kwargs),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(True),
            _strip_pooling_layers(enabled=use_strip_pooling, channels=out_channels),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)


class ConvBlock3dWithStrip(nn.Module):
    """@private
    """
    def __init__(self, in_channels: int, out_channels: int, use_strip_pooling: bool = True, **kwargs):
        """A 3d convolutional block with an optional depth strip-pooling layer appended.
        Extra keyword arguments are forwarded to `ConvBlock3d`.
        """
        super().__init__()
        conv = ConvBlock3d(in_channels, out_channels, **kwargs)
        strip = _strip_pooling_layers(enabled=use_strip_pooling, channels=out_channels)
        self.block = nn.Sequential(conv, strip)

    def forward(self, x):
        return self.block(x)
class SingleDeconv2DBlock(nn.Module):
    """@private
    """
    def __init__(self, scale_factor, in_channels, out_channels):
        """A single 2x transposed convolution.

        NOTE: kernel and stride are fixed at 2, so `scale_factor` is accepted for
        interface compatibility but not used — the block always upsamples by two.
        """
        super().__init__()
        self.block = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size=2, stride=2, padding=0, output_padding=0
        )

    def forward(self, x):
        return self.block(x)


class SingleConv2DBlock(nn.Module):
    """@private
    """
    def __init__(self, in_channels, out_channels, kernel_size):
        """A single stride-1 convolution with 'same' padding (for odd kernel sizes)."""
        super().__init__()
        pad = (kernel_size - 1) // 2
        self.block = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=pad)

    def forward(self, x):
        return self.block(x)


class Conv2DBlock(nn.Module):
    """@private
    """
    def __init__(self, in_channels, out_channels, kernel_size=3):
        """Convolution followed by batchnorm and ReLU."""
        super().__init__()
        layers = (
            SingleConv2DBlock(in_channels, out_channels, kernel_size),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),
        )
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)


class Deconv2DBlock(nn.Module):
    """@private
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, use_conv_transpose=True):
        """2x upsampling (transposed convolution or interpolation + conv)
        followed by conv + batchnorm + ReLU.
        """
        super().__init__()
        # Choose the upsampling implementation; only the selected one is instantiated.
        if use_conv_transpose:
            upsampler = SingleDeconv2DBlock(scale_factor=2, in_channels=in_channels, out_channels=out_channels)
        else:
            upsampler = Upsampler2d(scale_factor=2, in_channels=in_channels, out_channels=out_channels)
        self.block = nn.Sequential(
            upsampler,
            SingleConv2DBlock(out_channels, out_channels, kernel_size),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),
        )

    def forward(self, x):
        return self.block(x)
class UNETRBase(nn.Module):
    """Base class for implementing a UNETR.

    Args:
        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
        backbone: The name of the vision transformer implementation.
            One of "sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"
            (see all combinations below)
        encoder: The vision transformer. Can either be a name, such as "vit_b"
            (see all combinations for this below) or a torch module.
        decoder: The convolutional decoder.
        out_channels: The number of output channels of the UNETR.
        use_sam_stats: Whether to normalize the input data with the statistics of the
            pretrained SAM / SAM2 / SAM3 model.
        use_dino_stats: Whether to normalize the input data with the statistics of the
            pretrained DINOv2 / DINOv3 model.
        use_mae_stats: Whether to normalize the input data with the statistics of the pretrained MAE model.
        resize_input: Whether to resize the input images to match `img_size`.
            By default, it resizes the inputs to match the `img_size`.
        encoder_checkpoint: Checkpoint for initializing the vision transformer.
            Can either be a filepath or an already loaded checkpoint.
        final_activation: The activation to apply to the UNETR output.
        use_skip_connection: Whether to use skip connections. By default, it uses skip connections.
        embed_dim: The embedding dimensionality, corresponding to the output dimension of the vision transformer.
        use_conv_transpose: Whether to use transposed convolutions instead of resampling for upsampling.
            By default, it uses resampling for upsampling.

    NOTE: The currently supported combinations of 'backbone' x 'encoder' are the following:

    SAM_family_models:
        - 'sam' x 'vit_b'
        - 'sam' x 'vit_l'
        - 'sam' x 'vit_h'
        - 'sam2' x 'hvit_t'
        - 'sam2' x 'hvit_s'
        - 'sam2' x 'hvit_b'
        - 'sam2' x 'hvit_l'
        - 'sam3' x 'vit_pe'
        - 'cellpose_sam' x 'vit_l'

    DINO_family_models:
        - 'dinov2' x 'vit_s'
        - 'dinov2' x 'vit_b'
        - 'dinov2' x 'vit_l'
        - 'dinov2' x 'vit_g'
        - 'dinov2' x 'vit_s_reg4'
        - 'dinov2' x 'vit_b_reg4'
        - 'dinov2' x 'vit_l_reg4'
        - 'dinov2' x 'vit_g_reg4'
        - 'dinov3' x 'vit_s'
        - 'dinov3' x 'vit_s+'
        - 'dinov3' x 'vit_b'
        - 'dinov3' x 'vit_l'
        - 'dinov3' x 'vit_l+'
        - 'dinov3' x 'vit_h+'
        - 'dinov3' x 'vit_7b'

    MAE_family_models:
        - 'mae' x 'vit_b'
        - 'mae' x 'vit_l'
        - 'mae' x 'vit_h'
        - 'scalemae' x 'vit_b'
        - 'scalemae' x 'vit_l'
        - 'scalemae' x 'vit_h'
    """
    def __init__(
        self,
        img_size: int = 1024,
        backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
        encoder: Optional[Union[nn.Module, str]] = "vit_b",
        decoder: Optional[nn.Module] = None,
        out_channels: int = 1,
        use_sam_stats: bool = False,
        use_mae_stats: bool = False,
        use_dino_stats: bool = False,
        resize_input: bool = True,
        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
        final_activation: Optional[Union[str, nn.Module]] = None,
        use_skip_connection: bool = True,
        embed_dim: Optional[int] = None,
        use_conv_transpose: bool = False,
        **kwargs
    ) -> None:
        super().__init__()

        self.img_size = img_size
        self.use_sam_stats = use_sam_stats
        self.use_mae_stats = use_mae_stats
        self.use_dino_stats = use_dino_stats
        self.use_skip_connection = use_skip_connection
        self.resize_input = resize_input
        self.use_conv_transpose = use_conv_transpose
        self.backbone = backbone

        if isinstance(encoder, str):  # e.g. "vit_b" / "hvit_b" / "vit_pe"
            print(f"Using {encoder} from {backbone.upper()}")
            self.encoder = get_vision_transformer(img_size=img_size, backbone=backbone, model=encoder, **kwargs)

            if encoder_checkpoint is not None:
                self._load_encoder_from_checkpoint(backbone=backbone, encoder=encoder, checkpoint=encoder_checkpoint)

            if embed_dim is None:
                embed_dim = self.encoder.embed_dim

            # For SAM1 encoder, if 'apply_neck' is applied, the embedding dimension must change.
            if hasattr(self.encoder, "apply_neck") and self.encoder.apply_neck:
                embed_dim = self.encoder.neck[2].out_channels  # the value is 256

        else:  # `nn.Module` ViT backbone
            self.encoder = encoder

            # Detect whether the externally provided encoder carries a 'neck' sub-module,
            # since that changes which layer defines the output embedding dimension.
            have_neck = False
            for name, _ in self.encoder.named_parameters():
                if name.startswith("neck"):
                    have_neck = True

            if embed_dim is None:
                if have_neck:
                    embed_dim = self.encoder.neck[2].out_channels  # the value is 256
                else:
                    embed_dim = self.encoder.patch_embed.proj.out_channels

        self.embed_dim = embed_dim
        self.final_activation = self._get_activation(final_activation)

    def _load_encoder_from_checkpoint(self, backbone, encoder, checkpoint):
        """Function to load pretrained weights to the image encoder.

        Args:
            backbone: The name of the vision transformer implementation (see the class docstring).
            encoder: The encoder name (e.g. "vit_b") or a torch module.
            checkpoint: Either a filepath to a checkpoint or an already-loaded state dict.
        """
        if isinstance(checkpoint, str):
            if backbone == "sam" and isinstance(encoder, str):
                # If we have a SAM encoder, then we first try to load the full SAM Model
                # (using micro_sam) and otherwise fall back on directly loading the encoder state
                # from the checkpoint
                try:
                    _, model = get_sam_model(model_type=encoder, checkpoint_path=checkpoint, return_sam=True)
                    encoder_state = model.image_encoder.state_dict()
                except Exception:
                    # Try loading the encoder state directly from a checkpoint.
                    encoder_state = torch.load(checkpoint, weights_only=False)

            elif backbone == "cellpose_sam" and isinstance(encoder, str):
                # The architecture matches CellposeSAM exactly (same rel_pos sizes),
                # so weights load directly without any interpolation.
                encoder_state = torch.load(checkpoint, map_location="cpu", weights_only=False)
                # Handle DataParallel/DistributedDataParallel prefix.
                if any(k.startswith("module.") for k in encoder_state.keys()):
                    encoder_state = OrderedDict(
                        {k[len("module."):]: v for k, v in encoder_state.items()}
                    )
                # Extract encoder weights from CellposeSAM checkpoint format (strip 'encoder.' prefix).
                if any(k.startswith("encoder.") for k in encoder_state.keys()):
                    encoder_state = OrderedDict(
                        {k[len("encoder."):]: v for k, v in encoder_state.items() if k.startswith("encoder.")}
                    )

            elif backbone == "sam2" and isinstance(encoder, str):
                # If we have a SAM2 encoder, then we first try to load the full SAM2 Model.
                # (using micro_sam2) and otherwise fall back on directly loading the encoder state
                # from the checkpoint
                try:
                    model = get_sam2_model(model_type=encoder, checkpoint_path=checkpoint)
                    encoder_state = model.image_encoder.state_dict()
                except Exception:
                    # Try loading the encoder state directly from a checkpoint.
                    encoder_state = torch.load(checkpoint, weights_only=False)

            elif backbone == "sam3" and isinstance(encoder, str):
                # If we have a SAM3 encoder, then we first try to load the full SAM3 Model.
                # (using micro_sam3) and otherwise fall back on directly loading the encoder state
                # from the checkpoint
                try:
                    model = get_sam3_model(checkpoint_path=checkpoint)
                    encoder_state = model.backbone.vision_backbone.state_dict()
                    # Let's align loading the encoder weights with expected parameter names
                    encoder_state = {
                        k[len("trunk."):] if k.startswith("trunk.") else k: v for k, v in encoder_state.items()
                    }
                    # And drop the 'convs' and 'sam2_convs' - these seem like some upsampling blocks.
                    encoder_state = {
                        k: v for k, v in encoder_state.items()
                        if not (k.startswith("convs.") or k.startswith("sam2_convs."))
                    }
                except Exception:
                    # Try loading the encoder state directly from a checkpoint.
                    encoder_state = torch.load(checkpoint, weights_only=False)

            elif backbone == "mae":
                # vit initialization hints from:
                # - https://github.com/facebookresearch/mae/blob/main/main_finetune.py#L233-L242
                encoder_state = torch.load(checkpoint, weights_only=False)["model"]
                encoder_state = OrderedDict({
                    k: v for k, v in encoder_state.items() if (k != "mask_token" and not k.startswith("decoder"))
                })
                # Let's remove the `head` from our current encoder (as the MAE pretrained don't expect it)
                current_encoder_state = self.encoder.state_dict()
                if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
                    del self.encoder.head

            elif backbone == "scalemae":
                # Load the encoder state directly from a checkpoint.
                encoder_state = torch.load(checkpoint)["model"]
                encoder_state = OrderedDict({
                    k: v for k, v in encoder_state.items()
                    if not k.startswith(("mask_token", "decoder", "fcn", "fpn", "pos_embed"))
                })

                # Let's remove the `head` from our current encoder (as the MAE pretrained don't expect it)
                current_encoder_state = self.encoder.state_dict()
                if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
                    del self.encoder.head

                if "pos_embed" in current_encoder_state:  # NOTE: ScaleMAE uses 'pos. embeddings' in a diff. format.
                    del self.encoder.pos_embed

            elif backbone in ["dinov2", "dinov3"]:  # Load the encoder state directly from a checkpoint.
                encoder_state = torch.load(checkpoint)

            else:
                raise ValueError(
                    f"We don't support either the '{backbone}' backbone or the '{encoder}' model combination (or both)."
                )

        else:
            # An already-loaded state dict was passed; use it as-is.
            encoder_state = checkpoint

        self.encoder.load_state_dict(encoder_state)

    def _get_activation(self, activation):
        """Resolve 'activation' (None, an nn.Module, or a class name on torch.nn) to a module instance.

        Raises:
            ValueError: If 'activation' is a string that does not name a class in torch.nn.
        """
        return_activation = None
        if activation is None:
            return None
        if isinstance(activation, nn.Module):
            return activation
        if isinstance(activation, str):
            # Look the activation class up by name on torch.nn, e.g. "Sigmoid".
            return_activation = getattr(nn, activation, None)
        if return_activation is None:
            raise ValueError(f"Invalid activation: {activation}")

        return return_activation()

    @staticmethod
    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
        """Compute the output size given input size and target long side length.

        Args:
            oldh: The input image height.
            oldw: The input image width.
            long_side_length: The longest side length for resizing.

        Returns:
            The new image height.
            The new image width.
        """
        scale = long_side_length * 1.0 / max(oldh, oldw)
        newh, neww = oldh * scale, oldw * scale
        # Round to the nearest integer (the + 0.5 implements round-half-up for positive values).
        neww = int(neww + 0.5)
        newh = int(newh + 0.5)
        return (newh, neww)

    def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor:
        """Resize the image so that the longest side has the correct length.

        Expects batched images with shape BxCxHxW OR BxCxDxHxW and float format.

        Args:
            image: The input image.

        Returns:
            The resized image.
        """
        if image.ndim == 4:  # i.e. 2d image
            target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.encoder.img_size)
            return F.interpolate(image, target_size, mode="bilinear", align_corners=False, antialias=True)
        elif image.ndim == 5:  # i.e. 3d volume
            # NOTE(review): the 2d branch uses 'self.encoder.img_size' while this branch uses
            # 'self.img_size' — confirm the two are kept in sync for 3d use.
            B, C, Z, H, W = image.shape
            target_size = self.get_preprocess_shape(H, W, self.img_size)
            # Only H and W are resized; the depth axis Z is kept as-is.
            return F.interpolate(image, (Z, *target_size), mode="trilinear", align_corners=False)
        else:
            raise ValueError("Expected 4d or 5d inputs, got", image.shape)

    def _as_stats(self, mean, std, device, dtype, is_3d: bool):
        """Turn per-channel mean/std tuples into broadcastable tensors.

        @private
        """
        # Either 2d batch: (1, C, 1, 1) or 3d batch: (1, C, 1, 1, 1).
        view_shape = (1, -1, 1, 1, 1) if is_3d else (1, -1, 1, 1)
        pixel_mean = torch.tensor(mean, device=device, dtype=dtype).view(*view_shape)
        pixel_std = torch.tensor(std, device=device, dtype=dtype).view(*view_shape)
        return pixel_mean, pixel_std

    def preprocess(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, ...]]:
        """Normalize, optionally resize, and pad the input to the encoder's expected square size.

        Returns the processed tensor and the pre-padding spatial shape (needed by
        `postprocess_masks` to crop the padding off again).

        @private
        """
        is_3d = (x.ndim == 5)
        device, dtype = x.device, x.dtype

        if self.use_sam_stats:
            if self.backbone == "sam2":
                mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
            elif self.backbone == "sam3":
                mean, std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
            else:  # sam1 / default
                # NOTE(review): these SAM1 constants are on a 0-255 scale, while the sam2 / dino
                # constants above are on a 0-1 scale — confirm inputs use the matching value range.
                mean, std = (123.675, 116.28, 103.53), (58.395, 57.12, 57.375)
        elif self.use_mae_stats:  # TODO: add mean std from mae / scalemae experiments (or open up arguments for this)
            raise NotImplementedError
        elif self.use_dino_stats:
            mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
        else:
            # Identity normalization (no-op).
            mean, std = (0.0, 0.0, 0.0), (1.0, 1.0, 1.0)

        pixel_mean, pixel_std = self._as_stats(mean, std, device=device, dtype=dtype, is_3d=is_3d)

        if self.resize_input:
            x = self.resize_longest_side(x)
        input_shape = x.shape[-3:] if is_3d else x.shape[-2:]

        x = (x - pixel_mean) / pixel_std
        # Zero-pad bottom/right so that H and W match the encoder's square input size.
        h, w = x.shape[-2:]
        padh = self.encoder.img_size - h
        padw = self.encoder.img_size - w

        if is_3d:
            x = F.pad(x, (0, padw, 0, padh, 0, 0))
        else:
            x = F.pad(x, (0, padw, 0, padh))

        return x, input_shape

    def postprocess_masks(
        self, masks: torch.Tensor, input_size: Tuple[int, ...], original_size: Tuple[int, ...],
    ) -> torch.Tensor:
        """Undo the preprocessing: upsample to encoder size, crop the padding, resize to the original shape.

        @private
        """
        if masks.ndim == 4:  # i.e. 2d labels
            masks = F.interpolate(
                masks,
                (self.encoder.img_size, self.encoder.img_size),
                mode="bilinear",
                align_corners=False,
            )
            # Crop away the padding added in 'preprocess' before resizing to the original size.
            masks = masks[..., : input_size[0], : input_size[1]]
            masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)

        elif masks.ndim == 5:  # i.e. 3d volumetric labels
            masks = F.interpolate(
                masks,
                (input_size[0], self.img_size, self.img_size),
                mode="trilinear",
                align_corners=False,
            )
            masks = masks[..., :input_size[0], :input_size[1], :input_size[2]]
            masks = F.interpolate(masks, original_size, mode="trilinear", align_corners=False)

        else:
            raise ValueError("Expected 4d or 5d labels, got", masks.shape)

        return masks
Base class for implementing a UNETR.
Arguments:
- img_size: The size of the input for the image encoder. Input images will be resized to match this size.
- backbone: The name of the vision transformer implementation. One of "sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3" (see all combinations below)
- encoder: The vision transformer. Can either be a name, such as "vit_b" (see all combinations for this below) or a torch module.
- decoder: The convolutional decoder.
- out_channels: The number of output channels of the UNETR.
- use_sam_stats: Whether to normalize the input data with the statistics of the pretrained SAM / SAM2 / SAM3 model.
- use_dino_stats: Whether to normalize the input data with the statistics of the pretrained DINOv2 / DINOv3 model.
- use_mae_stats: Whether to normalize the input data with the statistics of the pretrained MAE model.
- resize_input: Whether to resize the input images to match `img_size`. By default, it resizes the inputs to match `img_size`.
- encoder_checkpoint: Checkpoint for initializing the vision transformer. Can either be a filepath or an already loaded checkpoint.
- final_activation: The activation to apply to the UNETR output.
- use_skip_connection: Whether to use skip connections. By default, it uses skip connections.
- embed_dim: The embedding dimensionality, corresponding to the output dimension of the vision transformer.
- use_conv_transpose: Whether to use transposed convolutions instead of resampling for upsampling. By default, it uses resampling for upsampling.
NOTE: The currently supported combinations of 'backbone' x 'encoder' are the following:
- SAM_family_models:
  - 'sam' x 'vit_b'
- 'sam' x 'vit_l'
- 'sam' x 'vit_h'
- 'sam2' x 'hvit_t'
- 'sam2' x 'hvit_s'
- 'sam2' x 'hvit_b'
- 'sam2' x 'hvit_l'
- 'sam3' x 'vit_pe'
- 'cellpose_sam' x 'vit_l'
- DINO_family_models:
  - 'dinov2' x 'vit_s'
- 'dinov2' x 'vit_b'
- 'dinov2' x 'vit_l'
- 'dinov2' x 'vit_g'
- 'dinov2' x 'vit_s_reg4'
- 'dinov2' x 'vit_b_reg4'
- 'dinov2' x 'vit_l_reg4'
- 'dinov2' x 'vit_g_reg4'
- 'dinov3' x 'vit_s'
- 'dinov3' x 'vit_s+'
- 'dinov3' x 'vit_b'
- 'dinov3' x 'vit_l'
- 'dinov3' x 'vit_l+'
- 'dinov3' x 'vit_h+'
- 'dinov3' x 'vit_7b'
- MAE_family_models:
  - 'mae' x 'vit_b'
- 'mae' x 'vit_l'
- 'mae' x 'vit_h'
- 'scalemae' x 'vit_b'
- 'scalemae' x 'vit_l'
- 'scalemae' x 'vit_h'
99 def __init__( 100 self, 101 img_size: int = 1024, 102 backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam", 103 encoder: Optional[Union[nn.Module, str]] = "vit_b", 104 decoder: Optional[nn.Module] = None, 105 out_channels: int = 1, 106 use_sam_stats: bool = False, 107 use_mae_stats: bool = False, 108 use_dino_stats: bool = False, 109 resize_input: bool = True, 110 encoder_checkpoint: Optional[Union[str, OrderedDict]] = None, 111 final_activation: Optional[Union[str, nn.Module]] = None, 112 use_skip_connection: bool = True, 113 embed_dim: Optional[int] = None, 114 use_conv_transpose: bool = False, 115 **kwargs 116 ) -> None: 117 super().__init__() 118 119 self.img_size = img_size 120 self.use_sam_stats = use_sam_stats 121 self.use_mae_stats = use_mae_stats 122 self.use_dino_stats = use_dino_stats 123 self.use_skip_connection = use_skip_connection 124 self.resize_input = resize_input 125 self.use_conv_transpose = use_conv_transpose 126 self.backbone = backbone 127 128 if isinstance(encoder, str): # e.g. "vit_b" / "hvit_b" / "vit_pe" 129 print(f"Using {encoder} from {backbone.upper()}") 130 self.encoder = get_vision_transformer(img_size=img_size, backbone=backbone, model=encoder, **kwargs) 131 132 if encoder_checkpoint is not None: 133 self._load_encoder_from_checkpoint(backbone=backbone, encoder=encoder, checkpoint=encoder_checkpoint) 134 135 if embed_dim is None: 136 embed_dim = self.encoder.embed_dim 137 138 # For SAM1 encoder, if 'apply_neck' is applied, the embedding dimension must change. 
139 if hasattr(self.encoder, "apply_neck") and self.encoder.apply_neck: 140 embed_dim = self.encoder.neck[2].out_channels # the value is 256 141 142 else: # `nn.Module` ViT backbone 143 self.encoder = encoder 144 145 have_neck = False 146 for name, _ in self.encoder.named_parameters(): 147 if name.startswith("neck"): 148 have_neck = True 149 150 if embed_dim is None: 151 if have_neck: 152 embed_dim = self.encoder.neck[2].out_channels # the value is 256 153 else: 154 embed_dim = self.encoder.patch_embed.proj.out_channels 155 156 self.embed_dim = embed_dim 157 self.final_activation = self._get_activation(final_activation)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
274 @staticmethod 275 def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 276 """Compute the output size given input size and target long side length. 277 278 Args: 279 oldh: The input image height. 280 oldw: The input image width. 281 long_side_length: The longest side length for resizing. 282 283 Returns: 284 The new image height. 285 The new image width. 286 """ 287 scale = long_side_length * 1.0 / max(oldh, oldw) 288 newh, neww = oldh * scale, oldw * scale 289 neww = int(neww + 0.5) 290 newh = int(newh + 0.5) 291 return (newh, neww)
Compute the output size given input size and target long side length.
Arguments:
- oldh: The input image height.
- oldw: The input image width.
- long_side_length: The longest side length for resizing.
Returns:
The new image height and the new image width.
293 def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor: 294 """Resize the image so that the longest side has the correct length. 295 296 Expects batched images with shape BxCxHxW OR BxCxDxHxW and float format. 297 298 Args: 299 image: The input image. 300 301 Returns: 302 The resized image. 303 """ 304 if image.ndim == 4: # i.e. 2d image 305 target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.encoder.img_size) 306 return F.interpolate(image, target_size, mode="bilinear", align_corners=False, antialias=True) 307 elif image.ndim == 5: # i.e. 3d volume 308 B, C, Z, H, W = image.shape 309 target_size = self.get_preprocess_shape(H, W, self.img_size) 310 return F.interpolate(image, (Z, *target_size), mode="trilinear", align_corners=False) 311 else: 312 raise ValueError("Expected 4d or 5d inputs, got", image.shape)
Resize the image so that the longest side has the correct length.
Expects batched images with shape BxCxHxW OR BxCxDxHxW and float format.
Arguments:
- image: The input image.
Returns:
The resized image.
class UNETR(UNETRBase):
    """A (2d-only) UNet Transformer using a vision transformer as encoder and a convolutional decoder.
    """
    def __init__(
        self,
        img_size: int = 1024,
        backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
        encoder: Optional[Union[nn.Module, str]] = "vit_b",
        decoder: Optional[nn.Module] = None,
        out_channels: int = 1,
        use_sam_stats: bool = False,
        use_mae_stats: bool = False,
        use_dino_stats: bool = False,
        resize_input: bool = True,
        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
        final_activation: Optional[Union[str, nn.Module]] = None,
        use_skip_connection: bool = True,
        embed_dim: Optional[int] = None,
        use_conv_transpose: bool = False,
        **kwargs
    ) -> None:

        super().__init__(
            img_size=img_size,
            backbone=backbone,
            encoder=encoder,
            decoder=decoder,
            out_channels=out_channels,
            use_sam_stats=use_sam_stats,
            use_mae_stats=use_mae_stats,
            use_dino_stats=use_dino_stats,
            resize_input=resize_input,
            encoder_checkpoint=encoder_checkpoint,
            final_activation=final_activation,
            use_skip_connection=use_skip_connection,
            embed_dim=embed_dim,
            use_conv_transpose=use_conv_transpose,
            **kwargs,
        )

        encoder = self.encoder

        # Determine the number of input channels from whichever encoder flavor we got.
        if backbone == "sam2" and hasattr(encoder, "trunk"):
            in_chans = encoder.trunk.patch_embed.proj.in_channels
        elif hasattr(encoder, "in_chans"):
            in_chans = encoder.in_chans
        else:  # `nn.Module` ViT backbone.
            try:
                in_chans = encoder.patch_embed.proj.in_channels
            except AttributeError:  # for getting the input channels while using 'vit_t' from MobileSam
                in_chans = encoder.patch_embed.seq[0].c.in_channels

        # parameters for the decoder network
        depth = 3
        initial_features = 64
        gain = 2
        # Feature widths from deepest to shallowest, e.g. [512, 256, 128, 64].
        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
        scale_factors = depth * [2]
        self.out_channels = out_channels

        # choice of upsampler - to use (bilinear interpolation + conv) or conv transpose
        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d

        self.decoder = decoder or Decoder(
            features=features_decoder,
            scale_factors=scale_factors[::-1],
            conv_block_impl=ConvBlock2d,
            sampler_impl=_upsampler,
        )

        if use_skip_connection:
            # Mapping blocks that bring intermediate encoder features to the decoder's
            # resolutions: 1x, 2x and 3x upsampling of the embedding, respectively.
            self.deconv1 = Deconv2DBlock(
                in_channels=self.embed_dim,
                out_channels=features_decoder[0],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv2 = nn.Sequential(
                Deconv2DBlock(
                    in_channels=self.embed_dim,
                    out_channels=features_decoder[0],
                    use_conv_transpose=use_conv_transpose,
                ),
                Deconv2DBlock(
                    in_channels=features_decoder[0],
                    out_channels=features_decoder[1],
                    use_conv_transpose=use_conv_transpose,
                )
            )
            self.deconv3 = nn.Sequential(
                Deconv2DBlock(
                    in_channels=self.embed_dim,
                    out_channels=features_decoder[0],
                    use_conv_transpose=use_conv_transpose,
                ),
                Deconv2DBlock(
                    in_channels=features_decoder[0],
                    out_channels=features_decoder[1],
                    use_conv_transpose=use_conv_transpose,
                ),
                Deconv2DBlock(
                    in_channels=features_decoder[1],
                    out_channels=features_decoder[2],
                    use_conv_transpose=use_conv_transpose,
                )
            )
            # Operates directly on the (preprocessed) input image.
            self.deconv4 = ConvBlock2d(in_chans, features_decoder[-1])
        else:
            # Without skip connections the deconv blocks form a sequential upsampling chain.
            self.deconv1 = Deconv2DBlock(
                in_channels=self.embed_dim,
                out_channels=features_decoder[0],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv2 = Deconv2DBlock(
                in_channels=features_decoder[0],
                out_channels=features_decoder[1],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv3 = Deconv2DBlock(
                in_channels=features_decoder[1],
                out_channels=features_decoder[2],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv4 = Deconv2DBlock(
                in_channels=features_decoder[2],
                out_channels=features_decoder[3],
                use_conv_transpose=use_conv_transpose,
            )

        self.base = ConvBlock2d(self.embed_dim, features_decoder[0])
        self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1)
        self.deconv_out = _upsampler(
            scale_factor=2, in_channels=features_decoder[-1], out_channels=features_decoder[-1]
        )
        self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the UNETR to the input data.

        Args:
            x: The input tensor.

        Returns:
            The UNETR output.
        """
        original_shape = x.shape[-2:]

        # Reshape the inputs to the shape expected by the encoder
        # and normalize the inputs if normalization is part of the model.
        x, input_shape = self.preprocess(x)

        encoder_outputs = self.encoder(x)

        if isinstance(encoder_outputs[-1], list):
            # `encoder_outputs` can be arranged in only two forms:
            # - either we only return the image embeddings
            # - or, we return the image embeddings and the "list" of global attention layers
            z12, from_encoder = encoder_outputs
        else:
            z12 = encoder_outputs

        if self.use_skip_connection:
            # NOTE(review): if the encoder did not return the attention-layer list above,
            # 'from_encoder' is unbound here and this raises NameError — confirm that
            # skip connections are only enabled with encoders that return both outputs.
            from_encoder = from_encoder[::-1]
            z9 = self.deconv1(from_encoder[0])
            z6 = self.deconv2(from_encoder[1])
            z3 = self.deconv3(from_encoder[2])
            z0 = self.deconv4(x)

        else:
            z9 = self.deconv1(z12)
            z6 = self.deconv2(z9)
            z3 = self.deconv3(z6)
            z0 = self.deconv4(z3)

        updated_from_encoder = [z9, z6, z3]

        x = self.base(z12)
        x = self.decoder(x, encoder_inputs=updated_from_encoder)
        x = self.deconv_out(x)

        # Fuse the highest-resolution features before the output head.
        x = torch.cat([x, z0], dim=1)
        x = self.decoder_head(x)

        x = self.out_conv(x)
        if self.final_activation is not None:
            x = self.final_activation(x)

        # Resize the prediction back to the original input shape.
        x = self.postprocess_masks(x, input_shape, original_shape)
        return x
A (2d-only) UNet Transformer using a vision transformer as encoder and a convolutional decoder.
395 def __init__( 396 self, 397 img_size: int = 1024, 398 backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam", 399 encoder: Optional[Union[nn.Module, str]] = "vit_b", 400 decoder: Optional[nn.Module] = None, 401 out_channels: int = 1, 402 use_sam_stats: bool = False, 403 use_mae_stats: bool = False, 404 use_dino_stats: bool = False, 405 resize_input: bool = True, 406 encoder_checkpoint: Optional[Union[str, OrderedDict]] = None, 407 final_activation: Optional[Union[str, nn.Module]] = None, 408 use_skip_connection: bool = True, 409 embed_dim: Optional[int] = None, 410 use_conv_transpose: bool = False, 411 **kwargs 412 ) -> None: 413 414 super().__init__( 415 img_size=img_size, 416 backbone=backbone, 417 encoder=encoder, 418 decoder=decoder, 419 out_channels=out_channels, 420 use_sam_stats=use_sam_stats, 421 use_mae_stats=use_mae_stats, 422 use_dino_stats=use_dino_stats, 423 resize_input=resize_input, 424 encoder_checkpoint=encoder_checkpoint, 425 final_activation=final_activation, 426 use_skip_connection=use_skip_connection, 427 embed_dim=embed_dim, 428 use_conv_transpose=use_conv_transpose, 429 **kwargs, 430 ) 431 432 encoder = self.encoder 433 434 if backbone == "sam2" and hasattr(encoder, "trunk"): 435 in_chans = encoder.trunk.patch_embed.proj.in_channels 436 elif hasattr(encoder, "in_chans"): 437 in_chans = encoder.in_chans 438 else: # `nn.Module` ViT backbone. 
439 try: 440 in_chans = encoder.patch_embed.proj.in_channels 441 except AttributeError: # for getting the input channels while using 'vit_t' from MobileSam 442 in_chans = encoder.patch_embed.seq[0].c.in_channels 443 444 # parameters for the decoder network 445 depth = 3 446 initial_features = 64 447 gain = 2 448 features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1] 449 scale_factors = depth * [2] 450 self.out_channels = out_channels 451 452 # choice of upsampler - to use (bilinear interpolation + conv) or conv transpose 453 _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d 454 455 self.decoder = decoder or Decoder( 456 features=features_decoder, 457 scale_factors=scale_factors[::-1], 458 conv_block_impl=ConvBlock2d, 459 sampler_impl=_upsampler, 460 ) 461 462 if use_skip_connection: 463 self.deconv1 = Deconv2DBlock( 464 in_channels=self.embed_dim, 465 out_channels=features_decoder[0], 466 use_conv_transpose=use_conv_transpose, 467 ) 468 self.deconv2 = nn.Sequential( 469 Deconv2DBlock( 470 in_channels=self.embed_dim, 471 out_channels=features_decoder[0], 472 use_conv_transpose=use_conv_transpose, 473 ), 474 Deconv2DBlock( 475 in_channels=features_decoder[0], 476 out_channels=features_decoder[1], 477 use_conv_transpose=use_conv_transpose, 478 ) 479 ) 480 self.deconv3 = nn.Sequential( 481 Deconv2DBlock( 482 in_channels=self.embed_dim, 483 out_channels=features_decoder[0], 484 use_conv_transpose=use_conv_transpose, 485 ), 486 Deconv2DBlock( 487 in_channels=features_decoder[0], 488 out_channels=features_decoder[1], 489 use_conv_transpose=use_conv_transpose, 490 ), 491 Deconv2DBlock( 492 in_channels=features_decoder[1], 493 out_channels=features_decoder[2], 494 use_conv_transpose=use_conv_transpose, 495 ) 496 ) 497 self.deconv4 = ConvBlock2d(in_chans, features_decoder[-1]) 498 else: 499 self.deconv1 = Deconv2DBlock( 500 in_channels=self.embed_dim, 501 out_channels=features_decoder[0], 502 
use_conv_transpose=use_conv_transpose, 503 ) 504 self.deconv2 = Deconv2DBlock( 505 in_channels=features_decoder[0], 506 out_channels=features_decoder[1], 507 use_conv_transpose=use_conv_transpose, 508 ) 509 self.deconv3 = Deconv2DBlock( 510 in_channels=features_decoder[1], 511 out_channels=features_decoder[2], 512 use_conv_transpose=use_conv_transpose, 513 ) 514 self.deconv4 = Deconv2DBlock( 515 in_channels=features_decoder[2], 516 out_channels=features_decoder[3], 517 use_conv_transpose=use_conv_transpose, 518 ) 519 520 self.base = ConvBlock2d(self.embed_dim, features_decoder[0]) 521 self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1) 522 self.deconv_out = _upsampler( 523 scale_factor=2, in_channels=features_decoder[-1], out_channels=features_decoder[-1] 524 ) 525 self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1])
Initialize internal Module state, shared by both nn.Module and ScriptModule.
527 def forward(self, x: torch.Tensor) -> torch.Tensor: 528 """Apply the UNETR to the input data. 529 530 Args: 531 x: The input tensor. 532 533 Returns: 534 The UNETR output. 535 """ 536 original_shape = x.shape[-2:] 537 538 # Reshape the inputs to the shape expected by the encoder 539 # and normalize the inputs if normalization is part of the model. 540 x, input_shape = self.preprocess(x) 541 542 encoder_outputs = self.encoder(x) 543 544 if isinstance(encoder_outputs[-1], list): 545 # `encoder_outputs` can be arranged in only two forms: 546 # - either we only return the image embeddings 547 # - or, we return the image embeddings and the "list" of global attention layers 548 z12, from_encoder = encoder_outputs 549 else: 550 z12 = encoder_outputs 551 552 if self.use_skip_connection: 553 from_encoder = from_encoder[::-1] 554 z9 = self.deconv1(from_encoder[0]) 555 z6 = self.deconv2(from_encoder[1]) 556 z3 = self.deconv3(from_encoder[2]) 557 z0 = self.deconv4(x) 558 559 else: 560 z9 = self.deconv1(z12) 561 z6 = self.deconv2(z9) 562 z3 = self.deconv3(z6) 563 z0 = self.deconv4(z3) 564 565 updated_from_encoder = [z9, z6, z3] 566 567 x = self.base(z12) 568 x = self.decoder(x, encoder_inputs=updated_from_encoder) 569 x = self.deconv_out(x) 570 571 x = torch.cat([x, z0], dim=1) 572 x = self.decoder_head(x) 573 574 x = self.out_conv(x) 575 if self.final_activation is not None: 576 x = self.final_activation(x) 577 578 x = self.postprocess_masks(x, input_shape, original_shape) 579 return x
Apply the UNETR to the input data.
Arguments:
- x: The input tensor.
Returns:
The UNETR output.
class UNETR2D(UNETR):
    """A two-dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.

    Alias of `UNETR`, kept for naming symmetry with `UNETR3D`; it adds no behavior of its own.
    """
A two-dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.
class UNETR3D(UNETRBase):
    """A three-dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.

    The vision transformer encodes the input volume one z-slice at a time; the stacked
    slice features are then brought back to the input resolution by an anisotropic 3d
    convolutional decoder (upsampling happens in-plane only, never along z). Skip
    connections and transposed convolutions are not supported yet and raise
    `NotImplementedError` when requested.
    """
    def __init__(
        self,
        img_size: int = 1024,
        backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
        encoder: Optional[Union[nn.Module, str]] = "hvit_b",
        decoder: Optional[nn.Module] = None,
        out_channels: int = 1,
        use_sam_stats: bool = False,
        use_mae_stats: bool = False,
        use_dino_stats: bool = False,
        resize_input: bool = True,
        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
        final_activation: Optional[Union[str, nn.Module]] = None,
        use_skip_connection: bool = False,
        embed_dim: Optional[int] = None,
        use_conv_transpose: bool = False,
        use_strip_pooling: bool = True,
        **kwargs
    ):
        # Fail fast on the features this implementation does not support yet.
        if use_skip_connection:
            raise NotImplementedError("The framework cannot handle skip connections atm.")
        if use_conv_transpose:
            raise NotImplementedError("It's not enabled to switch between interpolation and transposed convolutions.")

        # Default embedding dimension of the per-slice encoder features.
        if embed_dim is None:
            embed_dim = 256

        super().__init__(
            img_size=img_size,
            backbone=backbone,
            encoder=encoder,
            decoder=decoder,
            out_channels=out_channels,
            use_sam_stats=use_sam_stats,
            use_mae_stats=use_mae_stats,
            use_dino_stats=use_dino_stats,
            resize_input=resize_input,
            encoder_checkpoint=encoder_checkpoint,
            final_activation=final_activation,
            use_skip_connection=use_skip_connection,
            embed_dim=embed_dim,
            use_conv_transpose=use_conv_transpose,
            **kwargs,
        )

        # Decoder hyperparameters: feature widths halve towards the output, and the
        # scale factor is 1 along z (in-plane upsampling only).
        depth = 3
        initial_features = 64
        gain = 2
        features_decoder = [initial_features * gain ** level for level in reversed(range(depth + 1))]
        scale_factors = [1, 2, 2]
        self.out_channels = out_channels

        # The mapping blocks `deconv1` .. `deconv4`, chained sequentially in `forward`:
        # embed_dim -> 512 -> 256 -> 128 -> 64.
        channel_chain = [embed_dim] + features_decoder
        for idx, (n_in, n_out) in enumerate(zip(channel_chain[:-1], channel_chain[1:]), start=1):
            setattr(self, f"deconv{idx}", Deconv3DBlock(
                in_channels=n_in,
                out_channels=n_out,
                scale_factor=scale_factors,
                use_strip_pooling=use_strip_pooling,
            ))

        # The core decoder block (a caller-provided decoder takes precedence).
        self.decoder = decoder if decoder else Decoder(
            features=features_decoder,
            scale_factors=depth * [scale_factors],
            conv_block_impl=partial(ConvBlock3dWithStrip, use_strip_pooling=use_strip_pooling),
            sampler_impl=Upsampler3d,
        )

        # Final upsampler that brings the decoder output to the expected dimensions.
        self.deconv_out = Deconv3DBlock(
            in_channels=features_decoder[-1],
            out_channels=features_decoder[-1],
            scale_factor=scale_factors,
            use_strip_pooling=use_strip_pooling,
        )

        # Aligns the raw encoder features with the first decoder stage.
        self.base = ConvBlock3dWithStrip(
            in_channels=embed_dim,
            out_channels=features_decoder[0],
            use_strip_pooling=use_strip_pooling,
        )

        # Output head: fuses the upsampled decoder output with the deepest mapping
        # block, then projects to the requested number of output channels.
        self.decoder_head = ConvBlock3dWithStrip(
            in_channels=2 * features_decoder[-1],
            out_channels=features_decoder[-1],
            use_strip_pooling=use_strip_pooling,
        )
        self.out_conv = nn.Conv3d(features_decoder[-1], out_channels, 1)

    def forward(self, x: torch.Tensor):
        """Forward pass of the UNETR-3D model.

        Args:
            x: Inputs of expected shape (B, C, Z, Y, X), where Z considers flexible inputs.

        Returns:
            The UNETR output.
        """
        _, _, n_slices, height, width = x.shape
        original_shape = (n_slices, height, width)

        # Preprocessing step (also yields the shape needed for postprocessing).
        x, input_shape = self.preprocess(x)

        # Encode each z-slice independently and re-assemble the volume along dim 2.
        slice_features = [self.encoder(x[:, :, z])[0] for z in range(n_slices)]
        volume_features = torch.stack(slice_features, dim=2)

        # Sequentially map the encoder features; there are no skip connections atm.
        feat1 = self.deconv1(volume_features)
        feat2 = self.deconv2(feat1)
        feat3 = self.deconv3(feat2)
        feat4 = self.deconv4(feat3)

        # Align the encoder features through the base block, then run the decoder
        # against the mapped features and upsample to the expected dimensions.
        out = self.base(volume_features)
        out = self.decoder(out, encoder_inputs=[feat1, feat2, feat3])
        out = self.deconv_out(out)

        # The final output head.
        out = torch.cat([out, feat4], dim=1)
        out = self.decoder_head(out)
        out = self.out_conv(out)
        if self.final_activation is not None:
            out = self.final_activation(out)

        # Postprocess the output back to the original size.
        return self.postprocess_masks(out, input_shape, original_shape)
A three-dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.
def __init__(
    self,
    img_size: int = 1024,
    backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
    encoder: Optional[Union[nn.Module, str]] = "hvit_b",
    decoder: Optional[nn.Module] = None,
    out_channels: int = 1,
    use_sam_stats: bool = False,
    use_mae_stats: bool = False,
    use_dino_stats: bool = False,
    resize_input: bool = True,
    encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
    final_activation: Optional[Union[str, nn.Module]] = None,
    use_skip_connection: bool = False,
    embed_dim: Optional[int] = None,
    use_conv_transpose: bool = False,
    use_strip_pooling: bool = True,
    **kwargs
):
    """Initialize the 3d UNETR.

    Shared arguments are documented on the base class. Specific to this class:

    Args:
        use_strip_pooling: Whether to enable strip pooling in the 3d convolutional
            blocks (forwarded to `Deconv3DBlock` / `ConvBlock3dWithStrip`).

    Raises:
        NotImplementedError: If `use_skip_connection` or `use_conv_transpose` is set;
            neither feature is supported by this implementation yet.
    """
    if use_skip_connection:
        raise NotImplementedError("The framework cannot handle skip connections atm.")
    if use_conv_transpose:
        raise NotImplementedError("It's not enabled to switch between interpolation and transposed convolutions.")

    # Sort the `embed_dim` out: default to 256 when the caller does not specify it.
    embed_dim = 256 if embed_dim is None else embed_dim

    super().__init__(
        img_size=img_size,
        backbone=backbone,
        encoder=encoder,
        decoder=decoder,
        out_channels=out_channels,
        use_sam_stats=use_sam_stats,
        use_mae_stats=use_mae_stats,
        use_dino_stats=use_dino_stats,
        resize_input=resize_input,
        encoder_checkpoint=encoder_checkpoint,
        final_activation=final_activation,
        use_skip_connection=use_skip_connection,
        embed_dim=embed_dim,
        use_conv_transpose=use_conv_transpose,
        **kwargs,
    )

    # The 3d convolutional decoder.
    # First, get the important parameters for the decoder: feature widths halve
    # towards the output, and scaling is in-plane only (factor 1 along z, 2 along y/x).
    depth = 3
    initial_features = 64
    gain = 2
    features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
    scale_factors = [1, 2, 2]
    self.out_channels = out_channels

    # The mapping blocks: a sequential chain embed_dim -> 512 -> 256 -> 128 -> 64,
    # applied one after another in `forward` (no skip connections atm).
    self.deconv1 = Deconv3DBlock(
        in_channels=embed_dim,
        out_channels=features_decoder[0],
        scale_factor=scale_factors,
        use_strip_pooling=use_strip_pooling,
    )
    self.deconv2 = Deconv3DBlock(
        in_channels=features_decoder[0],
        out_channels=features_decoder[1],
        scale_factor=scale_factors,
        use_strip_pooling=use_strip_pooling,
    )
    self.deconv3 = Deconv3DBlock(
        in_channels=features_decoder[1],
        out_channels=features_decoder[2],
        scale_factor=scale_factors,
        use_strip_pooling=use_strip_pooling,
    )
    self.deconv4 = Deconv3DBlock(
        in_channels=features_decoder[2],
        out_channels=features_decoder[3],
        scale_factor=scale_factors,
        use_strip_pooling=use_strip_pooling,
    )

    # The core decoder block (a caller-provided `decoder` takes precedence).
    self.decoder = decoder or Decoder(
        features=features_decoder,
        scale_factors=[scale_factors] * depth,
        conv_block_impl=partial(ConvBlock3dWithStrip, use_strip_pooling=use_strip_pooling),
        sampler_impl=Upsampler3d,
    )

    # And the final upsampler to match the expected dimensions.
    self.deconv_out = Deconv3DBlock(  # NOTE: changed `end_up` to `deconv_out`
        in_channels=features_decoder[-1],
        out_channels=features_decoder[-1],
        scale_factor=scale_factors,
        use_strip_pooling=use_strip_pooling,
    )

    # Additional conjunction blocks: aligns the raw encoder features with the first
    # decoder stage.
    self.base = ConvBlock3dWithStrip(
        in_channels=embed_dim,
        out_channels=features_decoder[0],
        use_strip_pooling=use_strip_pooling,
    )

    # And the output layers: fuse with the deepest mapping block, then project to
    # `out_channels` via a 1x1x1 convolution.
    self.decoder_head = ConvBlock3dWithStrip(
        in_channels=2 * features_decoder[-1],
        out_channels=features_decoder[-1],
        use_strip_pooling=use_strip_pooling,
    )
    self.out_conv = nn.Conv3d(features_decoder[-1], out_channels, 1)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, x: torch.Tensor):
    """Forward pass of the UNETR-3D model.

    Args:
        x: Inputs of expected shape (B, C, Z, Y, X), where Z considers flexible inputs.

    Returns:
        The UNETR output.
    """
    B, C, Z, H, W = x.shape
    original_shape = (Z, H, W)

    # Preprocessing step (also returns the shape needed for postprocessing below).
    x, input_shape = self.preprocess(x)

    # Run the image encoder on each z-slice and stack the per-slice features back
    # into a volume along dim 2.
    # NOTE(review): iterates over the pre-preprocessing Z — assumes `preprocess`
    # does not change the number of z-slices; confirm against `UNETRBase.preprocess`.
    curr_features = torch.stack([self.encoder(x[:, :, i])[0] for i in range(Z)], dim=2)

    # Prepare the counterparts for the decoder.
    # NOTE: The section below is sequential, there's no skip connections atm.
    z9 = self.deconv1(curr_features)
    z6 = self.deconv2(z9)
    z3 = self.deconv3(z6)
    z0 = self.deconv4(z3)

    updated_from_encoder = [z9, z6, z3]

    # Align the features through the base block.
    x = self.base(curr_features)
    # Run the decoder against the mapped features, then upsample once more.
    x = self.decoder(x, encoder_inputs=updated_from_encoder)
    x = self.deconv_out(x)  # NOTE before `end_up`

    # And the final output head: fuse with the deepest mapping block, then project
    # to the requested number of output channels.
    x = torch.cat([x, z0], dim=1)
    x = self.decoder_head(x)
    x = self.out_conv(x)
    if self.final_activation is not None:
        x = self.final_activation(x)

    # Postprocess the output back to original size.
    x = self.postprocess_masks(x, input_shape, original_shape)
    return x
Forward pass of the UNETR-3D model.
Args:
- x: Inputs of expected shape (B, C, Z, Y, X), where Z considers flexible inputs.
Returns:
The UNETR output.