torch_em.model.unetr
1from functools import partial 2from collections import OrderedDict 3from typing import Optional, Tuple, Union, Literal 4 5import torch 6import torch.nn as nn 7import torch.nn.functional as F 8 9from .vit import get_vision_transformer 10from .unet import Decoder, ConvBlock2d, ConvBlock3d, Upsampler2d, Upsampler3d, _update_conv_kwargs 11 12try: 13 from micro_sam.util import get_sam_model 14except ImportError: 15 get_sam_model = None 16 17try: 18 from micro_sam.v2.util import get_sam2_model 19except ImportError: 20 get_sam2_model = None 21 22try: 23 from micro_sam3.util import get_sam3_model 24except ImportError: 25 get_sam3_model = None 26 27 28# 29# UNETR IMPLEMENTATION [Vision Transformer (ViT from SAM / CellposeSAM / SAM2 / SAM3 / DINOv2 / DINOv3 / MAE / ScaleMAE) + UNet Decoder from `torch_em`] # noqa 30# 31 32 33def _check_input_normalization_range( 34 x: torch.Tensor, 35 expected_range: Optional[Tuple[float, float]], 36 unit_scale_max: Optional[float] = None, 37) -> None: 38 """Check whether raw inputs match the value range expected by the model normalizer. 39 40 Args: 41 x: The input tensor to validate. 42 expected_range: The (min, max) value range the input must lie within. If None, all checks are skipped. 43 unit_scale_max: If set, raises an error when the input's maximum value is at or below this threshold, 44 catching inputs that are likely in the wrong scale (e.g. [0, 1] instead of [0, 255] for SAM1). 45 """ 46 if expected_range is None: 47 return 48 49 if not torch.all(torch.isfinite(x)): 50 raise ValueError("The input contains NaN or infinite values before normalization.") 51 52 min_value, max_value = expected_range 53 if torch.any((x < min_value) | (x > max_value)): 54 actual_min, actual_max = torch.aminmax(x.detach()) 55 raise ValueError( 56 "The input is outside the expected scale before normalization: " 57 f"expected values in [{min_value}, {max_value}], got [{actual_min.item()}, {actual_max.item()}]. " 58 "Please check whether the raw inputs should be scaled to [0, 1] or kept in [0, 255] " 59 "before applying the pretrained normalization statistics." 60 ) 61 62 if unit_scale_max is not None: 63 actual_max = x.detach().max().item() 64 if actual_max <= unit_scale_max: 65 raise ValueError( 66 f"The input maximum value ({actual_max:.4f}) suggests the input is in the wrong scale: " 67 f"expected inputs with values in [{min_value}, {max_value}], " 68 f"but the maximum is only {actual_max:.4f}. " 69 "Please check whether the raw inputs should be scaled to [0, 255] instead of [0, 1]." 70 ) 71 72 73def _as_stats(mean, std, device, dtype, is_3d: bool): 74 view_shape = (1, -1, 1, 1, 1) if is_3d else (1, -1, 1, 1) 75 pixel_mean = torch.tensor(mean, device=device, dtype=dtype).view(*view_shape) 76 pixel_std = torch.tensor(std, device=device, dtype=dtype).view(*view_shape) 77 return pixel_mean, pixel_std 78 79 80class UNETRBase(nn.Module): 81 """Base class for implementing a UNETR. 82 83 Args: 84 img_size: The size of the input for the image encoder. Input images will be resized to match this size. 85 backbone: The name of the vision transformer implementation. 86 One of "sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3" 87 (see all combinations below) 88 encoder: The vision transformer. Can either be a name, such as "vit_b" 89 (see all combinations for this below) or a torch module. 90 decoder: The convolutional decoder. 91 out_channels: The number of output channels of the UNETR. 
92 use_sam_stats: Whether to normalize the input data with the statistics of the 93 pretrained SAM / SAM2 / SAM3 model. 94 use_dino_stats: Whether to normalize the input data with the statistics of the 95 pretrained DINOv2 / DINOv3 model. 96 use_mae_stats: Whether to normalize the input data with the statistics of the pretrained MAE model. 97 resize_input: Whether to resize the input images to match `img_size`. 98 By default, it resizes the inputs to match the `img_size`. 99 encoder_checkpoint: Checkpoint for initializing the vision transformer. 100 Can either be a filepath or an already loaded checkpoint. 101 final_activation: The activation to apply to the UNETR output. 102 use_skip_connection: Whether to use skip connections. By default, it uses skip connections. 103 embed_dim: The embedding dimensionality, corresponding to the output dimension of the vision transformer. 104 use_conv_transpose: Whether to use transposed convolutions instead of resampling for upsampling. 105 By default, it uses resampling for upsampling. 106 perform_range_checks: Whether to validate the input value range before normalization on each forward pass. 107 You can disable the checks to avoid GPU sync overhead during training when inputs are known to be correct. 108 109 NOTE: The currently supported combinations of 'backbone' x 'encoder' are the following: 110 111 SAM_family_models: 112 - 'sam' x 'vit_b' 113 - 'sam' x 'vit_l' 114 - 'sam' x 'vit_h' 115 - 'sam2' x 'hvit_t' 116 - 'sam2' x 'hvit_s' 117 - 'sam2' x 'hvit_b' 118 - 'sam2' x 'hvit_l' 119 - 'sam3' x 'vit_pe' 120 - 'cellpose_sam' x 'vit_l' 121 122 DINO_family_models: 123 - 'dinov2' x 'vit_s' 124 - 'dinov2' x 'vit_b' 125 - 'dinov2' x 'vit_l' 126 - 'dinov2' x 'vit_g' 127 - 'dinov2' x 'vit_s_reg4' 128 - 'dinov2' x 'vit_b_reg4' 129 - 'dinov2' x 'vit_l_reg4' 130 - 'dinov2' x 'vit_g_reg4' 131 - 'dinov3' x 'vit_s' 132 - 'dinov3' x 'vit_s+' 133 - 'dinov3' x 'vit_b' 134 - 'dinov3' x 'vit_l' 135 - 'dinov3' x 'vit_l+' 136 - 'dinov3' x 'vit_h+' 137 - 'dinov3' x 'vit_7b' 138 139 MAE_family_models: 140 - 'mae' x 'vit_b' 141 - 'mae' x 'vit_l' 142 - 'mae' x 'vit_h' 143 - 'scalemae' x 'vit_b' 144 - 'scalemae' x 'vit_l' 145 - 'scalemae' x 'vit_h' 146 """ 147 def __init__( 148 self, 149 img_size: int = 1024, 150 backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam", 151 encoder: Optional[Union[nn.Module, str]] = "vit_b", 152 decoder: Optional[nn.Module] = None, 153 out_channels: int = 1, 154 use_sam_stats: bool = False, 155 use_mae_stats: bool = False, 156 use_dino_stats: bool = False, 157 resize_input: bool = True, 158 encoder_checkpoint: Optional[Union[str, OrderedDict]] = None, 159 final_activation: Optional[Union[str, nn.Module]] = None, 160 use_skip_connection: bool = True, 161 embed_dim: Optional[int] = None, 162 use_conv_transpose: bool = False, 163 perform_range_checks: bool = True, 164 **kwargs 165 ) -> None: 166 super().__init__() 167 168 self.img_size = img_size 169 self.use_sam_stats = use_sam_stats 170 self.use_mae_stats = use_mae_stats 171 self.use_dino_stats = use_dino_stats 172 self.use_skip_connection = use_skip_connection 173 self.resize_input = resize_input 174 self.perform_range_checks = perform_range_checks 175 self.use_conv_transpose = use_conv_transpose 176 self.backbone = backbone 177 178 if isinstance(encoder, str): # e.g. 
"vit_b" / "hvit_b" / "vit_pe" 179 print(f"Using {encoder} from {backbone.upper()}") 180 self.encoder = get_vision_transformer(img_size=img_size, backbone=backbone, model=encoder, **kwargs) 181 182 if encoder_checkpoint is not None: 183 self._load_encoder_from_checkpoint(backbone=backbone, encoder=encoder, checkpoint=encoder_checkpoint) 184 185 if embed_dim is None: 186 embed_dim = self.encoder.embed_dim 187 188 # For SAM1 encoder, if 'apply_neck' is applied, the embedding dimension must change. 189 if hasattr(self.encoder, "apply_neck") and self.encoder.apply_neck: 190 embed_dim = self.encoder.neck[2].out_channels # the value is 256 191 192 else: # `nn.Module` ViT backbone 193 self.encoder = encoder 194 195 have_neck = False 196 for name, _ in self.encoder.named_parameters(): 197 if name.startswith("neck"): 198 have_neck = True 199 200 if embed_dim is None: 201 if have_neck: 202 embed_dim = self.encoder.neck[2].out_channels # the value is 256 203 else: 204 embed_dim = self.encoder.patch_embed.proj.out_channels 205 206 self.embed_dim = embed_dim 207 self.final_activation = self._get_activation(final_activation) 208 209 def _load_encoder_from_checkpoint(self, backbone, encoder, checkpoint): 210 """Function to load pretrained weights to the image encoder. 211 """ 212 if isinstance(checkpoint, str): 213 if backbone == "sam" and isinstance(encoder, str): 214 # If we have a SAM encoder, then we first try to load the full SAM Model 215 # (using micro_sam) and otherwise fall back on directly loading the encoder state 216 # from the checkpoint 217 try: 218 _, model = get_sam_model(model_type=encoder, checkpoint_path=checkpoint, return_sam=True) 219 encoder_state = model.image_encoder.state_dict() 220 except Exception: 221 # Try loading the encoder state directly from a checkpoint. 222 encoder_state = torch.load(checkpoint, weights_only=False) 223 224 elif backbone == "cellpose_sam" and isinstance(encoder, str): 225 # The architecture matches CellposeSAM exactly (same rel_pos sizes), 226 # so weights load directly without any interpolation. 227 encoder_state = torch.load(checkpoint, map_location="cpu", weights_only=False) 228 # Handle DataParallel/DistributedDataParallel prefix. 229 if any(k.startswith("module.") for k in encoder_state.keys()): 230 encoder_state = OrderedDict( 231 {k[len("module."):]: v for k, v in encoder_state.items()} 232 ) 233 # Extract encoder weights from CellposeSAM checkpoint format (strip 'encoder.' prefix). 234 if any(k.startswith("encoder.") for k in encoder_state.keys()): 235 encoder_state = OrderedDict( 236 {k[len("encoder."):]: v for k, v in encoder_state.items() if k.startswith("encoder.")} 237 ) 238 239 elif backbone == "sam2" and isinstance(encoder, str): 240 # If we have a SAM2 encoder, then we first try to load the full SAM2 Model. 241 # (using micro_sam2) and otherwise fall back on directly loading the encoder state 242 # from the checkpoint 243 try: 244 model = get_sam2_model(model_type=encoder, checkpoint_path=checkpoint) 245 encoder_state = model.image_encoder.state_dict() 246 except Exception: 247 # Try loading the encoder state directly from a checkpoint. 248 encoder_state = torch.load(checkpoint, weights_only=False) 249 250 elif backbone == "sam3" and isinstance(encoder, str): 251 # If we have a SAM3 encoder, then we first try to load the full SAM3 Model. 
252 # (using micro_sam3) and otherwise fall back on directly loading the encoder state 253 # from the checkpoint 254 try: 255 model = get_sam3_model(checkpoint_path=checkpoint) 256 encoder_state = model.backbone.vision_backbone.state_dict() 257 # Let's align loading the encoder weights with expected parameter names 258 encoder_state = { 259 k[len("trunk."):] if k.startswith("trunk.") else k: v for k, v in encoder_state.items() 260 } 261 # And drop the 'convs' and 'sam2_convs' - these seem like some upsampling blocks. 262 encoder_state = { 263 k: v for k, v in encoder_state.items() 264 if not (k.startswith("convs.") or k.startswith("sam2_convs.")) 265 } 266 except Exception: 267 # Try loading the encoder state directly from a checkpoint. 268 encoder_state = torch.load(checkpoint, weights_only=False) 269 270 elif backbone == "mae": 271 # vit initialization hints from: 272 # - https://github.com/facebookresearch/mae/blob/main/main_finetune.py#L233-L242 273 encoder_state = torch.load(checkpoint, weights_only=False)["model"] 274 encoder_state = OrderedDict({ 275 k: v for k, v in encoder_state.items() if (k != "mask_token" and not k.startswith("decoder")) 276 }) 277 # Let's remove the `head` from our current encoder (as the MAE pretrained don't expect it) 278 current_encoder_state = self.encoder.state_dict() 279 if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state): 280 del self.encoder.head 281 282 elif backbone == "scalemae": 283 # Load the encoder state directly from a checkpoint. 284 encoder_state = torch.load(checkpoint)["model"] 285 encoder_state = OrderedDict({ 286 k: v for k, v in encoder_state.items() 287 if not k.startswith(("mask_token", "decoder", "fcn", "fpn", "pos_embed")) 288 }) 289 290 # Let's remove the `head` from our current encoder (as the MAE pretrained don't expect it) 291 current_encoder_state = self.encoder.state_dict() 292 if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state): 293 del self.encoder.head 294 295 if "pos_embed" in current_encoder_state: # NOTE: ScaleMAE uses 'pos. embeddings' in a diff. format. 296 del self.encoder.pos_embed 297 298 elif backbone in ["dinov2", "dinov3"]: # Load the encoder state directly from a checkpoint. 299 encoder_state = torch.load(checkpoint) 300 301 else: 302 raise ValueError( 303 f"We don't support either the '{backbone}' backbone or the '{encoder}' model combination (or both)." 304 ) 305 306 else: 307 encoder_state = checkpoint 308 309 self.encoder.load_state_dict(encoder_state) 310 311 def _get_activation(self, activation): 312 return_activation = None 313 if activation is None: 314 return None 315 if isinstance(activation, nn.Module): 316 return activation 317 if isinstance(activation, str): 318 return_activation = getattr(nn, activation, None) 319 if return_activation is None: 320 raise ValueError(f"Invalid activation: {activation}") 321 322 return return_activation() 323 324 @staticmethod 325 def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 326 """Compute the output size given input size and target long side length. 327 328 Args: 329 oldh: The input image height. 330 oldw: The input image width. 331 long_side_length: The longest side length for resizing. 332 333 Returns: 334 The new image height. 335 The new image width. 
336 """ 337 scale = long_side_length * 1.0 / max(oldh, oldw) 338 newh, neww = oldh * scale, oldw * scale 339 neww = int(neww + 0.5) 340 newh = int(newh + 0.5) 341 return (newh, neww) 342 343 def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor: 344 """Resize the image so that the longest side has the correct length. 345 346 Expects batched images with shape BxCxHxW OR BxCxDxHxW and float format. 347 348 Args: 349 image: The input image. 350 351 Returns: 352 The resized image. 353 """ 354 if image.ndim == 4: # i.e. 2d image 355 target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.encoder.img_size) 356 return F.interpolate(image, target_size, mode="bilinear", align_corners=False, antialias=True) 357 elif image.ndim == 5: # i.e. 3d volume 358 B, C, Z, H, W = image.shape 359 target_size = self.get_preprocess_shape(H, W, self.img_size) 360 return F.interpolate(image, (Z, *target_size), mode="trilinear", align_corners=False) 361 else: 362 raise ValueError("Expected 4d or 5d inputs, got", image.shape) 363 364 def _as_stats(self, mean, std, device, dtype, is_3d: bool): 365 """@private 366 """ 367 return _as_stats(mean, std, device, dtype, is_3d) 368 369 def _check_input_normalization_range(self, x: torch.Tensor, expected_range: Optional[Tuple[float, float]]) -> None: 370 """@private 371 """ 372 _check_input_normalization_range(x, expected_range) 373 374 def preprocess(self, x: torch.Tensor) -> torch.Tensor: 375 """@private 376 """ 377 return preprocess_vit_inputs( 378 x, 379 use_sam_stats=self.use_sam_stats, 380 backbone=self.backbone, 381 use_mae_stats=self.use_mae_stats, 382 use_dino_stats=self.use_dino_stats, 383 resize_input=self.resize_input, 384 img_size=self.img_size, 385 encoder_img_size=self.encoder.img_size, 386 perform_range_checks=self.perform_range_checks, 387 ) 388 389 def postprocess_masks( 390 self, masks: torch.Tensor, input_size: Tuple[int, ...], original_size: Tuple[int, ...], 391 ) -> torch.Tensor: 392 """@private 393 """ 394 if masks.ndim == 4: # i.e. 2d labels 395 masks = F.interpolate( 396 masks, 397 (self.encoder.img_size, self.encoder.img_size), 398 mode="bilinear", 399 align_corners=False, 400 ) 401 masks = masks[..., : input_size[0], : input_size[1]] 402 masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) 403 404 elif masks.ndim == 5: # i.e. 3d volumetric labels 405 masks = F.interpolate( 406 masks, 407 (input_size[0], self.img_size, self.img_size), 408 mode="trilinear", 409 align_corners=False, 410 ) 411 masks = masks[..., :input_size[0], :input_size[1], :input_size[2]] 412 masks = F.interpolate(masks, original_size, mode="trilinear", align_corners=False) 413 414 else: 415 raise ValueError("Expected 4d or 5d labels, got", masks.shape) 416 417 return masks 418 419 420def preprocess_vit_inputs( 421 x: torch.Tensor, 422 use_sam_stats: bool = False, 423 backbone: str = "sam", 424 use_mae_stats: bool = False, 425 use_dino_stats: bool = False, 426 resize_input: bool = True, 427 img_size: int = 1024, 428 encoder_img_size: int = 1024, 429 perform_range_checks: bool = True, 430) -> Tuple[torch.Tensor, Tuple]: 431 """Preprocess inputs for ViT-backbones in UNETR models. 432 433 Handles normalization stat selection, input range validation, optional resizing to the longest side, 434 and padding to `encoder_img_size`. Can be used as a standalone function without a model instance. 435 436 Args: 437 x: Input tensor of shape (B, C, H, W) for 2D or (B, C, Z, H, W) for 3D. 
438 use_sam_stats: Whether to normalize with SAM/SAM2/SAM3 backbone statistics. 439 backbone: The backbone name - controls which SAM stats are used when `use_sam_stats=True`. 440 use_mae_stats: Whether to normalize with MAE statistics. 441 use_dino_stats: Whether to normalize with DINOv2/DINOv3 statistics. 442 resize_input: Whether to resize the input to the longest side before padding. 443 img_size: The model image size, used for 3D resize. 444 encoder_img_size: The encoder image size, used for 2D resize and padding. 445 perform_range_checks: Whether to validate the expected input value range before normalization. 446 You can disable the checks to avoid GPU sync overhead during training when inputs are known to be correct. 447 448 Returns: 449 The preprocessed tensor and the spatial shape after resizing (before padding). 450 """ 451 is_3d = (x.ndim == 5) 452 device, dtype = x.device, x.dtype 453 mean, std = (0.0, 0.0, 0.0), (1.0, 1.0, 1.0) 454 expected_range = None 455 unit_scale_max = None 456 457 if use_sam_stats: 458 if backbone == "sam2": 459 mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225) 460 expected_range = (0.0, 1.0) 461 elif backbone == "sam3": 462 mean, std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5) 463 expected_range = (0.0, 1.0) 464 else: # sam1 / default 465 mean, std = (123.675, 116.28, 103.53), (58.395, 57.12, 57.375) 466 expected_range = (0.0, 255.0) 467 unit_scale_max = 1.0 468 elif use_mae_stats: # TODO: add mean std from mae / scalemae experiments (or open up arguments for this) 469 raise NotImplementedError 470 elif use_dino_stats: 471 mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225) 472 expected_range = (0.0, 1.0) 473 else: 474 mean, std = (0.0, 0.0, 0.0), (1.0, 1.0, 1.0) 475 expected_range = None 476 477 if perform_range_checks: 478 _check_input_normalization_range(x, expected_range, unit_scale_max) 479 pixel_mean, pixel_std = _as_stats(mean, std, device=device, dtype=dtype, is_3d=is_3d) 480 481 if resize_input: 482 if x.ndim == 4: 483 target_size = UNETRBase.get_preprocess_shape(x.shape[2], x.shape[3], encoder_img_size) 484 x = F.interpolate(x, target_size, mode="bilinear", align_corners=False, antialias=True) 485 elif x.ndim == 5: 486 B, C, Z, H, W = x.shape 487 target_size = UNETRBase.get_preprocess_shape(H, W, img_size) 488 x = F.interpolate(x, (Z, *target_size), mode="trilinear", align_corners=False) 489 490 input_shape = x.shape[-3:] if is_3d else x.shape[-2:] 491 492 x = (x - pixel_mean) / pixel_std 493 h, w = x.shape[-2:] 494 padh = encoder_img_size - h 495 padw = encoder_img_size - w 496 497 if is_3d: 498 x = F.pad(x, (0, padw, 0, padh, 0, 0)) 499 else: 500 x = F.pad(x, (0, padw, 0, padh)) 501 502 return x, input_shape 503 504 505class UNETR(UNETRBase): 506 """A (2d-only) UNet Transformer using a vision transformer as encoder and a convolutional decoder. 
507 """ 508 def __init__( 509 self, 510 img_size: int = 1024, 511 backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam", 512 encoder: Optional[Union[nn.Module, str]] = "vit_b", 513 decoder: Optional[nn.Module] = None, 514 out_channels: int = 1, 515 use_sam_stats: bool = False, 516 use_mae_stats: bool = False, 517 use_dino_stats: bool = False, 518 resize_input: bool = True, 519 encoder_checkpoint: Optional[Union[str, OrderedDict]] = None, 520 final_activation: Optional[Union[str, nn.Module]] = None, 521 use_skip_connection: bool = True, 522 embed_dim: Optional[int] = None, 523 use_conv_transpose: bool = False, 524 perform_range_checks: bool = True, 525 **kwargs 526 ) -> None: 527 528 super().__init__( 529 img_size=img_size, 530 backbone=backbone, 531 encoder=encoder, 532 decoder=decoder, 533 out_channels=out_channels, 534 use_sam_stats=use_sam_stats, 535 use_mae_stats=use_mae_stats, 536 use_dino_stats=use_dino_stats, 537 resize_input=resize_input, 538 encoder_checkpoint=encoder_checkpoint, 539 final_activation=final_activation, 540 use_skip_connection=use_skip_connection, 541 embed_dim=embed_dim, 542 use_conv_transpose=use_conv_transpose, 543 perform_range_checks=perform_range_checks, 544 **kwargs, 545 ) 546 547 encoder = self.encoder 548 549 if backbone == "sam2" and hasattr(encoder, "trunk"): 550 in_chans = encoder.trunk.patch_embed.proj.in_channels 551 elif hasattr(encoder, "in_chans"): 552 in_chans = encoder.in_chans 553 else: # `nn.Module` ViT backbone. 554 try: 555 in_chans = encoder.patch_embed.proj.in_channels 556 except AttributeError: # for getting the input channels while using 'vit_t' from MobileSam 557 in_chans = encoder.patch_embed.seq[0].c.in_channels 558 559 # parameters for the decoder network 560 depth = 3 561 initial_features = 64 562 gain = 2 563 features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1] 564 scale_factors = depth * [2] 565 self.out_channels = out_channels 566 567 # choice of upsampler - to use (bilinear interpolation + conv) or conv transpose 568 _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d 569 570 self.decoder = decoder or Decoder( 571 features=features_decoder, 572 scale_factors=scale_factors[::-1], 573 conv_block_impl=ConvBlock2d, 574 sampler_impl=_upsampler, 575 ) 576 577 if use_skip_connection: 578 self.deconv1 = Deconv2DBlock( 579 in_channels=self.embed_dim, 580 out_channels=features_decoder[0], 581 use_conv_transpose=use_conv_transpose, 582 ) 583 self.deconv2 = nn.Sequential( 584 Deconv2DBlock( 585 in_channels=self.embed_dim, 586 out_channels=features_decoder[0], 587 use_conv_transpose=use_conv_transpose, 588 ), 589 Deconv2DBlock( 590 in_channels=features_decoder[0], 591 out_channels=features_decoder[1], 592 use_conv_transpose=use_conv_transpose, 593 ) 594 ) 595 self.deconv3 = nn.Sequential( 596 Deconv2DBlock( 597 in_channels=self.embed_dim, 598 out_channels=features_decoder[0], 599 use_conv_transpose=use_conv_transpose, 600 ), 601 Deconv2DBlock( 602 in_channels=features_decoder[0], 603 out_channels=features_decoder[1], 604 use_conv_transpose=use_conv_transpose, 605 ), 606 Deconv2DBlock( 607 in_channels=features_decoder[1], 608 out_channels=features_decoder[2], 609 use_conv_transpose=use_conv_transpose, 610 ) 611 ) 612 self.deconv4 = ConvBlock2d(in_chans, features_decoder[-1]) 613 else: 614 self.deconv1 = Deconv2DBlock( 615 in_channels=self.embed_dim, 616 out_channels=features_decoder[0], 617 use_conv_transpose=use_conv_transpose, 618 ) 619 
self.deconv2 = Deconv2DBlock( 620 in_channels=features_decoder[0], 621 out_channels=features_decoder[1], 622 use_conv_transpose=use_conv_transpose, 623 ) 624 self.deconv3 = Deconv2DBlock( 625 in_channels=features_decoder[1], 626 out_channels=features_decoder[2], 627 use_conv_transpose=use_conv_transpose, 628 ) 629 self.deconv4 = Deconv2DBlock( 630 in_channels=features_decoder[2], 631 out_channels=features_decoder[3], 632 use_conv_transpose=use_conv_transpose, 633 ) 634 635 self.base = ConvBlock2d(self.embed_dim, features_decoder[0]) 636 self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1) 637 self.deconv_out = _upsampler( 638 scale_factor=2, in_channels=features_decoder[-1], out_channels=features_decoder[-1] 639 ) 640 self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1]) 641 642 def forward(self, x: torch.Tensor) -> torch.Tensor: 643 """Apply the UNETR to the input data. 644 645 Args: 646 x: The input tensor. 647 648 Returns: 649 The UNETR output. 650 """ 651 original_shape = x.shape[-2:] 652 653 # Reshape the inputs to the shape expected by the encoder 654 # and normalize the inputs if normalization is part of the model. 655 x, input_shape = self.preprocess(x) 656 657 encoder_outputs = self.encoder(x) 658 659 if isinstance(encoder_outputs[-1], list): 660 # `encoder_outputs` can be arranged in only two forms: 661 # - either we only return the image embeddings 662 # - or, we return the image embeddings and the "list" of global attention layers 663 z12, from_encoder = encoder_outputs 664 else: 665 z12 = encoder_outputs 666 667 if self.use_skip_connection: 668 from_encoder = from_encoder[::-1] 669 z9 = self.deconv1(from_encoder[0]) 670 z6 = self.deconv2(from_encoder[1]) 671 z3 = self.deconv3(from_encoder[2]) 672 z0 = self.deconv4(x) 673 674 else: 675 z9 = self.deconv1(z12) 676 z6 = self.deconv2(z9) 677 z3 = self.deconv3(z6) 678 z0 = self.deconv4(z3) 679 680 updated_from_encoder = [z9, z6, z3] 681 682 x = self.base(z12) 683 x = self.decoder(x, encoder_inputs=updated_from_encoder) 684 x = self.deconv_out(x) 685 686 x = torch.cat([x, z0], dim=1) 687 x = self.decoder_head(x) 688 689 x = self.out_conv(x) 690 if self.final_activation is not None: 691 x = self.final_activation(x) 692 693 x = self.postprocess_masks(x, input_shape, original_shape) 694 return x 695 696 697class UNETR2D(UNETR): 698 """A two-dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder. 699 """ 700 pass 701 702 703class UNETR3D(UNETRBase): 704 """A three dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder. 
705 """ 706 def __init__( 707 self, 708 img_size: int = 1024, 709 backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam", 710 encoder: Optional[Union[nn.Module, str]] = "hvit_b", 711 decoder: Optional[nn.Module] = None, 712 out_channels: int = 1, 713 use_sam_stats: bool = False, 714 use_mae_stats: bool = False, 715 use_dino_stats: bool = False, 716 resize_input: bool = True, 717 encoder_checkpoint: Optional[Union[str, OrderedDict]] = None, 718 final_activation: Optional[Union[str, nn.Module]] = None, 719 use_skip_connection: bool = False, 720 embed_dim: Optional[int] = None, 721 use_conv_transpose: bool = False, 722 use_strip_pooling: bool = True, 723 perform_range_checks: bool = True, 724 **kwargs 725 ): 726 if use_skip_connection: 727 raise NotImplementedError("The framework cannot handle skip connections atm.") 728 if use_conv_transpose: 729 raise NotImplementedError("It's not enabled to switch between interpolation and transposed convolutions.") 730 731 # Sort the `embed_dim` out 732 embed_dim = 256 if embed_dim is None else embed_dim 733 734 super().__init__( 735 img_size=img_size, 736 backbone=backbone, 737 encoder=encoder, 738 decoder=decoder, 739 out_channels=out_channels, 740 use_sam_stats=use_sam_stats, 741 use_mae_stats=use_mae_stats, 742 use_dino_stats=use_dino_stats, 743 resize_input=resize_input, 744 encoder_checkpoint=encoder_checkpoint, 745 final_activation=final_activation, 746 use_skip_connection=use_skip_connection, 747 embed_dim=embed_dim, 748 use_conv_transpose=use_conv_transpose, 749 perform_range_checks=perform_range_checks, 750 **kwargs, 751 ) 752 753 # The 3d convolutional decoder. 754 # First, get the important parameters for the decoder. 755 depth = 3 756 initial_features = 64 757 gain = 2 758 features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1] 759 scale_factors = [1, 2, 2] 760 self.out_channels = out_channels 761 762 # The mapping blocks. 763 self.deconv1 = Deconv3DBlock( 764 in_channels=embed_dim, 765 out_channels=features_decoder[0], 766 scale_factor=scale_factors, 767 use_strip_pooling=use_strip_pooling, 768 ) 769 self.deconv2 = Deconv3DBlock( 770 in_channels=features_decoder[0], 771 out_channels=features_decoder[1], 772 scale_factor=scale_factors, 773 use_strip_pooling=use_strip_pooling, 774 ) 775 self.deconv3 = Deconv3DBlock( 776 in_channels=features_decoder[1], 777 out_channels=features_decoder[2], 778 scale_factor=scale_factors, 779 use_strip_pooling=use_strip_pooling, 780 ) 781 self.deconv4 = Deconv3DBlock( 782 in_channels=features_decoder[2], 783 out_channels=features_decoder[3], 784 scale_factor=scale_factors, 785 use_strip_pooling=use_strip_pooling, 786 ) 787 788 # The core decoder block. 789 self.decoder = decoder or Decoder( 790 features=features_decoder, 791 scale_factors=[scale_factors] * depth, 792 conv_block_impl=partial(ConvBlock3dWithStrip, use_strip_pooling=use_strip_pooling), 793 sampler_impl=Upsampler3d, 794 ) 795 796 # And the final upsampler to match the expected dimensions. 797 self.deconv_out = Deconv3DBlock( # NOTE: changed `end_up` to `deconv_out` 798 in_channels=features_decoder[-1], 799 out_channels=features_decoder[-1], 800 scale_factor=scale_factors, 801 use_strip_pooling=use_strip_pooling, 802 ) 803 804 # Additional conjunction blocks. 805 self.base = ConvBlock3dWithStrip( 806 in_channels=embed_dim, 807 out_channels=features_decoder[0], 808 use_strip_pooling=use_strip_pooling, 809 ) 810 811 # And the output layers. 
812 self.decoder_head = ConvBlock3dWithStrip( 813 in_channels=2 * features_decoder[-1], 814 out_channels=features_decoder[-1], 815 use_strip_pooling=use_strip_pooling, 816 ) 817 self.out_conv = nn.Conv3d(features_decoder[-1], out_channels, 1) 818 819 def forward(self, x: torch.Tensor): 820 """Forward pass of the UNETR-3D model. 821 822 Args: 823 x: Inputs of expected shape (B, C, Z, Y, X), where Z considers flexible inputs. 824 825 Returns: 826 The UNETR output. 827 """ 828 B, C, Z, H, W = x.shape 829 original_shape = (Z, H, W) 830 831 # Preprocessing step 832 x, input_shape = self.preprocess(x) 833 834 # Run the image encoder. 835 curr_features = torch.stack([self.encoder(x[:, :, i])[0] for i in range(Z)], dim=2) 836 837 # Prepare the counterparts for the decoder. 838 # NOTE: The section below is sequential, there's no skip connections atm. 839 z9 = self.deconv1(curr_features) 840 z6 = self.deconv2(z9) 841 z3 = self.deconv3(z6) 842 z0 = self.deconv4(z3) 843 844 updated_from_encoder = [z9, z6, z3] 845 846 # Align the features through the base block. 847 x = self.base(curr_features) 848 # Run the decoder 849 x = self.decoder(x, encoder_inputs=updated_from_encoder) 850 x = self.deconv_out(x) # NOTE before `end_up` 851 852 # And the final output head. 853 x = torch.cat([x, z0], dim=1) 854 x = self.decoder_head(x) 855 x = self.out_conv(x) 856 if self.final_activation is not None: 857 x = self.final_activation(x) 858 859 # Postprocess the output back to original size. 860 x = self.postprocess_masks(x, input_shape, original_shape) 861 return x 862 863# 864# ADDITIONAL FUNCTIONALITIES 865# 866 867 868def _strip_pooling_layers(enabled, channels) -> nn.Module: 869 return DepthStripPooling(channels) if enabled else nn.Identity() 870 871 872class DepthStripPooling(nn.Module): 873 """@private 874 """ 875 def __init__(self, channels: int, reduction: int = 4): 876 """Block for strip pooling along the depth dimension (only). 877 878 eg. for 3D (Z > 1) - it aggregates global context across depth by adaptive avg pooling 879 to Z=1, and then passes through a small 1x1x1 MLP, then broadcasts it back to Z to 880 modulate the original features (using a gated residual). 881 882 For 2D (Z == 1): returns input unchanged (no-op). 883 884 Args: 885 channels: The output channels. 886 reduction: The reduction of the hidden layers. 887 """ 888 super().__init__() 889 hidden = max(1, channels // reduction) 890 self.conv1 = nn.Conv3d(channels, hidden, kernel_size=1) 891 self.bn1 = nn.BatchNorm3d(hidden) 892 self.relu = nn.ReLU(inplace=True) 893 self.conv2 = nn.Conv3d(hidden, channels, kernel_size=1) 894 895 def forward(self, x: torch.Tensor) -> torch.Tensor: 896 if x.dim() != 5: 897 raise ValueError(f"DepthStripPooling expects 5D tensors as input, got '{x.shape}'.") 898 899 B, C, Z, H, W = x.shape 900 if Z == 1: # i.e. always the case of all 2d. 901 return x # We simply do nothing there. 902 903 # We pool only along the depth dimension: i.e. 
target shape (B, C, 1, H, W) 904 feat = F.adaptive_avg_pool3d(x, output_size=(1, H, W)) 905 feat = self.conv1(feat) 906 feat = self.bn1(feat) 907 feat = self.relu(feat) 908 feat = self.conv2(feat) 909 gate = torch.sigmoid(feat).expand(B, C, Z, H, W) # Broadcast the collapsed depth context back to all slices 910 911 # Gated residual fusion 912 return x * gate + x 913 914 915class Deconv3DBlock(nn.Module): 916 """@private 917 """ 918 def __init__( 919 self, 920 scale_factor, 921 in_channels, 922 out_channels, 923 kernel_size=3, 924 anisotropic_kernel=True, 925 use_strip_pooling=True, 926 ): 927 super().__init__() 928 conv_block_kwargs = { 929 "in_channels": out_channels, 930 "out_channels": out_channels, 931 "kernel_size": kernel_size, 932 "padding": ((kernel_size - 1) // 2), 933 } 934 if anisotropic_kernel: 935 conv_block_kwargs = _update_conv_kwargs(conv_block_kwargs, scale_factor) 936 937 self.block = nn.Sequential( 938 Upsampler3d(scale_factor, in_channels, out_channels), 939 nn.Conv3d(**conv_block_kwargs), 940 nn.BatchNorm3d(out_channels), 941 nn.ReLU(True), 942 _strip_pooling_layers(enabled=use_strip_pooling, channels=out_channels), 943 ) 944 945 def forward(self, x): 946 return self.block(x) 947 948 949class ConvBlock3dWithStrip(nn.Module): 950 """@private 951 """ 952 def __init__( 953 self, in_channels: int, out_channels: int, use_strip_pooling: bool = True, **kwargs 954 ): 955 super().__init__() 956 self.block = nn.Sequential( 957 ConvBlock3d(in_channels, out_channels, **kwargs), 958 _strip_pooling_layers(enabled=use_strip_pooling, channels=out_channels), 959 ) 960 961 def forward(self, x): 962 return self.block(x) 963 964 965class SingleDeconv2DBlock(nn.Module): 966 """@private 967 """ 968 def __init__(self, scale_factor, in_channels, out_channels): 969 super().__init__() 970 self.block = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, padding=0, output_padding=0) 971 972 def forward(self, x): 973 return self.block(x) 974 975 976class SingleConv2DBlock(nn.Module): 977 """@private 978 """ 979 def __init__(self, in_channels, out_channels, kernel_size): 980 super().__init__() 981 self.block = nn.Conv2d( 982 in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=((kernel_size - 1) // 2) 983 ) 984 985 def forward(self, x): 986 return self.block(x) 987 988 989class Conv2DBlock(nn.Module): 990 """@private 991 """ 992 def __init__(self, in_channels, out_channels, kernel_size=3): 993 super().__init__() 994 self.block = nn.Sequential( 995 SingleConv2DBlock(in_channels, out_channels, kernel_size), 996 nn.BatchNorm2d(out_channels), 997 nn.ReLU(True) 998 ) 999 1000 def forward(self, x): 1001 return self.block(x) 1002 1003 1004class Deconv2DBlock(nn.Module): 1005 """@private 1006 """ 1007 def __init__(self, in_channels, out_channels, kernel_size=3, use_conv_transpose=True): 1008 super().__init__() 1009 _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d 1010 self.block = nn.Sequential( 1011 _upsampler(scale_factor=2, in_channels=in_channels, out_channels=out_channels), 1012 SingleConv2DBlock(out_channels, out_channels, kernel_size), 1013 nn.BatchNorm2d(out_channels), 1014 nn.ReLU(True) 1015 ) 1016 1017 def forward(self, x): 1018 return self.block(x)
class UNETRBase(nn.Module)
Base class for implementing a UNETR.
Arguments:
- img_size: The size of the input for the image encoder. Input images will be resized to match this size.
- backbone: The name of the vision transformer implementation. One of "sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3" (see all combinations below)
- encoder: The vision transformer. Can either be a name, such as "vit_b" (see all combinations for this below) or a torch module.
- decoder: The convolutional decoder.
- out_channels: The number of output channels of the UNETR.
- use_sam_stats: Whether to normalize the input data with the statistics of the pretrained SAM / SAM2 / SAM3 model.
- use_dino_stats: Whether to normalize the input data with the statistics of the pretrained DINOv2 / DINOv3 model.
- use_mae_stats: Whether to normalize the input data with the statistics of the pretrained MAE model.
- resize_input: Whether to resize the input images to match `img_size`. By default, it resizes the inputs to match the `img_size`.
- encoder_checkpoint: Checkpoint for initializing the vision transformer. Can either be a filepath or an already loaded checkpoint.
- final_activation: The activation to apply to the UNETR output.
- use_skip_connection: Whether to use skip connections. By default, it uses skip connections.
- embed_dim: The embedding dimensionality, corresponding to the output dimension of the vision transformer.
- use_conv_transpose: Whether to use transposed convolutions instead of resampling for upsampling. By default, it uses resampling for upsampling.
- perform_range_checks: Whether to validate the input value range before normalization on each forward pass. You can disable the checks to avoid GPU sync overhead during training when inputs are known to be correct.
NOTE: The currently supported combinations of 'backbone' x 'encoder' are the following:

SAM_family_models:
- 'sam' x 'vit_b'
- 'sam' x 'vit_l'
- 'sam' x 'vit_h'
- 'sam2' x 'hvit_t'
- 'sam2' x 'hvit_s'
- 'sam2' x 'hvit_b'
- 'sam2' x 'hvit_l'
- 'sam3' x 'vit_pe'
- 'cellpose_sam' x 'vit_l'
DINO_family_models:
- 'dinov2' x 'vit_s'
- 'dinov2' x 'vit_b'
- 'dinov2' x 'vit_l'
- 'dinov2' x 'vit_g'
- 'dinov2' x 'vit_s_reg4'
- 'dinov2' x 'vit_b_reg4'
- 'dinov2' x 'vit_l_reg4'
- 'dinov2' x 'vit_g_reg4'
- 'dinov3' x 'vit_s'
- 'dinov3' x 'vit_s+'
- 'dinov3' x 'vit_b'
- 'dinov3' x 'vit_l'
- 'dinov3' x 'vit_l+'
- 'dinov3' x 'vit_h+'
- 'dinov3' x 'vit_7b'
MAE_family_models:
- 'mae' x 'vit_b'
- 'mae' x 'vit_l'
- 'mae' x 'vit_h'
- 'scalemae' x 'vit_b'
- 'scalemae' x 'vit_l'
- 'scalemae' x 'vit_h'
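As a usage sketch (assuming the SAM dependencies required by `get_vision_transformer` are installed; the commented checkpoint path is hypothetical), a 2D UNETR with a SAM ViT-B encoder can be built and applied like this:

```python
import torch
from torch_em.model.unetr import UNETR

# 2D UNETR with a SAM ViT-B image encoder; `use_sam_stats=True` normalizes
# inputs with the SAM1 pixel statistics, so raw values must lie in [0, 255].
model = UNETR(
    img_size=1024,
    backbone="sam",
    encoder="vit_b",
    out_channels=2,
    use_sam_stats=True,
    final_activation="Sigmoid",
    # encoder_checkpoint="/path/to/sam_vit_b.pth",  # optional pretrained weights (hypothetical path)
)

# Inputs are resized / padded to the encoder size internally and the output
# is interpolated back to the original spatial shape.
x = torch.rand(1, 3, 512, 512) * 255.0
with torch.no_grad():
    y = model(x)
print(y.shape)  # torch.Size([1, 2, 512, 512])
```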
UNETRBase.__init__(...)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
UNETRBase.get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]
Compute the output size given input size and target long side length.
Arguments:
- oldh: The input image height.
- oldw: The input image width.
- long_side_length: The longest side length for resizing.
Returns:
The new image height. The new image width.
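For example, resizing so that the longest side becomes 1024 preserves the aspect ratio and rounds to the nearest pixel:

```python
from torch_em.model.unetr import UNETRBase

# The longest side (height 768) is scaled to 1024; the width follows the same
# factor: 512 * (1024 / 768) ≈ 682.67, which rounds to 683.
new_h, new_w = UNETRBase.get_preprocess_shape(oldh=768, oldw=512, long_side_length=1024)
print(new_h, new_w)  # 1024 683
```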
UNETRBase.resize_longest_side(image: torch.Tensor) -> torch.Tensor
Resize the image so that the longest side has the correct length.
Expects batched images with shape BxCxHxW OR BxCxDxHxW and float format.
Arguments:
- image: The input image.
Returns:
The resized image.
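A short sketch, assuming `model` is the UNETR instance constructed in the example above (its encoder uses `img_size=1024`):

```python
import torch

# 2D case: the longest side (800) is resized to the encoder image size (1024)
# and the shorter side is scaled by the same factor (600 -> 768).
image = torch.rand(1, 3, 800, 600)
resized = model.resize_longest_side(image)
print(resized.shape)  # torch.Size([1, 3, 1024, 768])
```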
preprocess_vit_inputs(x: torch.Tensor, use_sam_stats: bool = False, backbone: str = "sam", use_mae_stats: bool = False, use_dino_stats: bool = False, resize_input: bool = True, img_size: int = 1024, encoder_img_size: int = 1024, perform_range_checks: bool = True) -> Tuple[torch.Tensor, Tuple]
Preprocess inputs for ViT-backbones in UNETR models.
Handles normalization stat selection, input range validation, optional resizing to the longest side,
and padding to encoder_img_size. Can be used as a standalone function without a model instance.
Arguments:
- x: Input tensor of shape (B, C, H, W) for 2D or (B, C, Z, H, W) for 3D.
- use_sam_stats: Whether to normalize with SAM/SAM2/SAM3 backbone statistics.
- backbone: The backbone name - controls which SAM stats are used when use_sam_stats=True.
- use_mae_stats: Whether to normalize with MAE statistics.
- use_dino_stats: Whether to normalize with DINOv2/DINOv3 statistics.
- resize_input: Whether to resize the input to the longest side before padding.
- img_size: The model image size, used for 3D resize.
- encoder_img_size: The encoder image size, used for 2D resize and padding.
- perform_range_checks: Whether to validate the expected input value range before normalization. You can disable the checks to avoid GPU sync overhead during training when inputs are known to be correct.
Returns:
The preprocessed tensor and the spatial shape after resizing (before padding).
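A usage sketch for the standalone preprocessing function with SAM(-1) statistics; note that the SAM-1 branch expects raw intensities in [0, 255], while the SAM2/SAM3 and DINO branches expect inputs scaled to [0, 1]:

import torch

from torch_em.model.unetr import preprocess_vit_inputs

x = torch.randint(0, 256, (1, 3, 512, 768)).float()  # 2D input in the [0, 255] range expected for SAM-1
x_prep, input_shape = preprocess_vit_inputs(
    x, use_sam_stats=True, backbone="sam", resize_input=True, encoder_img_size=1024,
)
print(x_prep.shape)  # torch.Size([1, 3, 1024, 1024]) - resized to the longest side and padded
print(input_shape)   # spatial shape after resizing, before padding, here roughly (683, 1024)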
506class UNETR(UNETRBase): 507 """A (2d-only) UNet Transformer using a vision transformer as encoder and a convolutional decoder. 508 """ 509 def __init__( 510 self, 511 img_size: int = 1024, 512 backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam", 513 encoder: Optional[Union[nn.Module, str]] = "vit_b", 514 decoder: Optional[nn.Module] = None, 515 out_channels: int = 1, 516 use_sam_stats: bool = False, 517 use_mae_stats: bool = False, 518 use_dino_stats: bool = False, 519 resize_input: bool = True, 520 encoder_checkpoint: Optional[Union[str, OrderedDict]] = None, 521 final_activation: Optional[Union[str, nn.Module]] = None, 522 use_skip_connection: bool = True, 523 embed_dim: Optional[int] = None, 524 use_conv_transpose: bool = False, 525 perform_range_checks: bool = True, 526 **kwargs 527 ) -> None: 528 529 super().__init__( 530 img_size=img_size, 531 backbone=backbone, 532 encoder=encoder, 533 decoder=decoder, 534 out_channels=out_channels, 535 use_sam_stats=use_sam_stats, 536 use_mae_stats=use_mae_stats, 537 use_dino_stats=use_dino_stats, 538 resize_input=resize_input, 539 encoder_checkpoint=encoder_checkpoint, 540 final_activation=final_activation, 541 use_skip_connection=use_skip_connection, 542 embed_dim=embed_dim, 543 use_conv_transpose=use_conv_transpose, 544 perform_range_checks=perform_range_checks, 545 **kwargs, 546 ) 547 548 encoder = self.encoder 549 550 if backbone == "sam2" and hasattr(encoder, "trunk"): 551 in_chans = encoder.trunk.patch_embed.proj.in_channels 552 elif hasattr(encoder, "in_chans"): 553 in_chans = encoder.in_chans 554 else: # `nn.Module` ViT backbone. 555 try: 556 in_chans = encoder.patch_embed.proj.in_channels 557 except AttributeError: # for getting the input channels while using 'vit_t' from MobileSam 558 in_chans = encoder.patch_embed.seq[0].c.in_channels 559 560 # parameters for the decoder network 561 depth = 3 562 initial_features = 64 563 gain = 2 564 features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1] 565 scale_factors = depth * [2] 566 self.out_channels = out_channels 567 568 # choice of upsampler - to use (bilinear interpolation + conv) or conv transpose 569 _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d 570 571 self.decoder = decoder or Decoder( 572 features=features_decoder, 573 scale_factors=scale_factors[::-1], 574 conv_block_impl=ConvBlock2d, 575 sampler_impl=_upsampler, 576 ) 577 578 if use_skip_connection: 579 self.deconv1 = Deconv2DBlock( 580 in_channels=self.embed_dim, 581 out_channels=features_decoder[0], 582 use_conv_transpose=use_conv_transpose, 583 ) 584 self.deconv2 = nn.Sequential( 585 Deconv2DBlock( 586 in_channels=self.embed_dim, 587 out_channels=features_decoder[0], 588 use_conv_transpose=use_conv_transpose, 589 ), 590 Deconv2DBlock( 591 in_channels=features_decoder[0], 592 out_channels=features_decoder[1], 593 use_conv_transpose=use_conv_transpose, 594 ) 595 ) 596 self.deconv3 = nn.Sequential( 597 Deconv2DBlock( 598 in_channels=self.embed_dim, 599 out_channels=features_decoder[0], 600 use_conv_transpose=use_conv_transpose, 601 ), 602 Deconv2DBlock( 603 in_channels=features_decoder[0], 604 out_channels=features_decoder[1], 605 use_conv_transpose=use_conv_transpose, 606 ), 607 Deconv2DBlock( 608 in_channels=features_decoder[1], 609 out_channels=features_decoder[2], 610 use_conv_transpose=use_conv_transpose, 611 ) 612 ) 613 self.deconv4 = ConvBlock2d(in_chans, features_decoder[-1]) 614 else: 615 self.deconv1 = Deconv2DBlock( 
616 in_channels=self.embed_dim, 617 out_channels=features_decoder[0], 618 use_conv_transpose=use_conv_transpose, 619 ) 620 self.deconv2 = Deconv2DBlock( 621 in_channels=features_decoder[0], 622 out_channels=features_decoder[1], 623 use_conv_transpose=use_conv_transpose, 624 ) 625 self.deconv3 = Deconv2DBlock( 626 in_channels=features_decoder[1], 627 out_channels=features_decoder[2], 628 use_conv_transpose=use_conv_transpose, 629 ) 630 self.deconv4 = Deconv2DBlock( 631 in_channels=features_decoder[2], 632 out_channels=features_decoder[3], 633 use_conv_transpose=use_conv_transpose, 634 ) 635 636 self.base = ConvBlock2d(self.embed_dim, features_decoder[0]) 637 self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1) 638 self.deconv_out = _upsampler( 639 scale_factor=2, in_channels=features_decoder[-1], out_channels=features_decoder[-1] 640 ) 641 self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1]) 642 643 def forward(self, x: torch.Tensor) -> torch.Tensor: 644 """Apply the UNETR to the input data. 645 646 Args: 647 x: The input tensor. 648 649 Returns: 650 The UNETR output. 651 """ 652 original_shape = x.shape[-2:] 653 654 # Reshape the inputs to the shape expected by the encoder 655 # and normalize the inputs if normalization is part of the model. 656 x, input_shape = self.preprocess(x) 657 658 encoder_outputs = self.encoder(x) 659 660 if isinstance(encoder_outputs[-1], list): 661 # `encoder_outputs` can be arranged in only two forms: 662 # - either we only return the image embeddings 663 # - or, we return the image embeddings and the "list" of global attention layers 664 z12, from_encoder = encoder_outputs 665 else: 666 z12 = encoder_outputs 667 668 if self.use_skip_connection: 669 from_encoder = from_encoder[::-1] 670 z9 = self.deconv1(from_encoder[0]) 671 z6 = self.deconv2(from_encoder[1]) 672 z3 = self.deconv3(from_encoder[2]) 673 z0 = self.deconv4(x) 674 675 else: 676 z9 = self.deconv1(z12) 677 z6 = self.deconv2(z9) 678 z3 = self.deconv3(z6) 679 z0 = self.deconv4(z3) 680 681 updated_from_encoder = [z9, z6, z3] 682 683 x = self.base(z12) 684 x = self.decoder(x, encoder_inputs=updated_from_encoder) 685 x = self.deconv_out(x) 686 687 x = torch.cat([x, z0], dim=1) 688 x = self.decoder_head(x) 689 690 x = self.out_conv(x) 691 if self.final_activation is not None: 692 x = self.final_activation(x) 693 694 x = self.postprocess_masks(x, input_shape, original_shape) 695 return x
A (2d-only) UNet Transformer using a vision transformer as encoder and a convolutional decoder.
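A minimal construction sketch (the argument choices below are illustrative, not a recommended configuration); without encoder_checkpoint the ViT weights are randomly initialized, and the "sam" backbone requires the corresponding segment-anything implementation to be importable:

import torch.nn as nn

from torch_em.model.unetr import UNETR

model = UNETR(
    backbone="sam",
    encoder="vit_b",
    out_channels=2,
    use_sam_stats=True,           # normalize inputs with the SAM-1 pixel statistics
    final_activation=nn.Sigmoid(),
)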
509 def __init__( 510 self, 511 img_size: int = 1024, 512 backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam", 513 encoder: Optional[Union[nn.Module, str]] = "vit_b", 514 decoder: Optional[nn.Module] = None, 515 out_channels: int = 1, 516 use_sam_stats: bool = False, 517 use_mae_stats: bool = False, 518 use_dino_stats: bool = False, 519 resize_input: bool = True, 520 encoder_checkpoint: Optional[Union[str, OrderedDict]] = None, 521 final_activation: Optional[Union[str, nn.Module]] = None, 522 use_skip_connection: bool = True, 523 embed_dim: Optional[int] = None, 524 use_conv_transpose: bool = False, 525 perform_range_checks: bool = True, 526 **kwargs 527 ) -> None: 528 529 super().__init__( 530 img_size=img_size, 531 backbone=backbone, 532 encoder=encoder, 533 decoder=decoder, 534 out_channels=out_channels, 535 use_sam_stats=use_sam_stats, 536 use_mae_stats=use_mae_stats, 537 use_dino_stats=use_dino_stats, 538 resize_input=resize_input, 539 encoder_checkpoint=encoder_checkpoint, 540 final_activation=final_activation, 541 use_skip_connection=use_skip_connection, 542 embed_dim=embed_dim, 543 use_conv_transpose=use_conv_transpose, 544 perform_range_checks=perform_range_checks, 545 **kwargs, 546 ) 547 548 encoder = self.encoder 549 550 if backbone == "sam2" and hasattr(encoder, "trunk"): 551 in_chans = encoder.trunk.patch_embed.proj.in_channels 552 elif hasattr(encoder, "in_chans"): 553 in_chans = encoder.in_chans 554 else: # `nn.Module` ViT backbone. 555 try: 556 in_chans = encoder.patch_embed.proj.in_channels 557 except AttributeError: # for getting the input channels while using 'vit_t' from MobileSam 558 in_chans = encoder.patch_embed.seq[0].c.in_channels 559 560 # parameters for the decoder network 561 depth = 3 562 initial_features = 64 563 gain = 2 564 features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1] 565 scale_factors = depth * [2] 566 self.out_channels = out_channels 567 568 # choice of upsampler - to use (bilinear interpolation + conv) or conv transpose 569 _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d 570 571 self.decoder = decoder or Decoder( 572 features=features_decoder, 573 scale_factors=scale_factors[::-1], 574 conv_block_impl=ConvBlock2d, 575 sampler_impl=_upsampler, 576 ) 577 578 if use_skip_connection: 579 self.deconv1 = Deconv2DBlock( 580 in_channels=self.embed_dim, 581 out_channels=features_decoder[0], 582 use_conv_transpose=use_conv_transpose, 583 ) 584 self.deconv2 = nn.Sequential( 585 Deconv2DBlock( 586 in_channels=self.embed_dim, 587 out_channels=features_decoder[0], 588 use_conv_transpose=use_conv_transpose, 589 ), 590 Deconv2DBlock( 591 in_channels=features_decoder[0], 592 out_channels=features_decoder[1], 593 use_conv_transpose=use_conv_transpose, 594 ) 595 ) 596 self.deconv3 = nn.Sequential( 597 Deconv2DBlock( 598 in_channels=self.embed_dim, 599 out_channels=features_decoder[0], 600 use_conv_transpose=use_conv_transpose, 601 ), 602 Deconv2DBlock( 603 in_channels=features_decoder[0], 604 out_channels=features_decoder[1], 605 use_conv_transpose=use_conv_transpose, 606 ), 607 Deconv2DBlock( 608 in_channels=features_decoder[1], 609 out_channels=features_decoder[2], 610 use_conv_transpose=use_conv_transpose, 611 ) 612 ) 613 self.deconv4 = ConvBlock2d(in_chans, features_decoder[-1]) 614 else: 615 self.deconv1 = Deconv2DBlock( 616 in_channels=self.embed_dim, 617 out_channels=features_decoder[0], 618 use_conv_transpose=use_conv_transpose, 619 ) 620 self.deconv2 = 
Deconv2DBlock( 621 in_channels=features_decoder[0], 622 out_channels=features_decoder[1], 623 use_conv_transpose=use_conv_transpose, 624 ) 625 self.deconv3 = Deconv2DBlock( 626 in_channels=features_decoder[1], 627 out_channels=features_decoder[2], 628 use_conv_transpose=use_conv_transpose, 629 ) 630 self.deconv4 = Deconv2DBlock( 631 in_channels=features_decoder[2], 632 out_channels=features_decoder[3], 633 use_conv_transpose=use_conv_transpose, 634 ) 635 636 self.base = ConvBlock2d(self.embed_dim, features_decoder[0]) 637 self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1) 638 self.deconv_out = _upsampler( 639 scale_factor=2, in_channels=features_decoder[-1], out_channels=features_decoder[-1] 640 ) 641 self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1])
Initialize internal Module state, shared by both nn.Module and ScriptModule.
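The decoder parameters fixed in the __init__ above yield a four-level feature plan with doubling feature maps and an isotropic upscaling factor of 2 per level; a small sketch of the resulting values:

# Reproduced from the __init__ above.
depth = 3
initial_features = 64
gain = 2
features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
scale_factors = depth * [2]
print(features_decoder)  # [512, 256, 128, 64]
print(scale_factors)     # [2, 2, 2] - each decoder level doubles the spatial resolution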
643 def forward(self, x: torch.Tensor) -> torch.Tensor: 644 """Apply the UNETR to the input data. 645 646 Args: 647 x: The input tensor. 648 649 Returns: 650 The UNETR output. 651 """ 652 original_shape = x.shape[-2:] 653 654 # Reshape the inputs to the shape expected by the encoder 655 # and normalize the inputs if normalization is part of the model. 656 x, input_shape = self.preprocess(x) 657 658 encoder_outputs = self.encoder(x) 659 660 if isinstance(encoder_outputs[-1], list): 661 # `encoder_outputs` can be arranged in only two forms: 662 # - either we only return the image embeddings 663 # - or, we return the image embeddings and the "list" of global attention layers 664 z12, from_encoder = encoder_outputs 665 else: 666 z12 = encoder_outputs 667 668 if self.use_skip_connection: 669 from_encoder = from_encoder[::-1] 670 z9 = self.deconv1(from_encoder[0]) 671 z6 = self.deconv2(from_encoder[1]) 672 z3 = self.deconv3(from_encoder[2]) 673 z0 = self.deconv4(x) 674 675 else: 676 z9 = self.deconv1(z12) 677 z6 = self.deconv2(z9) 678 z3 = self.deconv3(z6) 679 z0 = self.deconv4(z3) 680 681 updated_from_encoder = [z9, z6, z3] 682 683 x = self.base(z12) 684 x = self.decoder(x, encoder_inputs=updated_from_encoder) 685 x = self.deconv_out(x) 686 687 x = torch.cat([x, z0], dim=1) 688 x = self.decoder_head(x) 689 690 x = self.out_conv(x) 691 if self.final_activation is not None: 692 x = self.final_activation(x) 693 694 x = self.postprocess_masks(x, input_shape, original_shape) 695 return x
Apply the UNETR to the input data.
Arguments:
- x: The input tensor.
Returns:
The UNETR output.
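A usage sketch for the 2D forward pass (random encoder weights here, so the outputs are not meaningful); with use_sam_stats=True the raw inputs are expected in [0, 255] and the model handles normalization, resizing, and padding internally:

import torch

from torch_em.model.unetr import UNETR

model = UNETR(backbone="sam", encoder="vit_b", out_channels=2, use_sam_stats=True)
model.eval()

x = torch.randint(0, 256, (1, 3, 512, 512)).float()
with torch.no_grad():
    y = model(x)
print(y.shape)  # torch.Size([1, 2, 512, 512]) - the output is resized back to the input resolution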
698class UNETR2D(UNETR): 699 """A two-dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder. 700 """ 701 pass
A two-dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.
704class UNETR3D(UNETRBase): 705 """A three dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder. 706 """ 707 def __init__( 708 self, 709 img_size: int = 1024, 710 backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam", 711 encoder: Optional[Union[nn.Module, str]] = "hvit_b", 712 decoder: Optional[nn.Module] = None, 713 out_channels: int = 1, 714 use_sam_stats: bool = False, 715 use_mae_stats: bool = False, 716 use_dino_stats: bool = False, 717 resize_input: bool = True, 718 encoder_checkpoint: Optional[Union[str, OrderedDict]] = None, 719 final_activation: Optional[Union[str, nn.Module]] = None, 720 use_skip_connection: bool = False, 721 embed_dim: Optional[int] = None, 722 use_conv_transpose: bool = False, 723 use_strip_pooling: bool = True, 724 perform_range_checks: bool = True, 725 **kwargs 726 ): 727 if use_skip_connection: 728 raise NotImplementedError("The framework cannot handle skip connections atm.") 729 if use_conv_transpose: 730 raise NotImplementedError("It's not enabled to switch between interpolation and transposed convolutions.") 731 732 # Sort the `embed_dim` out 733 embed_dim = 256 if embed_dim is None else embed_dim 734 735 super().__init__( 736 img_size=img_size, 737 backbone=backbone, 738 encoder=encoder, 739 decoder=decoder, 740 out_channels=out_channels, 741 use_sam_stats=use_sam_stats, 742 use_mae_stats=use_mae_stats, 743 use_dino_stats=use_dino_stats, 744 resize_input=resize_input, 745 encoder_checkpoint=encoder_checkpoint, 746 final_activation=final_activation, 747 use_skip_connection=use_skip_connection, 748 embed_dim=embed_dim, 749 use_conv_transpose=use_conv_transpose, 750 perform_range_checks=perform_range_checks, 751 **kwargs, 752 ) 753 754 # The 3d convolutional decoder. 755 # First, get the important parameters for the decoder. 756 depth = 3 757 initial_features = 64 758 gain = 2 759 features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1] 760 scale_factors = [1, 2, 2] 761 self.out_channels = out_channels 762 763 # The mapping blocks. 764 self.deconv1 = Deconv3DBlock( 765 in_channels=embed_dim, 766 out_channels=features_decoder[0], 767 scale_factor=scale_factors, 768 use_strip_pooling=use_strip_pooling, 769 ) 770 self.deconv2 = Deconv3DBlock( 771 in_channels=features_decoder[0], 772 out_channels=features_decoder[1], 773 scale_factor=scale_factors, 774 use_strip_pooling=use_strip_pooling, 775 ) 776 self.deconv3 = Deconv3DBlock( 777 in_channels=features_decoder[1], 778 out_channels=features_decoder[2], 779 scale_factor=scale_factors, 780 use_strip_pooling=use_strip_pooling, 781 ) 782 self.deconv4 = Deconv3DBlock( 783 in_channels=features_decoder[2], 784 out_channels=features_decoder[3], 785 scale_factor=scale_factors, 786 use_strip_pooling=use_strip_pooling, 787 ) 788 789 # The core decoder block. 790 self.decoder = decoder or Decoder( 791 features=features_decoder, 792 scale_factors=[scale_factors] * depth, 793 conv_block_impl=partial(ConvBlock3dWithStrip, use_strip_pooling=use_strip_pooling), 794 sampler_impl=Upsampler3d, 795 ) 796 797 # And the final upsampler to match the expected dimensions. 798 self.deconv_out = Deconv3DBlock( # NOTE: changed `end_up` to `deconv_out` 799 in_channels=features_decoder[-1], 800 out_channels=features_decoder[-1], 801 scale_factor=scale_factors, 802 use_strip_pooling=use_strip_pooling, 803 ) 804 805 # Additional conjunction blocks. 
806 self.base = ConvBlock3dWithStrip( 807 in_channels=embed_dim, 808 out_channels=features_decoder[0], 809 use_strip_pooling=use_strip_pooling, 810 ) 811 812 # And the output layers. 813 self.decoder_head = ConvBlock3dWithStrip( 814 in_channels=2 * features_decoder[-1], 815 out_channels=features_decoder[-1], 816 use_strip_pooling=use_strip_pooling, 817 ) 818 self.out_conv = nn.Conv3d(features_decoder[-1], out_channels, 1) 819 820 def forward(self, x: torch.Tensor): 821 """Forward pass of the UNETR-3D model. 822 823 Args: 824 x: Inputs of expected shape (B, C, Z, Y, X), where Z considers flexible inputs. 825 826 Returns: 827 The UNETR output. 828 """ 829 B, C, Z, H, W = x.shape 830 original_shape = (Z, H, W) 831 832 # Preprocessing step 833 x, input_shape = self.preprocess(x) 834 835 # Run the image encoder. 836 curr_features = torch.stack([self.encoder(x[:, :, i])[0] for i in range(Z)], dim=2) 837 838 # Prepare the counterparts for the decoder. 839 # NOTE: The section below is sequential, there's no skip connections atm. 840 z9 = self.deconv1(curr_features) 841 z6 = self.deconv2(z9) 842 z3 = self.deconv3(z6) 843 z0 = self.deconv4(z3) 844 845 updated_from_encoder = [z9, z6, z3] 846 847 # Align the features through the base block. 848 x = self.base(curr_features) 849 # Run the decoder 850 x = self.decoder(x, encoder_inputs=updated_from_encoder) 851 x = self.deconv_out(x) # NOTE before `end_up` 852 853 # And the final output head. 854 x = torch.cat([x, z0], dim=1) 855 x = self.decoder_head(x) 856 x = self.out_conv(x) 857 if self.final_activation is not None: 858 x = self.final_activation(x) 859 860 # Postprocess the output back to original size. 861 x = self.postprocess_masks(x, input_shape, original_shape) 862 return x
A three dimensional UNet Transformer using a vision transformer as encoder and a convolutional decoder.
707 def __init__( 708 self, 709 img_size: int = 1024, 710 backbone: Literal["sam", "sam2", "sam3", "cellpose_sam", "mae", "scalemae", "dinov2", "dinov3"] = "sam", 711 encoder: Optional[Union[nn.Module, str]] = "hvit_b", 712 decoder: Optional[nn.Module] = None, 713 out_channels: int = 1, 714 use_sam_stats: bool = False, 715 use_mae_stats: bool = False, 716 use_dino_stats: bool = False, 717 resize_input: bool = True, 718 encoder_checkpoint: Optional[Union[str, OrderedDict]] = None, 719 final_activation: Optional[Union[str, nn.Module]] = None, 720 use_skip_connection: bool = False, 721 embed_dim: Optional[int] = None, 722 use_conv_transpose: bool = False, 723 use_strip_pooling: bool = True, 724 perform_range_checks: bool = True, 725 **kwargs 726 ): 727 if use_skip_connection: 728 raise NotImplementedError("The framework cannot handle skip connections atm.") 729 if use_conv_transpose: 730 raise NotImplementedError("It's not enabled to switch between interpolation and transposed convolutions.") 731 732 # Sort the `embed_dim` out 733 embed_dim = 256 if embed_dim is None else embed_dim 734 735 super().__init__( 736 img_size=img_size, 737 backbone=backbone, 738 encoder=encoder, 739 decoder=decoder, 740 out_channels=out_channels, 741 use_sam_stats=use_sam_stats, 742 use_mae_stats=use_mae_stats, 743 use_dino_stats=use_dino_stats, 744 resize_input=resize_input, 745 encoder_checkpoint=encoder_checkpoint, 746 final_activation=final_activation, 747 use_skip_connection=use_skip_connection, 748 embed_dim=embed_dim, 749 use_conv_transpose=use_conv_transpose, 750 perform_range_checks=perform_range_checks, 751 **kwargs, 752 ) 753 754 # The 3d convolutional decoder. 755 # First, get the important parameters for the decoder. 756 depth = 3 757 initial_features = 64 758 gain = 2 759 features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1] 760 scale_factors = [1, 2, 2] 761 self.out_channels = out_channels 762 763 # The mapping blocks. 764 self.deconv1 = Deconv3DBlock( 765 in_channels=embed_dim, 766 out_channels=features_decoder[0], 767 scale_factor=scale_factors, 768 use_strip_pooling=use_strip_pooling, 769 ) 770 self.deconv2 = Deconv3DBlock( 771 in_channels=features_decoder[0], 772 out_channels=features_decoder[1], 773 scale_factor=scale_factors, 774 use_strip_pooling=use_strip_pooling, 775 ) 776 self.deconv3 = Deconv3DBlock( 777 in_channels=features_decoder[1], 778 out_channels=features_decoder[2], 779 scale_factor=scale_factors, 780 use_strip_pooling=use_strip_pooling, 781 ) 782 self.deconv4 = Deconv3DBlock( 783 in_channels=features_decoder[2], 784 out_channels=features_decoder[3], 785 scale_factor=scale_factors, 786 use_strip_pooling=use_strip_pooling, 787 ) 788 789 # The core decoder block. 790 self.decoder = decoder or Decoder( 791 features=features_decoder, 792 scale_factors=[scale_factors] * depth, 793 conv_block_impl=partial(ConvBlock3dWithStrip, use_strip_pooling=use_strip_pooling), 794 sampler_impl=Upsampler3d, 795 ) 796 797 # And the final upsampler to match the expected dimensions. 798 self.deconv_out = Deconv3DBlock( # NOTE: changed `end_up` to `deconv_out` 799 in_channels=features_decoder[-1], 800 out_channels=features_decoder[-1], 801 scale_factor=scale_factors, 802 use_strip_pooling=use_strip_pooling, 803 ) 804 805 # Additional conjunction blocks. 806 self.base = ConvBlock3dWithStrip( 807 in_channels=embed_dim, 808 out_channels=features_decoder[0], 809 use_strip_pooling=use_strip_pooling, 810 ) 811 812 # And the output layers. 
813 self.decoder_head = ConvBlock3dWithStrip( 814 in_channels=2 * features_decoder[-1], 815 out_channels=features_decoder[-1], 816 use_strip_pooling=use_strip_pooling, 817 ) 818 self.out_conv = nn.Conv3d(features_decoder[-1], out_channels, 1)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
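Compared to the 2D model, the decoder set up in the __init__ above keeps the same feature progression but uses anisotropic scale factors, so upsampling only happens in-plane; a small sketch of the resulting values:

# Reproduced from the UNETR3D __init__ above.
depth = 3
initial_features = 64
gain = 2
features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
scale_factors = [1, 2, 2]  # (Z, Y, X): the Z-extent stays fixed, only H and W are upsampled
print(features_decoder)         # [512, 256, 128, 64]
print([scale_factors] * depth)  # [[1, 2, 2], [1, 2, 2], [1, 2, 2]] - the per-level factors passed to the Decoder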
820 def forward(self, x: torch.Tensor): 821 """Forward pass of the UNETR-3D model. 822 823 Args: 824 x: Inputs of expected shape (B, C, Z, Y, X), where Z considers flexible inputs. 825 826 Returns: 827 The UNETR output. 828 """ 829 B, C, Z, H, W = x.shape 830 original_shape = (Z, H, W) 831 832 # Preprocessing step 833 x, input_shape = self.preprocess(x) 834 835 # Run the image encoder. 836 curr_features = torch.stack([self.encoder(x[:, :, i])[0] for i in range(Z)], dim=2) 837 838 # Prepare the counterparts for the decoder. 839 # NOTE: The section below is sequential, there's no skip connections atm. 840 z9 = self.deconv1(curr_features) 841 z6 = self.deconv2(z9) 842 z3 = self.deconv3(z6) 843 z0 = self.deconv4(z3) 844 845 updated_from_encoder = [z9, z6, z3] 846 847 # Align the features through the base block. 848 x = self.base(curr_features) 849 # Run the decoder 850 x = self.decoder(x, encoder_inputs=updated_from_encoder) 851 x = self.deconv_out(x) # NOTE before `end_up` 852 853 # And the final output head. 854 x = torch.cat([x, z0], dim=1) 855 x = self.decoder_head(x) 856 x = self.out_conv(x) 857 if self.final_activation is not None: 858 x = self.final_activation(x) 859 860 # Postprocess the output back to original size. 861 x = self.postprocess_masks(x, input_shape, original_shape) 862 return x
Forward pass of the UNETR-3D model.
Arguments:
- x: Inputs of expected shape (B, C, Z, Y, X), where the number of slices Z may vary.
Returns:
The UNETR output.
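A usage sketch for the 3D forward pass; the backbone/encoder combination below ('sam2' x 'hvit_b') and the availability of the SAM2 dependencies are assumptions, and with SAM2 statistics the inputs are expected in [0, 1]:

import torch

from torch_em.model.unetr import UNETR3D

model_3d = UNETR3D(backbone="sam2", encoder="hvit_b", out_channels=1, use_sam_stats=True)
model_3d.eval()

x = torch.rand(1, 3, 8, 256, 256)  # (B, C, Z, Y, X)
with torch.no_grad():
    y = model_3d(x)  # the encoder runs per Z-slice, the decoder operates on the stacked volume
print(y.shape)       # torch.Size([1, 1, 8, 256, 256]) - restored to the original (Z, Y, X)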