torch_em.model.unetr
1from functools import partial 2from collections import OrderedDict 3from typing import Optional, Tuple, Union, Literal 4 5import torch 6import torch.nn as nn 7import torch.nn.functional as F 8 9from .vit import get_vision_transformer 10from .unet import Decoder, ConvBlock2d, ConvBlock3d, Upsampler2d, Upsampler3d, _update_conv_kwargs 11 12try: 13 from micro_sam.util import get_sam_model 14except ImportError: 15 get_sam_model = None 16 17try: 18 from micro_sam.v2.util import get_sam2_model 19except ImportError: 20 get_sam2_model = None 21 22try: 23 from micro_sam3.util import get_sam3_model 24except ImportError: 25 get_sam3_model = None 26 27 28# 29# UNETR IMPLEMENTATION [Vision Transformer (ViT from SAM / CellposeSAM / SAM2 / SAM3 / DINOv2 / DINOv3 / MAE / ScaleMAE) + UNet Decoder from `torch_em`] # noqa 30# 31 32 33class UNETRBase(nn.Module): 34 """Base class for implementing a UNETR. 35 36 Args: 37 img_size: The size of the input for the image encoder. Input images will be resized to match this size. 38 backbone: The name of the vision transformer implementation. 39 One of "sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3" 40 (see all combinations below) 41 encoder: The vision transformer. Can either be a name, such as "vit_b" 42 (see all combinations for this below) or a torch module. 43 decoder: The convolutional decoder. 44 out_channels: The number of output channels of the UNETR. 45 use_sam_stats: Whether to normalize the input data with the statistics of the 46 pretrained SAM / SAM2 / SAM3 model. 47 use_dino_stats: Whether to normalize the input data with the statistics of the 48 pretrained DINOv2 / DINOv3 model. 49 use_mae_stats: Whether to normalize the input data with the statistics of the pretrained MAE model. 50 resize_input: Whether to resize the input images to match `img_size`. 51 By default, it resizes the inputs to match the `img_size`. 52 encoder_checkpoint: Checkpoint for initializing the vision transformer. 53 Can either be a filepath or an already loaded checkpoint. 54 final_activation: The activation to apply to the UNETR output. 55 use_skip_connection: Whether to use skip connections. By default, it uses skip connections. 56 embed_dim: The embedding dimensionality, corresponding to the output dimension of the vision transformer. 57 use_conv_transpose: Whether to use transposed convolutions instead of resampling for upsampling. 58 By default, it uses resampling for upsampling. 
59 60 NOTE: The currently supported combinations of 'backbone' x 'encoder' are the following: 61 62 SAM_family_models: 63 - 'sam' x 'vit_b' 64 - 'sam' x 'vit_l' 65 - 'sam' x 'vit_h' 66 - 'sam2' x 'hvit_t' 67 - 'sam2' x 'hvit_s' 68 - 'sam2' x 'hvit_b' 69 - 'sam2' x 'hvit_l' 70 - 'sam3' x 'vit_pe' 71 - 'cellpose_sam' x 'vit_l' 72 73 DINO_family_models: 74 - 'dinov2' x 'vit_s' 75 - 'dinov2' x 'vit_b' 76 - 'dinov2' x 'vit_l' 77 - 'dinov2' x 'vit_g' 78 - 'dinov2' x 'vit_s_reg4' 79 - 'dinov2' x 'vit_b_reg4' 80 - 'dinov2' x 'vit_l_reg4' 81 - 'dinov2' x 'vit_g_reg4' 82 - 'dinov3' x 'vit_s' 83 - 'dinov3' x 'vit_s+' 84 - 'dinov3' x 'vit_b' 85 - 'dinov3' x 'vit_l' 86 - 'dinov3' x 'vit_l+' 87 - 'dinov3' x 'vit_h+' 88 - 'dinov3' x 'vit_7b' 89 90 MAE_family_models: 91 - 'mae' x 'vit_b' 92 - 'mae' x 'vit_l' 93 - 'mae' x 'vit_h' 94 - 'scalemae' x 'vit_b' 95 - 'scalemae' x 'vit_l' 96 - 'scalemae' x 'vit_h' 97 """ 98 def __init__( 99 self, 100 img_size: int = 1024, 101 backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam", 102 encoder: Optional[Union[nn.Module, str]] = "vit_b", 103 decoder: Optional[nn.Module] = None, 104 out_channels: int = 1, 105 use_sam_stats: bool = False, 106 use_mae_stats: bool = False, 107 use_dino_stats: bool = False, 108 resize_input: bool = True, 109 encoder_checkpoint: Optional[Union[str, OrderedDict]] = None, 110 final_activation: Optional[Union[str, nn.Module]] = None, 111 use_skip_connection: bool = True, 112 embed_dim: Optional[int] = None, 113 use_conv_transpose: bool = False, 114 **kwargs 115 ) -> None: 116 super().__init__() 117 118 self.img_size = img_size 119 self.use_sam_stats = use_sam_stats 120 self.use_mae_stats = use_mae_stats 121 self.use_dino_stats = use_dino_stats 122 self.use_skip_connection = use_skip_connection 123 self.resize_input = resize_input 124 self.use_conv_transpose = use_conv_transpose 125 self.backbone = backbone 126 127 if isinstance(encoder, str): # e.g. "vit_b" / "hvit_b" / "vit_pe" 128 print(f"Using {encoder} from {backbone.upper()}") 129 self.encoder = get_vision_transformer(img_size=img_size, backbone=backbone, model=encoder, **kwargs) 130 131 if encoder_checkpoint is not None: 132 self._load_encoder_from_checkpoint(backbone=backbone, encoder=encoder, checkpoint=encoder_checkpoint) 133 134 if embed_dim is None: 135 embed_dim = self.encoder.embed_dim 136 137 else: # `nn.Module` ViT backbone 138 self.encoder = encoder 139 140 have_neck = False 141 for name, _ in self.encoder.named_parameters(): 142 if name.startswith("neck"): 143 have_neck = True 144 145 if embed_dim is None: 146 if have_neck: 147 embed_dim = self.encoder.neck[2].out_channels # the value is 256 148 else: 149 embed_dim = self.encoder.patch_embed.proj.out_channels 150 151 self.embed_dim = embed_dim 152 self.final_activation = self._get_activation(final_activation) 153 154 def _load_encoder_from_checkpoint(self, backbone, encoder, checkpoint): 155 """Function to load pretrained weights to the image encoder. 156 """ 157 if isinstance(checkpoint, str): 158 if backbone == "sam" and isinstance(encoder, str): 159 # If we have a SAM encoder, then we first try to load the full SAM Model 160 # (using micro_sam) and otherwise fall back on directly loading the encoder state 161 # from the checkpoint 162 try: 163 _, model = get_sam_model(model_type=encoder, checkpoint_path=checkpoint, return_sam=True) 164 encoder_state = model.image_encoder.state_dict() 165 except Exception: 166 # Try loading the encoder state directly from a checkpoint. 
167 encoder_state = torch.load(checkpoint, weights_only=False) 168 169 elif backbone == "cellpose_sam" and isinstance(encoder, str): 170 # The architecture matches CellposeSAM exactly (same rel_pos sizes), 171 # so weights load directly without any interpolation. 172 encoder_state = torch.load(checkpoint, map_location="cpu", weights_only=False) 173 # Handle DataParallel/DistributedDataParallel prefix. 174 if any(k.startswith("module.") for k in encoder_state.keys()): 175 encoder_state = OrderedDict( 176 {k[len("module."):]: v for k, v in encoder_state.items()} 177 ) 178 # Extract encoder weights from CellposeSAM checkpoint format (strip 'encoder.' prefix). 179 if any(k.startswith("encoder.") for k in encoder_state.keys()): 180 encoder_state = OrderedDict( 181 {k[len("encoder."):]: v for k, v in encoder_state.items() if k.startswith("encoder.")} 182 ) 183 184 elif backbone == "sam2" and isinstance(encoder, str): 185 # If we have a SAM2 encoder, then we first try to load the full SAM2 Model. 186 # (using micro_sam2) and otherwise fall back on directly loading the encoder state 187 # from the checkpoint 188 try: 189 model = get_sam2_model(model_type=encoder, checkpoint_path=checkpoint) 190 encoder_state = model.image_encoder.state_dict() 191 except Exception: 192 # Try loading the encoder state directly from a checkpoint. 193 encoder_state = torch.load(checkpoint, weights_only=False) 194 195 elif backbone == "sam3" and isinstance(encoder, str): 196 # If we have a SAM3 encoder, then we first try to load the full SAM3 Model. 197 # (using micro_sam3) and otherwise fall back on directly loading the encoder state 198 # from the checkpoint 199 try: 200 model = get_sam3_model(checkpoint_path=checkpoint) 201 encoder_state = model.backbone.vision_backbone.state_dict() 202 # Let's align loading the encoder weights with expected parameter names 203 encoder_state = { 204 k[len("trunk."):] if k.startswith("trunk.") else k: v for k, v in encoder_state.items() 205 } 206 # And drop the 'convs' and 'sam2_convs' - these seem like some upsampling blocks. 207 encoder_state = { 208 k: v for k, v in encoder_state.items() 209 if not (k.startswith("convs.") or k.startswith("sam2_convs.")) 210 } 211 except Exception: 212 # Try loading the encoder state directly from a checkpoint. 213 encoder_state = torch.load(checkpoint, weights_only=False) 214 215 elif backbone == "mae": 216 # vit initialization hints from: 217 # - https://github.com/facebookresearch/mae/blob/main/main_finetune.py#L233-L242 218 encoder_state = torch.load(checkpoint, weights_only=False)["model"] 219 encoder_state = OrderedDict({ 220 k: v for k, v in encoder_state.items() if (k != "mask_token" and not k.startswith("decoder")) 221 }) 222 # Let's remove the `head` from our current encoder (as the MAE pretrained don't expect it) 223 current_encoder_state = self.encoder.state_dict() 224 if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state): 225 del self.encoder.head 226 227 elif backbone == "scalemae": 228 # Load the encoder state directly from a checkpoint. 
229 encoder_state = torch.load(checkpoint)["model"] 230 encoder_state = OrderedDict({ 231 k: v for k, v in encoder_state.items() 232 if not k.startswith(("mask_token", "decoder", "fcn", "fpn", "pos_embed")) 233 }) 234 235 # Let's remove the `head` from our current encoder (as the MAE pretrained don't expect it) 236 current_encoder_state = self.encoder.state_dict() 237 if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state): 238 del self.encoder.head 239 240 if "pos_embed" in current_encoder_state: # NOTE: ScaleMAE uses 'pos. embeddings' in a diff. format. 241 del self.encoder.pos_embed 242 243 elif backbone in ["dinov2", "dinov3"]: # Load the encoder state directly from a checkpoint. 244 encoder_state = torch.load(checkpoint) 245 246 else: 247 raise ValueError( 248 f"We don't support either the '{backbone}' backbone or the '{encoder}' model combination (or both)." 249 ) 250 251 else: 252 encoder_state = checkpoint 253 254 self.encoder.load_state_dict(encoder_state) 255 256 def _get_activation(self, activation): 257 return_activation = None 258 if activation is None: 259 return None 260 if isinstance(activation, nn.Module): 261 return activation 262 if isinstance(activation, str): 263 return_activation = getattr(nn, activation, None) 264 if return_activation is None: 265 raise ValueError(f"Invalid activation: {activation}") 266 267 return return_activation() 268 269 @staticmethod 270 def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 271 """Compute the output size given input size and target long side length. 272 273 Args: 274 oldh: The input image height. 275 oldw: The input image width. 276 long_side_length: The longest side length for resizing. 277 278 Returns: 279 The new image height. 280 The new image width. 281 """ 282 scale = long_side_length * 1.0 / max(oldh, oldw) 283 newh, neww = oldh * scale, oldw * scale 284 neww = int(neww + 0.5) 285 newh = int(newh + 0.5) 286 return (newh, neww) 287 288 def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor: 289 """Resize the image so that the longest side has the correct length. 290 291 Expects batched images with shape BxCxHxW OR BxCxDxHxW and float format. 292 293 Args: 294 image: The input image. 295 296 Returns: 297 The resized image. 298 """ 299 if image.ndim == 4: # i.e. 2d image 300 target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.encoder.img_size) 301 return F.interpolate(image, target_size, mode="bilinear", align_corners=False, antialias=True) 302 elif image.ndim == 5: # i.e. 3d volume 303 B, C, Z, H, W = image.shape 304 target_size = self.get_preprocess_shape(H, W, self.img_size) 305 return F.interpolate(image, (Z, *target_size), mode="trilinear", align_corners=False) 306 else: 307 raise ValueError("Expected 4d or 5d inputs, got", image.shape) 308 309 def _as_stats(self, mean, std, device, dtype, is_3d: bool): 310 """@private 311 """ 312 # Either 2d batch: (1, C, 1, 1) or 3d batch: (1, C, 1, 1, 1). 
313 view_shape = (1, -1, 1, 1, 1) if is_3d else (1, -1, 1, 1) 314 pixel_mean = torch.tensor(mean, device=device, dtype=dtype).view(*view_shape) 315 pixel_std = torch.tensor(std, device=device, dtype=dtype).view(*view_shape) 316 return pixel_mean, pixel_std 317 318 def preprocess(self, x: torch.Tensor) -> torch.Tensor: 319 """@private 320 """ 321 device = x.device 322 is_3d = (x.ndim == 5) 323 device, dtype = x.device, x.dtype 324 325 if self.use_sam_stats: 326 mean, std = (123.675, 116.28, 103.53), (58.395, 57.12, 57.375) 327 elif self.use_mae_stats: # TODO: add mean std from mae / scalemae experiments (or open up arguments for this) 328 raise NotImplementedError 329 elif self.use_dino_stats or (self.use_sam_stats and self.backbone == "sam2"): 330 mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225) 331 elif self.use_sam_stats and self.backbone == "sam3": 332 mean, std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5) 333 else: 334 mean, std = (0.0, 0.0, 0.0), (1.0, 1.0, 1.0) 335 336 pixel_mean, pixel_std = self._as_stats(mean, std, device=device, dtype=dtype, is_3d=is_3d) 337 338 if self.resize_input: 339 x = self.resize_longest_side(x) 340 input_shape = x.shape[-3:] if is_3d else x.shape[-2:] 341 342 x = (x - pixel_mean) / pixel_std 343 h, w = x.shape[-2:] 344 padh = self.encoder.img_size - h 345 padw = self.encoder.img_size - w 346 347 if is_3d: 348 x = F.pad(x, (0, padw, 0, padh, 0, 0)) 349 else: 350 x = F.pad(x, (0, padw, 0, padh)) 351 352 return x, input_shape 353 354 def postprocess_masks( 355 self, masks: torch.Tensor, input_size: Tuple[int, ...], original_size: Tuple[int, ...], 356 ) -> torch.Tensor: 357 """@private 358 """ 359 if masks.ndim == 4: # i.e. 2d labels 360 masks = F.interpolate( 361 masks, 362 (self.encoder.img_size, self.encoder.img_size), 363 mode="bilinear", 364 align_corners=False, 365 ) 366 masks = masks[..., : input_size[0], : input_size[1]] 367 masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) 368 369 elif masks.ndim == 5: # i.e. 3d volumetric labels 370 masks = F.interpolate( 371 masks, 372 (input_size[0], self.img_size, self.img_size), 373 mode="trilinear", 374 align_corners=False, 375 ) 376 masks = masks[..., :input_size[0], :input_size[1], :input_size[2]] 377 masks = F.interpolate(masks, original_size, mode="trilinear", align_corners=False) 378 379 else: 380 raise ValueError("Expected 4d or 5d labels, got", masks.shape) 381 382 return masks 383 384 385class UNETR(UNETRBase): 386 """A (2d-only) UNet Transformer using a vision transformer as encoder and a convolutional decoder. 
387 """ 388 def __init__( 389 self, 390 img_size: int = 1024, 391 backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam", 392 encoder: Optional[Union[nn.Module, str]] = "vit_b", 393 decoder: Optional[nn.Module] = None, 394 out_channels: int = 1, 395 use_sam_stats: bool = False, 396 use_mae_stats: bool = False, 397 use_dino_stats: bool = False, 398 resize_input: bool = True, 399 encoder_checkpoint: Optional[Union[str, OrderedDict]] = None, 400 final_activation: Optional[Union[str, nn.Module]] = None, 401 use_skip_connection: bool = True, 402 embed_dim: Optional[int] = None, 403 use_conv_transpose: bool = False, 404 **kwargs 405 ) -> None: 406 407 super().__init__( 408 img_size=img_size, 409 backbone=backbone, 410 encoder=encoder, 411 decoder=decoder, 412 out_channels=out_channels, 413 use_sam_stats=use_sam_stats, 414 use_mae_stats=use_mae_stats, 415 use_dino_stats=use_dino_stats, 416 resize_input=resize_input, 417 encoder_checkpoint=encoder_checkpoint, 418 final_activation=final_activation, 419 use_skip_connection=use_skip_connection, 420 embed_dim=embed_dim, 421 use_conv_transpose=use_conv_transpose, 422 **kwargs, 423 ) 424 425 encoder = self.encoder 426 427 if backbone == "sam2" and hasattr(encoder, "trunk"): 428 in_chans = encoder.trunk.patch_embed.proj.in_channels 429 elif hasattr(encoder, "in_chans"): 430 in_chans = encoder.in_chans 431 else: # `nn.Module` ViT backbone. 432 try: 433 in_chans = encoder.patch_embed.proj.in_channels 434 except AttributeError: # for getting the input channels while using 'vit_t' from MobileSam 435 in_chans = encoder.patch_embed.seq[0].c.in_channels 436 437 # parameters for the decoder network 438 depth = 3 439 initial_features = 64 440 gain = 2 441 features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1] 442 scale_factors = depth * [2] 443 self.out_channels = out_channels 444 445 # choice of upsampler - to use (bilinear interpolation + conv) or conv transpose 446 _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d 447 448 self.decoder = decoder or Decoder( 449 features=features_decoder, 450 scale_factors=scale_factors[::-1], 451 conv_block_impl=ConvBlock2d, 452 sampler_impl=_upsampler, 453 ) 454 455 if use_skip_connection: 456 self.deconv1 = Deconv2DBlock( 457 in_channels=self.embed_dim, 458 out_channels=features_decoder[0], 459 use_conv_transpose=use_conv_transpose, 460 ) 461 self.deconv2 = nn.Sequential( 462 Deconv2DBlock( 463 in_channels=self.embed_dim, 464 out_channels=features_decoder[0], 465 use_conv_transpose=use_conv_transpose, 466 ), 467 Deconv2DBlock( 468 in_channels=features_decoder[0], 469 out_channels=features_decoder[1], 470 use_conv_transpose=use_conv_transpose, 471 ) 472 ) 473 self.deconv3 = nn.Sequential( 474 Deconv2DBlock( 475 in_channels=self.embed_dim, 476 out_channels=features_decoder[0], 477 use_conv_transpose=use_conv_transpose, 478 ), 479 Deconv2DBlock( 480 in_channels=features_decoder[0], 481 out_channels=features_decoder[1], 482 use_conv_transpose=use_conv_transpose, 483 ), 484 Deconv2DBlock( 485 in_channels=features_decoder[1], 486 out_channels=features_decoder[2], 487 use_conv_transpose=use_conv_transpose, 488 ) 489 ) 490 self.deconv4 = ConvBlock2d(in_chans, features_decoder[-1]) 491 else: 492 self.deconv1 = Deconv2DBlock( 493 in_channels=self.embed_dim, 494 out_channels=features_decoder[0], 495 use_conv_transpose=use_conv_transpose, 496 ) 497 self.deconv2 = Deconv2DBlock( 498 in_channels=features_decoder[0], 499 
out_channels=features_decoder[1], 500 use_conv_transpose=use_conv_transpose, 501 ) 502 self.deconv3 = Deconv2DBlock( 503 in_channels=features_decoder[1], 504 out_channels=features_decoder[2], 505 use_conv_transpose=use_conv_transpose, 506 ) 507 self.deconv4 = Deconv2DBlock( 508 in_channels=features_decoder[2], 509 out_channels=features_decoder[3], 510 use_conv_transpose=use_conv_transpose, 511 ) 512 513 self.base = ConvBlock2d(self.embed_dim, features_decoder[0]) 514 self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1) 515 self.deconv_out = _upsampler( 516 scale_factor=2, in_channels=features_decoder[-1], out_channels=features_decoder[-1] 517 ) 518 self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1]) 519 520 def forward(self, x: torch.Tensor) -> torch.Tensor: 521 """Apply the UNETR to the input data. 522 523 Args: 524 x: The input tensor. 525 526 Returns: 527 The UNETR output. 528 """ 529 original_shape = x.shape[-2:] 530 531 # Reshape the inputs to the shape expected by the encoder 532 # and normalize the inputs if normalization is part of the model. 533 x, input_shape = self.preprocess(x) 534 535 encoder_outputs = self.encoder(x) 536 537 if isinstance(encoder_outputs[-1], list): 538 # `encoder_outputs` can be arranged in only two forms: 539 # - either we only return the image embeddings 540 # - or, we return the image embeddings and the "list" of global attention layers 541 z12, from_encoder = encoder_outputs 542 else: 543 z12 = encoder_outputs 544 545 if self.use_skip_connection: 546 from_encoder = from_encoder[::-1] 547 z9 = self.deconv1(from_encoder[0]) 548 z6 = self.deconv2(from_encoder[1]) 549 z3 = self.deconv3(from_encoder[2]) 550 z0 = self.deconv4(x) 551 552 else: 553 z9 = self.deconv1(z12) 554 z6 = self.deconv2(z9) 555 z3 = self.deconv3(z6) 556 z0 = self.deconv4(z3) 557 558 updated_from_encoder = [z9, z6, z3] 559 560 x = self.base(z12) 561 x = self.decoder(x, encoder_inputs=updated_from_encoder) 562 x = self.deconv_out(x) 563 564 x = torch.cat([x, z0], dim=1) 565 x = self.decoder_head(x) 566 567 x = self.out_conv(x) 568 if self.final_activation is not None: 569 x = self.final_activation(x) 570 571 x = self.postprocess_masks(x, input_shape, original_shape) 572 return x 573 574 575class UNETR2D(UNETR): 576 """A two-dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder. 577 """ 578 pass 579 580 581class UNETR3D(UNETRBase): 582 """A three dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder. 
583 """ 584 def __init__( 585 self, 586 img_size: int = 1024, 587 backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam", 588 encoder: Optional[Union[nn.Module, str]] = "hvit_b", 589 decoder: Optional[nn.Module] = None, 590 out_channels: int = 1, 591 use_sam_stats: bool = False, 592 use_mae_stats: bool = False, 593 use_dino_stats: bool = False, 594 resize_input: bool = True, 595 encoder_checkpoint: Optional[Union[str, OrderedDict]] = None, 596 final_activation: Optional[Union[str, nn.Module]] = None, 597 use_skip_connection: bool = False, 598 embed_dim: Optional[int] = None, 599 use_conv_transpose: bool = False, 600 use_strip_pooling: bool = True, 601 **kwargs 602 ): 603 if use_skip_connection: 604 raise NotImplementedError("The framework cannot handle skip connections atm.") 605 if use_conv_transpose: 606 raise NotImplementedError("It's not enabled to switch between interpolation and transposed convolutions.") 607 608 # Sort the `embed_dim` out 609 embed_dim = 256 if embed_dim is None else embed_dim 610 611 super().__init__( 612 img_size=img_size, 613 backbone=backbone, 614 encoder=encoder, 615 decoder=decoder, 616 out_channels=out_channels, 617 use_sam_stats=use_sam_stats, 618 use_mae_stats=use_mae_stats, 619 use_dino_stats=use_dino_stats, 620 resize_input=resize_input, 621 encoder_checkpoint=encoder_checkpoint, 622 final_activation=final_activation, 623 use_skip_connection=use_skip_connection, 624 embed_dim=embed_dim, 625 use_conv_transpose=use_conv_transpose, 626 **kwargs, 627 ) 628 629 # The 3d convolutional decoder. 630 # First, get the important parameters for the decoder. 631 depth = 3 632 initial_features = 64 633 gain = 2 634 features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1] 635 scale_factors = [1, 2, 2] 636 self.out_channels = out_channels 637 638 # The mapping blocks. 639 self.deconv1 = Deconv3DBlock( 640 in_channels=embed_dim, 641 out_channels=features_decoder[0], 642 scale_factor=scale_factors, 643 use_strip_pooling=use_strip_pooling, 644 ) 645 self.deconv2 = Deconv3DBlock( 646 in_channels=features_decoder[0], 647 out_channels=features_decoder[1], 648 scale_factor=scale_factors, 649 use_strip_pooling=use_strip_pooling, 650 ) 651 self.deconv3 = Deconv3DBlock( 652 in_channels=features_decoder[1], 653 out_channels=features_decoder[2], 654 scale_factor=scale_factors, 655 use_strip_pooling=use_strip_pooling, 656 ) 657 self.deconv4 = Deconv3DBlock( 658 in_channels=features_decoder[2], 659 out_channels=features_decoder[3], 660 scale_factor=scale_factors, 661 use_strip_pooling=use_strip_pooling, 662 ) 663 664 # The core decoder block. 665 self.decoder = decoder or Decoder( 666 features=features_decoder, 667 scale_factors=[scale_factors] * depth, 668 conv_block_impl=partial(ConvBlock3dWithStrip, use_strip_pooling=use_strip_pooling), 669 sampler_impl=Upsampler3d, 670 ) 671 672 # And the final upsampler to match the expected dimensions. 673 self.deconv_out = Deconv3DBlock( # NOTE: changed `end_up` to `deconv_out` 674 in_channels=features_decoder[-1], 675 out_channels=features_decoder[-1], 676 scale_factor=scale_factors, 677 use_strip_pooling=use_strip_pooling, 678 ) 679 680 # Additional conjunction blocks. 681 self.base = ConvBlock3dWithStrip( 682 in_channels=embed_dim, 683 out_channels=features_decoder[0], 684 use_strip_pooling=use_strip_pooling, 685 ) 686 687 # And the output layers. 
688 self.decoder_head = ConvBlock3dWithStrip( 689 in_channels=2 * features_decoder[-1], 690 out_channels=features_decoder[-1], 691 use_strip_pooling=use_strip_pooling, 692 ) 693 self.out_conv = nn.Conv3d(features_decoder[-1], out_channels, 1) 694 695 def forward(self, x: torch.Tensor): 696 """Forward pass of the UNETR-3D model. 697 698 Args: 699 x: Inputs of expected shape (B, C, Z, Y, X), where Z considers flexible inputs. 700 701 Returns: 702 The UNETR output. 703 """ 704 B, C, Z, H, W = x.shape 705 original_shape = (Z, H, W) 706 707 # Preprocessing step 708 x, input_shape = self.preprocess(x) 709 710 # Run the image encoder. 711 curr_features = torch.stack([self.encoder(x[:, :, i])[0] for i in range(Z)], dim=2) 712 713 # Prepare the counterparts for the decoder. 714 # NOTE: The section below is sequential, there's no skip connections atm. 715 z9 = self.deconv1(curr_features) 716 z6 = self.deconv2(z9) 717 z3 = self.deconv3(z6) 718 z0 = self.deconv4(z3) 719 720 updated_from_encoder = [z9, z6, z3] 721 722 # Align the features through the base block. 723 x = self.base(curr_features) 724 # Run the decoder 725 x = self.decoder(x, encoder_inputs=updated_from_encoder) 726 x = self.deconv_out(x) # NOTE before `end_up` 727 728 # And the final output head. 729 x = torch.cat([x, z0], dim=1) 730 x = self.decoder_head(x) 731 x = self.out_conv(x) 732 if self.final_activation is not None: 733 x = self.final_activation(x) 734 735 # Postprocess the output back to original size. 736 x = self.postprocess_masks(x, input_shape, original_shape) 737 return x 738 739# 740# ADDITIONAL FUNCTIONALITIES 741# 742 743 744def _strip_pooling_layers(enabled, channels) -> nn.Module: 745 return DepthStripPooling(channels) if enabled else nn.Identity() 746 747 748class DepthStripPooling(nn.Module): 749 """@private 750 """ 751 def __init__(self, channels: int, reduction: int = 4): 752 """Block for strip pooling along the depth dimension (only). 753 754 eg. for 3D (Z > 1) - it aggregates global context across depth by adaptive avg pooling 755 to Z=1, and then passes through a small 1x1x1 MLP, then broadcasts it back to Z to 756 modulate the original features (using a gated residual). 757 758 For 2D (Z == 1 or Z == 3): returns input unchanged (no-op). 759 760 Args: 761 channels: The output channels. 762 reduction: The reduction of the hidden layers. 763 """ 764 super().__init__() 765 hidden = max(1, channels // reduction) 766 self.conv1 = nn.Conv3d(channels, hidden, kernel_size=1) 767 self.bn1 = nn.BatchNorm3d(hidden) 768 self.relu = nn.ReLU(inplace=True) 769 self.conv2 = nn.Conv3d(hidden, channels, kernel_size=1) 770 771 def forward(self, x: torch.Tensor) -> torch.Tensor: 772 if x.dim() != 5: 773 raise ValueError(f"DepthStripPooling expects 5D tensors as input, got '{x.shape}'.") 774 775 B, C, Z, H, W = x.shape 776 if Z == 1 or Z == 3: # i.e. 2d-as-1-slice or RGB_2d-as-1-slice. 777 return x # We simply do nothing there. 778 779 # We pool only along the depth dimension: i.e. 
target shape (B, C, 1, H, W) 780 feat = F.adaptive_avg_pool3d(x, output_size=(1, H, W)) 781 feat = self.conv1(feat) 782 feat = self.bn1(feat) 783 feat = self.relu(feat) 784 feat = self.conv2(feat) 785 gate = torch.sigmoid(feat).expand(B, C, Z, H, W) # Broadcast the collapsed depth context back to all slices 786 787 # Gated residual fusion 788 return x * gate + x 789 790 791class Deconv3DBlock(nn.Module): 792 """@private 793 """ 794 def __init__( 795 self, 796 scale_factor, 797 in_channels, 798 out_channels, 799 kernel_size=3, 800 anisotropic_kernel=True, 801 use_strip_pooling=True, 802 ): 803 super().__init__() 804 conv_block_kwargs = { 805 "in_channels": out_channels, 806 "out_channels": out_channels, 807 "kernel_size": kernel_size, 808 "padding": ((kernel_size - 1) // 2), 809 } 810 if anisotropic_kernel: 811 conv_block_kwargs = _update_conv_kwargs(conv_block_kwargs, scale_factor) 812 813 self.block = nn.Sequential( 814 Upsampler3d(scale_factor, in_channels, out_channels), 815 nn.Conv3d(**conv_block_kwargs), 816 nn.BatchNorm3d(out_channels), 817 nn.ReLU(True), 818 _strip_pooling_layers(enabled=use_strip_pooling, channels=out_channels), 819 ) 820 821 def forward(self, x): 822 return self.block(x) 823 824 825class ConvBlock3dWithStrip(nn.Module): 826 """@private 827 """ 828 def __init__( 829 self, in_channels: int, out_channels: int, use_strip_pooling: bool = True, **kwargs 830 ): 831 super().__init__() 832 self.block = nn.Sequential( 833 ConvBlock3d(in_channels, out_channels, **kwargs), 834 _strip_pooling_layers(enabled=use_strip_pooling, channels=out_channels), 835 ) 836 837 def forward(self, x): 838 return self.block(x) 839 840 841class SingleDeconv2DBlock(nn.Module): 842 """@private 843 """ 844 def __init__(self, scale_factor, in_channels, out_channels): 845 super().__init__() 846 self.block = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, padding=0, output_padding=0) 847 848 def forward(self, x): 849 return self.block(x) 850 851 852class SingleConv2DBlock(nn.Module): 853 """@private 854 """ 855 def __init__(self, in_channels, out_channels, kernel_size): 856 super().__init__() 857 self.block = nn.Conv2d( 858 in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=((kernel_size - 1) // 2) 859 ) 860 861 def forward(self, x): 862 return self.block(x) 863 864 865class Conv2DBlock(nn.Module): 866 """@private 867 """ 868 def __init__(self, in_channels, out_channels, kernel_size=3): 869 super().__init__() 870 self.block = nn.Sequential( 871 SingleConv2DBlock(in_channels, out_channels, kernel_size), 872 nn.BatchNorm2d(out_channels), 873 nn.ReLU(True) 874 ) 875 876 def forward(self, x): 877 return self.block(x) 878 879 880class Deconv2DBlock(nn.Module): 881 """@private 882 """ 883 def __init__(self, in_channels, out_channels, kernel_size=3, use_conv_transpose=True): 884 super().__init__() 885 _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d 886 self.block = nn.Sequential( 887 _upsampler(scale_factor=2, in_channels=in_channels, out_channels=out_channels), 888 SingleConv2DBlock(out_channels, out_channels, kernel_size), 889 nn.BatchNorm2d(out_channels), 890 nn.ReLU(True) 891 ) 892 893 def forward(self, x): 894 return self.block(x)
class UNETRBase(nn.Module):
Base class for implementing a UNETR.
Arguments:
- img_size: The size of the input for the image encoder. Input images will be resized to match this size.
- backbone: The name of the vision transformer implementation. One of "sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3" (see all combinations below)
- encoder: The vision transformer. Can either be a name, such as "vit_b" (see all combinations for this below) or a torch module.
- decoder: The convolutional decoder.
- out_channels: The number of output channels of the UNETR.
- use_sam_stats: Whether to normalize the input data with the statistics of the pretrained SAM / SAM2 / SAM3 model.
- use_dino_stats: Whether to normalize the input data with the statistics of the pretrained DINOv2 / DINOv3 model.
- use_mae_stats: Whether to normalize the input data with the statistics of the pretrained MAE model.
- resize_input: Whether to resize the input images to match img_size. By default, it resizes the inputs to match img_size.
- encoder_checkpoint: Checkpoint for initializing the vision transformer. Can either be a filepath or an already loaded checkpoint.
- final_activation: The activation to apply to the UNETR output.
- use_skip_connection: Whether to use skip connections. By default, it uses skip connections.
- embed_dim: The embedding dimensionality, corresponding to the output dimension of the vision transformer.
- use_conv_transpose: Whether to use transposed convolutions instead of resampling for upsampling. By default, it uses resampling for upsampling.
NOTE: The currently supported combinations of 'backbone' x 'encoder' are the following:

SAM_family_models:
- 'sam' x 'vit_b'
- 'sam' x 'vit_l'
- 'sam' x 'vit_h'
- 'sam2' x 'hvit_t'
- 'sam2' x 'hvit_s'
- 'sam2' x 'hvit_b'
- 'sam2' x 'hvit_l'
- 'sam3' x 'vit_pe'
- 'cellpose_sam' x 'vit_l'
DINO_family_models:
- 'dinov2' x 'vit_s'
- 'dinov2' x 'vit_b'
- 'dinov2' x 'vit_l'
- 'dinov2' x 'vit_g'
- 'dinov2' x 'vit_s_reg4'
- 'dinov2' x 'vit_b_reg4'
- 'dinov2' x 'vit_l_reg4'
- 'dinov2' x 'vit_g_reg4'
- 'dinov3' x 'vit_s'
- 'dinov3' x 'vit_s+'
- 'dinov3' x 'vit_b'
- 'dinov3' x 'vit_l'
- 'dinov3' x 'vit_l+'
- 'dinov3' x 'vit_h+'
- 'dinov3' x 'vit_7b'
MAE_family_models:
- 'mae' x 'vit_b'
- 'mae' x 'vit_l'
- 'mae' x 'vit_h'
- 'scalemae' x 'vit_b'
- 'scalemae' x 'vit_l'
- 'scalemae' x 'vit_h'
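For illustration, a minimal construction sketch (it uses the concrete 2D UNETR subclass documented below; the parameter names are the ones listed above, and the encoder is randomly initialized since no encoder_checkpoint is passed):

from torch_em.model.unetr import UNETR

# Build a UNETR with a SAM ViT-B image encoder and a fresh convolutional decoder.
# use_sam_stats=True normalizes inputs with the SAM pixel statistics;
# final_activation="Sigmoid" is resolved to torch.nn.Sigmoid.
model = UNETR(
    backbone="sam",
    encoder="vit_b",
    out_channels=2,
    use_sam_stats=True,
    final_activation="Sigmoid",
)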
def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
Compute the output size given input size and target long side length.
Arguments:
- oldh: The input image height.
- oldw: The input image width.
- long_side_length: The longest side length for resizing.
Returns:
The new image height and the new image width, returned as a tuple (newh, neww).
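The rounding is easiest to see on concrete numbers; a small check that follows directly from the scaling formula (the long side is mapped to long_side_length, the short side is scaled by the same factor and rounded):

from torch_em.model.unetr import UNETRBase

# Longest side (640) is scaled to 1024, the short side follows with the same factor.
print(UNETRBase.get_preprocess_shape(480, 640, long_side_length=1024))   # (768, 1024)
# A square input simply maps to a square of the target size.
print(UNETRBase.get_preprocess_shape(512, 512, long_side_length=1024))   # (1024, 1024)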
def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor:
Resize the image so that the longest side has the correct length.
Expects batched images with shape BxCxHxW OR BxCxDxHxW and float format.
Arguments:
- image: The input image.
Returns:
The resized image.
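A short shape sketch, assuming a UNETR built with the default SAM "vit_b" encoder (encoder image size 1024) and randomly initialized weights:

import torch
from torch_em.model.unetr import UNETR

model = UNETR(backbone="sam", encoder="vit_b")
image = torch.rand(1, 3, 480, 640)                # B x C x H x W, float

resized = model.resize_longest_side(image)
print(resized.shape)                              # torch.Size([1, 3, 768, 1024])

# preprocess (marked @private in the source) additionally normalizes the input
# and zero-pads it to a square of the encoder image size:
padded, input_shape = model.preprocess(image)
print(padded.shape)                               # torch.Size([1, 3, 1024, 1024])
print(input_shape)                                # torch.Size([768, 1024])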
class UNETR(UNETRBase):
A (2d-only) UNet Transformer using a vision transformer as encoder and a convolutional decoder.
    def __init__(
        self,
        img_size: int = 1024,
        backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
        encoder: Optional[Union[nn.Module, str]] = "vit_b",
        decoder: Optional[nn.Module] = None,
        out_channels: int = 1,
        use_sam_stats: bool = False,
        use_mae_stats: bool = False,
        use_dino_stats: bool = False,
        resize_input: bool = True,
        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
        final_activation: Optional[Union[str, nn.Module]] = None,
        use_skip_connection: bool = True,
        embed_dim: Optional[int] = None,
        use_conv_transpose: bool = False,
        **kwargs
    ) -> None:

        super().__init__(
            img_size=img_size,
            backbone=backbone,
            encoder=encoder,
            decoder=decoder,
            out_channels=out_channels,
            use_sam_stats=use_sam_stats,
            use_mae_stats=use_mae_stats,
            use_dino_stats=use_dino_stats,
            resize_input=resize_input,
            encoder_checkpoint=encoder_checkpoint,
            final_activation=final_activation,
            use_skip_connection=use_skip_connection,
            embed_dim=embed_dim,
            use_conv_transpose=use_conv_transpose,
            **kwargs,
        )

        encoder = self.encoder

        if backbone == "sam2" and hasattr(encoder, "trunk"):
            in_chans = encoder.trunk.patch_embed.proj.in_channels
        elif hasattr(encoder, "in_chans"):
            in_chans = encoder.in_chans
        else:  # `nn.Module` ViT backbone.
            try:
                in_chans = encoder.patch_embed.proj.in_channels
            except AttributeError:  # for getting the input channels while using 'vit_t' from MobileSam
                in_chans = encoder.patch_embed.seq[0].c.in_channels

        # parameters for the decoder network
        depth = 3
        initial_features = 64
        gain = 2
        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
        scale_factors = depth * [2]
        self.out_channels = out_channels

        # choice of upsampler - to use (bilinear interpolation + conv) or conv transpose
        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d

        self.decoder = decoder or Decoder(
            features=features_decoder,
            scale_factors=scale_factors[::-1],
            conv_block_impl=ConvBlock2d,
            sampler_impl=_upsampler,
        )

        if use_skip_connection:
            self.deconv1 = Deconv2DBlock(
                in_channels=self.embed_dim,
                out_channels=features_decoder[0],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv2 = nn.Sequential(
                Deconv2DBlock(
                    in_channels=self.embed_dim,
                    out_channels=features_decoder[0],
                    use_conv_transpose=use_conv_transpose,
                ),
                Deconv2DBlock(
                    in_channels=features_decoder[0],
                    out_channels=features_decoder[1],
                    use_conv_transpose=use_conv_transpose,
                )
            )
            self.deconv3 = nn.Sequential(
                Deconv2DBlock(
                    in_channels=self.embed_dim,
                    out_channels=features_decoder[0],
                    use_conv_transpose=use_conv_transpose,
                ),
                Deconv2DBlock(
                    in_channels=features_decoder[0],
                    out_channels=features_decoder[1],
                    use_conv_transpose=use_conv_transpose,
                ),
                Deconv2DBlock(
                    in_channels=features_decoder[1],
                    out_channels=features_decoder[2],
                    use_conv_transpose=use_conv_transpose,
                )
            )
            self.deconv4 = ConvBlock2d(in_chans, features_decoder[-1])
        else:
            self.deconv1 = Deconv2DBlock(
                in_channels=self.embed_dim,
                out_channels=features_decoder[0],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv2 = Deconv2DBlock(
                in_channels=features_decoder[0],
                out_channels=features_decoder[1],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv3 = Deconv2DBlock(
                in_channels=features_decoder[1],
                out_channels=features_decoder[2],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv4 = Deconv2DBlock(
                in_channels=features_decoder[2],
                out_channels=features_decoder[3],
                use_conv_transpose=use_conv_transpose,
            )

        self.base = ConvBlock2d(self.embed_dim, features_decoder[0])
        self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1)
        self.deconv_out = _upsampler(
            scale_factor=2, in_channels=features_decoder[-1], out_channels=features_decoder[-1]
        )
        self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1])
Initializes the 2D UNETR: builds the vision-transformer encoder, the convolutional decoder, and the deconvolution blocks that bring the encoder features to the decoder's resolution levels.
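For orientation, here is a minimal construction sketch. The argument values are illustrative rather than prescriptive, and building a pretrained encoder requires the dependencies and weights of the chosen backbone to be available. With the defaults used above (depth = 3, initial_features = 64, gain = 2), the decoder feature plan evaluates to [512, 256, 128, 64], from the coarsest to the finest level.

import torch.nn as nn
from torch_em.model.unetr import UNETR

# Illustrative configuration: a SAM ViT-B encoder with the default 2D decoder.
model = UNETR(
    img_size=1024,                  # inputs are resized to this size before the encoder
    backbone="sam",
    encoder="vit_b",                # must form a supported combination with the chosen backbone
    out_channels=1,
    final_activation=nn.Sigmoid(),  # the signature also accepts a string
    use_skip_connection=True,
)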
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the UNETR to the input data.

        Args:
            x: The input tensor.

        Returns:
            The UNETR output.
        """
        original_shape = x.shape[-2:]

        # Reshape the inputs to the shape expected by the encoder
        # and normalize the inputs if normalization is part of the model.
        x, input_shape = self.preprocess(x)

        encoder_outputs = self.encoder(x)

        if isinstance(encoder_outputs[-1], list):
            # `encoder_outputs` can be arranged in only two forms:
            # - either we only return the image embeddings
            # - or, we return the image embeddings and the "list" of global attention layers
            z12, from_encoder = encoder_outputs
        else:
            z12 = encoder_outputs

        if self.use_skip_connection:
            from_encoder = from_encoder[::-1]
            z9 = self.deconv1(from_encoder[0])
            z6 = self.deconv2(from_encoder[1])
            z3 = self.deconv3(from_encoder[2])
            z0 = self.deconv4(x)

        else:
            z9 = self.deconv1(z12)
            z6 = self.deconv2(z9)
            z3 = self.deconv3(z6)
            z0 = self.deconv4(z3)

        updated_from_encoder = [z9, z6, z3]

        x = self.base(z12)
        x = self.decoder(x, encoder_inputs=updated_from_encoder)
        x = self.deconv_out(x)

        x = torch.cat([x, z0], dim=1)
        x = self.decoder_head(x)

        x = self.out_conv(x)
        if self.final_activation is not None:
            x = self.final_activation(x)

        x = self.postprocess_masks(x, input_shape, original_shape)
        return x
Apply the UNETR to the input data.
Arguments:
- x: The input tensor.
Returns:
The UNETR output.
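A short usage sketch for the forward pass, assuming `model` is the instance constructed above and that its encoder expects three input channels; the output is resized back to the spatial shape of the input by `postprocess_masks`.

import torch

x = torch.rand(1, 3, 512, 512)  # (B, C, Y, X); resized internally to `img_size`
with torch.no_grad():
    y = model(x)
print(y.shape)  # expected: torch.Size([1, 1, 512, 512]) for out_channels=1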
class UNETR2D(UNETR):
    """A two-dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.
    """
    pass
A two-dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.
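Since `UNETR2D` adds no behaviour of its own, it is effectively an alias for `UNETR`, and constructing it is equivalent to constructing `UNETR` directly, e.g.:

from torch_em.model.unetr import UNETR2D

model_2d = UNETR2D(backbone="sam", encoder="vit_b", out_channels=1)  # same arguments as UNETR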
class UNETR3D(UNETRBase):
    """A three dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.
    """
A three dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.
    def __init__(
        self,
        img_size: int = 1024,
        backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam",
        encoder: Optional[Union[nn.Module, str]] = "hvit_b",
        decoder: Optional[nn.Module] = None,
        out_channels: int = 1,
        use_sam_stats: bool = False,
        use_mae_stats: bool = False,
        use_dino_stats: bool = False,
        resize_input: bool = True,
        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
        final_activation: Optional[Union[str, nn.Module]] = None,
        use_skip_connection: bool = False,
        embed_dim: Optional[int] = None,
        use_conv_transpose: bool = False,
        use_strip_pooling: bool = True,
        **kwargs
    ):
        if use_skip_connection:
            raise NotImplementedError("The framework cannot handle skip connections atm.")
        if use_conv_transpose:
            raise NotImplementedError("It's not enabled to switch between interpolation and transposed convolutions.")

        # Sort the `embed_dim` out
        embed_dim = 256 if embed_dim is None else embed_dim

        super().__init__(
            img_size=img_size,
            backbone=backbone,
            encoder=encoder,
            decoder=decoder,
            out_channels=out_channels,
            use_sam_stats=use_sam_stats,
            use_mae_stats=use_mae_stats,
            use_dino_stats=use_dino_stats,
            resize_input=resize_input,
            encoder_checkpoint=encoder_checkpoint,
            final_activation=final_activation,
            use_skip_connection=use_skip_connection,
            embed_dim=embed_dim,
            use_conv_transpose=use_conv_transpose,
            **kwargs,
        )

        # The 3d convolutional decoder.
        # First, get the important parameters for the decoder.
        depth = 3
        initial_features = 64
        gain = 2
        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
        scale_factors = [1, 2, 2]
        self.out_channels = out_channels

        # The mapping blocks.
        self.deconv1 = Deconv3DBlock(
            in_channels=embed_dim,
            out_channels=features_decoder[0],
            scale_factor=scale_factors,
            use_strip_pooling=use_strip_pooling,
        )
        self.deconv2 = Deconv3DBlock(
            in_channels=features_decoder[0],
            out_channels=features_decoder[1],
            scale_factor=scale_factors,
            use_strip_pooling=use_strip_pooling,
        )
        self.deconv3 = Deconv3DBlock(
            in_channels=features_decoder[1],
            out_channels=features_decoder[2],
            scale_factor=scale_factors,
            use_strip_pooling=use_strip_pooling,
        )
        self.deconv4 = Deconv3DBlock(
            in_channels=features_decoder[2],
            out_channels=features_decoder[3],
            scale_factor=scale_factors,
            use_strip_pooling=use_strip_pooling,
        )

        # The core decoder block.
        self.decoder = decoder or Decoder(
            features=features_decoder,
            scale_factors=[scale_factors] * depth,
            conv_block_impl=partial(ConvBlock3dWithStrip, use_strip_pooling=use_strip_pooling),
            sampler_impl=Upsampler3d,
        )

        # And the final upsampler to match the expected dimensions.
        self.deconv_out = Deconv3DBlock(  # NOTE: changed `end_up` to `deconv_out`
            in_channels=features_decoder[-1],
            out_channels=features_decoder[-1],
            scale_factor=scale_factors,
            use_strip_pooling=use_strip_pooling,
        )

        # Additional conjunction blocks.
        self.base = ConvBlock3dWithStrip(
            in_channels=embed_dim,
            out_channels=features_decoder[0],
            use_strip_pooling=use_strip_pooling,
        )

        # And the output layers.
        self.decoder_head = ConvBlock3dWithStrip(
            in_channels=2 * features_decoder[-1],
            out_channels=features_decoder[-1],
            use_strip_pooling=use_strip_pooling,
        )
        self.out_conv = nn.Conv3d(features_decoder[-1], out_channels, 1)
Initializes the 3D UNETR: builds the slice-wise vision-transformer encoder and a 3D convolutional decoder whose blocks upsample only in-plane (scale factors (1, 2, 2)), optionally using strip pooling.
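A minimal construction sketch for the 3D variant (values are illustrative; the default `hvit_b` encoder belongs to the SAM2 family and needs its own dependencies and weights). As set above, `embed_dim` falls back to 256 when not given, and the decoder never upsamples along Z.

import torch.nn as nn
from torch_em.model.unetr import UNETR3D

model_3d = UNETR3D(
    backbone="sam2",
    encoder="hvit_b",
    out_channels=1,
    final_activation=nn.Sigmoid(),
    use_strip_pooling=True,   # toggles the strip-pooling variant of the 3D conv blocks
)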
    def forward(self, x: torch.Tensor):
        """Forward pass of the UNETR-3D model.

        Args:
            x: Inputs of expected shape (B, C, Z, Y, X), where Z considers flexible inputs.

        Returns:
            The UNETR output.
        """
        B, C, Z, H, W = x.shape
        original_shape = (Z, H, W)

        # Preprocessing step
        x, input_shape = self.preprocess(x)

        # Run the image encoder.
        curr_features = torch.stack([self.encoder(x[:, :, i])[0] for i in range(Z)], dim=2)

        # Prepare the counterparts for the decoder.
        # NOTE: The section below is sequential, there's no skip connections atm.
        z9 = self.deconv1(curr_features)
        z6 = self.deconv2(z9)
        z3 = self.deconv3(z6)
        z0 = self.deconv4(z3)

        updated_from_encoder = [z9, z6, z3]

        # Align the features through the base block.
        x = self.base(curr_features)
        # Run the decoder
        x = self.decoder(x, encoder_inputs=updated_from_encoder)
        x = self.deconv_out(x)  # NOTE before `end_up`

        # And the final output head.
        x = torch.cat([x, z0], dim=1)
        x = self.decoder_head(x)
        x = self.out_conv(x)
        if self.final_activation is not None:
            x = self.final_activation(x)

        # Postprocess the output back to original size.
        x = self.postprocess_masks(x, input_shape, original_shape)
        return x
Forward pass of the UNETR-3D model.
Arguments:
- x: Inputs of expected shape (B, C, Z, Y, X); the number of slices Z is flexible.
Returns:
The UNETR output.
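A usage sketch for the volumetric forward pass, assuming `model_3d` from above and a three-channel input: every Z-slice is encoded independently by the 2D image encoder, the per-slice embeddings are stacked along Z, and the 3D decoder together with `postprocess_masks` returns the prediction at the original (Z, Y, X) size.

import torch

volume = torch.rand(1, 3, 8, 512, 512)  # (B, C, Z, Y, X) with a flexible number of slices Z
with torch.no_grad():
    seg = model_3d(volume)
print(seg.shape)  # expected: torch.Size([1, 1, 8, 512, 512]) for out_channels=1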