torch_em.model.vit
1import math 2from functools import partial 3from typing import Tuple, List 4 5import torch 6import torch.nn as nn 7 8# we catch ImportErrors here because segment_anything, micro_sam, scale_mae and timm should 9# only be optional dependencies for torch_em 10try: 11 from segment_anything.modeling import ImageEncoderViT 12 _sam_import_success = True 13except ImportError: 14 ImageEncoderViT = object 15 _sam_import_success = False 16 17try: 18 from timm.models.vision_transformer import VisionTransformer, PatchEmbed 19 _timm_import_success = True 20except ImportError: 21 VisionTransformer = object 22 PatchEmbed = object 23 _timm_import_success = False 24 25try: 26 from sam2.modeling.backbones.hieradet import Hiera 27 from sam2.modeling.position_encoding import PositionEmbeddingSine 28 from sam2.modeling.backbones.image_encoder import ImageEncoder, FpnNeck 29 _sam2_import_success = True 30except ImportError: 31 ImageEncoder = object 32 _sam2_import_success = False 33 34try: 35 from dinov2.models.vision_transformer import DinoVisionTransformer as DinoV2VisionTransformer 36 from dinov2.layers import MemEffAttention, NestedTensorBlock as Block 37 _dinov2_import_success = True 38except ImportError: 39 DinoV2VisionTransformer = object 40 _dinov2_import_success = False 41 42try: 43 from dinov3.models.vision_transformer import DinoVisionTransformer as DinoV3VisionTransformer 44 _dinov3_import_success = True 45except ImportError: 46 DinoV3VisionTransformer = object 47 _dinov3_import_success = False 48 49 50try: 51 from sam3.model.vitdet import ViT as SAM3ViT, get_abs_pos 52 _sam3_import_success = True 53except ImportError: 54 SAM3ViT = object 55 _sam3_import_success = False 56 57 58class ViT_Sam(ImageEncoderViT): 59 """Vision Transformer derived from the Segment Anything Codebase (https://arxiv.org/abs/2304.02643). 60 61 Based on: 62 https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py 63 64 Args: 65 in_chans: The number of input channels. 66 embed_dim: The embedding dimension, corresponding to the number of output channels of the vision transformer. 67 global_attn_indexes: The global attention indices. 68 kwargs: Keyword arguments for the image encoder base class. 69 """ 70 def __init__( 71 self, 72 in_chans: int = 3, 73 embed_dim: int = 768, 74 global_attn_indexes: Tuple[int, ...] = [2, 5, 8, 11], 75 **kwargs, 76 ) -> None: 77 if not _sam_import_success: 78 raise RuntimeError( 79 "The vision transformer backend can only be initialized if segment anything is installed. " 80 "Please install segment anything from https://github.com/facebookresearch/segment-anything " 81 "and then rerun your code." 82 ) 83 84 super().__init__(embed_dim=embed_dim, global_attn_indexes=global_attn_indexes, **kwargs) 85 self.chunks_for_projection = global_attn_indexes 86 self.in_chans = in_chans 87 self.embed_dim = embed_dim 88 89 def forward(self, x: torch.Tensor) -> torch.Tensor: 90 """Apply the vision transformer to input data. 91 92 Args: 93 x: The input data. 94 95 Returns: 96 The vision transformer output. 
97 """ 98 x = self.patch_embed(x) 99 if self.pos_embed is not None: 100 x = x + self.pos_embed 101 102 list_from_encoder = [] 103 for i, blk in enumerate(self.blocks): 104 x = blk(x) 105 if i in self.chunks_for_projection: 106 list_from_encoder.append(x) 107 108 x = x.permute(0, 3, 1, 2) 109 list_from_encoder = [e.permute(0, 3, 1, 2) for e in list_from_encoder] 110 return x, list_from_encoder[:3] 111 112 113class ViT_CellposeSAM(nn.Module): 114 """Vision Transformer derived from the CellposeSAM Codebase (https://doi.org/10.1038/s41592-025-02595-x). 115 116 This replicates CellposeSAM's actual initialization: instantiate SAM's ``ImageEncoderViT`` via 117 ``sam_model_registry``, then modify the patch embedding, position embeddings, and set global attention. 118 This preserves SAM's original relative position bias sizes, enabling direct checkpoint loading 119 without any interpolation. 120 121 Based on: https://github.com/MouseLand/cellpose/blob/main/cellpose/vit_sam.py 122 123 NOTE: The pretrained CellposeSAM model uses ``vit_l`` exclusively. 124 125 Args: 126 ps: The patch size (default for CellposeSAM is 8). 127 bsize: The input image size (default for CellposeSAM is 256). 128 """ 129 def __init__(self, ps: int = 8, bsize: int = 256) -> None: 130 super().__init__() 131 132 if not _sam_import_success: 133 raise RuntimeError( 134 "The vision transformer backend can only be initialized if segment anything is installed. " 135 "Please install segment anything from https://github.com/facebookresearch/segment-anything " 136 "and then rerun your code." 137 ) 138 139 from segment_anything import sam_model_registry 140 141 # Creates the SAM vit_l encoder and applies CellposeSAM's modifications (same as cellpose.vit_sam.Transformer). 142 encoder = sam_model_registry["vit_l"](None).image_encoder 143 144 w = encoder.patch_embed.proj.weight.detach() 145 nchan = w.shape[0] 146 147 # CellPoseSAM changes the patch size from 16 to 'ps'. 148 encoder.patch_embed.proj = nn.Conv2d(3, nchan, stride=ps, kernel_size=ps) 149 encoder.patch_embed.proj.weight.data = w[:, :, ::16 // ps, ::16 // ps] 150 151 # Next, they subsample position embeddings for the new patch size and input resolution. 152 ds = (1024 // 16) // (bsize // ps) 153 encoder.pos_embed = nn.Parameter(encoder.pos_embed[:, ::ds, ::ds], requires_grad=True) 154 155 # Finally, they set all blocks to global attention. 156 for blk in encoder.blocks: 157 blk.window_size = 0 158 159 # Store encoder submodules directly ('state_dict' keys match CellposeSAM after prefix stripping). 160 self.patch_embed = encoder.patch_embed 161 self.pos_embed = encoder.pos_embed 162 self.blocks = encoder.blocks 163 self.neck = encoder.neck 164 165 # Additional attributes expected by UNETR. 166 self.embed_dim = nchan 167 self.img_size = bsize 168 self.in_chans = 3 169 170 # Feature extraction at evenly-spaced depths. 171 depth = len(self.blocks) 172 _chunks = depth // 4 173 self.chunks_for_projection = [_chunks - 1, 2 * _chunks - 1, 3 * _chunks - 1, 4 * _chunks - 1] 174 175 def forward(self, x: torch.Tensor) -> torch.Tensor: 176 """Apply the vision transformer to input data. 177 178 Args: 179 x: The input data. 180 181 Returns: 182 The vision transformer output. 
183 """ 184 x = self.patch_embed(x) 185 if self.pos_embed is not None: 186 x = x + self.pos_embed 187 188 list_from_encoder = [] 189 for i, blk in enumerate(self.blocks): 190 x = blk(x) 191 if i in self.chunks_for_projection: 192 list_from_encoder.append(x) 193 194 x = x.permute(0, 3, 1, 2) 195 list_from_encoder = [e.permute(0, 3, 1, 2) for e in list_from_encoder] 196 return x, list_from_encoder[:3] 197 198 199class ViT_MAE(VisionTransformer): 200 """Vision Transformer derived from the Masked Auto Encoder Codebase (https://arxiv.org/abs/2111.06377). 201 202 Based on: 203 https://github.com/facebookresearch/mae/blob/main/models_vit.py#L20-L53 204 205 Args: 206 img_size: The size of the input for the image encoder. Input images will be resized to match this size. 207 in_chans: The number of input channels. 208 depth: The depth of the vision transformer. 209 kwargs: Additional keyword arguments for the vision transformer base class. 210 """ 211 def __init__( 212 self, 213 img_size: int = 1024, # chosen to match our experiments with segment anything 214 in_chans: int = 3, 215 depth: int = 12, 216 **kwargs 217 ): 218 if not _timm_import_success: 219 raise RuntimeError( 220 "The vision transformer backend can only be initialized if timm is installed. " 221 "Please install timm (using conda/mamba) for using https://github.com/facebookresearch/mae/ " 222 "and then rerun your code" 223 ) 224 super().__init__(img_size=img_size, depth=depth, **kwargs) 225 self.img_size = img_size 226 self.in_chans = in_chans 227 self.depth = depth 228 229 def convert_to_expected_dim(self, inputs_): 230 """@private 231 """ 232 inputs_ = inputs_[:, 1:, :] # removing the class tokens 233 # reshape the outputs to desired shape (N x H*W X C -> N x H x W x C) 234 rdim = inputs_.shape[1] 235 dshape = int(rdim ** 0.5) # finding the square root of the outputs for obtaining the patch shape 236 inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape)) 237 inputs_ = inputs_.permute(0, 3, 1, 2) 238 return inputs_ 239 240 def forward_features(self, x): 241 """@private 242 """ 243 B = x.shape[0] 244 x = self.patch_embed(x) 245 246 cls_tokens = self.cls_token.expand(B, -1, -1) 247 x = torch.cat((cls_tokens, x), dim=1) 248 249 x = x + self.pos_embed 250 x = self.pos_drop(x) 251 252 # chunks obtained for getting the projections for conjuctions with upsampling blocks 253 _chunks = int(self.depth / 4) 254 chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1] 255 256 list_from_encoder = [] 257 for i, blk in enumerate(self.blocks): 258 x = blk(x) 259 if i in chunks_for_projection: 260 list_from_encoder.append(self.convert_to_expected_dim(x)) 261 262 x = self.convert_to_expected_dim(x) 263 return x, list_from_encoder[:3] 264 265 def forward(self, x: torch.Tensor) -> torch.Tensor: 266 """Apply the vision transformer to input data. 267 268 Args: 269 x: The input data. 270 271 Returns: 272 The vision transformer output. 273 """ 274 x, list_from_encoder = self.forward_features(x) 275 return x, list_from_encoder 276 277 278class ViT_Sam2(ImageEncoder): 279 """Vision Transformer derived from the Segment Anything 2 Codebase (https://arxiv.org/abs/2408.00714). 280 281 Based on https://github.com/facebookresearch/sam2/blob/main/sam2/modeling/backbones/image_encoder.py. 282 283 Args: 284 backbone_channel_list: The channels throughout the entire backbone. 285 embed_dim: The initial embedding dimension. 286 num_heads: The initial number of heads. 287 stages: The number of blocks per stage. 
288 global_att_blocks: The parameter to decide which blocks have global attention. 289 window_pos_embed_bkg_spatial_size: The spatial size of windowed positional embedding. 290 window_spec: The window size per stage, when not using global attention. 291 scalp: The count of lowest resolution features to discard. 292 """ 293 def __init__( 294 self, 295 backbone_channel_list: List[int], 296 img_size: int = 1024, 297 embed_dim: int = 96, 298 num_heads: int = 1, 299 stages: Tuple[int, ...] = (2, 3, 16, 3), 300 global_att_blocks: Tuple[int, ...] = (12, 16, 20), 301 window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14), 302 window_spec: Tuple[int, ...] = (8, 4, 14, 7), 303 scalp: int = 1, 304 **kwargs 305 ): 306 if not _sam2_import_success: 307 raise RuntimeError( 308 "The vision transformer backend can only be initialized if segment anything 2 is installed. " 309 "Please install segment anything 2 from https://github.com/facebookresearch/sam2 " 310 "and then rerun your code" 311 ) 312 313 trunk = Hiera( 314 embed_dim=embed_dim, 315 num_heads=num_heads, 316 stages=stages, 317 global_att_blocks=global_att_blocks, 318 window_pos_embed_bkg_spatial_size=window_pos_embed_bkg_spatial_size, 319 window_spec=window_spec, 320 ) 321 neck = FpnNeck( 322 position_encoding=PositionEmbeddingSine(num_pos_feats=256), 323 d_model=256, 324 backbone_channel_list=backbone_channel_list, 325 fpn_top_down_levels=[2, 3], 326 fpn_interp_model="nearest", 327 ) 328 329 super().__init__(trunk=trunk, neck=neck, scalp=scalp, **kwargs) 330 self.scalp = scalp 331 self.embed_dim = embed_dim 332 self.img_size = img_size 333 334 def forward(self, x: torch.Tensor): 335 # The forward pass throught the backbone. 336 features, pos = self.neck(self.trunk(x)) 337 if self.scalp > 0: # This discard the "n" lowest resolution features. 338 features, pos = features[:-self.scalp], pos[:-self.scalp] 339 340 return features[-1], features 341 342 343class ViT_Sam3(SAM3ViT): 344 """Vision Transformer derived from the Segment Anything 3 Codebase (https://arxiv.org/abs/2511.16719). 345 346 Based on https://github.com/facebookresearch/sam3/blob/main/sam3/model/vitdet.py. 347 348 Args: 349 img_size: The input image size. 350 embed_dim: The embedding dimension, corresponding to the number of output channels of the vision transformer. 351 kwargs: Keyword arguments for the image encoder base class. 352 """ 353 def __init__(self, img_size: int = 1024, embed_dim: int = 768, **kwargs): 354 if not _sam3_import_success: 355 raise RuntimeError( 356 "The vision transformer backend can only be initialized if segment anything 3 is installed. " 357 "Please install segment anything 3 from https://github.com/facebookresearch/sam3 " 358 "and then rerun your code" 359 ) 360 361 super().__init__(img_size=img_size, embed_dim=embed_dim, **kwargs) 362 self.img_size = img_size 363 self.embed_dim = embed_dim 364 365 def forward_features(self, x): 366 """@private 367 """ 368 x = self.patch_embed(x) 369 h, w = x.shape[1], x.shape[2] 370 371 s = 0 372 if self.retain_cls_token: 373 # If the 'cls_token' is retained, we don't maintain the spatial shape. 
374 x = torch.cat([self.class_embedding, x.flatten(1, 2)], dim=1) 375 s = 1 376 377 if self.pos_embed is not None: 378 x = x + get_abs_pos( 379 self.pos_embed, self.pretrain_use_cls_token, (h, w), self.retain_cls_token, tiling=self.tile_abs_pos, 380 ) 381 382 x = self.ln_pre(x) 383 384 list_from_encoder = [] 385 for i, blk in enumerate(self.blocks): 386 if self.use_act_checkpoint and self.training: 387 x = torch.utils.checkpoint.checkpoint(blk, x, use_reentrant=False) 388 else: 389 x = blk(x) 390 391 x = self._convert_to_expected_dim(x, i, s) 392 393 if i in self.full_attn_ids: 394 list_from_encoder.append(x) 395 396 return x, list_from_encoder 397 398 def _convert_to_expected_dim(self, x, i, s): 399 if (i == self.full_attn_ids[-1]) or ( 400 self.return_interm_layers and i in self.full_attn_ids 401 ): 402 if i == self.full_attn_ids[-1]: 403 x = self.ln_post(x) 404 405 feats = x[:, s:] 406 if feats.ndim == 4: 407 feats = feats.permute(0, 3, 1, 2) 408 else: 409 assert feats.ndim == 3 410 h = w = math.sqrt(feats.shape[1]) 411 feats = feats.reshape(feats.shape[0], h, w, feats.shape[-1]).permute(0, 3, 1, 2) 412 return feats 413 414 else: 415 return x 416 417 def forward(self, x: torch.Tensor): 418 """Apply the vision transformer to input data. 419 420 Args: 421 x: The input data. 422 423 Returns: 424 The vision transformer output. 425 """ 426 x, list_from_encoder = self.forward_features(x) 427 return x, list_from_encoder 428 429# 430# Utilities for ScaleMAE's ViT 431# 432 433 434class CustomCompose: 435 def __init__(self, rescale_transform, other_transforms, src_transform): 436 self.rescale_transform = rescale_transform 437 self.other_transforms = other_transforms 438 self.src_transform = src_transform 439 440 def __call__(self, x, valid_masks=None): 441 if valid_masks is not None: 442 nodata = (x * (1 - valid_masks.float())).max() 443 x_aug = self.rescale_transform(x) 444 parms = self.rescale_transform._params 445 446 # sanity check, comment if this is working 447 # valid_masks = self.rescale_transform(valid_masks.float(), params=parms) 448 # assert (x_aug==self.rescale_transform(x, params=parms)).all() # 449 450 if valid_masks is not None: 451 valid_masks = x_aug != nodata 452 _, c, h, w = x_aug.shape 453 zero_ratio = ((valid_masks == 0).sum((1, 2, 3)) / (h * w * c)).cpu().numpy() 454 else: 455 zero_ratio = -1 456 457 if self.other_transforms: 458 x_aug = self.other_transforms(x_aug) 459 x_src = self.src_transform(x_aug) 460 dx = parms["src"][:, 1, 0] - parms["src"][:, 0, 0] 461 462 # dy = (parms['src'][:,2,1] - parms['src'][:,1,1]) 463 # assert (dx == dy).all() 464 465 h, w = x_aug.shape[-2:] 466 # assert h == w 467 468 return x_aug, x_src, dx / h, zero_ratio, valid_masks 469 470 471def get_2d_sincos_pos_embed_with_resolution(embed_dim, grid_size, res, cls_token=False, device="cpu"): 472 """ 473 grid_size: int of the grid height and width 474 res: array of size n, representing the resolution of a pixel (say, in meters), 475 return: 476 pos_embed: [n,grid_size*grid_size, embed_dim] or [n,1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 477 """ 478 # res = torch.FloatTensor(res).to(device) 479 res = res.to(device) 480 grid_h = torch.arange(grid_size, dtype=torch.float32, device=device) 481 grid_w = torch.arange(grid_size, dtype=torch.float32, device=device) 482 grid = torch.meshgrid(grid_w, grid_h, indexing="xy") # here h goes first,direction reversed for numpy 483 grid = torch.stack(grid, dim=0) # 2 x h x w 484 485 # grid = grid.reshape([2, 1, grid_size, grid_size]) 486 grid = 
torch.einsum("chw,n->cnhw", grid, res) # 2 x n x h x w 487 _, n, h, w = grid.shape 488 pos_embed = get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid) # (nxH*W, D/2) 489 pos_embed = pos_embed.reshape(n, h * w, embed_dim) 490 if cls_token: 491 pos_embed = torch.cat( 492 [torch.zeros([n, 1, embed_dim], dtype=torch.float32, device=pos_embed.device), pos_embed], dim=1 493 ) 494 495 return pos_embed 496 497 498def get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid): 499 assert embed_dim % 2 == 0 500 501 # use half of dimensions to encode grid_h 502 emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid[0]) # (H*W, D/2) 503 emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid[1]) # (H*W, D/2) 504 505 emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D) 506 return emb 507 508 509def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos): 510 """ 511 embed_dim: output dimension for each position 512 pos: a list of positions to be encoded: size (M,) 513 out: (M, D) 514 """ 515 assert embed_dim % 2 == 0 516 # old_shape = pos 517 omega = torch.arange(embed_dim // 2, dtype=torch.float32, device=pos.device) 518 omega /= embed_dim / 2.0 519 omega = 1.0 / 10000**omega # (D/2,) 520 521 pos = pos.reshape(-1) # (M,) 522 out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product 523 524 emb_sin = torch.sin(out) # (M, D/2) 525 emb_cos = torch.cos(out) # (M, D/2) 526 527 emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) 528 return emb 529 530 531class PatchEmbedUnSafe(PatchEmbed): 532 """Image to Patch Embedding""" 533 534 def forward(self, x): 535 B, C, H, W = x.shape 536 537 # NOTE: Comment code from ScaleMAE: Dropped size check in timm 538 # assert H == self.img_size[0] and W == self.img_size[1], \ 539 # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 540 541 x = self.proj(x).flatten(2).transpose(1, 2) 542 return x 543 544 545class ViT_ScaleMAE(VisionTransformer): 546 """Vision Transformer dervied from the Scale Masked Auto Encoder codebase (TODO: paper and github link). 547 548 NOTE: For downstream tasks, the "base_resoulution" parameter needs to be adjusted manually when using 549 the model on a different zoom factor dataset. 
550 """ 551 552 def __init__( 553 self, img_size=224, patch_size=16, in_chans=3, embed_dim=1024, depth=12, base_resolution=2.5, **kwargs 554 ): 555 super().__init__(img_size=img_size, embed_dim=embed_dim, **kwargs) 556 self.img_size = img_size 557 self.in_chans = in_chans 558 self.depth = depth 559 self.base_resolution = base_resolution 560 561 self.patch_embed = PatchEmbedUnSafe( 562 img_size=img_size, 563 patch_size=patch_size, 564 in_chans=in_chans, 565 embed_dim=embed_dim, 566 ) 567 568 def transform_inputs(self, x): 569 import kornia.augmentation as K 570 from kornia.constants import Resample 571 572 self._transforms = CustomCompose( 573 rescale_transform=K.RandomResizedCrop( 574 (448, 448), 575 ratio=(1.0, 1.0), 576 scale=(1.0, 1.0), 577 resample=Resample.BICUBIC.name, 578 ), 579 other_transforms=None, 580 src_transform=K.Resize((224, 224)), 581 ) 582 x, _, ratios, _, _ = self._transforms(x) 583 input_res = ratios * self.base_resolution 584 return x, input_res 585 586 def convert_to_expected_dim(self, x): 587 inputs_ = x[:, 1:, :] # removing the class tokens 588 # reshape the outputs to desired shape (N X H*W X C -> N X H X W X C) 589 rdim = inputs_.shape[1] 590 dshape = int(rdim ** 0.5) # finding square root of the outputs for obtaining the patch shape 591 inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape)) 592 inputs_ = inputs_.permute(0, 3, 1, 2) 593 return inputs_ 594 595 def forward_features(self, x): 596 x, input_res = self.transform_inputs(x) 597 598 B, _, h, w = x.shape 599 x = self.patch_embed(x) 600 601 num_patches = int((h * w) / (self.patch_embed.patch_size[0] * self.patch_embed.patch_size[1])) 602 pos_embed = get_2d_sincos_pos_embed_with_resolution( 603 x.shape[-1], 604 int(num_patches ** 0.5), 605 input_res, 606 cls_token=True, 607 device=x.device, 608 ) 609 610 cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 611 x = torch.cat((cls_tokens, x), dim=1) 612 x = x + pos_embed 613 x = self.pos_drop(x) 614 615 # chunks obtained for getting the projections for conjuctions with upsampling blocks 616 _chunks = int(self.depth / 4) 617 chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1] 618 619 list_from_encoder = [] 620 for i, blk in enumerate(self.blocks): 621 x = blk(x) 622 if i in chunks_for_projection: 623 list_from_encoder.append(self.convert_to_expected_dim(x)) 624 625 x = self.convert_to_expected_dim(x) 626 627 return x, list_from_encoder 628 629 def forward(self, x): 630 x, list_from_encoder = self.forward_features(x) 631 return x, list_from_encoder 632 633 634class ViT_DINOv2(DinoV2VisionTransformer): 635 """Vision Transformer derived from the DINOv2 Codebase (https://arxiv.org/abs/2304.07193). 636 637 Based on: 638 https://github.com/facebookresearch/dinov2/blob/main/dinov2/models/vision_transformer.py. 639 640 Args: 641 img_size: The input image size. 642 patch_size: The patch size. 643 depth: The depth of the network. 644 num_register_tokens: The number of registers added (in addition to the class tokens). 645 It's important to know for ViTs trained with registers, to remove them at the end. 646 """ 647 def __init__( 648 self, 649 img_size: int = 224, 650 patch_size: int = 16, 651 depth: int = 12, 652 num_register_tokens: int = 0, 653 **kwargs 654 ): 655 if not _dinov2_import_success: 656 raise RuntimeError( 657 "The vision transformer backend can only be initialized if DINOv2 is installed. 
" 658 "Please install DINOv2 from https://github.com/facebookresearch/dinov2 " 659 "and then rerun your code." 660 ) 661 662 super().__init__( 663 img_size=img_size, 664 depth=depth, 665 patch_size=patch_size, 666 num_register_tokens=num_register_tokens, 667 **kwargs 668 ) 669 670 self.img_size = img_size 671 self.num_register_tokens = num_register_tokens 672 self.patch_size = patch_size 673 self.attn_outs = [i for i in range(depth) if i % 3 == 2] 674 675 def forward(self, x, masks=None) -> torch.Tensor: 676 677 B = x.shape[0] 678 679 x = self.prepare_tokens_with_masks(x) 680 681 list_of_encoder = [] 682 for i, blk in enumerate(self.blocks): 683 x = blk(x) 684 if i in self.attn_outs: 685 list_of_encoder.append(x) 686 687 x = self.norm(x) 688 x = x[:, self.num_register_tokens + 1:].reshape( 689 B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1 690 ).permute(0, 3, 1, 2).contiguous() 691 692 list_of_encoder = [ 693 o[:, self.num_register_tokens + 1:].reshape( 694 B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1 695 ).permute(0, 3, 1, 2).contiguous() for o in list_of_encoder 696 ] 697 698 return x, list_of_encoder[:3] 699 700 701class ViT_DINOv3(DinoV3VisionTransformer): 702 """Vision Transformer derived from the DINOv3 Codebase (https://arxiv.org/abs/2508.10104). 703 704 Based on: 705 https://github.com/facebookresearch/dinov3/blob/main/dinov3/models/vision_transformer.py. 706 707 Args: 708 img_size: The input image size. 709 patch_size: The patch size. 710 embed_dim: The embedding dimension. 711 depth: The depth of the network. 712 num_heads: The number of heads. 713 ffn_ratio: The FFN rato. 714 n_storage_tokens: The number of storage (class) tokens to remove. 715 kwargs: Keyword arguments for the image encoder base class. 716 """ 717 def __init__( 718 self, 719 in_chans: int = 3, 720 img_size: int = 224, 721 patch_size: int = 16, 722 embed_dim: int = 768, 723 depth: int = 12, 724 num_heads: int = 12, 725 ffn_ratio: float = 4.0, 726 n_storage_tokens: int = 0, 727 **kwargs 728 ): 729 if not _dinov3_import_success: 730 raise RuntimeError( 731 "The vision transformer backend can only be initialized if DINOv3 is installed. " 732 "Please install DINOv3 from https://github.com/facebookresearch/dinov3 " 733 "and then rerun your code." 
734 ) 735 736 super().__init__( 737 in_chans=in_chans, 738 img_size=img_size, 739 patch_size=patch_size, 740 embed_dim=embed_dim, 741 depth=depth, 742 num_heads=num_heads, 743 ffn_ratio=ffn_ratio, 744 n_storage_tokens=n_storage_tokens, 745 **kwargs 746 ) 747 748 self.in_chans = in_chans 749 self.img_size = img_size 750 self.n_storage_tokens = n_storage_tokens 751 self.attn_outs = [i for i in range(depth) if i % 3 == 2] 752 753 def forward(self, x) -> torch.Tensor: 754 755 B = x.shape[0] 756 757 x, hw_tuple = self.prepare_tokens_with_masks(x) 758 759 list_of_encoder = [] 760 for i, blk in enumerate(self.blocks): 761 rope_sincos = self.rope_embed(H=hw_tuple[0], W=hw_tuple[1]) 762 x = blk(x, rope_sincos) 763 if i in self.attn_outs: 764 list_of_encoder.append(x) 765 766 x = self.norm(x) 767 x = x[:, self.n_storage_tokens + 1:].reshape( 768 B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1 769 ).permute(0, 3, 1, 2).contiguous() 770 771 list_of_encoder = [ 772 o[:, self.n_storage_tokens + 1:].reshape( 773 B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1 774 ).permute(0, 3, 1, 2).contiguous() for o in list_of_encoder 775 ] 776 777 return x, list_of_encoder[:3] 778 779 780def get_vision_transformer(backbone: str, model: str, img_size: int = 1024, **kwargs) -> nn.Module: 781 """Get vision transformer encoder. 782 783 Args: 784 backbone: The name of the vision transformer implementation. 785 One of "sam" / "cellpose_sam" / "sam2" / "sam3" / "mae" / "scalemae" / "dinov2" / "dinov3". 786 model: The name of the model. One of "vit_b", "vit_l" or "vit_h". 787 img_size: The size of the input for the image encoder. Input images will be resized to match this size. 788 kwargs: Additional kwargs which can be expected by the vision transformer, 789 e.g. 'base_resolution' for `ViT_ScaleMAE`. 790 791 Returns: 792 The vision transformer. 793 """ 794 if backbone == "sam": 795 if model == "vit_b": 796 encoder = ViT_Sam( 797 depth=12, embed_dim=768, img_size=img_size, mlp_ratio=4, 798 norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 799 num_heads=12, patch_size=16, qkv_bias=True, use_rel_pos=True, 800 global_attn_indexes=[2, 5, 8, 11], 801 window_size=14, out_chans=256, 802 ) 803 elif model == "vit_l": 804 encoder = ViT_Sam( 805 depth=24, embed_dim=1024, img_size=img_size, mlp_ratio=4, 806 norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 807 num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True, 808 global_attn_indexes=[5, 11, 17, 23], 809 window_size=14, out_chans=256, 810 ) 811 elif model == "vit_h": 812 encoder = ViT_Sam( 813 depth=32, embed_dim=1280, img_size=img_size, mlp_ratio=4, 814 norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 815 num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True, 816 global_attn_indexes=[7, 15, 23, 31], 817 window_size=14, out_chans=256, 818 ) 819 else: 820 raise ValueError(f"'{model}' is not supported by SAM. Currently, 'vit_b', 'vit_l', 'vit_h' are supported.") 821 822 elif backbone == "cellpose_sam": 823 if model != "vit_l": 824 raise ValueError(f"'{model}' is not supported by CellposeSAM. 
Only 'vit_l' is supported.") 825 encoder = ViT_CellposeSAM(ps=8, bsize=img_size) 826 827 elif backbone == "sam2": 828 if model == "hvit_t": 829 encoder = ViT_Sam2( 830 img_size=img_size, embed_dim=96, num_heads=1, stages=[1, 2, 7, 2], global_att_blocks=[5, 7, 9], 831 window_pos_embed_bkg_spatial_size=[7, 7], backbone_channel_list=[768, 384, 192, 96], 832 ) 833 elif model == "hvit_s": 834 encoder = ViT_Sam2( 835 img_size=img_size, embed_dim=96, num_heads=1, stages=[1, 2, 11, 2], global_att_blocks=[7, 10, 13], 836 window_pos_embed_bkg_spatial_size=[7, 7], backbone_channel_list=[768, 384, 192, 96], 837 ) 838 elif model == "hvit_b": 839 encoder = ViT_Sam2( 840 img_size=img_size, embed_dim=112, num_heads=2, backbone_channel_list=[896, 448, 224, 112], 841 ) 842 elif model == "hvit_l": 843 encoder = ViT_Sam2( 844 img_size=img_size, embed_dim=144, num_heads=2, stages=[2, 6, 36, 4], global_att_blocks=[23, 33, 43], 845 window_spec=[8, 4, 16, 8], backbone_channel_list=[1152, 576, 288, 144], 846 ) 847 else: 848 raise ValueError( 849 f"'{model}' is not supported by SAM2. Currently, 'hvit_t', 'hvit_s', 'hvit_b', 'hvit_l' are supported." 850 ) 851 852 elif backbone == "sam3": 853 if model != "vit_pe": 854 raise ValueError( 855 "'sam3' does not have multiple model configurations. Please use 'vit_pe' as the model configuration." 856 ) 857 858 encoder = ViT_Sam3( 859 img_size=1008, pretrain_img_size=336, patch_size=14, embed_dim=1024, depth=32, num_heads=16, 860 mlp_ratio=4.625, norm_layer="LayerNorm", drop_path_rate=0.1, qkv_bias=True, use_abs_pos=True, 861 tile_abs_pos=True, global_att_blocks=(7, 15, 23, 31), rel_pos_blocks=(), use_rope=True, 862 use_interp_rope=True, window_size=24, pretrain_use_cls_token=True, retain_cls_token=False, ln_pre=True, 863 ln_post=False, return_interm_layers=False, bias_patch_embed=False, compile_mode=None, 864 ) 865 866 elif backbone == "mae": 867 if model == "vit_b": 868 encoder = ViT_MAE( 869 img_size=img_size, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, 870 qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6) 871 ) 872 elif model == "vit_l": 873 encoder = ViT_MAE( 874 img_size=img_size, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, 875 qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6) 876 ) 877 elif model == "vit_h": 878 encoder = ViT_MAE( 879 img_size=img_size, patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, 880 qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6) 881 ) 882 else: 883 raise ValueError(f"'{model}' is not supported by MAE. 
Currently, 'vit_b', 'vit_l', 'vit_h' are supported.") 884 885 elif backbone == "scalemae": 886 base_resolution = kwargs.get("base_resolution", 2.5) 887 888 if model == "vit_b": 889 encoder = ViT_ScaleMAE( 890 img_size=img_size, patch_size=8, embed_dim=768, depth=12, num_heads=12, 891 mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), 892 base_resolution=base_resolution, 893 ) 894 elif model == "vit_l": 895 encoder = ViT_ScaleMAE( 896 img_size=img_size, patch_size=8, embed_dim=1024, depth=24, num_heads=16, 897 mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), 898 base_resolution=base_resolution, 899 ) 900 elif model == "vit_h": 901 encoder = ViT_ScaleMAE( 902 img_size=img_size, patch_size=8, embed_dim=1280, depth=32, num_heads=16, 903 mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), 904 base_resolution=base_resolution, 905 ) 906 else: 907 raise ValueError( 908 f"'{model}' is not supported by ScaleMAE. Currently, 'vit_b', 'vit_l' and 'vit_h' are supported." 909 ) 910 911 elif backbone == "dinov2": 912 block_fn = partial(Block, attn_class=MemEffAttention) 913 msg = "The model name should be either 'vit_<X>' or 'vit_<X>_reg<Y>." 914 915 if model.startswith("vit_s"): 916 assert model in ["vit_s", "vit_s_reg4"], msg 917 encoder = ViT_DINOv2( 918 img_size=img_size, patch_size=14, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, 919 block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0, 920 num_register_tokens=4 if model.endswith("_reg4") else 0, 921 ) 922 elif model.startswith("vit_b"): 923 assert model in ["vit_b", "vit_b_reg4"], msg 924 encoder = ViT_DINOv2( 925 img_size=img_size, patch_size=14, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, 926 block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0, 927 num_register_tokens=4 if model.endswith("_reg4") else 0, 928 ) 929 elif model.startswith("vit_l"): 930 assert model in ["vit_l", "vit_l_reg4"], msg 931 encoder = ViT_DINOv2( 932 img_size=img_size, patch_size=14, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, 933 block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0, 934 num_register_tokens=4 if model.endswith("_reg4") else 0, 935 ) 936 elif model.startswith("vit_g"): 937 assert model in ["vit_g", "vit_g_reg4"], msg 938 encoder = ViT_DINOv2( 939 img_size=img_size, patch_size=14, embed_dim=1536, depth=40, num_heads=24, mlp_ratio=4, 940 block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0, 941 num_register_tokens=4 if model.endswith("_reg4") else 0, ffn_layer="swiglu", 942 ) 943 else: 944 raise ValueError( 945 f"'{model}' is not supported by DINOv2. Currently, 'vit_s', 'vit_b', 'vit_l' and 'vit_g' are supported." 
946 ) 947 948 elif backbone == "dinov3": 949 950 if model == "vit_s": 951 encoder = ViT_DINOv3( 952 img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=384, 953 num_heads=6, layerscale_init=1.0e-05, norm_layer="layernormbf16", n_storage_tokens=4, mask_k_bias=True, 954 ) 955 elif model == "vit_s+": 956 encoder = ViT_DINOv3( 957 img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=384, 958 num_heads=6, ffn_ratio=6, layerscale_init=1.0e-05, norm_layer="layernormbf16", 959 ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True, 960 ) 961 962 elif model == "vit_b": 963 encoder = ViT_DINOv3( 964 img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", 965 layerscale_init=1.0e-05, norm_layer="layernormbf16", n_storage_tokens=4, mask_k_bias=True, 966 ) 967 elif model == "vit_l": 968 encoder = ViT_DINOv3( 969 img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1024, 970 depth=24, num_heads=16, layerscale_init=1.0e-05, norm_layer="layernormbf16", 971 n_storage_tokens=4, mask_k_bias=True, 972 ) 973 elif model == "vit_l+": 974 encoder = ViT_DINOv3( 975 img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1024, 976 depth=24, num_heads=16, ffn_ratio=6.0, layerscale_init=1.0e-05, norm_layer="layernormbf16", 977 ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True, 978 ) 979 elif model == "vit_h+": 980 encoder = ViT_DINOv3( 981 img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1280, 982 depth=32, num_heads=20, ffn_ratio=6.0, layerscale_init=1.0e-05, norm_layer="layernormbf16", 983 ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True, 984 ) 985 elif model == "vit_7b": 986 encoder = ViT_DINOv3( 987 img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=4096, 988 depth=40, num_heads=32, ffn_ratio=3, qkv_bias=False, drop_path_rate=0.0, layerscale_init=1.0e-05, 989 norm_layer="layernormbf16", ffn_layer="swiglu64", n_storage_tokens=4, mask_k_bias=True, 990 untie_global_and_local_cls_norm=True, 991 ) 992 else: 993 raise ValueError( 994 f"'{model}' is not supported by DINOv3. Currently, " 995 " 'vit_s', 'vit_s+', 'vit_b', 'vit_l', 'vit_l+', 'vit_h+'. 'vit_7b' are supported." 996 ) 997 998 else: 999 raise ValueError( 1000 "The 'UNETR' supported backbones are 'sam', 'cellpose_sam', 'sam2', 'sam3', " 1001 "'mae', 'scalemae', 'dinov2' or 'dinov3'. Please choose one of them." 1002 ) 1003 1004 return encoder
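The encoders defined in this module are usually created through get_vision_transformer (defined at the end of the file), which maps a backbone name and a model name to one of the classes documented below. A minimal usage sketch, assuming the optional dependency for the chosen backbone is installed (segment_anything for "sam"):

import torch
from torch_em.model.vit import get_vision_transformer

# The supported backbones are "sam", "cellpose_sam", "sam2", "sam3", "mae", "scalemae",
# "dinov2" and "dinov3"; the model string selects the architecture variant, e.g. "vit_b".
encoder = get_vision_transformer(backbone="sam", model="vit_b", img_size=1024)

# All encoders return the final feature map plus a list of intermediate feature maps,
# which torch_em's UNETR consumes for its skip connections.
x = torch.randn(1, 3, 1024, 1024)
features, intermediates = encoder(x)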
59class ViT_Sam(ImageEncoderViT): 60 """Vision Transformer derived from the Segment Anything Codebase (https://arxiv.org/abs/2304.02643). 61 62 Based on: 63 https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py 64 65 Args: 66 in_chans: The number of input channels. 67 embed_dim: The embedding dimension, corresponding to the number of output channels of the vision transformer. 68 global_attn_indexes: The global attention indices. 69 kwargs: Keyword arguments for the image encoder base class. 70 """ 71 def __init__( 72 self, 73 in_chans: int = 3, 74 embed_dim: int = 768, 75 global_attn_indexes: Tuple[int, ...] = [2, 5, 8, 11], 76 **kwargs, 77 ) -> None: 78 if not _sam_import_success: 79 raise RuntimeError( 80 "The vision transformer backend can only be initialized if segment anything is installed. " 81 "Please install segment anything from https://github.com/facebookresearch/segment-anything " 82 "and then rerun your code." 83 ) 84 85 super().__init__(embed_dim=embed_dim, global_attn_indexes=global_attn_indexes, **kwargs) 86 self.chunks_for_projection = global_attn_indexes 87 self.in_chans = in_chans 88 self.embed_dim = embed_dim 89 90 def forward(self, x: torch.Tensor) -> torch.Tensor: 91 """Apply the vision transformer to input data. 92 93 Args: 94 x: The input data. 95 96 Returns: 97 The vision transformer output. 98 """ 99 x = self.patch_embed(x) 100 if self.pos_embed is not None: 101 x = x + self.pos_embed 102 103 list_from_encoder = [] 104 for i, blk in enumerate(self.blocks): 105 x = blk(x) 106 if i in self.chunks_for_projection: 107 list_from_encoder.append(x) 108 109 x = x.permute(0, 3, 1, 2) 110 list_from_encoder = [e.permute(0, 3, 1, 2) for e in list_from_encoder] 111 return x, list_from_encoder[:3]
Vision Transformer derived from the Segment Anything Codebase (https://arxiv.org/abs/2304.02643).
Arguments:
- in_chans: The number of input channels.
- embed_dim: The embedding dimension, corresponding to the number of output channels of the vision transformer.
- global_attn_indexes: The global attention indices.
- kwargs: Keyword arguments for the image encoder base class.
def __init__(self, in_chans: int = 3, embed_dim: int = 768, global_attn_indexes: Tuple[int, ...] = [2, 5, 8, 11], **kwargs) -> None
def forward(self, x: torch.Tensor) -> torch.Tensor
Apply the vision transformer to input data.
Arguments:
- x: The input data.
Returns:
The vision transformer output.
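A shape sketch for the vit_b configuration created by get_vision_transformer (patch size 16, embedding dimension 768), assuming segment_anything is installed. Note that this forward pass does not apply the neck, so the channel dimension stays at embed_dim:

import torch
from torch_em.model.vit import get_vision_transformer

encoder = get_vision_transformer(backbone="sam", model="vit_b", img_size=1024)
x = torch.randn(1, 3, 1024, 1024)
features, from_encoder = encoder(x)
# 1024 / 16 = 64 tokens per side; outputs are permuted to (N, C, H, W).
print(features.shape)     # torch.Size([1, 768, 64, 64])
print(len(from_encoder))  # 3 intermediate feature maps with the same shape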
114class ViT_CellposeSAM(nn.Module): 115 """Vision Transformer derived from the CellposeSAM Codebase (https://doi.org/10.1038/s41592-025-02595-x). 116 117 This replicates CellposeSAM's actual initialization: instantiate SAM's ``ImageEncoderViT`` via 118 ``sam_model_registry``, then modify the patch embedding, position embeddings, and set global attention. 119 This preserves SAM's original relative position bias sizes, enabling direct checkpoint loading 120 without any interpolation. 121 122 Based on: https://github.com/MouseLand/cellpose/blob/main/cellpose/vit_sam.py 123 124 NOTE: The pretrained CellposeSAM model uses ``vit_l`` exclusively. 125 126 Args: 127 ps: The patch size (default for CellposeSAM is 8). 128 bsize: The input image size (default for CellposeSAM is 256). 129 """ 130 def __init__(self, ps: int = 8, bsize: int = 256) -> None: 131 super().__init__() 132 133 if not _sam_import_success: 134 raise RuntimeError( 135 "The vision transformer backend can only be initialized if segment anything is installed. " 136 "Please install segment anything from https://github.com/facebookresearch/segment-anything " 137 "and then rerun your code." 138 ) 139 140 from segment_anything import sam_model_registry 141 142 # Creates the SAM vit_l encoder and applies CellposeSAM's modifications (same as cellpose.vit_sam.Transformer). 143 encoder = sam_model_registry["vit_l"](None).image_encoder 144 145 w = encoder.patch_embed.proj.weight.detach() 146 nchan = w.shape[0] 147 148 # CellPoseSAM changes the patch size from 16 to 'ps'. 149 encoder.patch_embed.proj = nn.Conv2d(3, nchan, stride=ps, kernel_size=ps) 150 encoder.patch_embed.proj.weight.data = w[:, :, ::16 // ps, ::16 // ps] 151 152 # Next, they subsample position embeddings for the new patch size and input resolution. 153 ds = (1024 // 16) // (bsize // ps) 154 encoder.pos_embed = nn.Parameter(encoder.pos_embed[:, ::ds, ::ds], requires_grad=True) 155 156 # Finally, they set all blocks to global attention. 157 for blk in encoder.blocks: 158 blk.window_size = 0 159 160 # Store encoder submodules directly ('state_dict' keys match CellposeSAM after prefix stripping). 161 self.patch_embed = encoder.patch_embed 162 self.pos_embed = encoder.pos_embed 163 self.blocks = encoder.blocks 164 self.neck = encoder.neck 165 166 # Additional attributes expected by UNETR. 167 self.embed_dim = nchan 168 self.img_size = bsize 169 self.in_chans = 3 170 171 # Feature extraction at evenly-spaced depths. 172 depth = len(self.blocks) 173 _chunks = depth // 4 174 self.chunks_for_projection = [_chunks - 1, 2 * _chunks - 1, 3 * _chunks - 1, 4 * _chunks - 1] 175 176 def forward(self, x: torch.Tensor) -> torch.Tensor: 177 """Apply the vision transformer to input data. 178 179 Args: 180 x: The input data. 181 182 Returns: 183 The vision transformer output. 184 """ 185 x = self.patch_embed(x) 186 if self.pos_embed is not None: 187 x = x + self.pos_embed 188 189 list_from_encoder = [] 190 for i, blk in enumerate(self.blocks): 191 x = blk(x) 192 if i in self.chunks_for_projection: 193 list_from_encoder.append(x) 194 195 x = x.permute(0, 3, 1, 2) 196 list_from_encoder = [e.permute(0, 3, 1, 2) for e in list_from_encoder] 197 return x, list_from_encoder[:3]
Vision Transformer derived from the CellposeSAM Codebase (https://doi.org/10.1038/s41592-025-02595-x).
This replicates CellposeSAM's actual initialization: instantiate SAM's ImageEncoderViT via
sam_model_registry, then modify the patch embedding, position embeddings, and set global attention.
This preserves SAM's original relative position bias sizes, enabling direct checkpoint loading
without any interpolation.
Based on: https://github.com/MouseLand/cellpose/blob/main/cellpose/vit_sam.py
NOTE: The pretrained CellposeSAM model uses vit_l exclusively.
Arguments:
- ps: The patch size (default for CellposeSAM is 8).
- bsize: The input image size (default for CellposeSAM is 256).
def __init__(self, ps: int = 8, bsize: int = 256) -> None
Build the SAM vit_l image encoder via sam_model_registry and apply CellposeSAM's modifications: a smaller patch embedding, subsampled position embeddings, and global attention in all blocks.
def forward(self, x: torch.Tensor) -> torch.Tensor
Apply the vision transformer to input data.
Arguments:
- x: The input data.
Returns:
The vision transformer output.
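A shape sketch for the default configuration (patch size 8, 256 x 256 inputs), assuming segment_anything is installed. Since the underlying encoder is SAM's vit_l, the embedding dimension is 1024:

import torch
from torch_em.model.vit import ViT_CellposeSAM

encoder = ViT_CellposeSAM(ps=8, bsize=256)  # builds the SAM vit_l encoder, no checkpoint required
x = torch.randn(1, 3, 256, 256)
features, from_encoder = encoder(x)
# 256 / 8 = 32 tokens per side.
print(features.shape)     # torch.Size([1, 1024, 32, 32])
print(len(from_encoder))  # 3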
200class ViT_MAE(VisionTransformer): 201 """Vision Transformer derived from the Masked Auto Encoder Codebase (https://arxiv.org/abs/2111.06377). 202 203 Based on: 204 https://github.com/facebookresearch/mae/blob/main/models_vit.py#L20-L53 205 206 Args: 207 img_size: The size of the input for the image encoder. Input images will be resized to match this size. 208 in_chans: The number of input channels. 209 depth: The depth of the vision transformer. 210 kwargs: Additional keyword arguments for the vision transformer base class. 211 """ 212 def __init__( 213 self, 214 img_size: int = 1024, # chosen to match our experiments with segment anything 215 in_chans: int = 3, 216 depth: int = 12, 217 **kwargs 218 ): 219 if not _timm_import_success: 220 raise RuntimeError( 221 "The vision transformer backend can only be initialized if timm is installed. " 222 "Please install timm (using conda/mamba) for using https://github.com/facebookresearch/mae/ " 223 "and then rerun your code" 224 ) 225 super().__init__(img_size=img_size, depth=depth, **kwargs) 226 self.img_size = img_size 227 self.in_chans = in_chans 228 self.depth = depth 229 230 def convert_to_expected_dim(self, inputs_): 231 """@private 232 """ 233 inputs_ = inputs_[:, 1:, :] # removing the class tokens 234 # reshape the outputs to desired shape (N x H*W X C -> N x H x W x C) 235 rdim = inputs_.shape[1] 236 dshape = int(rdim ** 0.5) # finding the square root of the outputs for obtaining the patch shape 237 inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape)) 238 inputs_ = inputs_.permute(0, 3, 1, 2) 239 return inputs_ 240 241 def forward_features(self, x): 242 """@private 243 """ 244 B = x.shape[0] 245 x = self.patch_embed(x) 246 247 cls_tokens = self.cls_token.expand(B, -1, -1) 248 x = torch.cat((cls_tokens, x), dim=1) 249 250 x = x + self.pos_embed 251 x = self.pos_drop(x) 252 253 # chunks obtained for getting the projections for conjuctions with upsampling blocks 254 _chunks = int(self.depth / 4) 255 chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1] 256 257 list_from_encoder = [] 258 for i, blk in enumerate(self.blocks): 259 x = blk(x) 260 if i in chunks_for_projection: 261 list_from_encoder.append(self.convert_to_expected_dim(x)) 262 263 x = self.convert_to_expected_dim(x) 264 return x, list_from_encoder[:3] 265 266 def forward(self, x: torch.Tensor) -> torch.Tensor: 267 """Apply the vision transformer to input data. 268 269 Args: 270 x: The input data. 271 272 Returns: 273 The vision transformer output. 274 """ 275 x, list_from_encoder = self.forward_features(x) 276 return x, list_from_encoder
Vision Transformer derived from the Masked Auto Encoder Codebase (https://arxiv.org/abs/2111.06377).
Based on: https://github.com/facebookresearch/mae/blob/main/models_vit.py#L20-L53
Arguments:
- img_size: The size of the input for the image encoder. Input images will be resized to match this size.
- in_chans: The number of input channels.
- depth: The depth of the vision transformer.
- kwargs: Additional keyword arguments for the vision transformer base class.
def __init__(self, img_size: int = 1024, in_chans: int = 3, depth: int = 12, **kwargs)
def forward(self, x: torch.Tensor) -> torch.Tensor
Apply the vision transformer to input data.
Arguments:
- x: The input data.
Returns:
The vision transformer output.
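A shape sketch, assuming timm is installed. The class token is dropped and the remaining tokens are reshaped back into a spatial feature map; the keyword arguments below mirror the vit_b configuration used by get_vision_transformer:

from functools import partial

import torch
import torch.nn as nn
from torch_em.model.vit import ViT_MAE

encoder = ViT_MAE(
    img_size=1024, patch_size=16, embed_dim=768, depth=12, num_heads=12,
    mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
)
x = torch.randn(1, 3, 1024, 1024)
features, from_encoder = encoder(x)
print(features.shape)     # torch.Size([1, 768, 64, 64])
print(len(from_encoder))  # 3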
279class ViT_Sam2(ImageEncoder): 280 """Vision Transformer derived from the Segment Anything 2 Codebase (https://arxiv.org/abs/2408.00714). 281 282 Based on https://github.com/facebookresearch/sam2/blob/main/sam2/modeling/backbones/image_encoder.py. 283 284 Args: 285 backbone_channel_list: The channels throughout the entire backbone. 286 embed_dim: The initial embedding dimension. 287 num_heads: The initial number of heads. 288 stages: The number of blocks per stage. 289 global_att_blocks: The parameter to decide which blocks have global attention. 290 window_pos_embed_bkg_spatial_size: The spatial size of windowed positional embedding. 291 window_spec: The window size per stage, when not using global attention. 292 scalp: The count of lowest resolution features to discard. 293 """ 294 def __init__( 295 self, 296 backbone_channel_list: List[int], 297 img_size: int = 1024, 298 embed_dim: int = 96, 299 num_heads: int = 1, 300 stages: Tuple[int, ...] = (2, 3, 16, 3), 301 global_att_blocks: Tuple[int, ...] = (12, 16, 20), 302 window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14), 303 window_spec: Tuple[int, ...] = (8, 4, 14, 7), 304 scalp: int = 1, 305 **kwargs 306 ): 307 if not _sam2_import_success: 308 raise RuntimeError( 309 "The vision transformer backend can only be initialized if segment anything 2 is installed. " 310 "Please install segment anything 2 from https://github.com/facebookresearch/sam2 " 311 "and then rerun your code" 312 ) 313 314 trunk = Hiera( 315 embed_dim=embed_dim, 316 num_heads=num_heads, 317 stages=stages, 318 global_att_blocks=global_att_blocks, 319 window_pos_embed_bkg_spatial_size=window_pos_embed_bkg_spatial_size, 320 window_spec=window_spec, 321 ) 322 neck = FpnNeck( 323 position_encoding=PositionEmbeddingSine(num_pos_feats=256), 324 d_model=256, 325 backbone_channel_list=backbone_channel_list, 326 fpn_top_down_levels=[2, 3], 327 fpn_interp_model="nearest", 328 ) 329 330 super().__init__(trunk=trunk, neck=neck, scalp=scalp, **kwargs) 331 self.scalp = scalp 332 self.embed_dim = embed_dim 333 self.img_size = img_size 334 335 def forward(self, x: torch.Tensor): 336 # The forward pass throught the backbone. 337 features, pos = self.neck(self.trunk(x)) 338 if self.scalp > 0: # This discard the "n" lowest resolution features. 339 features, pos = features[:-self.scalp], pos[:-self.scalp] 340 341 return features[-1], features
Vision Transformer derived from the Segment Anything 2 Codebase (https://arxiv.org/abs/2408.00714).
Based on https://github.com/facebookresearch/sam2/blob/main/sam2/modeling/backbones/image_encoder.py.
Arguments:
- backbone_channel_list: The channel dimensions of the backbone stages, ordered from the deepest (lowest-resolution) stage to the shallowest.
- embed_dim: The initial embedding dimension.
- num_heads: The initial number of heads.
- stages: The number of blocks per stage.
- global_att_blocks: The indices of the blocks that use global attention.
- window_pos_embed_bkg_spatial_size: The spatial size of windowed positional embedding.
- window_spec: The window size per stage, when not using global attention.
- scalp: The number of lowest-resolution feature maps to discard.
def __init__(self, backbone_channel_list: List[int], img_size: int = 1024, embed_dim: int = 96, num_heads: int = 1, stages: Tuple[int, ...] = (2, 3, 16, 3), global_att_blocks: Tuple[int, ...] = (12, 16, 20), window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14), window_spec: Tuple[int, ...] = (8, 4, 14, 7), scalp: int = 1, **kwargs)
def forward(self, x: torch.Tensor)
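A usage sketch, assuming sam2 is installed. Unlike the other encoders, the second return value holds the multi-scale FPN maps produced by the neck (each with 256 channels), and the first return value is simply the last entry of that list:

import torch
from torch_em.model.vit import get_vision_transformer

encoder = get_vision_transformer(backbone="sam2", model="hvit_b", img_size=1024)
x = torch.randn(1, 3, 1024, 1024)
features, fpn_features = encoder(x)
# With scalp=1 the lowest-resolution FPN output is discarded, leaving three maps.
print(len(fpn_features))                   # 3
print([f.shape[1] for f in fpn_features])  # [256, 256, 256]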
344class ViT_Sam3(SAM3ViT): 345 """Vision Transformer derived from the Segment Anything 3 Codebase (https://arxiv.org/abs/2511.16719). 346 347 Based on https://github.com/facebookresearch/sam3/blob/main/sam3/model/vitdet.py. 348 349 Args: 350 img_size: The input image size. 351 embed_dim: The embedding dimension, corresponding to the number of output channels of the vision transformer. 352 kwargs: Keyword arguments for the image encoder base class. 353 """ 354 def __init__(self, img_size: int = 1024, embed_dim: int = 768, **kwargs): 355 if not _sam3_import_success: 356 raise RuntimeError( 357 "The vision transformer backend can only be initialized if segment anything 3 is installed. " 358 "Please install segment anything 3 from https://github.com/facebookresearch/sam3 " 359 "and then rerun your code" 360 ) 361 362 super().__init__(img_size=img_size, embed_dim=embed_dim, **kwargs) 363 self.img_size = img_size 364 self.embed_dim = embed_dim 365 366 def forward_features(self, x): 367 """@private 368 """ 369 x = self.patch_embed(x) 370 h, w = x.shape[1], x.shape[2] 371 372 s = 0 373 if self.retain_cls_token: 374 # If the 'cls_token' is retained, we don't maintain the spatial shape. 375 x = torch.cat([self.class_embedding, x.flatten(1, 2)], dim=1) 376 s = 1 377 378 if self.pos_embed is not None: 379 x = x + get_abs_pos( 380 self.pos_embed, self.pretrain_use_cls_token, (h, w), self.retain_cls_token, tiling=self.tile_abs_pos, 381 ) 382 383 x = self.ln_pre(x) 384 385 list_from_encoder = [] 386 for i, blk in enumerate(self.blocks): 387 if self.use_act_checkpoint and self.training: 388 x = torch.utils.checkpoint.checkpoint(blk, x, use_reentrant=False) 389 else: 390 x = blk(x) 391 392 x = self._convert_to_expected_dim(x, i, s) 393 394 if i in self.full_attn_ids: 395 list_from_encoder.append(x) 396 397 return x, list_from_encoder 398 399 def _convert_to_expected_dim(self, x, i, s): 400 if (i == self.full_attn_ids[-1]) or ( 401 self.return_interm_layers and i in self.full_attn_ids 402 ): 403 if i == self.full_attn_ids[-1]: 404 x = self.ln_post(x) 405 406 feats = x[:, s:] 407 if feats.ndim == 4: 408 feats = feats.permute(0, 3, 1, 2) 409 else: 410 assert feats.ndim == 3 411 h = w = math.sqrt(feats.shape[1]) 412 feats = feats.reshape(feats.shape[0], h, w, feats.shape[-1]).permute(0, 3, 1, 2) 413 return feats 414 415 else: 416 return x 417 418 def forward(self, x: torch.Tensor): 419 """Apply the vision transformer to input data. 420 421 Args: 422 x: The input data. 423 424 Returns: 425 The vision transformer output. 426 """ 427 x, list_from_encoder = self.forward_features(x) 428 return x, list_from_encoder
Vision Transformer derived from the Segment Anything 3 Codebase (https://arxiv.org/abs/2511.16719).
Based on https://github.com/facebookresearch/sam3/blob/main/sam3/model/vitdet.py.
Arguments:
- img_size: The input image size.
- embed_dim: The embedding dimension, corresponding to the number of output channels of the vision transformer.
- kwargs: Keyword arguments for the image encoder base class.
354 def __init__(self, img_size: int = 1024, embed_dim: int = 768, **kwargs): 355 if not _sam3_import_success: 356 raise RuntimeError( 357 "The vision transformer backend can only be initialized if segment anything 3 is installed. " 358 "Please install segment anything 3 from https://github.com/facebookresearch/sam3 " 359 "and then rerun your code" 360 ) 361 362 super().__init__(img_size=img_size, embed_dim=embed_dim, **kwargs) 363 self.img_size = img_size 364 self.embed_dim = embed_dim
418 def forward(self, x: torch.Tensor): 419 """Apply the vision transformer to input data. 420 421 Args: 422 x: The input data. 423 424 Returns: 425 The vision transformer output. 426 """ 427 x, list_from_encoder = self.forward_features(x) 428 return x, list_from_encoder
Apply the vision transformer to input data.
Arguments:
- x: The input data.
Returns:
The vision transformer output.
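As a shape sketch (assuming sam3 is installed), the forward pass returns the final feature map together with the intermediate feature maps collected at the global-attention blocks, both in channels-first layout:

```python
# Shape sketch, assuming sam3 is installed; uses the "vit_pe" configuration from get_vision_transformer.
import torch
from torch_em.model.vit import get_vision_transformer

encoder = get_vision_transformer(backbone="sam3", model="vit_pe")
x = torch.randn(1, 3, 1008, 1008)  # this configuration uses a 1008 input size with patch size 14

final_feature, intermediates = encoder(x)
# The features come out channels-first on a 72 x 72 grid (1008 / 14); one intermediate is collected
# per block listed in "full_attn_ids" (the global-attention blocks of this configuration).
```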
435class CustomCompose: 436 def __init__(self, rescale_transform, other_transforms, src_transform): 437 self.rescale_transform = rescale_transform 438 self.other_transforms = other_transforms 439 self.src_transform = src_transform 440 441 def __call__(self, x, valid_masks=None): 442 if valid_masks is not None: 443 nodata = (x * (1 - valid_masks.float())).max() 444 x_aug = self.rescale_transform(x) 445 parms = self.rescale_transform._params 446 447 # sanity check, comment if this is working 448 # valid_masks = self.rescale_transform(valid_masks.float(), params=parms) 449 # assert (x_aug==self.rescale_transform(x, params=parms)).all() # 450 451 if valid_masks is not None: 452 valid_masks = x_aug != nodata 453 _, c, h, w = x_aug.shape 454 zero_ratio = ((valid_masks == 0).sum((1, 2, 3)) / (h * w * c)).cpu().numpy() 455 else: 456 zero_ratio = -1 457 458 if self.other_transforms: 459 x_aug = self.other_transforms(x_aug) 460 x_src = self.src_transform(x_aug) 461 dx = parms["src"][:, 1, 0] - parms["src"][:, 0, 0] 462 463 # dy = (parms['src'][:,2,1] - parms['src'][:,1,1]) 464 # assert (dx == dy).all() 465 466 h, w = x_aug.shape[-2:] 467 # assert h == w 468 469 return x_aug, x_src, dx / h, zero_ratio, valid_masks
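A hedged usage sketch (assuming kornia is installed), mirroring how ViT_ScaleMAE.transform_inputs wires this up below: the call returns the cropped batch, a resized "source" view, the relative scale of the crop (dx / h, used to scale base_resolution), the ratio of invalid pixels, and the (optional) valid masks.

```python
import torch
import kornia.augmentation as K
from kornia.constants import Resample
from torch_em.model.vit import CustomCompose

transforms = CustomCompose(
    rescale_transform=K.RandomResizedCrop(
        (448, 448), ratio=(1.0, 1.0), scale=(1.0, 1.0), resample=Resample.BICUBIC.name,
    ),
    other_transforms=None,
    src_transform=K.Resize((224, 224)),
)
x_aug, x_src, rel_scale, zero_ratio, valid_masks = transforms(torch.randn(2, 3, 448, 448))
```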
472def get_2d_sincos_pos_embed_with_resolution(embed_dim, grid_size, res, cls_token=False, device="cpu"): 473 """ 474 grid_size: int of the grid height and width 475 res: array of size n, representing the resolution of a pixel (say, in meters), 476 return: 477 pos_embed: [n,grid_size*grid_size, embed_dim] or [n,1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 478 """ 479 # res = torch.FloatTensor(res).to(device) 480 res = res.to(device) 481 grid_h = torch.arange(grid_size, dtype=torch.float32, device=device) 482 grid_w = torch.arange(grid_size, dtype=torch.float32, device=device) 483 grid = torch.meshgrid(grid_w, grid_h, indexing="xy") # here h goes first,direction reversed for numpy 484 grid = torch.stack(grid, dim=0) # 2 x h x w 485 486 # grid = grid.reshape([2, 1, grid_size, grid_size]) 487 grid = torch.einsum("chw,n->cnhw", grid, res) # 2 x n x h x w 488 _, n, h, w = grid.shape 489 pos_embed = get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid) # (nxH*W, D/2) 490 pos_embed = pos_embed.reshape(n, h * w, embed_dim) 491 if cls_token: 492 pos_embed = torch.cat( 493 [torch.zeros([n, 1, embed_dim], dtype=torch.float32, device=pos_embed.device), pos_embed], dim=1 494 ) 495 496 return pos_embed
Arguments:
- grid_size: int of the grid height and width.
- res: array of size n, representing the resolution of a pixel (say, in meters).
Returns:
pos_embed: [n, grid_size*grid_size, embed_dim] or [n, 1+grid_size*grid_size, embed_dim] (with or without cls_token).
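A small worked shape example (a sketch, not part of the library): positional embeddings for a batch of two images with per-image resolutions 1.0 and 2.5 (e.g. meters per pixel), on a 14 x 14 token grid with 768-dimensional embeddings.

```python
import torch
from torch_em.model.vit import get_2d_sincos_pos_embed_with_resolution

res = torch.tensor([1.0, 2.5])
pos_embed = get_2d_sincos_pos_embed_with_resolution(embed_dim=768, grid_size=14, res=res, cls_token=True)
print(pos_embed.shape)  # torch.Size([2, 197, 768]) -> (n, 1 + grid_size * grid_size, embed_dim)
```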
499def get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid): 500 assert embed_dim % 2 == 0 501 502 # use half of dimensions to encode grid_h 503 emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid[0]) # (H*W, D/2) 504 emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid[1]) # (H*W, D/2) 505 506 emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D) 507 return emb
510def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos): 511 """ 512 embed_dim: output dimension for each position 513 pos: a list of positions to be encoded: size (M,) 514 out: (M, D) 515 """ 516 assert embed_dim % 2 == 0 517 # old_shape = pos 518 omega = torch.arange(embed_dim // 2, dtype=torch.float32, device=pos.device) 519 omega /= embed_dim / 2.0 520 omega = 1.0 / 10000**omega # (D/2,) 521 522 pos = pos.reshape(-1) # (M,) 523 out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product 524 525 emb_sin = torch.sin(out) # (M, D/2) 526 emb_cos = torch.cos(out) # (M, D/2) 527 528 emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) 529 return emb
Arguments:
- embed_dim: output dimension for each position.
- pos: a list of positions to be encoded, of size (M,).
Returns:
out: (M, D)
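A quick check of the 1D building block (a sketch, not part of the library): the first D/2 channels are sines and the last D/2 are cosines of the position scaled by geometrically spaced frequencies omega_i = 1 / 10000 ** (i / (D / 2)).

```python
import torch
from torch_em.model.vit import get_1d_sincos_pos_embed_from_grid_torch

pos = torch.arange(5, dtype=torch.float32)  # M = 5 positions
emb = get_1d_sincos_pos_embed_from_grid_torch(embed_dim=8, pos=pos)
print(emb.shape)  # torch.Size([5, 8]); columns 0-3 are sin(pos * omega), columns 4-7 are cos(pos * omega)
```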
532class PatchEmbedUnSafe(PatchEmbed): 533 """Image to Patch Embedding""" 534 535 def forward(self, x): 536 B, C, H, W = x.shape 537 538 # NOTE: Comment code from ScaleMAE: Dropped size check in timm 539 # assert H == self.img_size[0] and W == self.img_size[1], \ 540 # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 541 542 x = self.proj(x).flatten(2).transpose(1, 2) 543 return x
Image to Patch Embedding
535 def forward(self, x): 536 B, C, H, W = x.shape 537 538 # NOTE: Comment code from ScaleMAE: Dropped size check in timm 539 # assert H == self.img_size[0] and W == self.img_size[1], \ 540 # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 541 542 x = self.proj(x).flatten(2).transpose(1, 2) 543 return x
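The only change relative to timm's PatchEmbed is that the input-size check is dropped, so the ScaleMAE encoder can embed the 448 x 448 crops produced by CustomCompose even though the patch embedding was configured for 224 x 224. A small sketch (assuming timm is installed):

```python
import torch
from torch_em.model.vit import PatchEmbedUnSafe

patch_embed = PatchEmbedUnSafe(img_size=224, patch_size=16, in_chans=3, embed_dim=1024)
tokens = patch_embed(torch.randn(1, 3, 448, 448))  # a strict PatchEmbed would reject this input size
print(tokens.shape)  # torch.Size([1, 784, 1024]) -> (448 / 16) ** 2 = 784 patch tokens
```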
546class ViT_ScaleMAE(VisionTransformer): 547 """Vision Transformer derived from the Scale Masked Auto Encoder codebase (TODO: paper and github link). 548 549 NOTE: For downstream tasks, the "base_resolution" parameter needs to be adjusted manually when using 550 the model on a dataset with a different zoom factor. 551 """ 552 553 def __init__( 554 self, img_size=224, patch_size=16, in_chans=3, embed_dim=1024, depth=12, base_resolution=2.5, **kwargs 555 ): 556 super().__init__(img_size=img_size, embed_dim=embed_dim, **kwargs) 557 self.img_size = img_size 558 self.in_chans = in_chans 559 self.depth = depth 560 self.base_resolution = base_resolution 561 562 self.patch_embed = PatchEmbedUnSafe( 563 img_size=img_size, 564 patch_size=patch_size, 565 in_chans=in_chans, 566 embed_dim=embed_dim, 567 ) 568 569 def transform_inputs(self, x): 570 import kornia.augmentation as K 571 from kornia.constants import Resample 572 573 self._transforms = CustomCompose( 574 rescale_transform=K.RandomResizedCrop( 575 (448, 448), 576 ratio=(1.0, 1.0), 577 scale=(1.0, 1.0), 578 resample=Resample.BICUBIC.name, 579 ), 580 other_transforms=None, 581 src_transform=K.Resize((224, 224)), 582 ) 583 x, _, ratios, _, _ = self._transforms(x) 584 input_res = ratios * self.base_resolution 585 return x, input_res 586 587 def convert_to_expected_dim(self, x): 588 inputs_ = x[:, 1:, :] # removing the class tokens 589 # reshape the outputs to desired shape (N X H*W X C -> N X H X W X C) 590 rdim = inputs_.shape[1] 591 dshape = int(rdim ** 0.5) # finding square root of the outputs for obtaining the patch shape 592 inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape)) 593 inputs_ = inputs_.permute(0, 3, 1, 2) 594 return inputs_ 595 596 def forward_features(self, x): 597 x, input_res = self.transform_inputs(x) 598 599 B, _, h, w = x.shape 600 x = self.patch_embed(x) 601 602 num_patches = int((h * w) / (self.patch_embed.patch_size[0] * self.patch_embed.patch_size[1])) 603 pos_embed = get_2d_sincos_pos_embed_with_resolution( 604 x.shape[-1], 605 int(num_patches ** 0.5), 606 input_res, 607 cls_token=True, 608 device=x.device, 609 ) 610 611 cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 612 x = torch.cat((cls_tokens, x), dim=1) 613 x = x + pos_embed 614 x = self.pos_drop(x) 615 616 # chunks obtained for getting the projections in conjunction with the upsampling blocks 617 _chunks = int(self.depth / 4) 618 chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1] 619 620 list_from_encoder = [] 621 for i, blk in enumerate(self.blocks): 622 x = blk(x) 623 if i in chunks_for_projection: 624 list_from_encoder.append(self.convert_to_expected_dim(x)) 625 626 x = self.convert_to_expected_dim(x) 627 628 return x, list_from_encoder 629 630 def forward(self, x): 631 x, list_from_encoder = self.forward_features(x) 632 return x, list_from_encoder
Vision Transformer derived from the Scale Masked Auto Encoder codebase (TODO: paper and github link).
NOTE: For downstream tasks, the "base_resolution" parameter needs to be adjusted manually when using the model on a dataset with a different zoom factor.
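The "base_resolution" value is forwarded through the kwargs of get_vision_transformer (see below), so a dataset with a different ground resolution only requires changing this value. A hedged example (assuming timm is installed):

```python
from torch_em.model.vit import get_vision_transformer

encoder = get_vision_transformer(backbone="scalemae", model="vit_b", img_size=224, base_resolution=1.0)
print(encoder.base_resolution)  # 1.0
```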
553 def __init__( 554 self, img_size=224, patch_size=16, in_chans=3, embed_dim=1024, depth=12, base_resolution=2.5, **kwargs 555 ): 556 super().__init__(img_size=img_size, embed_dim=embed_dim, **kwargs) 557 self.img_size = img_size 558 self.in_chans = in_chans 559 self.depth = depth 560 self.base_resolution = base_resolution 561 562 self.patch_embed = PatchEmbedUnSafe( 563 img_size=img_size, 564 patch_size=patch_size, 565 in_chans=in_chans, 566 embed_dim=embed_dim, 567 )
569 def transform_inputs(self, x): 570 import kornia.augmentation as K 571 from kornia.constants import Resample 572 573 self._transforms = CustomCompose( 574 rescale_transform=K.RandomResizedCrop( 575 (448, 448), 576 ratio=(1.0, 1.0), 577 scale=(1.0, 1.0), 578 resample=Resample.BICUBIC.name, 579 ), 580 other_transforms=None, 581 src_transform=K.Resize((224, 224)), 582 ) 583 x, _, ratios, _, _ = self._transforms(x) 584 input_res = ratios * self.base_resolution 585 return x, input_res
587 def convert_to_expected_dim(self, x): 588 inputs_ = x[:, 1:, :] # removing the class tokens 589 # reshape the outputs to desired shape (N X H*W X C -> N X H X W X C) 590 rdim = inputs_.shape[1] 591 dshape = int(rdim ** 0.5) # finding square root of the outputs for obtaining the patch shape 592 inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape)) 593 inputs_ = inputs_.permute(0, 3, 1, 2) 594 return inputs_
596 def forward_features(self, x): 597 x, input_res = self.transform_inputs(x) 598 599 B, _, h, w = x.shape 600 x = self.patch_embed(x) 601 602 num_patches = int((h * w) / (self.patch_embed.patch_size[0] * self.patch_embed.patch_size[1])) 603 pos_embed = get_2d_sincos_pos_embed_with_resolution( 604 x.shape[-1], 605 int(num_patches ** 0.5), 606 input_res, 607 cls_token=True, 608 device=x.device, 609 ) 610 611 cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 612 x = torch.cat((cls_tokens, x), dim=1) 613 x = x + pos_embed 614 x = self.pos_drop(x) 615 616 # chunks obtained for getting the projections in conjunction with the upsampling blocks 617 _chunks = int(self.depth / 4) 618 chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1] 619 620 list_from_encoder = [] 621 for i, blk in enumerate(self.blocks): 622 x = blk(x) 623 if i in chunks_for_projection: 624 list_from_encoder.append(self.convert_to_expected_dim(x)) 625 626 x = self.convert_to_expected_dim(x) 627 628 return x, list_from_encoder
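The intermediate projections are taken at evenly spaced depths, i.e. at the end of each quarter of the encoder; for the default depth of 12 this selects blocks 2, 5, 8 and 11:

```python
# Plain arithmetic, mirroring the chunk computation in forward_features above.
depth = 12
_chunks = int(depth / 4)
print([_chunks - 1, 2 * _chunks - 1, 3 * _chunks - 1, 4 * _chunks - 1])  # [2, 5, 8, 11]
```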
635class ViT_DINOv2(DinoV2VisionTransformer): 636 """Vision Transformer derived from the DINOv2 Codebase (https://arxiv.org/abs/2304.07193). 637 638 Based on: 639 https://github.com/facebookresearch/dinov2/blob/main/dinov2/models/vision_transformer.py. 640 641 Args: 642 img_size: The input image size. 643 patch_size: The patch size. 644 depth: The depth of the network. 645 num_register_tokens: The number of registers added (in addition to the class tokens). 646 It's important to know for ViTs trained with registers, to remove them at the end. 647 """ 648 def __init__( 649 self, 650 img_size: int = 224, 651 patch_size: int = 16, 652 depth: int = 12, 653 num_register_tokens: int = 0, 654 **kwargs 655 ): 656 if not _dinov2_import_success: 657 raise RuntimeError( 658 "The vision transformer backend can only be initialized if DINOv2 is installed. " 659 "Please install DINOv2 from https://github.com/facebookresearch/dinov2 " 660 "and then rerun your code." 661 ) 662 663 super().__init__( 664 img_size=img_size, 665 depth=depth, 666 patch_size=patch_size, 667 num_register_tokens=num_register_tokens, 668 **kwargs 669 ) 670 671 self.img_size = img_size 672 self.num_register_tokens = num_register_tokens 673 self.patch_size = patch_size 674 self.attn_outs = [i for i in range(depth) if i % 3 == 2] 675 676 def forward(self, x, masks=None) -> torch.Tensor: 677 678 B = x.shape[0] 679 680 x = self.prepare_tokens_with_masks(x) 681 682 list_of_encoder = [] 683 for i, blk in enumerate(self.blocks): 684 x = blk(x) 685 if i in self.attn_outs: 686 list_of_encoder.append(x) 687 688 x = self.norm(x) 689 x = x[:, self.num_register_tokens + 1:].reshape( 690 B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1 691 ).permute(0, 3, 1, 2).contiguous() 692 693 list_of_encoder = [ 694 o[:, self.num_register_tokens + 1:].reshape( 695 B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1 696 ).permute(0, 3, 1, 2).contiguous() for o in list_of_encoder 697 ] 698 699 return x, list_of_encoder[:3]
Vision Transformer derived from the DINOv2 Codebase (https://arxiv.org/abs/2304.07193).
Based on: https://github.com/facebookresearch/dinov2/blob/main/dinov2/models/vision_transformer.py.
Arguments:
- img_size: The input image size.
- patch_size: The patch size.
- depth: The depth of the network.
- num_register_tokens: The number of registers added (in addition to the class tokens). It's important to know for ViTs trained with registers, to remove them at the end.
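A shape sketch (assuming DINOv2 is installed): the class token and any register tokens are stripped before the token sequence is reshaped back to a channels-first feature map with img_size // patch_size tokens per side.

```python
import torch
from torch_em.model.vit import get_vision_transformer

encoder = get_vision_transformer(backbone="dinov2", model="vit_b_reg4", img_size=224)
x = torch.randn(1, 3, 224, 224)

features, intermediates = encoder(x)
print(features.shape)      # torch.Size([1, 768, 16, 16]) -> 224 / 14 = 16
print(len(intermediates))  # 3, the first three of the blocks selected via "attn_outs"
```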
648 def __init__( 649 self, 650 img_size: int = 224, 651 patch_size: int = 16, 652 depth: int = 12, 653 num_register_tokens: int = 0, 654 **kwargs 655 ): 656 if not _dinov2_import_success: 657 raise RuntimeError( 658 "The vision transformer backend can only be initialized if DINOv2 is installed. " 659 "Please install DINOv2 from https://github.com/facebookresearch/dinov2 " 660 "and then rerun your code." 661 ) 662 663 super().__init__( 664 img_size=img_size, 665 depth=depth, 666 patch_size=patch_size, 667 num_register_tokens=num_register_tokens, 668 **kwargs 669 ) 670 671 self.img_size = img_size 672 self.num_register_tokens = num_register_tokens 673 self.patch_size = patch_size 674 self.attn_outs = [i for i in range(depth) if i % 3 == 2]
676 def forward(self, x, masks=None) -> torch.Tensor: 677 678 B = x.shape[0] 679 680 x = self.prepare_tokens_with_masks(x) 681 682 list_of_encoder = [] 683 for i, blk in enumerate(self.blocks): 684 x = blk(x) 685 if i in self.attn_outs: 686 list_of_encoder.append(x) 687 688 x = self.norm(x) 689 x = x[:, self.num_register_tokens + 1:].reshape( 690 B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1 691 ).permute(0, 3, 1, 2).contiguous() 692 693 list_of_encoder = [ 694 o[:, self.num_register_tokens + 1:].reshape( 695 B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1 696 ).permute(0, 3, 1, 2).contiguous() for o in list_of_encoder 697 ] 698 699 return x, list_of_encoder[:3]
702class ViT_DINOv3(DinoV3VisionTransformer): 703 """Vision Transformer derived from the DINOv3 Codebase (https://arxiv.org/abs/2508.10104). 704 705 Based on: 706 https://github.com/facebookresearch/dinov3/blob/main/dinov3/models/vision_transformer.py. 707 708 Args: 709 img_size: The input image size. 710 patch_size: The patch size. 711 embed_dim: The embedding dimension. 712 depth: The depth of the network. 713 num_heads: The number of heads. 714 ffn_ratio: The FFN rato. 715 n_storage_tokens: The number of storage (class) tokens to remove. 716 kwargs: Keyword arguments for the image encoder base class. 717 """ 718 def __init__( 719 self, 720 in_chans: int = 3, 721 img_size: int = 224, 722 patch_size: int = 16, 723 embed_dim: int = 768, 724 depth: int = 12, 725 num_heads: int = 12, 726 ffn_ratio: float = 4.0, 727 n_storage_tokens: int = 0, 728 **kwargs 729 ): 730 if not _dinov3_import_success: 731 raise RuntimeError( 732 "The vision transformer backend can only be initialized if DINOv3 is installed. " 733 "Please install DINOv3 from https://github.com/facebookresearch/dinov3 " 734 "and then rerun your code." 735 ) 736 737 super().__init__( 738 in_chans=in_chans, 739 img_size=img_size, 740 patch_size=patch_size, 741 embed_dim=embed_dim, 742 depth=depth, 743 num_heads=num_heads, 744 ffn_ratio=ffn_ratio, 745 n_storage_tokens=n_storage_tokens, 746 **kwargs 747 ) 748 749 self.in_chans = in_chans 750 self.img_size = img_size 751 self.n_storage_tokens = n_storage_tokens 752 self.attn_outs = [i for i in range(depth) if i % 3 == 2] 753 754 def forward(self, x) -> torch.Tensor: 755 756 B = x.shape[0] 757 758 x, hw_tuple = self.prepare_tokens_with_masks(x) 759 760 list_of_encoder = [] 761 for i, blk in enumerate(self.blocks): 762 rope_sincos = self.rope_embed(H=hw_tuple[0], W=hw_tuple[1]) 763 x = blk(x, rope_sincos) 764 if i in self.attn_outs: 765 list_of_encoder.append(x) 766 767 x = self.norm(x) 768 x = x[:, self.n_storage_tokens + 1:].reshape( 769 B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1 770 ).permute(0, 3, 1, 2).contiguous() 771 772 list_of_encoder = [ 773 o[:, self.n_storage_tokens + 1:].reshape( 774 B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1 775 ).permute(0, 3, 1, 2).contiguous() for o in list_of_encoder 776 ] 777 778 return x, list_of_encoder[:3]
Vision Transformer derived from the DINOv3 Codebase (https://arxiv.org/abs/2508.10104).
Based on: https://github.com/facebookresearch/dinov3/blob/main/dinov3/models/vision_transformer.py.
Arguments:
- img_size: The input image size.
- patch_size: The patch size.
- embed_dim: The embedding dimension.
- depth: The depth of the network.
- num_heads: The number of heads.
- ffn_ratio: The FFN ratio.
- n_storage_tokens: The number of storage (class) tokens to remove.
- kwargs: Keyword arguments for the image encoder base class.
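The DINOv3 variant follows the same pattern (a sketch, assuming DINOv3 is installed): the class token and the n_storage_tokens storage tokens are dropped, and the remaining tokens are reshaped to an (img_size // patch_size) x (img_size // patch_size) grid.

```python
import torch
from torch_em.model.vit import get_vision_transformer

encoder = get_vision_transformer(backbone="dinov3", model="vit_s", img_size=224)
x = torch.randn(1, 3, 224, 224)

features, intermediates = encoder(x)
print(features.shape)      # torch.Size([1, 384, 14, 14]) for the "vit_s" configuration (patch size 16)
print(len(intermediates))  # 3, taken from every third block and truncated to the first three
```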
718 def __init__( 719 self, 720 in_chans: int = 3, 721 img_size: int = 224, 722 patch_size: int = 16, 723 embed_dim: int = 768, 724 depth: int = 12, 725 num_heads: int = 12, 726 ffn_ratio: float = 4.0, 727 n_storage_tokens: int = 0, 728 **kwargs 729 ): 730 if not _dinov3_import_success: 731 raise RuntimeError( 732 "The vision transformer backend can only be initialized if DINOv3 is installed. " 733 "Please install DINOv3 from https://github.com/facebookresearch/dinov3 " 734 "and then rerun your code." 735 ) 736 737 super().__init__( 738 in_chans=in_chans, 739 img_size=img_size, 740 patch_size=patch_size, 741 embed_dim=embed_dim, 742 depth=depth, 743 num_heads=num_heads, 744 ffn_ratio=ffn_ratio, 745 n_storage_tokens=n_storage_tokens, 746 **kwargs 747 ) 748 749 self.in_chans = in_chans 750 self.img_size = img_size 751 self.n_storage_tokens = n_storage_tokens 752 self.attn_outs = [i for i in range(depth) if i % 3 == 2]
754 def forward(self, x) -> torch.Tensor: 755 756 B = x.shape[0] 757 758 x, hw_tuple = self.prepare_tokens_with_masks(x) 759 760 list_of_encoder = [] 761 for i, blk in enumerate(self.blocks): 762 rope_sincos = self.rope_embed(H=hw_tuple[0], W=hw_tuple[1]) 763 x = blk(x, rope_sincos) 764 if i in self.attn_outs: 765 list_of_encoder.append(x) 766 767 x = self.norm(x) 768 x = x[:, self.n_storage_tokens + 1:].reshape( 769 B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1 770 ).permute(0, 3, 1, 2).contiguous() 771 772 list_of_encoder = [ 773 o[:, self.n_storage_tokens + 1:].reshape( 774 B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1 775 ).permute(0, 3, 1, 2).contiguous() for o in list_of_encoder 776 ] 777 778 return x, list_of_encoder[:3]
781def get_vision_transformer(backbone: str, model: str, img_size: int = 1024, **kwargs) -> nn.Module: 782 """Get vision transformer encoder. 783 784 Args: 785 backbone: The name of the vision transformer implementation. 786 One of "sam" / "cellpose_sam" / "sam2" / "sam3" / "mae" / "scalemae" / "dinov2" / "dinov3". 787 model: The name of the model. One of "vit_b", "vit_l" or "vit_h". 788 img_size: The size of the input for the image encoder. Input images will be resized to match this size. 789 kwargs: Additional kwargs which can be expected by the vision transformer, 790 e.g. 'base_resolution' for `ViT_ScaleMAE`. 791 792 Returns: 793 The vision transformer. 794 """ 795 if backbone == "sam": 796 if model == "vit_b": 797 encoder = ViT_Sam( 798 depth=12, embed_dim=768, img_size=img_size, mlp_ratio=4, 799 norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 800 num_heads=12, patch_size=16, qkv_bias=True, use_rel_pos=True, 801 global_attn_indexes=[2, 5, 8, 11], 802 window_size=14, out_chans=256, 803 ) 804 elif model == "vit_l": 805 encoder = ViT_Sam( 806 depth=24, embed_dim=1024, img_size=img_size, mlp_ratio=4, 807 norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 808 num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True, 809 global_attn_indexes=[5, 11, 17, 23], 810 window_size=14, out_chans=256, 811 ) 812 elif model == "vit_h": 813 encoder = ViT_Sam( 814 depth=32, embed_dim=1280, img_size=img_size, mlp_ratio=4, 815 norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 816 num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True, 817 global_attn_indexes=[7, 15, 23, 31], 818 window_size=14, out_chans=256, 819 ) 820 else: 821 raise ValueError(f"'{model}' is not supported by SAM. Currently, 'vit_b', 'vit_l', 'vit_h' are supported.") 822 823 elif backbone == "cellpose_sam": 824 if model != "vit_l": 825 raise ValueError(f"'{model}' is not supported by CellposeSAM. Only 'vit_l' is supported.") 826 encoder = ViT_CellposeSAM(ps=8, bsize=img_size) 827 828 elif backbone == "sam2": 829 if model == "hvit_t": 830 encoder = ViT_Sam2( 831 img_size=img_size, embed_dim=96, num_heads=1, stages=[1, 2, 7, 2], global_att_blocks=[5, 7, 9], 832 window_pos_embed_bkg_spatial_size=[7, 7], backbone_channel_list=[768, 384, 192, 96], 833 ) 834 elif model == "hvit_s": 835 encoder = ViT_Sam2( 836 img_size=img_size, embed_dim=96, num_heads=1, stages=[1, 2, 11, 2], global_att_blocks=[7, 10, 13], 837 window_pos_embed_bkg_spatial_size=[7, 7], backbone_channel_list=[768, 384, 192, 96], 838 ) 839 elif model == "hvit_b": 840 encoder = ViT_Sam2( 841 img_size=img_size, embed_dim=112, num_heads=2, backbone_channel_list=[896, 448, 224, 112], 842 ) 843 elif model == "hvit_l": 844 encoder = ViT_Sam2( 845 img_size=img_size, embed_dim=144, num_heads=2, stages=[2, 6, 36, 4], global_att_blocks=[23, 33, 43], 846 window_spec=[8, 4, 16, 8], backbone_channel_list=[1152, 576, 288, 144], 847 ) 848 else: 849 raise ValueError( 850 f"'{model}' is not supported by SAM2. Currently, 'hvit_t', 'hvit_s', 'hvit_b', 'hvit_l' are supported." 851 ) 852 853 elif backbone == "sam3": 854 if model != "vit_pe": 855 raise ValueError( 856 "'sam3' does not have multiple model configurations. Please use 'vit_pe' as the model configuration." 
857 ) 858 859 encoder = ViT_Sam3( 860 img_size=1008, pretrain_img_size=336, patch_size=14, embed_dim=1024, depth=32, num_heads=16, 861 mlp_ratio=4.625, norm_layer="LayerNorm", drop_path_rate=0.1, qkv_bias=True, use_abs_pos=True, 862 tile_abs_pos=True, global_att_blocks=(7, 15, 23, 31), rel_pos_blocks=(), use_rope=True, 863 use_interp_rope=True, window_size=24, pretrain_use_cls_token=True, retain_cls_token=False, ln_pre=True, 864 ln_post=False, return_interm_layers=False, bias_patch_embed=False, compile_mode=None, 865 ) 866 867 elif backbone == "mae": 868 if model == "vit_b": 869 encoder = ViT_MAE( 870 img_size=img_size, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, 871 qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6) 872 ) 873 elif model == "vit_l": 874 encoder = ViT_MAE( 875 img_size=img_size, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, 876 qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6) 877 ) 878 elif model == "vit_h": 879 encoder = ViT_MAE( 880 img_size=img_size, patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, 881 qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6) 882 ) 883 else: 884 raise ValueError(f"'{model}' is not supported by MAE. Currently, 'vit_b', 'vit_l', 'vit_h' are supported.") 885 886 elif backbone == "scalemae": 887 base_resolution = kwargs.get("base_resolution", 2.5) 888 889 if model == "vit_b": 890 encoder = ViT_ScaleMAE( 891 img_size=img_size, patch_size=8, embed_dim=768, depth=12, num_heads=12, 892 mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), 893 base_resolution=base_resolution, 894 ) 895 elif model == "vit_l": 896 encoder = ViT_ScaleMAE( 897 img_size=img_size, patch_size=8, embed_dim=1024, depth=24, num_heads=16, 898 mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), 899 base_resolution=base_resolution, 900 ) 901 elif model == "vit_h": 902 encoder = ViT_ScaleMAE( 903 img_size=img_size, patch_size=8, embed_dim=1280, depth=32, num_heads=16, 904 mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), 905 base_resolution=base_resolution, 906 ) 907 else: 908 raise ValueError( 909 f"'{model}' is not supported by ScaleMAE. Currently, 'vit_b', 'vit_l' and 'vit_h' are supported." 910 ) 911 912 elif backbone == "dinov2": 913 block_fn = partial(Block, attn_class=MemEffAttention) 914 msg = "The model name should be either 'vit_<X>' or 'vit_<X>_reg<Y>." 
915 916 if model.startswith("vit_s"): 917 assert model in ["vit_s", "vit_s_reg4"], msg 918 encoder = ViT_DINOv2( 919 img_size=img_size, patch_size=14, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, 920 block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0, 921 num_register_tokens=4 if model.endswith("_reg4") else 0, 922 ) 923 elif model.startswith("vit_b"): 924 assert model in ["vit_b", "vit_b_reg4"], msg 925 encoder = ViT_DINOv2( 926 img_size=img_size, patch_size=14, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, 927 block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0, 928 num_register_tokens=4 if model.endswith("_reg4") else 0, 929 ) 930 elif model.startswith("vit_l"): 931 assert model in ["vit_l", "vit_l_reg4"], msg 932 encoder = ViT_DINOv2( 933 img_size=img_size, patch_size=14, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, 934 block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0, 935 num_register_tokens=4 if model.endswith("_reg4") else 0, 936 ) 937 elif model.startswith("vit_g"): 938 assert model in ["vit_g", "vit_g_reg4"], msg 939 encoder = ViT_DINOv2( 940 img_size=img_size, patch_size=14, embed_dim=1536, depth=40, num_heads=24, mlp_ratio=4, 941 block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0, 942 num_register_tokens=4 if model.endswith("_reg4") else 0, ffn_layer="swiglu", 943 ) 944 else: 945 raise ValueError( 946 f"'{model}' is not supported by DINOv2. Currently, 'vit_s', 'vit_b', 'vit_l' and 'vit_g' are supported." 947 ) 948 949 elif backbone == "dinov3": 950 951 if model == "vit_s": 952 encoder = ViT_DINOv3( 953 img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=384, 954 num_heads=6, layerscale_init=1.0e-05, norm_layer="layernormbf16", n_storage_tokens=4, mask_k_bias=True, 955 ) 956 elif model == "vit_s+": 957 encoder = ViT_DINOv3( 958 img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=384, 959 num_heads=6, ffn_ratio=6, layerscale_init=1.0e-05, norm_layer="layernormbf16", 960 ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True, 961 ) 962 963 elif model == "vit_b": 964 encoder = ViT_DINOv3( 965 img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", 966 layerscale_init=1.0e-05, norm_layer="layernormbf16", n_storage_tokens=4, mask_k_bias=True, 967 ) 968 elif model == "vit_l": 969 encoder = ViT_DINOv3( 970 img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1024, 971 depth=24, num_heads=16, layerscale_init=1.0e-05, norm_layer="layernormbf16", 972 n_storage_tokens=4, mask_k_bias=True, 973 ) 974 elif model == "vit_l+": 975 encoder = ViT_DINOv3( 976 img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1024, 977 depth=24, num_heads=16, ffn_ratio=6.0, layerscale_init=1.0e-05, norm_layer="layernormbf16", 978 ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True, 979 ) 980 elif model == "vit_h+": 981 encoder = ViT_DINOv3( 982 img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1280, 983 depth=32, num_heads=20, ffn_ratio=6.0, layerscale_init=1.0e-05, norm_layer="layernormbf16", 984 ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True, 985 ) 986 elif model == "vit_7b": 987 encoder = ViT_DINOv3( 988 img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=4096, 989 
depth=40, num_heads=32, ffn_ratio=3, qkv_bias=False, drop_path_rate=0.0, layerscale_init=1.0e-05, 990 norm_layer="layernormbf16", ffn_layer="swiglu64", n_storage_tokens=4, mask_k_bias=True, 991 untie_global_and_local_cls_norm=True, 992 ) 993 else: 994 raise ValueError( 995 f"'{model}' is not supported by DINOv3. Currently, " 996 "'vit_s', 'vit_s+', 'vit_b', 'vit_l', 'vit_l+', 'vit_h+' and 'vit_7b' are supported." 997 ) 998 999 else: 1000 raise ValueError( 1001 "The backbones supported by 'UNETR' are 'sam', 'cellpose_sam', 'sam2', 'sam3', " 1002 "'mae', 'scalemae', 'dinov2' or 'dinov3'. Please choose one of them." 1003 ) 1004 1005 return encoder
Get vision transformer encoder.
Arguments:
- backbone: The name of the vision transformer implementation. One of "sam" / "cellpose_sam" / "sam2" / "sam3" / "mae" / "scalemae" / "dinov2" / "dinov3".
- model: The name of the model configuration. For the "sam", "mae" and "scalemae" backbones this is one of "vit_b", "vit_l" or "vit_h"; other backbones expect their own names, e.g. "hvit_b" for "sam2" or "vit_pe" for "sam3".
- img_size: The size of the input for the image encoder. Input images will be resized to match this size.
- kwargs: Additional kwargs which can be expected by the vision transformer, e.g. 'base_resolution' for ViT_ScaleMAE.
Returns:
The vision transformer.
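A minimal end-to-end sketch (assuming segment_anything is installed for the "sam" backbone): the returned encoder is a torch.nn.Module that outputs a feature map plus a list of intermediate features, as consumed by torch_em's UNETR.

```python
import torch
from torch_em.model.vit import get_vision_transformer

encoder = get_vision_transformer(backbone="sam", model="vit_b", img_size=1024)
x = torch.randn(1, 3, 1024, 1024)

features, intermediates = encoder(x)
print(features.shape)      # torch.Size([1, 768, 64, 64]) -> 1024 / 16 patches per side
print(len(intermediates))  # 3 intermediate feature maps from the global-attention blocks
```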