torch_em.model.vit
import math
from functools import partial
from typing import Tuple, List

import torch
import torch.nn as nn

# we catch ImportErrors here because segment_anything, micro_sam, scale_mae and timm should
# only be optional dependencies for torch_em
try:
    from segment_anything.modeling import ImageEncoderViT
    _sam_import_success = True
except ImportError:
    ImageEncoderViT = object
    _sam_import_success = False

try:
    from timm.models.vision_transformer import VisionTransformer, PatchEmbed
    _timm_import_success = True
except ImportError:
    VisionTransformer = object
    PatchEmbed = object
    _timm_import_success = False

try:
    from sam2.modeling.backbones.hieradet import Hiera
    from sam2.modeling.position_encoding import PositionEmbeddingSine
    from sam2.modeling.backbones.image_encoder import ImageEncoder, FpnNeck
    _sam2_import_success = True
except ImportError:
    ImageEncoder = object
    _sam2_import_success = False

try:
    from dinov2.models.vision_transformer import DinoVisionTransformer as DinoV2VisionTransformer
    from dinov2.layers import MemEffAttention, NestedTensorBlock as Block
    _dinov2_import_success = True
except ImportError:
    DinoV2VisionTransformer = object
    _dinov2_import_success = False

try:
    from dinov3.models.vision_transformer import DinoVisionTransformer as DinoV3VisionTransformer
    _dinov3_import_success = True
except ImportError:
    DinoV3VisionTransformer = object
    _dinov3_import_success = False


try:
    from sam3.model.vitdet import ViT as SAM3ViT, get_abs_pos
    _sam3_import_success = True
except ImportError:
    SAM3ViT = object
    _sam3_import_success = False


class ViT_Sam(ImageEncoderViT):
    """Vision Transformer derived from the Segment Anything Codebase (https://arxiv.org/abs/2304.02643).

    Based on:
    https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py

    Args:
        in_chans: The number of input channels.
        embed_dim: The embedding dimension, corresponding to the number of output channels of the vision transformer.
        global_attn_indexes: The global attention indices.
        apply_neck: Whether to apply the convolutional bottleneck after outputs of the last attention head.
        kwargs: Keyword arguments for the image encoder base class.
    """
    def __init__(
        self,
        in_chans: int = 3,
        embed_dim: int = 768,
        # NOTE: The default is a tuple (not a list) so that the default argument is immutable.
        global_attn_indexes: Tuple[int, ...] = (2, 5, 8, 11),
        apply_neck: bool = False,
        **kwargs,
    ) -> None:
        if not _sam_import_success:
            raise RuntimeError(
                "The vision transformer backend can only be initialized if segment anything is installed. "
                "Please install segment anything from https://github.com/facebookresearch/segment-anything "
                "and then rerun your code."
            )

        super().__init__(embed_dim=embed_dim, global_attn_indexes=global_attn_indexes, **kwargs)
        self.chunks_for_projection = global_attn_indexes
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.apply_neck = apply_neck

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Apply the vision transformer to input data.

        Args:
            x: The input data.

        Returns:
            The vision transformer output and the (first three) intermediate features
            collected at the global attention blocks.
        """
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            x = x + self.pos_embed

        list_from_encoder = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in self.chunks_for_projection:
                list_from_encoder.append(x)

        # Convert from (N, H, W, C) to (N, C, H, W) for downstream consumers.
        x = x.permute(0, 3, 1, 2)

        if self.apply_neck:
            x = self.neck(x)

        list_from_encoder = [e.permute(0, 3, 1, 2) for e in list_from_encoder]
        return x, list_from_encoder[:3]


class ViT_CellposeSAM(nn.Module):
    """Vision Transformer derived from the CellposeSAM Codebase (https://doi.org/10.1038/s41592-025-02595-x).

    This replicates CellposeSAM's actual initialization: instantiate SAM's ``ImageEncoderViT`` via
    ``sam_model_registry``, then modify the patch embedding, position embeddings, and set global attention.
    This preserves SAM's original relative position bias sizes, enabling direct checkpoint loading
    without any interpolation.

    Based on: https://github.com/MouseLand/cellpose/blob/main/cellpose/vit_sam.py

    NOTE: The pretrained CellposeSAM model uses ``vit_l`` exclusively.

    Args:
        ps: The patch size (default for CellposeSAM is 8).
        bsize: The input image size (default for CellposeSAM is 256).
        apply_neck: Whether to apply the convolutional bottleneck after outputs of the last attention head.
    """
    def __init__(self, ps: int = 8, bsize: int = 256, apply_neck: bool = True) -> None:
        super().__init__()

        if not _sam_import_success:
            raise RuntimeError(
                "The vision transformer backend can only be initialized if segment anything is installed. "
                "Please install segment anything from https://github.com/facebookresearch/segment-anything "
                "and then rerun your code."
            )

        from segment_anything import sam_model_registry

        # Creates the SAM vit_l encoder and applies CellposeSAM's modifications (same as cellpose.vit_sam.Transformer).
        encoder = sam_model_registry["vit_l"](None).image_encoder

        w = encoder.patch_embed.proj.weight.detach()
        nchan = w.shape[0]

        # CellPoseSAM changes the patch size from 16 to 'ps'.
        encoder.patch_embed.proj = nn.Conv2d(3, nchan, stride=ps, kernel_size=ps)
        encoder.patch_embed.proj.weight.data = w[:, :, ::16 // ps, ::16 // ps]

        # Next, they subsample position embeddings for the new patch size and input resolution.
        ds = (1024 // 16) // (bsize // ps)
        encoder.pos_embed = nn.Parameter(encoder.pos_embed[:, ::ds, ::ds], requires_grad=True)

        # Finally, they set all blocks to global attention.
        for blk in encoder.blocks:
            blk.window_size = 0

        # Store encoder submodules directly ('state_dict' keys match CellposeSAM after prefix stripping).
        self.patch_embed = encoder.patch_embed
        self.pos_embed = encoder.pos_embed
        self.blocks = encoder.blocks
        self.neck = encoder.neck

        # Additional attributes expected by UNETR.
        self.embed_dim = nchan
        self.img_size = bsize
        self.in_chans = 3
        self.apply_neck = apply_neck

        # Feature extraction at evenly-spaced depths.
        depth = len(self.blocks)
        _chunks = depth // 4
        self.chunks_for_projection = [_chunks - 1, 2 * _chunks - 1, 3 * _chunks - 1, 4 * _chunks - 1]

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Apply the vision transformer to input data.

        Args:
            x: The input data.

        Returns:
            The vision transformer output and the (first three) intermediate features
            collected at evenly-spaced depths.
        """
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            x = x + self.pos_embed

        list_from_encoder = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in self.chunks_for_projection:
                list_from_encoder.append(x)

        # Convert from (N, H, W, C) to (N, C, H, W) for downstream consumers.
        x = x.permute(0, 3, 1, 2)

        if self.apply_neck:
            x = self.neck(x)

        list_from_encoder = [e.permute(0, 3, 1, 2) for e in list_from_encoder]
        return x, list_from_encoder[:3]


class ViT_MAE(VisionTransformer):
    """Vision Transformer derived from the Masked Auto Encoder Codebase (https://arxiv.org/abs/2111.06377).

    Based on:
    https://github.com/facebookresearch/mae/blob/main/models_vit.py#L20-L53

    Args:
        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
        in_chans: The number of input channels.
        depth: The depth of the vision transformer.
        kwargs: Additional keyword arguments for the vision transformer base class.
    """
    def __init__(
        self,
        img_size: int = 1024,  # chosen to match our experiments with segment anything
        in_chans: int = 3,
        depth: int = 12,
        **kwargs
    ):
        if not _timm_import_success:
            raise RuntimeError(
                "The vision transformer backend can only be initialized if timm is installed. "
                "Please install timm (using conda/mamba) for using https://github.com/facebookresearch/mae/ "
                "and then rerun your code"
            )
        super().__init__(img_size=img_size, depth=depth, **kwargs)
        self.img_size = img_size
        self.in_chans = in_chans
        self.depth = depth

    def convert_to_expected_dim(self, inputs_):
        """@private
        """
        inputs_ = inputs_[:, 1:, :]  # removing the class tokens
        # reshape the outputs to desired shape (N x H*W X C -> N x H x W x C)
        rdim = inputs_.shape[1]
        dshape = int(rdim ** 0.5)  # finding the square root of the outputs for obtaining the patch shape
        inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape))
        inputs_ = inputs_.permute(0, 3, 1, 2)
        return inputs_

    def forward_features(self, x):
        """@private
        """
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        x = x + self.pos_embed
        x = self.pos_drop(x)

        # chunks obtained for getting the projections for conjunctions with upsampling blocks
        _chunks = int(self.depth / 4)
        chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1]

        list_from_encoder = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in chunks_for_projection:
                list_from_encoder.append(self.convert_to_expected_dim(x))

        x = self.convert_to_expected_dim(x)
        return x, list_from_encoder[:3]

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Apply the vision transformer to input data.

        Args:
            x: The input data.

        Returns:
            The vision transformer output and the (first three) intermediate features.
        """
        x, list_from_encoder = self.forward_features(x)
        return x, list_from_encoder


class ViT_Sam2(ImageEncoder):
    """Vision Transformer derived from the Segment Anything 2 Codebase (https://arxiv.org/abs/2408.00714).

    Based on https://github.com/facebookresearch/sam2/blob/main/sam2/modeling/backbones/image_encoder.py.

    Args:
        backbone_channel_list: The channels throughout the entire backbone.
        img_size: The size of the input for the image encoder.
        embed_dim: The initial embedding dimension.
        num_heads: The initial number of heads.
        stages: The number of blocks per stage.
        global_att_blocks: The parameter to decide which blocks have global attention.
        window_pos_embed_bkg_spatial_size: The spatial size of windowed positional embedding.
        window_spec: The window size per stage, when not using global attention.
        scalp: The count of lowest resolution features to discard.
    """
    def __init__(
        self,
        backbone_channel_list: List[int],
        img_size: int = 1024,
        embed_dim: int = 96,
        num_heads: int = 1,
        stages: Tuple[int, ...] = (2, 3, 16, 3),
        global_att_blocks: Tuple[int, ...] = (12, 16, 20),
        window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
        window_spec: Tuple[int, ...] = (8, 4, 14, 7),
        scalp: int = 1,
        **kwargs
    ):
        if not _sam2_import_success:
            raise RuntimeError(
                "The vision transformer backend can only be initialized if segment anything 2 is installed. "
                "Please install segment anything 2 from https://github.com/facebookresearch/sam2 "
                "and then rerun your code"
            )

        trunk = Hiera(
            embed_dim=embed_dim,
            num_heads=num_heads,
            stages=stages,
            global_att_blocks=global_att_blocks,
            window_pos_embed_bkg_spatial_size=window_pos_embed_bkg_spatial_size,
            window_spec=window_spec,
        )
        neck = FpnNeck(
            position_encoding=PositionEmbeddingSine(num_pos_feats=256),
            d_model=256,
            backbone_channel_list=backbone_channel_list,
            fpn_top_down_levels=[2, 3],
            fpn_interp_model="nearest",
        )

        super().__init__(trunk=trunk, neck=neck, scalp=scalp, **kwargs)
        self.scalp = scalp
        self.embed_dim = embed_dim
        self.img_size = img_size

    def forward(self, x: torch.Tensor):
        """Apply the vision transformer to input data.

        Args:
            x: The input data.

        Returns:
            The highest resolution feature map and the list of all retained feature maps.
        """
        # The forward pass through the backbone.
        features, pos = self.neck(self.trunk(x))
        if self.scalp > 0:  # This discards the "n" lowest resolution features.
            features, pos = features[:-self.scalp], pos[:-self.scalp]

        return features[-1], features


class ViT_Sam3(SAM3ViT):
    """Vision Transformer derived from the Segment Anything 3 Codebase (https://arxiv.org/abs/2511.16719).

    Based on https://github.com/facebookresearch/sam3/blob/main/sam3/model/vitdet.py.

    Args:
        img_size: The input image size.
        embed_dim: The embedding dimension, corresponding to the number of output channels of the vision transformer.
        kwargs: Keyword arguments for the image encoder base class.
    """
    def __init__(self, img_size: int = 1024, embed_dim: int = 768, **kwargs):
        if not _sam3_import_success:
            raise RuntimeError(
                "The vision transformer backend can only be initialized if segment anything 3 is installed. "
                "Please install segment anything 3 from https://github.com/facebookresearch/sam3 "
                "and then rerun your code"
            )

        super().__init__(img_size=img_size, embed_dim=embed_dim, **kwargs)
        self.img_size = img_size
        self.embed_dim = embed_dim

    def forward_features(self, x):
        """@private
        """
        x = self.patch_embed(x)
        h, w = x.shape[1], x.shape[2]

        s = 0
        if self.retain_cls_token:
            # If the 'cls_token' is retained, we don't maintain the spatial shape.
            x = torch.cat([self.class_embedding, x.flatten(1, 2)], dim=1)
            s = 1

        if self.pos_embed is not None:
            x = x + get_abs_pos(
                self.pos_embed, self.pretrain_use_cls_token, (h, w), self.retain_cls_token, tiling=self.tile_abs_pos,
            )

        x = self.ln_pre(x)

        list_from_encoder = []
        for i, blk in enumerate(self.blocks):
            if self.use_act_checkpoint and self.training:
                x = torch.utils.checkpoint.checkpoint(blk, x, use_reentrant=False)
            else:
                x = blk(x)

            x = self._convert_to_expected_dim(x, i, s)

            if i in self.full_attn_ids:
                list_from_encoder.append(x)

        return x, list_from_encoder

    def _convert_to_expected_dim(self, x, i, s):
        # Convert the token sequence to a (N, C, H, W) feature map at the output collection points.
        if (i == self.full_attn_ids[-1]) or (
            self.return_interm_layers and i in self.full_attn_ids
        ):
            if i == self.full_attn_ids[-1]:
                x = self.ln_post(x)

            feats = x[:, s:]
            if feats.ndim == 4:
                feats = feats.permute(0, 3, 1, 2)
            else:
                assert feats.ndim == 3
                # The tokens come from a square patch grid, so the side length is the
                # integer square root of the token count ('math.sqrt' would yield a
                # float, which 'reshape' does not accept).
                side = math.isqrt(feats.shape[1])
                feats = feats.reshape(feats.shape[0], side, side, feats.shape[-1]).permute(0, 3, 1, 2)
            return feats

        else:
            return x

    def forward(self, x: torch.Tensor):
        """Apply the vision transformer to input data.

        Args:
            x: The input data.

        Returns:
            The vision transformer output and the intermediate features collected
            at the global attention blocks.
        """
        x, list_from_encoder = self.forward_features(x)
        return x, list_from_encoder

#
# Utilities for ScaleMAE's ViT
#


class CustomCompose:
    """Composition of a rescale transform, optional extra transforms and a source transform,
    which also tracks the rescaling ratio and the valid-data mask (used by ScaleMAE).
    """
    def __init__(self, rescale_transform, other_transforms, src_transform):
        self.rescale_transform = rescale_transform
        self.other_transforms = other_transforms
        self.src_transform = src_transform

    def __call__(self, x, valid_masks=None):
        if valid_masks is not None:
            nodata = (x * (1 - valid_masks.float())).max()
        x_aug = self.rescale_transform(x)
        parms = self.rescale_transform._params

        # sanity check, comment if this is working
        # valid_masks = self.rescale_transform(valid_masks.float(), params=parms)
        # assert (x_aug==self.rescale_transform(x, params=parms)).all() #

        if valid_masks is not None:
            valid_masks = x_aug != nodata
            _, c, h, w = x_aug.shape
            zero_ratio = ((valid_masks == 0).sum((1, 2, 3)) / (h * w * c)).cpu().numpy()
        else:
            zero_ratio = -1

        if self.other_transforms:
            x_aug = self.other_transforms(x_aug)
        x_src = self.src_transform(x_aug)
        dx = parms["src"][:, 1, 0] - parms["src"][:, 0, 0]

        # dy = (parms['src'][:,2,1] - parms['src'][:,1,1])
        # assert (dx == dy).all()

        h, w = x_aug.shape[-2:]
        # assert h == w

        return x_aug, x_src, dx / h, zero_ratio, valid_masks


def get_2d_sincos_pos_embed_with_resolution(embed_dim, grid_size, res, cls_token=False, device="cpu"):
    """
    grid_size: int of the grid height and width
    res: array of size n, representing the resolution of a pixel (say, in meters),
    return:
    pos_embed: [n,grid_size*grid_size, embed_dim] or [n,1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    # res = torch.FloatTensor(res).to(device)
    res = res.to(device)
    grid_h = torch.arange(grid_size, dtype=torch.float32, device=device)
    grid_w = torch.arange(grid_size, dtype=torch.float32, device=device)
    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")  # here h goes first,direction reversed for numpy
    grid = torch.stack(grid, dim=0)  # 2 x h x w

    # grid = grid.reshape([2, 1, grid_size, grid_size])
    grid = torch.einsum("chw,n->cnhw", grid, res)  # 2 x n x h x w
    _, n, h, w = grid.shape
    pos_embed = get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid)  # (nxH*W, D/2)
    pos_embed = pos_embed.reshape(n, h * w, embed_dim)
    if cls_token:
        pos_embed = torch.cat(
            [torch.zeros([n, 1, embed_dim], dtype=torch.float32, device=pos_embed.device), pos_embed], dim=1
        )

    return pos_embed


def get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid):
    """Compute 2d sin-cos position embeddings for a coordinate grid by encoding
    the h- and w-coordinates separately with half of the embedding dimensions each.
    """
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = torch.cat([emb_h, emb_w], dim=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    # old_shape = pos
    omega = torch.arange(embed_dim // 2, dtype=torch.float32, device=pos.device)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = torch.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = torch.sin(out)  # (M, D/2)
    emb_cos = torch.cos(out)  # (M, D/2)

    emb = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)
    return emb


class PatchEmbedUnSafe(PatchEmbed):
    """Image to Patch Embedding"""

    def forward(self, x):
        B, C, H, W = x.shape

        # NOTE: Comment code from ScaleMAE: Dropped size check in timm
        # assert H == self.img_size[0] and W == self.img_size[1], \
        #     f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."

        x = self.proj(x).flatten(2).transpose(1, 2)
        return x


class ViT_ScaleMAE(VisionTransformer):
    """Vision Transformer derived from the Scale Masked Auto Encoder codebase (TODO: paper and github link).

    NOTE: For downstream tasks, the "base_resolution" parameter needs to be adjusted manually when using
    the model on a different zoom factor dataset.
    """

    def __init__(
        self, img_size=224, patch_size=16, in_chans=3, embed_dim=1024, depth=12, base_resolution=2.5, **kwargs
    ):
        super().__init__(img_size=img_size, embed_dim=embed_dim, **kwargs)
        self.img_size = img_size
        self.in_chans = in_chans
        self.depth = depth
        self.base_resolution = base_resolution

        self.patch_embed = PatchEmbedUnSafe(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

    def transform_inputs(self, x):
        import kornia.augmentation as K
        from kornia.constants import Resample

        self._transforms = CustomCompose(
            rescale_transform=K.RandomResizedCrop(
                (448, 448),
                ratio=(1.0, 1.0),
                scale=(1.0, 1.0),
                resample=Resample.BICUBIC.name,
            ),
            other_transforms=None,
            src_transform=K.Resize((224, 224)),
        )
        x, _, ratios, _, _ = self._transforms(x)
        input_res = ratios * self.base_resolution
        return x, input_res

    def convert_to_expected_dim(self, x):
        inputs_ = x[:, 1:, :]  # removing the class tokens
        # reshape the outputs to desired shape (N X H*W X C -> N X H X W X C)
        rdim = inputs_.shape[1]
        dshape = int(rdim ** 0.5)  # finding square root of the outputs for obtaining the patch shape
        inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape))
        inputs_ = inputs_.permute(0, 3, 1, 2)
        return inputs_

    def forward_features(self, x):
        x, input_res = self.transform_inputs(x)

        B, _, h, w = x.shape
        x = self.patch_embed(x)

        num_patches = int((h * w) / (self.patch_embed.patch_size[0] * self.patch_embed.patch_size[1]))
        pos_embed = get_2d_sincos_pos_embed_with_resolution(
            x.shape[-1],
            int(num_patches ** 0.5),
            input_res,
            cls_token=True,
            device=x.device,
        )

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + pos_embed
        x = self.pos_drop(x)

        # chunks obtained for getting the projections for conjunctions with upsampling blocks
        _chunks = int(self.depth / 4)
        chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1]

        list_from_encoder = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in chunks_for_projection:
                list_from_encoder.append(self.convert_to_expected_dim(x))

        x = self.convert_to_expected_dim(x)

        return x, list_from_encoder

    def forward(self, x):
        """Apply the vision transformer to input data.

        Args:
            x: The input data.

        Returns:
            The vision transformer output and the intermediate features.
        """
        x, list_from_encoder = self.forward_features(x)
        return x, list_from_encoder


class ViT_DINOv2(DinoV2VisionTransformer):
    """Vision Transformer derived from the DINOv2 Codebase (https://arxiv.org/abs/2304.07193).

    Based on:
    https://github.com/facebookresearch/dinov2/blob/main/dinov2/models/vision_transformer.py.

    Args:
        img_size: The input image size.
        patch_size: The patch size.
        depth: The depth of the network.
        num_register_tokens: The number of registers added (in addition to the class tokens).
            It's important to know for ViTs trained with registers, to remove them at the end.
    """
    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 16,
        depth: int = 12,
        num_register_tokens: int = 0,
        **kwargs
    ):
        if not _dinov2_import_success:
            raise RuntimeError(
                "The vision transformer backend can only be initialized if DINOv2 is installed. "
                "Please install DINOv2 from https://github.com/facebookresearch/dinov2 "
                "and then rerun your code."
            )

        super().__init__(
            img_size=img_size,
            depth=depth,
            patch_size=patch_size,
            num_register_tokens=num_register_tokens,
            **kwargs
        )

        self.img_size = img_size
        self.num_register_tokens = num_register_tokens
        self.patch_size = patch_size
        # Collect features after every third block.
        self.attn_outs = [i for i in range(depth) if i % 3 == 2]

    def forward(self, x, masks=None) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Apply the vision transformer to input data.

        Args:
            x: The input data.
            masks: Optional masks passed on to token preparation.

        Returns:
            The vision transformer output and the (first three) intermediate features.
        """
        B = x.shape[0]

        x = self.prepare_tokens_with_masks(x)

        list_of_encoder = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in self.attn_outs:
                list_of_encoder.append(x)

        x = self.norm(x)
        # Strip the class and register tokens and restore the spatial layout (N, C, H, W).
        x = x[:, self.num_register_tokens + 1:].reshape(
            B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
        ).permute(0, 3, 1, 2).contiguous()

        list_of_encoder = [
            o[:, self.num_register_tokens + 1:].reshape(
                B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
            ).permute(0, 3, 1, 2).contiguous() for o in list_of_encoder
        ]

        return x, list_of_encoder[:3]


class ViT_DINOv3(DinoV3VisionTransformer):
    """Vision Transformer derived from the DINOv3 Codebase (https://arxiv.org/abs/2508.10104).

    Based on:
    https://github.com/facebookresearch/dinov3/blob/main/dinov3/models/vision_transformer.py.

    Args:
        in_chans: The number of input channels.
        img_size: The input image size.
        patch_size: The patch size.
        embed_dim: The embedding dimension.
        depth: The depth of the network.
        num_heads: The number of heads.
        ffn_ratio: The FFN ratio.
        n_storage_tokens: The number of storage (class) tokens to remove.
        kwargs: Keyword arguments for the image encoder base class.
    """
    def __init__(
        self,
        in_chans: int = 3,
        img_size: int = 224,
        patch_size: int = 16,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        ffn_ratio: float = 4.0,
        n_storage_tokens: int = 0,
        **kwargs
    ):
        if not _dinov3_import_success:
            raise RuntimeError(
                "The vision transformer backend can only be initialized if DINOv3 is installed. "
                "Please install DINOv3 from https://github.com/facebookresearch/dinov3 "
                "and then rerun your code."
            )

        super().__init__(
            in_chans=in_chans,
            img_size=img_size,
            patch_size=patch_size,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            ffn_ratio=ffn_ratio,
            n_storage_tokens=n_storage_tokens,
            **kwargs
        )

        self.in_chans = in_chans
        self.img_size = img_size
        self.n_storage_tokens = n_storage_tokens
        # Collect features after every third block.
        self.attn_outs = [i for i in range(depth) if i % 3 == 2]

    def forward(self, x) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Apply the vision transformer to input data.

        Args:
            x: The input data.

        Returns:
            The vision transformer output and the (first three) intermediate features.
        """
        B = x.shape[0]

        x, hw_tuple = self.prepare_tokens_with_masks(x)

        list_of_encoder = []
        for i, blk in enumerate(self.blocks):
            rope_sincos = self.rope_embed(H=hw_tuple[0], W=hw_tuple[1])
            x = blk(x, rope_sincos)
            if i in self.attn_outs:
                list_of_encoder.append(x)

        x = self.norm(x)
        # Strip the class and storage tokens and restore the spatial layout (N, C, H, W).
        x = x[:, self.n_storage_tokens + 1:].reshape(
            B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
        ).permute(0, 3, 1, 2).contiguous()

        list_of_encoder = [
            o[:, self.n_storage_tokens + 1:].reshape(
                B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
            ).permute(0, 3, 1, 2).contiguous() for o in list_of_encoder
        ]

        return x, list_of_encoder[:3]


def get_vision_transformer(backbone: str, model: str, img_size: int = 1024, **kwargs) -> nn.Module:
    """Get vision transformer encoder.

    Args:
        backbone: The name of the vision transformer implementation.
            One of "sam" / "cellpose_sam" / "sam2" / "sam3" / "mae" / "scalemae" / "dinov2" / "dinov3".
        model: The name of the model. One of "vit_b", "vit_l" or "vit_h".
        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
        kwargs: Additional kwargs which can be expected by the vision transformer,
            e.g. 'base_resolution' for `ViT_ScaleMAE`.

    Returns:
        The vision transformer.
    """
    if backbone == "sam":
        if model == "vit_b":
            encoder = ViT_Sam(
                depth=12, embed_dim=768, img_size=img_size, mlp_ratio=4,
                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
                num_heads=12, patch_size=16, qkv_bias=True, use_rel_pos=True,
                global_attn_indexes=[2, 5, 8, 11],
                window_size=14, out_chans=256,
            )
        elif model == "vit_l":
            encoder = ViT_Sam(
                depth=24, embed_dim=1024, img_size=img_size, mlp_ratio=4,
                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
                num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True,
                global_attn_indexes=[5, 11, 17, 23],
                window_size=14, out_chans=256,
            )
        elif model == "vit_h":
            encoder = ViT_Sam(
                depth=32, embed_dim=1280, img_size=img_size, mlp_ratio=4,
                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
                num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True,
                global_attn_indexes=[7, 15, 23, 31],
                window_size=14, out_chans=256,
            )
        else:
            raise ValueError(f"'{model}' is not supported by SAM. Currently, 'vit_b', 'vit_l', 'vit_h' are supported.")

    elif backbone == "cellpose_sam":
        if model != "vit_l":
            raise ValueError(f"'{model}' is not supported by CellposeSAM. Only 'vit_l' is supported.")
        encoder = ViT_CellposeSAM(ps=8, bsize=img_size)

    elif backbone == "sam2":
        if model == "hvit_t":
            encoder = ViT_Sam2(
                img_size=img_size, embed_dim=96, num_heads=1, stages=[1, 2, 7, 2], global_att_blocks=[5, 7, 9],
                window_pos_embed_bkg_spatial_size=[7, 7], backbone_channel_list=[768, 384, 192, 96],
            )
        elif model == "hvit_s":
            encoder = ViT_Sam2(
                img_size=img_size, embed_dim=96, num_heads=1, stages=[1, 2, 11, 2], global_att_blocks=[7, 10, 13],
                window_pos_embed_bkg_spatial_size=[7, 7], backbone_channel_list=[768, 384, 192, 96],
            )
        elif model == "hvit_b":
            encoder = ViT_Sam2(
                img_size=img_size, embed_dim=112, num_heads=2, backbone_channel_list=[896, 448, 224, 112],
            )
        elif model == "hvit_l":
            encoder = ViT_Sam2(
                img_size=img_size, embed_dim=144, num_heads=2, stages=[2, 6, 36, 4], global_att_blocks=[23, 33, 43],
                window_spec=[8, 4, 16, 8], backbone_channel_list=[1152, 576, 288, 144],
            )
        else:
            raise ValueError(
                f"'{model}' is not supported by SAM2. Currently, 'hvit_t', 'hvit_s', 'hvit_b', 'hvit_l' are supported."
            )

    elif backbone == "sam3":
        if model != "vit_pe":
            raise ValueError(
                "'sam3' does not have multiple model configurations. Please use 'vit_pe' as the model configuration."
            )

        # NOTE: SAM3 uses a fixed input size of 1008, irrespective of the 'img_size' argument.
        encoder = ViT_Sam3(
            img_size=1008, pretrain_img_size=336, patch_size=14, embed_dim=1024, depth=32, num_heads=16,
            mlp_ratio=4.625, norm_layer="LayerNorm", drop_path_rate=0.1, qkv_bias=True, use_abs_pos=True,
            tile_abs_pos=True, global_att_blocks=(7, 15, 23, 31), rel_pos_blocks=(), use_rope=True,
            use_interp_rope=True, window_size=24, pretrain_use_cls_token=True, retain_cls_token=False, ln_pre=True,
            ln_post=False, return_interm_layers=False, bias_patch_embed=False, compile_mode=None,
        )

    elif backbone == "mae":
        if model == "vit_b":
            encoder = ViT_MAE(
                img_size=img_size, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
            )
        elif model == "vit_l":
            encoder = ViT_MAE(
                img_size=img_size, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
            )
        elif model == "vit_h":
            encoder = ViT_MAE(
                img_size=img_size, patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4,
                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
            )
        else:
            raise ValueError(f"'{model}' is not supported by MAE. Currently, 'vit_b', 'vit_l', 'vit_h' are supported.")

    elif backbone == "scalemae":
        base_resolution = kwargs.get("base_resolution", 2.5)

        if model == "vit_b":
            encoder = ViT_ScaleMAE(
                img_size=img_size, patch_size=8, embed_dim=768, depth=12, num_heads=12,
                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
                base_resolution=base_resolution,
            )
        elif model == "vit_l":
            encoder = ViT_ScaleMAE(
                img_size=img_size, patch_size=8, embed_dim=1024, depth=24, num_heads=16,
                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
                base_resolution=base_resolution,
            )
        elif model == "vit_h":
            encoder = ViT_ScaleMAE(
                img_size=img_size, patch_size=8, embed_dim=1280, depth=32, num_heads=16,
                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
                base_resolution=base_resolution,
            )
        else:
            raise ValueError(
                f"'{model}' is not supported by ScaleMAE. Currently, 'vit_b', 'vit_l' and 'vit_h' are supported."
            )

    elif backbone == "dinov2":
        # Check the import before touching 'Block' / 'MemEffAttention': without this guard
        # a missing dinov2 installation surfaces as a NameError instead of a helpful message.
        if not _dinov2_import_success:
            raise RuntimeError(
                "The vision transformer backend can only be initialized if DINOv2 is installed. "
                "Please install DINOv2 from https://github.com/facebookresearch/dinov2 "
                "and then rerun your code."
            )

        block_fn = partial(Block, attn_class=MemEffAttention)
        msg = "The model name should be either 'vit_<X>' or 'vit_<X>_reg<Y>'."

        if model.startswith("vit_s"):
            assert model in ["vit_s", "vit_s_reg4"], msg
            encoder = ViT_DINOv2(
                img_size=img_size, patch_size=14, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
                block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0,
                num_register_tokens=4 if model.endswith("_reg4") else 0,
            )
        elif model.startswith("vit_b"):
            assert model in ["vit_b", "vit_b_reg4"], msg
            encoder = ViT_DINOv2(
                img_size=img_size, patch_size=14, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
                block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0,
                num_register_tokens=4 if model.endswith("_reg4") else 0,
            )
        elif model.startswith("vit_l"):
            assert model in ["vit_l", "vit_l_reg4"], msg
            encoder = ViT_DINOv2(
                img_size=img_size, patch_size=14, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
                block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0,
                num_register_tokens=4 if model.endswith("_reg4") else 0,
            )
        elif model.startswith("vit_g"):
            assert model in ["vit_g", "vit_g_reg4"], msg
            encoder = ViT_DINOv2(
                img_size=img_size, patch_size=14, embed_dim=1536, depth=40, num_heads=24, mlp_ratio=4,
                block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0,
                num_register_tokens=4 if model.endswith("_reg4") else 0, ffn_layer="swiglu",
            )
        else:
            raise ValueError(
                f"'{model}' is not supported by DINOv2. Currently, 'vit_s', 'vit_b', 'vit_l' and 'vit_g' are supported."
            )

    elif backbone == "dinov3":

        if model == "vit_s":
            encoder = ViT_DINOv3(
                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=384,
                num_heads=6, layerscale_init=1.0e-05, norm_layer="layernormbf16", n_storage_tokens=4, mask_k_bias=True,
            )
        elif model == "vit_s+":
            encoder = ViT_DINOv3(
                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=384,
                num_heads=6, ffn_ratio=6, layerscale_init=1.0e-05, norm_layer="layernormbf16",
                ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True,
            )

        elif model == "vit_b":
            encoder = ViT_DINOv3(
                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32",
                layerscale_init=1.0e-05, norm_layer="layernormbf16", n_storage_tokens=4, mask_k_bias=True,
            )
        elif model == "vit_l":
            encoder = ViT_DINOv3(
                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1024,
                depth=24, num_heads=16, layerscale_init=1.0e-05, norm_layer="layernormbf16",
                n_storage_tokens=4, mask_k_bias=True,
            )
        elif model == "vit_l+":
            encoder = ViT_DINOv3(
                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1024,
                depth=24, num_heads=16, ffn_ratio=6.0, layerscale_init=1.0e-05, norm_layer="layernormbf16",
                ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True,
            )
        elif model == "vit_h+":
            encoder = ViT_DINOv3(
                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1280,
                depth=32, num_heads=20, ffn_ratio=6.0, layerscale_init=1.0e-05, norm_layer="layernormbf16",
                ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True,
            )
        elif model == "vit_7b":
            encoder = ViT_DINOv3(
                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=4096,
                depth=40, num_heads=32, ffn_ratio=3, qkv_bias=False, drop_path_rate=0.0, layerscale_init=1.0e-05,
                norm_layer="layernormbf16", ffn_layer="swiglu64", n_storage_tokens=4, mask_k_bias=True,
                untie_global_and_local_cls_norm=True,
            )
        else:
            raise ValueError(
                f"'{model}' is not supported by DINOv3. Currently, "
                "'vit_s', 'vit_s+', 'vit_b', 'vit_l', 'vit_l+', 'vit_h+', 'vit_7b' are supported."
            )

    else:
        raise ValueError(
            "The 'UNETR' supported backbones are 'sam', 'cellpose_sam', 'sam2', 'sam3', "
            "'mae', 'scalemae', 'dinov2' or 'dinov3'. Please choose one of them."
        )

    return encoder
class ViT_Sam(ImageEncoderViT):
    """Vision Transformer derived from the Segment Anything Codebase (https://arxiv.org/abs/2304.02643).

    Based on:
    https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py

    Args:
        in_chans: The number of input channels.
        embed_dim: The embedding dimension, corresponding to the number of output channels of the vision transformer.
        global_attn_indexes: The global attention indices.
        apply_neck: Whether to apply the convolutional bottleneck after outputs of the last attention head.
        kwargs: Keyword arguments for the image encoder base class.
    """
    def __init__(
        self,
        in_chans: int = 3,
        embed_dim: int = 768,
        # Use an immutable tuple: a list default is mutable and shared across instances,
        # and the annotation declares a tuple anyway.
        global_attn_indexes: Tuple[int, ...] = (2, 5, 8, 11),
        apply_neck: bool = False,
        **kwargs,
    ) -> None:
        if not _sam_import_success:
            raise RuntimeError(
                "The vision transformer backend can only be initialized if segment anything is installed. "
                "Please install segment anything from https://github.com/facebookresearch/segment-anything "
                "and then rerun your code."
            )

        super().__init__(embed_dim=embed_dim, global_attn_indexes=global_attn_indexes, **kwargs)
        # Intermediate features are collected at the global attention blocks.
        self.chunks_for_projection = global_attn_indexes
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.apply_neck = apply_neck

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Apply the vision transformer to input data.

        Args:
            x: The input data.

        Returns:
            The vision transformer output and the intermediate features from the global attention blocks.
        """
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            x = x + self.pos_embed

        list_from_encoder = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in self.chunks_for_projection:
                list_from_encoder.append(x)

        # (N, H, W, C) -> (N, C, H, W)
        x = x.permute(0, 3, 1, 2)

        if self.apply_neck:
            x = self.neck(x)

        list_from_encoder = [e.permute(0, 3, 1, 2) for e in list_from_encoder]
        return x, list_from_encoder[:3]
Vision Transformer derived from the Segment Anything Codebase (https://arxiv.org/abs/2304.02643).
Arguments:
- in_chans: The number of input channels.
- embed_dim: The embedding dimension, corresponding to the number of output channels of the vision transformer.
- global_attn_indexes: The global attention indices.
- apply_neck: Whether to apply the convolutional bottleneck after outputs of the last attention head.
- kwargs: Keyword arguments for the image encoder base class.
def __init__(
    self,
    in_chans: int = 3,
    embed_dim: int = 768,
    # Immutable tuple instead of a mutable list default (the annotation declares a tuple).
    global_attn_indexes: Tuple[int, ...] = (2, 5, 8, 11),
    apply_neck: bool = False,
    **kwargs,
) -> None:
    """Initialize the SAM-derived vision transformer.

    Args:
        in_chans: The number of input channels.
        embed_dim: The embedding dimension of the vision transformer.
        global_attn_indexes: The global attention indices.
        apply_neck: Whether to apply the convolutional bottleneck after outputs of the last attention head.
        kwargs: Keyword arguments for the image encoder base class.

    Raises:
        RuntimeError: If segment anything is not installed.
    """
    if not _sam_import_success:
        raise RuntimeError(
            "The vision transformer backend can only be initialized if segment anything is installed. "
            "Please install segment anything from https://github.com/facebookresearch/segment-anything "
            "and then rerun your code."
        )

    super().__init__(embed_dim=embed_dim, global_attn_indexes=global_attn_indexes, **kwargs)
    # Intermediate features are collected at the global attention blocks.
    self.chunks_for_projection = global_attn_indexes
    self.in_chans = in_chans
    self.embed_dim = embed_dim
    self.apply_neck = apply_neck
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply the vision transformer to input data.

    Args:
        x: The input data.

    Returns:
        The vision transformer output and the intermediate features from the global attention blocks.
    """
    embeddings = self.patch_embed(x)
    if self.pos_embed is not None:
        embeddings = embeddings + self.pos_embed

    intermediates = []
    for idx, block in enumerate(self.blocks):
        embeddings = block(embeddings)
        if idx in self.chunks_for_projection:
            intermediates.append(embeddings)

    # (N, H, W, C) -> (N, C, H, W)
    output = embeddings.permute(0, 3, 1, 2)
    if self.apply_neck:
        output = self.neck(output)

    return output, [feat.permute(0, 3, 1, 2) for feat in intermediates[:3]]
Apply the vision transformer to input data.
Arguments:
- x: The input data.
Returns:
The vision transformer output.
class ViT_CellposeSAM(nn.Module):
    """Vision Transformer derived from the CellposeSAM Codebase (https://doi.org/10.1038/s41592-025-02595-x).

    This replicates CellposeSAM's actual initialization: instantiate SAM's ``ImageEncoderViT`` via
    ``sam_model_registry``, then modify the patch embedding, position embeddings, and set global attention.
    This preserves SAM's original relative position bias sizes, enabling direct checkpoint loading
    without any interpolation.

    Based on: https://github.com/MouseLand/cellpose/blob/main/cellpose/vit_sam.py

    NOTE: The pretrained CellposeSAM model uses ``vit_l`` exclusively.

    Args:
        ps: The patch size (default for CellposeSAM is 8).
        bsize: The input image size (default for CellposeSAM is 256).
        apply_neck: Whether to apply the convolutional bottleneck after outputs of the last attention head.
    """
    def __init__(self, ps: int = 8, bsize: int = 256, apply_neck: bool = True) -> None:
        super().__init__()

        if not _sam_import_success:
            raise RuntimeError(
                "The vision transformer backend can only be initialized if segment anything is installed. "
                "Please install segment anything from https://github.com/facebookresearch/segment-anything "
                "and then rerun your code."
            )

        from segment_anything import sam_model_registry

        # Creates the SAM vit_l encoder and applies CellposeSAM's modifications (same as cellpose.vit_sam.Transformer).
        encoder = sam_model_registry["vit_l"](None).image_encoder

        w = encoder.patch_embed.proj.weight.detach()
        nchan = w.shape[0]

        # CellposeSAM changes the patch size from 16 to 'ps' by subsampling the original 16x16 kernel.
        encoder.patch_embed.proj = nn.Conv2d(3, nchan, stride=ps, kernel_size=ps)
        encoder.patch_embed.proj.weight.data = w[:, :, ::16 // ps, ::16 // ps]

        # Next, they subsample position embeddings for the new patch size and input resolution.
        ds = (1024 // 16) // (bsize // ps)
        encoder.pos_embed = nn.Parameter(encoder.pos_embed[:, ::ds, ::ds], requires_grad=True)

        # Finally, they set all blocks to global attention.
        for blk in encoder.blocks:
            blk.window_size = 0

        # Store encoder submodules directly ('state_dict' keys match CellposeSAM after prefix stripping).
        self.patch_embed = encoder.patch_embed
        self.pos_embed = encoder.pos_embed
        self.blocks = encoder.blocks
        self.neck = encoder.neck

        # Additional attributes expected by UNETR.
        self.embed_dim = nchan
        self.img_size = bsize
        self.in_chans = 3
        self.apply_neck = apply_neck

        # Feature extraction at evenly-spaced depths.
        depth = len(self.blocks)
        _chunks = depth // 4
        self.chunks_for_projection = [_chunks - 1, 2 * _chunks - 1, 3 * _chunks - 1, 4 * _chunks - 1]

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Apply the vision transformer to input data.

        Args:
            x: The input data.

        Returns:
            The vision transformer output and the intermediate features.
        """
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            x = x + self.pos_embed

        list_from_encoder = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in self.chunks_for_projection:
                list_from_encoder.append(x)

        # (N, H, W, C) -> (N, C, H, W)
        x = x.permute(0, 3, 1, 2)

        if self.apply_neck:
            x = self.neck(x)

        list_from_encoder = [e.permute(0, 3, 1, 2) for e in list_from_encoder]
        return x, list_from_encoder[:3]
Vision Transformer derived from the CellposeSAM Codebase (https://doi.org/10.1038/s41592-025-02595-x).
This replicates CellposeSAM's actual initialization: instantiate SAM's ImageEncoderViT via
sam_model_registry, then modify the patch embedding, position embeddings, and set global attention.
This preserves SAM's original relative position bias sizes, enabling direct checkpoint loading
without any interpolation.
Based on: https://github.com/MouseLand/cellpose/blob/main/cellpose/vit_sam.py
NOTE: The pretrained CellposeSAM model uses vit_l exclusively.
Arguments:
- ps: The patch size (default for CellposeSAM is 8).
- bsize: The input image size (default for CellposeSAM is 256).
- apply_neck: Whether to apply the convolutional bottleneck after outputs of the last attention head.
def __init__(self, ps: int = 8, bsize: int = 256, apply_neck: bool = True) -> None:
    """Initialize the CellposeSAM-style encoder from SAM's 'vit_l' image encoder.

    Args:
        ps: The patch size (default for CellposeSAM is 8).
        bsize: The input image size (default for CellposeSAM is 256).
        apply_neck: Whether to apply the convolutional bottleneck after outputs of the last attention head.

    Raises:
        RuntimeError: If segment anything is not installed.
    """
    super().__init__()

    if not _sam_import_success:
        raise RuntimeError(
            "The vision transformer backend can only be initialized if segment anything is installed. "
            "Please install segment anything from https://github.com/facebookresearch/segment-anything "
            "and then rerun your code."
        )

    from segment_anything import sam_model_registry

    # Creates the SAM vit_l encoder and applies CellposeSAM's modifications (same as cellpose.vit_sam.Transformer).
    encoder = sam_model_registry["vit_l"](None).image_encoder

    # Original patch-embedding kernel; its output channel count defines the embedding dimension.
    w = encoder.patch_embed.proj.weight.detach()
    nchan = w.shape[0]

    # CellPoseSAM changes the patch size from 16 to 'ps'.
    # The new kernel is initialized by strided subsampling of the original 16x16 kernel.
    encoder.patch_embed.proj = nn.Conv2d(3, nchan, stride=ps, kernel_size=ps)
    encoder.patch_embed.proj.weight.data = w[:, :, ::16 // ps, ::16 // ps]

    # Next, they subsample position embeddings for the new patch size and input resolution.
    ds = (1024 // 16) // (bsize // ps)
    encoder.pos_embed = nn.Parameter(encoder.pos_embed[:, ::ds, ::ds], requires_grad=True)

    # Finally, they set all blocks to global attention.
    for blk in encoder.blocks:
        blk.window_size = 0

    # Store encoder submodules directly ('state_dict' keys match CellposeSAM after prefix stripping).
    self.patch_embed = encoder.patch_embed
    self.pos_embed = encoder.pos_embed
    self.blocks = encoder.blocks
    self.neck = encoder.neck

    # Additional attributes expected by UNETR.
    self.embed_dim = nchan
    self.img_size = bsize
    self.in_chans = 3
    self.apply_neck = apply_neck

    # Feature extraction at evenly-spaced depths.
    depth = len(self.blocks)
    _chunks = depth // 4
    self.chunks_for_projection = [_chunks - 1, 2 * _chunks - 1, 3 * _chunks - 1, 4 * _chunks - 1]
Initialize internal Module state, shared by both nn.Module and ScriptModule.
185 def forward(self, x: torch.Tensor) -> torch.Tensor: 186 """Apply the vision transformer to input data. 187 188 Args: 189 x: The input data. 190 191 Returns: 192 The vision transformer output. 193 """ 194 x = self.patch_embed(x) 195 if self.pos_embed is not None: 196 x = x + self.pos_embed 197 198 list_from_encoder = [] 199 for i, blk in enumerate(self.blocks): 200 x = blk(x) 201 if i in self.chunks_for_projection: 202 list_from_encoder.append(x) 203 204 x = x.permute(0, 3, 1, 2) 205 206 if self.apply_neck: 207 x = self.neck(x) 208 209 list_from_encoder = [e.permute(0, 3, 1, 2) for e in list_from_encoder] 210 return x, list_from_encoder[:3]
Apply the vision transformer to input data.
Arguments:
- x: The input data.
Returns:
The vision transformer output.
class ViT_MAE(VisionTransformer):
    """Vision Transformer derived from the Masked Auto Encoder Codebase (https://arxiv.org/abs/2111.06377).

    Based on:
    https://github.com/facebookresearch/mae/blob/main/models_vit.py#L20-L53

    Args:
        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
        in_chans: The number of input channels.
        depth: The depth of the vision transformer.
        kwargs: Additional keyword arguments for the vision transformer base class.
    """
    def __init__(
        self,
        img_size: int = 1024,  # chosen to match our experiments with segment anything
        in_chans: int = 3,
        depth: int = 12,
        **kwargs
    ):
        if not _timm_import_success:
            raise RuntimeError(
                "The vision transformer backend can only be initialized if timm is installed. "
                "Please install timm (using conda/mamba) for using https://github.com/facebookresearch/mae/ "
                "and then rerun your code"
            )
        super().__init__(img_size=img_size, depth=depth, **kwargs)
        self.img_size = img_size
        self.in_chans = in_chans
        self.depth = depth

    def convert_to_expected_dim(self, inputs_):
        """@private
        """
        inputs_ = inputs_[:, 1:, :]  # removing the class tokens
        # reshape the outputs to desired shape (N x H*W X C -> N x H x W x C)
        rdim = inputs_.shape[1]
        # Use the exact integer square root of the (square) patch grid instead of
        # 'int(rdim ** 0.5)', which relies on float rounding.
        dshape = math.isqrt(rdim)
        inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape))
        inputs_ = inputs_.permute(0, 3, 1, 2)
        return inputs_

    def forward_features(self, x):
        """@private
        """
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        x = x + self.pos_embed
        x = self.pos_drop(x)

        # chunks obtained for getting the projections for conjunctions with upsampling blocks
        _chunks = int(self.depth / 4)
        chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1]

        list_from_encoder = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in chunks_for_projection:
                list_from_encoder.append(self.convert_to_expected_dim(x))

        x = self.convert_to_expected_dim(x)
        return x, list_from_encoder[:3]

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Apply the vision transformer to input data.

        Args:
            x: The input data.

        Returns:
            The vision transformer output and the intermediate encoder features.
        """
        x, list_from_encoder = self.forward_features(x)
        return x, list_from_encoder
Vision Transformer derived from the Masked Auto Encoder Codebase (https://arxiv.org/abs/2111.06377).
Based on: https://github.com/facebookresearch/mae/blob/main/models_vit.py#L20-L53
Arguments:
- img_size: The size of the input for the image encoder. Input images will be resized to match this size.
- in_chans: The number of input channels.
- depth: The depth of the vision transformer.
- kwargs: Additional keyword arguments for the vision transformer base class.
def __init__(
    self,
    img_size: int = 1024,  # chosen to match our experiments with segment anything
    in_chans: int = 3,
    depth: int = 12,
    **kwargs
):
    """Initialize the MAE vision transformer.

    Args:
        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
        in_chans: The number of input channels.
        depth: The depth of the vision transformer.
        kwargs: Additional keyword arguments for the vision transformer base class.

    Raises:
        RuntimeError: If timm is not installed.
    """
    if not _timm_import_success:
        raise RuntimeError(
            "The vision transformer backend can only be initialized if timm is installed. "
            "Please install timm (using conda/mamba) for using https://github.com/facebookresearch/mae/ "
            "and then rerun your code"
        )
    super().__init__(img_size=img_size, depth=depth, **kwargs)
    self.depth = depth
    self.img_size = img_size
    self.in_chans = in_chans
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply the vision transformer to input data.

    Args:
        x: The input data.

    Returns:
        The vision transformer output and the intermediate encoder features.
    """
    # 'forward_features' already returns the (output, intermediates) pair.
    return self.forward_features(x)
Apply the vision transformer to input data.
Arguments:
- x: The input data.
Returns:
The vision transformer output.
class ViT_Sam2(ImageEncoder):
    """Vision Transformer derived from the Segment Anything 2 Codebase (https://arxiv.org/abs/2408.00714).

    Based on https://github.com/facebookresearch/sam2/blob/main/sam2/modeling/backbones/image_encoder.py.

    Args:
        backbone_channel_list: The channels throughout the entire backbone.
        img_size: The input image size.
        embed_dim: The initial embedding dimension.
        num_heads: The initial number of heads.
        stages: The number of blocks per stage.
        global_att_blocks: The parameter to decide which blocks have global attention.
        window_pos_embed_bkg_spatial_size: The spatial size of windowed positional embedding.
        window_spec: The window size per stage, when not using global attention.
        scalp: The count of lowest resolution features to discard.
    """
    def __init__(
        self,
        backbone_channel_list: List[int],
        img_size: int = 1024,
        embed_dim: int = 96,
        num_heads: int = 1,
        stages: Tuple[int, ...] = (2, 3, 16, 3),
        global_att_blocks: Tuple[int, ...] = (12, 16, 20),
        window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
        window_spec: Tuple[int, ...] = (8, 4, 14, 7),
        scalp: int = 1,
        **kwargs
    ):
        if not _sam2_import_success:
            raise RuntimeError(
                "The vision transformer backend can only be initialized if segment anything 2 is installed. "
                "Please install segment anything 2 from https://github.com/facebookresearch/sam2 "
                "and then rerun your code"
            )

        # Hierarchical trunk plus FPN neck, replicating SAM2's image encoder layout.
        hiera_trunk = Hiera(
            embed_dim=embed_dim,
            num_heads=num_heads,
            stages=stages,
            global_att_blocks=global_att_blocks,
            window_pos_embed_bkg_spatial_size=window_pos_embed_bkg_spatial_size,
            window_spec=window_spec,
        )
        fpn_neck = FpnNeck(
            position_encoding=PositionEmbeddingSine(num_pos_feats=256),
            d_model=256,
            backbone_channel_list=backbone_channel_list,
            fpn_top_down_levels=[2, 3],
            fpn_interp_model="nearest",
        )

        super().__init__(trunk=hiera_trunk, neck=fpn_neck, scalp=scalp, **kwargs)
        self.scalp = scalp
        self.embed_dim = embed_dim
        self.img_size = img_size

    def forward(self, x: torch.Tensor):
        """Apply the image encoder to input data.

        Args:
            x: The input data.

        Returns:
            The last retained feature map and the list of retained feature maps.
        """
        # Forward pass through the backbone; then discard the 'scalp' lowest-resolution features.
        features, pos = self.neck(self.trunk(x))
        if self.scalp > 0:
            features = features[:-self.scalp]
            pos = pos[:-self.scalp]
        return features[-1], features
Vision Transformer derived from the Segment Anything 2 Codebase (https://arxiv.org/abs/2408.00714).
Based on https://github.com/facebookresearch/sam2/blob/main/sam2/modeling/backbones/image_encoder.py.
Arguments:
- backbone_channel_list: The channels throughout the entire backbone.
- embed_dim: The initial embedding dimension.
- num_heads: The initial number of heads.
- stages: The number of blocks per stage.
- global_att_blocks: The parameter to decide which blocks have global attention.
- window_pos_embed_bkg_spatial_size: The spatial size of windowed positional embedding.
- window_spec: The window size per stage, when not using global attention.
- scalp: The count of lowest resolution features to discard.
def __init__(
    self,
    backbone_channel_list: List[int],
    img_size: int = 1024,
    embed_dim: int = 96,
    num_heads: int = 1,
    stages: Tuple[int, ...] = (2, 3, 16, 3),
    global_att_blocks: Tuple[int, ...] = (12, 16, 20),
    window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
    window_spec: Tuple[int, ...] = (8, 4, 14, 7),
    scalp: int = 1,
    **kwargs
):
    """Initialize the SAM2-derived image encoder.

    Args:
        backbone_channel_list: The channels throughout the entire backbone.
        img_size: The input image size.
        embed_dim: The initial embedding dimension.
        num_heads: The initial number of heads.
        stages: The number of blocks per stage.
        global_att_blocks: The parameter to decide which blocks have global attention.
        window_pos_embed_bkg_spatial_size: The spatial size of windowed positional embedding.
        window_spec: The window size per stage, when not using global attention.
        scalp: The count of lowest resolution features to discard.
        kwargs: Keyword arguments for the image encoder base class.

    Raises:
        RuntimeError: If segment anything 2 is not installed.
    """
    if not _sam2_import_success:
        raise RuntimeError(
            "The vision transformer backend can only be initialized if segment anything 2 is installed. "
            "Please install segment anything 2 from https://github.com/facebookresearch/sam2 "
            "and then rerun your code"
        )

    trunk_net = Hiera(
        embed_dim=embed_dim,
        num_heads=num_heads,
        stages=stages,
        global_att_blocks=global_att_blocks,
        window_pos_embed_bkg_spatial_size=window_pos_embed_bkg_spatial_size,
        window_spec=window_spec,
    )
    neck_net = FpnNeck(
        position_encoding=PositionEmbeddingSine(num_pos_feats=256),
        d_model=256,
        backbone_channel_list=backbone_channel_list,
        fpn_top_down_levels=[2, 3],
        fpn_interp_model="nearest",
    )

    super().__init__(trunk=trunk_net, neck=neck_net, scalp=scalp, **kwargs)
    self.scalp = scalp
    self.embed_dim = embed_dim
    self.img_size = img_size
def forward(self, x: torch.Tensor):
    """Apply the image encoder to the input and return the retained feature maps."""
    trunk_out = self.trunk(x)
    features, pos = self.neck(trunk_out)
    # Discard the 'scalp' lowest-resolution feature maps (and their position encodings).
    if self.scalp > 0:
        features = features[:-self.scalp]
        pos = pos[:-self.scalp]
    return features[-1], features
class ViT_Sam3(SAM3ViT):
    """Vision Transformer derived from the Segment Anything 3 Codebase (https://arxiv.org/abs/2511.16719).

    Based on https://github.com/facebookresearch/sam3/blob/main/sam3/model/vitdet.py.

    Args:
        img_size: The input image size.
        embed_dim: The embedding dimension, corresponding to the number of output channels of the vision transformer.
        kwargs: Keyword arguments for the image encoder base class.
    """
    def __init__(self, img_size: int = 1024, embed_dim: int = 768, **kwargs):
        if not _sam3_import_success:
            raise RuntimeError(
                "The vision transformer backend can only be initialized if segment anything 3 is installed. "
                "Please install segment anything 3 from https://github.com/facebookresearch/sam3 "
                "and then rerun your code"
            )

        super().__init__(img_size=img_size, embed_dim=embed_dim, **kwargs)
        self.img_size = img_size
        self.embed_dim = embed_dim

    def forward_features(self, x):
        """@private
        """
        x = self.patch_embed(x)
        h, w = x.shape[1], x.shape[2]

        s = 0
        if self.retain_cls_token:
            # If the 'cls_token' is retained, we don't maintain the spatial shape.
            x = torch.cat([self.class_embedding, x.flatten(1, 2)], dim=1)
            s = 1

        if self.pos_embed is not None:
            x = x + get_abs_pos(
                self.pos_embed, self.pretrain_use_cls_token, (h, w), self.retain_cls_token, tiling=self.tile_abs_pos,
            )

        x = self.ln_pre(x)

        list_from_encoder = []
        for i, blk in enumerate(self.blocks):
            if self.use_act_checkpoint and self.training:
                x = torch.utils.checkpoint.checkpoint(blk, x, use_reentrant=False)
            else:
                x = blk(x)

            x = self._convert_to_expected_dim(x, i, s)

            if i in self.full_attn_ids:
                list_from_encoder.append(x)

        return x, list_from_encoder

    def _convert_to_expected_dim(self, x, i, s):
        # Only reshape at blocks whose features are returned; otherwise pass through unchanged.
        if (i == self.full_attn_ids[-1]) or (
            self.return_interm_layers and i in self.full_attn_ids
        ):
            if i == self.full_attn_ids[-1]:
                x = self.ln_post(x)

            feats = x[:, s:]  # drop the cls token when it is retained
            if feats.ndim == 4:
                feats = feats.permute(0, 3, 1, 2)
            else:
                assert feats.ndim == 3
                # BUGFIX: 'math.sqrt' returns a float, which is not a valid size argument
                # for 'reshape'; use the exact integer square root instead.
                h = w = math.isqrt(feats.shape[1])
                feats = feats.reshape(feats.shape[0], h, w, feats.shape[-1]).permute(0, 3, 1, 2)
            return feats

        else:
            return x

    def forward(self, x: torch.Tensor):
        """Apply the vision transformer to input data.

        Args:
            x: The input data.

        Returns:
            The vision transformer output and the intermediate features.
        """
        x, list_from_encoder = self.forward_features(x)
        return x, list_from_encoder
Vision Transformer derived from the Segment Anything 3 Codebase (https://arxiv.org/abs/2511.16719).
Based on https://github.com/facebookresearch/sam3/blob/main/sam3/model/vitdet.py.
Arguments:
- img_size: The input image size.
- embed_dim: The embedding dimension, corresponding to the number of output channels of the vision transformer.
- kwargs: Keyword arguments for the image encoder base class.
def __init__(self, img_size: int = 1024, embed_dim: int = 768, **kwargs):
    """Initialize the SAM3-derived vision transformer.

    Args:
        img_size: The input image size.
        embed_dim: The embedding dimension of the vision transformer.
        kwargs: Keyword arguments for the image encoder base class.

    Raises:
        RuntimeError: If segment anything 3 is not installed.
    """
    if not _sam3_import_success:
        raise RuntimeError(
            "The vision transformer backend can only be initialized if segment anything 3 is installed. "
            "Please install segment anything 3 from https://github.com/facebookresearch/sam3 "
            "and then rerun your code"
        )

    super().__init__(img_size=img_size, embed_dim=embed_dim, **kwargs)
    self.embed_dim = embed_dim
    self.img_size = img_size
def forward(self, x: torch.Tensor):
    """Apply the vision transformer to input data.

    Args:
        x: The input data.

    Returns:
        The vision transformer output and the intermediate features.
    """
    # 'forward_features' already returns the (output, intermediates) pair.
    return self.forward_features(x)
Apply the vision transformer to input data.
Arguments:
- x: The input data.
Returns:
The vision transformer output.
class CustomCompose:
    """Compose a rescale transform, optional extra transforms, and a source-resize transform,
    while tracking the rescale factor and (optionally) a mask of valid pixels.

    @private
    """
    def __init__(self, rescale_transform, other_transforms, src_transform):
        self.rescale_transform = rescale_transform
        self.other_transforms = other_transforms
        self.src_transform = src_transform

    def __call__(self, x, valid_masks=None):
        if valid_masks is not None:
            # Value that marks invalid (masked-out) pixels after augmentation.
            nodata = (x * (1 - valid_masks.float())).max()

        x_aug = self.rescale_transform(x)
        params = self.rescale_transform._params

        if valid_masks is not None:
            # Re-derive the valid mask from the augmented image and compute the invalid-pixel ratio.
            valid_masks = x_aug != nodata
            _, c, h, w = x_aug.shape
            zero_ratio = ((valid_masks == 0).sum((1, 2, 3)) / (h * w * c)).cpu().numpy()
        else:
            zero_ratio = -1

        if self.other_transforms:
            x_aug = self.other_transforms(x_aug)

        x_src = self.src_transform(x_aug)
        # Horizontal extent of the sampled crop, taken from the rescale transform's parameters.
        dx = params["src"][:, 1, 0] - params["src"][:, 0, 0]
        h, w = x_aug.shape[-2:]

        return x_aug, x_src, dx / h, zero_ratio, valid_masks
def get_2d_sincos_pos_embed_with_resolution(embed_dim, grid_size, res, cls_token=False, device="cpu"):
    """Compute resolution-aware 2D sin-cos positional embeddings.

    Args:
        embed_dim: The embedding dimension.
        grid_size: The height and width of the (square) grid.
        res: Tensor of size n with the per-sample resolution of a pixel (say, in meters).
        cls_token: Whether to prepend a zero embedding for the class token.
        device: The device to compute on.

    Returns:
        Tensor of shape (n, grid_size*grid_size, embed_dim), or
        (n, 1 + grid_size*grid_size, embed_dim) when 'cls_token' is set.
    """
    res = res.to(device)
    coords = torch.arange(grid_size, dtype=torch.float32, device=device)
    # here h goes first, direction reversed for numpy
    grid = torch.stack(torch.meshgrid(coords, coords, indexing="xy"), dim=0)  # 2 x h x w
    # Scale each sample's grid by its resolution: 2 x n x h x w.
    grid = torch.einsum("chw,n->cnhw", grid, res)
    _, n, h, w = grid.shape

    pos_embed = get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid).reshape(n, h * w, embed_dim)
    if cls_token:
        cls_pad = torch.zeros([n, 1, embed_dim], dtype=torch.float32, device=pos_embed.device)
        pos_embed = torch.cat([cls_pad, pos_embed], dim=1)

    return pos_embed
grid_size: int of the grid height and width; res: array of size n, representing the resolution of a pixel (say, in meters); return: pos_embed: [n, grid_size*grid_size, embed_dim] or [n, 1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
def get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid):
    """Compute 2D sin-cos embeddings from a coordinate grid.

    Half of the dimensions encode the first grid axis, the other half the second.
    """
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    emb_h = get_1d_sincos_pos_embed_from_grid_torch(half, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid_torch(half, grid[1])  # (H*W, D/2)
    return torch.cat([emb_h, emb_w], dim=1)  # (H*W, D)
def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos):
    """Compute 1D sin-cos embeddings for a set of positions.

    Args:
        embed_dim: The output dimension for each position; must be even.
        pos: Tensor of positions to encode, flattened to size (M,).

    Returns:
        Tensor of shape (M, embed_dim) with sine and cosine features.
    """
    assert embed_dim % 2 == 0
    frequencies = torch.arange(embed_dim // 2, dtype=torch.float32, device=pos.device)
    frequencies = frequencies / (embed_dim / 2.0)
    frequencies = 1.0 / 10000 ** frequencies  # (D/2,)

    angles = torch.einsum("m,d->md", pos.reshape(-1), frequencies)  # (M, D/2), outer product
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)  # (M, D)
embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
class PatchEmbedUnSafe(PatchEmbed):
    """Image to Patch Embedding without timm's input-size assertion."""

    def forward(self, x):
        B, C, H, W = x.shape
        # NOTE from ScaleMAE: timm's check that (H, W) matches the configured
        # image size is intentionally dropped here.
        tokens = self.proj(x)
        return tokens.flatten(2).transpose(1, 2)
Image to Patch Embedding
def forward(self, x):
    """Embed the input image into patch tokens without checking the input size."""
    B, C, H, W = x.shape
    # NOTE from ScaleMAE: timm's assertion that (H, W) equals the configured image size is dropped.
    patches = self.proj(x)
    return patches.flatten(2).transpose(1, 2)
class ViT_ScaleMAE(VisionTransformer):
    """Vision Transformer derived from the Scale Masked Auto Encoder (ScaleMAE) codebase
    (https://arxiv.org/abs/2212.14532, https://github.com/bair-climate-initiative/scale-mae).

    NOTE: For downstream tasks, the "base_resolution" parameter needs to be adjusted manually when using
    the model on a different zoom factor dataset.

    Args:
        img_size: The input image size.
        patch_size: The patch size.
        in_chans: The number of input channels.
        embed_dim: The embedding dimension.
        depth: The depth of the vision transformer.
        base_resolution: The base input resolution for the scale-aware positional embeddings.
        kwargs: Additional keyword arguments for the vision transformer base class.
    """

    def __init__(
        self, img_size=224, patch_size=16, in_chans=3, embed_dim=1024, depth=12, base_resolution=2.5, **kwargs
    ):
        super().__init__(img_size=img_size, embed_dim=embed_dim, **kwargs)
        self.img_size = img_size
        self.in_chans = in_chans
        self.depth = depth
        self.base_resolution = base_resolution

        # Replace timm's patch embedding with one that skips the input-size check.
        self.patch_embed = PatchEmbedUnSafe(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

    def transform_inputs(self, x):
        """@private
        """
        import kornia.augmentation as K
        from kornia.constants import Resample

        # Rescale to 448x448 (identity crop), then resize to 224x224; the crop parameters
        # are used to derive the effective input resolution.
        self._transforms = CustomCompose(
            rescale_transform=K.RandomResizedCrop(
                (448, 448),
                ratio=(1.0, 1.0),
                scale=(1.0, 1.0),
                resample=Resample.BICUBIC.name,
            ),
            other_transforms=None,
            src_transform=K.Resize((224, 224)),
        )
        x, _, ratios, _, _ = self._transforms(x)
        input_res = ratios * self.base_resolution
        return x, input_res

    def convert_to_expected_dim(self, x):
        """@private
        """
        inputs_ = x[:, 1:, :]  # removing the class tokens
        # reshape the outputs to desired shape (N X H*W X C -> N X H X W X C)
        rdim = inputs_.shape[1]
        # Use the exact integer square root of the (square) patch grid instead of
        # 'int(rdim ** 0.5)', which relies on float rounding.
        dshape = math.isqrt(rdim)
        inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape))
        inputs_ = inputs_.permute(0, 3, 1, 2)
        return inputs_

    def forward_features(self, x):
        """@private
        """
        x, input_res = self.transform_inputs(x)

        B, _, h, w = x.shape
        x = self.patch_embed(x)

        # Scale-aware positional embeddings computed from the effective input resolution.
        num_patches = int((h * w) / (self.patch_embed.patch_size[0] * self.patch_embed.patch_size[1]))
        pos_embed = get_2d_sincos_pos_embed_with_resolution(
            x.shape[-1],
            int(num_patches ** 0.5),
            input_res,
            cls_token=True,
            device=x.device,
        )

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + pos_embed
        x = self.pos_drop(x)

        # chunks obtained for getting the projections for conjunctions with upsampling blocks
        _chunks = int(self.depth / 4)
        chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1]

        list_from_encoder = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in chunks_for_projection:
                list_from_encoder.append(self.convert_to_expected_dim(x))

        x = self.convert_to_expected_dim(x)

        return x, list_from_encoder

    def forward(self, x):
        """Apply the vision transformer to input data.

        Args:
            x: The input data.

        Returns:
            The vision transformer output and the intermediate encoder features.
        """
        x, list_from_encoder = self.forward_features(x)
        return x, list_from_encoder
Vision Transformer derived from the Scale Masked Auto Encoder codebase (https://arxiv.org/abs/2212.14532, https://github.com/bair-climate-initiative/scale-mae).
NOTE: For downstream tasks, the "base_resolution" parameter needs to be adjusted manually when using the model on a different zoom factor dataset.
566 def __init__( 567 self, img_size=224, patch_size=16, in_chans=3, embed_dim=1024, depth=12, base_resolution=2.5, **kwargs 568 ): 569 super().__init__(img_size=img_size, embed_dim=embed_dim, **kwargs) 570 self.img_size = img_size 571 self.in_chans = in_chans 572 self.depth = depth 573 self.base_resolution = base_resolution 574 575 self.patch_embed = PatchEmbedUnSafe( 576 img_size=img_size, 577 patch_size=patch_size, 578 in_chans=in_chans, 579 embed_dim=embed_dim, 580 )
582 def transform_inputs(self, x): 583 import kornia.augmentation as K 584 from kornia.constants import Resample 585 586 self._transforms = CustomCompose( 587 rescale_transform=K.RandomResizedCrop( 588 (448, 448), 589 ratio=(1.0, 1.0), 590 scale=(1.0, 1.0), 591 resample=Resample.BICUBIC.name, 592 ), 593 other_transforms=None, 594 src_transform=K.Resize((224, 224)), 595 ) 596 x, _, ratios, _, _ = self._transforms(x) 597 input_res = ratios * self.base_resolution 598 return x, input_res
600 def convert_to_expected_dim(self, x): 601 inputs_ = x[:, 1:, :] # removing the class tokens 602 # reshape the outputs to desired shape (N X H*W X C -> N X H X W X C) 603 rdim = inputs_.shape[1] 604 dshape = int(rdim ** 0.5) # finding square root of the outputs for obtaining the patch shape 605 inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape)) 606 inputs_ = inputs_.permute(0, 3, 1, 2) 607 return inputs_
609 def forward_features(self, x): 610 x, input_res = self.transform_inputs(x) 611 612 B, _, h, w = x.shape 613 x = self.patch_embed(x) 614 615 num_patches = int((h * w) / (self.patch_embed.patch_size[0] * self.patch_embed.patch_size[1])) 616 pos_embed = get_2d_sincos_pos_embed_with_resolution( 617 x.shape[-1], 618 int(num_patches ** 0.5), 619 input_res, 620 cls_token=True, 621 device=x.device, 622 ) 623 624 cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 625 x = torch.cat((cls_tokens, x), dim=1) 626 x = x + pos_embed 627 x = self.pos_drop(x) 628 629 # chunks obtained for getting the projections for conjuctions with upsampling blocks 630 _chunks = int(self.depth / 4) 631 chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1] 632 633 list_from_encoder = [] 634 for i, blk in enumerate(self.blocks): 635 x = blk(x) 636 if i in chunks_for_projection: 637 list_from_encoder.append(self.convert_to_expected_dim(x)) 638 639 x = self.convert_to_expected_dim(x) 640 641 return x, list_from_encoder
class ViT_DINOv2(DinoV2VisionTransformer):
    """Vision Transformer derived from the DINOv2 Codebase (https://arxiv.org/abs/2304.07193).

    Based on:
    https://github.com/facebookresearch/dinov2/blob/main/dinov2/models/vision_transformer.py.

    Args:
        img_size: The input image size.
        patch_size: The patch size.
        depth: The depth of the network.
        num_register_tokens: The number of registers added (in addition to the class tokens).
            It's important to know for ViTs trained with registers, to remove them at the end.
        kwargs: Keyword arguments for the vision transformer base class.
    """
    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 16,
        depth: int = 12,
        num_register_tokens: int = 0,
        **kwargs
    ):
        if not _dinov2_import_success:
            raise RuntimeError(
                "The vision transformer backend can only be initialized if DINOv2 is installed. "
                "Please install DINOv2 from https://github.com/facebookresearch/dinov2 "
                "and then rerun your code."
            )

        super().__init__(
            img_size=img_size,
            depth=depth,
            patch_size=patch_size,
            num_register_tokens=num_register_tokens,
            **kwargs
        )

        self.img_size = img_size
        self.num_register_tokens = num_register_tokens
        self.patch_size = patch_size
        # The outputs of every third block are collected as intermediate features.
        self.attn_outs = [i for i in range(depth) if i % 3 == 2]

    def forward(self, x, masks=None) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Apply the vision transformer to the input data.

        Args:
            x: The input data.
            masks: Optional token masks, forwarded to the DINOv2 tokenizer.

        Returns:
            The final feature map and the first three intermediate feature maps.
        """
        B = x.shape[0]

        # BUGFIX: forward the masks to the tokenizer; previously they were accepted but ignored.
        x = self.prepare_tokens_with_masks(x, masks)

        list_of_encoder = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in self.attn_outs:
                list_of_encoder.append(x)

        x = self.norm(x)

        # Drop the class token and the register tokens, then reshape tokens to N x C x H x W.
        # NOTE(review): this assumes square inputs of size self.img_size — confirm for callers.
        grid = self.img_size // self.patch_size
        x = x[:, self.num_register_tokens + 1:].reshape(B, grid, grid, -1).permute(0, 3, 1, 2).contiguous()

        list_of_encoder = [
            o[:, self.num_register_tokens + 1:].reshape(B, grid, grid, -1).permute(0, 3, 1, 2).contiguous()
            for o in list_of_encoder
        ]

        return x, list_of_encoder[:3]
Vision Transformer derived from the DINOv2 Codebase (https://arxiv.org/abs/2304.07193).
Based on: https://github.com/facebookresearch/dinov2/blob/main/dinov2/models/vision_transformer.py.
Arguments:
- img_size: The input image size.
- patch_size: The patch size.
- depth: The depth of the network.
- num_register_tokens: The number of registers added (in addition to the class tokens). It's important to know for ViTs trained with registers, to remove them at the end.
661 def __init__( 662 self, 663 img_size: int = 224, 664 patch_size: int = 16, 665 depth: int = 12, 666 num_register_tokens: int = 0, 667 **kwargs 668 ): 669 if not _dinov2_import_success: 670 raise RuntimeError( 671 "The vision transformer backend can only be initialized if DINOv2 is installed. " 672 "Please install DINOv2 from https://github.com/facebookresearch/dinov2 " 673 "and then rerun your code." 674 ) 675 676 super().__init__( 677 img_size=img_size, 678 depth=depth, 679 patch_size=patch_size, 680 num_register_tokens=num_register_tokens, 681 **kwargs 682 ) 683 684 self.img_size = img_size 685 self.num_register_tokens = num_register_tokens 686 self.patch_size = patch_size 687 self.attn_outs = [i for i in range(depth) if i % 3 == 2]
689 def forward(self, x, masks=None) -> torch.Tensor: 690 691 B = x.shape[0] 692 693 x = self.prepare_tokens_with_masks(x) 694 695 list_of_encoder = [] 696 for i, blk in enumerate(self.blocks): 697 x = blk(x) 698 if i in self.attn_outs: 699 list_of_encoder.append(x) 700 701 x = self.norm(x) 702 x = x[:, self.num_register_tokens + 1:].reshape( 703 B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1 704 ).permute(0, 3, 1, 2).contiguous() 705 706 list_of_encoder = [ 707 o[:, self.num_register_tokens + 1:].reshape( 708 B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1 709 ).permute(0, 3, 1, 2).contiguous() for o in list_of_encoder 710 ] 711 712 return x, list_of_encoder[:3]
class ViT_DINOv3(DinoV3VisionTransformer):
    """Vision Transformer derived from the DINOv3 Codebase (https://arxiv.org/abs/2508.10104).

    Based on:
    https://github.com/facebookresearch/dinov3/blob/main/dinov3/models/vision_transformer.py.

    Args:
        in_chans: The number of input channels.
        img_size: The input image size.
        patch_size: The patch size.
        embed_dim: The embedding dimension.
        depth: The depth of the network.
        num_heads: The number of heads.
        ffn_ratio: The FFN ratio.
        n_storage_tokens: The number of storage (class) tokens to remove.
        kwargs: Keyword arguments for the image encoder base class.
    """
    def __init__(
        self,
        in_chans: int = 3,
        img_size: int = 224,
        patch_size: int = 16,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        ffn_ratio: float = 4.0,
        n_storage_tokens: int = 0,
        **kwargs
    ):
        if not _dinov3_import_success:
            raise RuntimeError(
                "The vision transformer backend can only be initialized if DINOv3 is installed. "
                "Please install DINOv3 from https://github.com/facebookresearch/dinov3 "
                "and then rerun your code."
            )

        super().__init__(
            in_chans=in_chans,
            img_size=img_size,
            patch_size=patch_size,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            ffn_ratio=ffn_ratio,
            n_storage_tokens=n_storage_tokens,
            **kwargs
        )

        self.in_chans = in_chans
        self.img_size = img_size
        self.n_storage_tokens = n_storage_tokens
        # The outputs of every third block are collected as intermediate features.
        self.attn_outs = [i for i in range(depth) if i % 3 == 2]

    def forward(self, x) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Apply the vision transformer to the input data.

        Args:
            x: The input data.

        Returns:
            The final feature map and the first three intermediate feature maps.
        """
        B = x.shape[0]

        x, hw_tuple = self.prepare_tokens_with_masks(x)

        # The token grid does not change between blocks, so the rotary position embeddings
        # are computed once per forward pass (matching the upstream DINOv3 implementation)
        # instead of recomputing them for every block.
        rope_sincos = self.rope_embed(H=hw_tuple[0], W=hw_tuple[1])

        list_of_encoder = []
        for i, blk in enumerate(self.blocks):
            x = blk(x, rope_sincos)
            if i in self.attn_outs:
                list_of_encoder.append(x)

        x = self.norm(x)

        # Drop the class token and storage tokens, then reshape tokens to N x C x H x W.
        grid = self.img_size // self.patch_size
        x = x[:, self.n_storage_tokens + 1:].reshape(B, grid, grid, -1).permute(0, 3, 1, 2).contiguous()

        list_of_encoder = [
            o[:, self.n_storage_tokens + 1:].reshape(B, grid, grid, -1).permute(0, 3, 1, 2).contiguous()
            for o in list_of_encoder
        ]

        return x, list_of_encoder[:3]
Vision Transformer derived from the DINOv3 Codebase (https://arxiv.org/abs/2508.10104).
Based on: https://github.com/facebookresearch/dinov3/blob/main/dinov3/models/vision_transformer.py.
Arguments:
- img_size: The input image size.
- patch_size: The patch size.
- embed_dim: The embedding dimension.
- depth: The depth of the network.
- num_heads: The number of heads.
- ffn_ratio: The FFN ratio.
- n_storage_tokens: The number of storage (class) tokens to remove.
- kwargs: Keyword arguments for the image encoder base class.
731 def __init__( 732 self, 733 in_chans: int = 3, 734 img_size: int = 224, 735 patch_size: int = 16, 736 embed_dim: int = 768, 737 depth: int = 12, 738 num_heads: int = 12, 739 ffn_ratio: float = 4.0, 740 n_storage_tokens: int = 0, 741 **kwargs 742 ): 743 if not _dinov3_import_success: 744 raise RuntimeError( 745 "The vision transformer backend can only be initialized if DINOv3 is installed. " 746 "Please install DINOv3 from https://github.com/facebookresearch/dinov3 " 747 "and then rerun your code." 748 ) 749 750 super().__init__( 751 in_chans=in_chans, 752 img_size=img_size, 753 patch_size=patch_size, 754 embed_dim=embed_dim, 755 depth=depth, 756 num_heads=num_heads, 757 ffn_ratio=ffn_ratio, 758 n_storage_tokens=n_storage_tokens, 759 **kwargs 760 ) 761 762 self.in_chans = in_chans 763 self.img_size = img_size 764 self.n_storage_tokens = n_storage_tokens 765 self.attn_outs = [i for i in range(depth) if i % 3 == 2]
767 def forward(self, x) -> torch.Tensor: 768 769 B = x.shape[0] 770 771 x, hw_tuple = self.prepare_tokens_with_masks(x) 772 773 list_of_encoder = [] 774 for i, blk in enumerate(self.blocks): 775 rope_sincos = self.rope_embed(H=hw_tuple[0], W=hw_tuple[1]) 776 x = blk(x, rope_sincos) 777 if i in self.attn_outs: 778 list_of_encoder.append(x) 779 780 x = self.norm(x) 781 x = x[:, self.n_storage_tokens + 1:].reshape( 782 B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1 783 ).permute(0, 3, 1, 2).contiguous() 784 785 list_of_encoder = [ 786 o[:, self.n_storage_tokens + 1:].reshape( 787 B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1 788 ).permute(0, 3, 1, 2).contiguous() for o in list_of_encoder 789 ] 790 791 return x, list_of_encoder[:3]
def get_vision_transformer(backbone: str, model: str, img_size: int = 1024, **kwargs) -> nn.Module:
    """Get vision transformer encoder.

    Args:
        backbone: The name of the vision transformer implementation.
            One of "sam" / "cellpose_sam" / "sam2" / "sam3" / "mae" / "scalemae" / "dinov2" / "dinov3".
        model: The name of the model configuration; the supported values depend on the backbone,
            e.g. "vit_b", "vit_l" or "vit_h" for "sam", "mae" and "scalemae".
        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
        kwargs: Additional kwargs which can be expected by the vision transformer,
            e.g. 'base_resolution' for `ViT_ScaleMAE`.

    Returns:
        The vision transformer.

    Raises:
        ValueError: If the backbone or the model configuration is not supported.
    """
    if backbone == "sam":
        if model == "vit_b":
            encoder = ViT_Sam(
                depth=12, embed_dim=768, img_size=img_size, mlp_ratio=4,
                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
                num_heads=12, patch_size=16, qkv_bias=True, use_rel_pos=True,
                global_attn_indexes=[2, 5, 8, 11],
                window_size=14, out_chans=256,
            )
        elif model == "vit_l":
            encoder = ViT_Sam(
                depth=24, embed_dim=1024, img_size=img_size, mlp_ratio=4,
                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
                num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True,
                global_attn_indexes=[5, 11, 17, 23],
                window_size=14, out_chans=256,
            )
        elif model == "vit_h":
            encoder = ViT_Sam(
                depth=32, embed_dim=1280, img_size=img_size, mlp_ratio=4,
                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
                num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True,
                global_attn_indexes=[7, 15, 23, 31],
                window_size=14, out_chans=256,
            )
        else:
            raise ValueError(f"'{model}' is not supported by SAM. Currently, 'vit_b', 'vit_l', 'vit_h' are supported.")

    elif backbone == "cellpose_sam":
        if model != "vit_l":
            raise ValueError(f"'{model}' is not supported by CellposeSAM. Only 'vit_l' is supported.")
        encoder = ViT_CellposeSAM(ps=8, bsize=img_size)

    elif backbone == "sam2":
        if model == "hvit_t":
            encoder = ViT_Sam2(
                img_size=img_size, embed_dim=96, num_heads=1, stages=[1, 2, 7, 2], global_att_blocks=[5, 7, 9],
                window_pos_embed_bkg_spatial_size=[7, 7], backbone_channel_list=[768, 384, 192, 96],
            )
        elif model == "hvit_s":
            encoder = ViT_Sam2(
                img_size=img_size, embed_dim=96, num_heads=1, stages=[1, 2, 11, 2], global_att_blocks=[7, 10, 13],
                window_pos_embed_bkg_spatial_size=[7, 7], backbone_channel_list=[768, 384, 192, 96],
            )
        elif model == "hvit_b":
            encoder = ViT_Sam2(
                img_size=img_size, embed_dim=112, num_heads=2, backbone_channel_list=[896, 448, 224, 112],
            )
        elif model == "hvit_l":
            encoder = ViT_Sam2(
                img_size=img_size, embed_dim=144, num_heads=2, stages=[2, 6, 36, 4], global_att_blocks=[23, 33, 43],
                window_spec=[8, 4, 16, 8], backbone_channel_list=[1152, 576, 288, 144],
            )
        else:
            raise ValueError(
                f"'{model}' is not supported by SAM2. Currently, 'hvit_t', 'hvit_s', 'hvit_b', 'hvit_l' are supported."
            )

    elif backbone == "sam3":
        if model != "vit_pe":
            raise ValueError(
                "'sam3' does not have multiple model configurations. Please use 'vit_pe' as the model configuration."
            )

        # NOTE: SAM3 is created with a fixed input size of 1008; the 'img_size' argument is not used here.
        encoder = ViT_Sam3(
            img_size=1008, pretrain_img_size=336, patch_size=14, embed_dim=1024, depth=32, num_heads=16,
            mlp_ratio=4.625, norm_layer="LayerNorm", drop_path_rate=0.1, qkv_bias=True, use_abs_pos=True,
            tile_abs_pos=True, global_att_blocks=(7, 15, 23, 31), rel_pos_blocks=(), use_rope=True,
            use_interp_rope=True, window_size=24, pretrain_use_cls_token=True, retain_cls_token=False, ln_pre=True,
            ln_post=False, return_interm_layers=False, bias_patch_embed=False, compile_mode=None,
        )

    elif backbone == "mae":
        if model == "vit_b":
            encoder = ViT_MAE(
                img_size=img_size, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
            )
        elif model == "vit_l":
            encoder = ViT_MAE(
                img_size=img_size, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
            )
        elif model == "vit_h":
            encoder = ViT_MAE(
                img_size=img_size, patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4,
                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
            )
        else:
            raise ValueError(f"'{model}' is not supported by MAE. Currently, 'vit_b', 'vit_l', 'vit_h' are supported.")

    elif backbone == "scalemae":
        base_resolution = kwargs.get("base_resolution", 2.5)

        if model == "vit_b":
            encoder = ViT_ScaleMAE(
                img_size=img_size, patch_size=8, embed_dim=768, depth=12, num_heads=12,
                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
                base_resolution=base_resolution,
            )
        elif model == "vit_l":
            encoder = ViT_ScaleMAE(
                img_size=img_size, patch_size=8, embed_dim=1024, depth=24, num_heads=16,
                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
                base_resolution=base_resolution,
            )
        elif model == "vit_h":
            encoder = ViT_ScaleMAE(
                img_size=img_size, patch_size=8, embed_dim=1280, depth=32, num_heads=16,
                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
                base_resolution=base_resolution,
            )
        else:
            raise ValueError(
                f"'{model}' is not supported by ScaleMAE. Currently, 'vit_b', 'vit_l' and 'vit_h' are supported."
            )

    elif backbone == "dinov2":
        block_fn = partial(Block, attn_class=MemEffAttention)
        # FIX: closing quote was missing after 'vit_<X>_reg<Y>'.
        msg = "The model name should be either 'vit_<X>' or 'vit_<X>_reg<Y>'."

        if model.startswith("vit_s"):
            assert model in ["vit_s", "vit_s_reg4"], msg
            encoder = ViT_DINOv2(
                img_size=img_size, patch_size=14, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
                block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0,
                num_register_tokens=4 if model.endswith("_reg4") else 0,
            )
        elif model.startswith("vit_b"):
            assert model in ["vit_b", "vit_b_reg4"], msg
            encoder = ViT_DINOv2(
                img_size=img_size, patch_size=14, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
                block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0,
                num_register_tokens=4 if model.endswith("_reg4") else 0,
            )
        elif model.startswith("vit_l"):
            assert model in ["vit_l", "vit_l_reg4"], msg
            encoder = ViT_DINOv2(
                img_size=img_size, patch_size=14, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
                block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0,
                num_register_tokens=4 if model.endswith("_reg4") else 0,
            )
        elif model.startswith("vit_g"):
            assert model in ["vit_g", "vit_g_reg4"], msg
            encoder = ViT_DINOv2(
                img_size=img_size, patch_size=14, embed_dim=1536, depth=40, num_heads=24, mlp_ratio=4,
                block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0,
                num_register_tokens=4 if model.endswith("_reg4") else 0, ffn_layer="swiglu",
            )
        else:
            raise ValueError(
                f"'{model}' is not supported by DINOv2. Currently, 'vit_s', 'vit_b', 'vit_l' and 'vit_g' are supported."
            )

    elif backbone == "dinov3":

        if model == "vit_s":
            encoder = ViT_DINOv3(
                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=384,
                num_heads=6, layerscale_init=1.0e-05, norm_layer="layernormbf16", n_storage_tokens=4, mask_k_bias=True,
            )
        elif model == "vit_s+":
            encoder = ViT_DINOv3(
                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=384,
                num_heads=6, ffn_ratio=6, layerscale_init=1.0e-05, norm_layer="layernormbf16",
                ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True,
            )
        elif model == "vit_b":
            encoder = ViT_DINOv3(
                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32",
                layerscale_init=1.0e-05, norm_layer="layernormbf16", n_storage_tokens=4, mask_k_bias=True,
            )
        elif model == "vit_l":
            encoder = ViT_DINOv3(
                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1024,
                depth=24, num_heads=16, layerscale_init=1.0e-05, norm_layer="layernormbf16",
                n_storage_tokens=4, mask_k_bias=True,
            )
        elif model == "vit_l+":
            encoder = ViT_DINOv3(
                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1024,
                depth=24, num_heads=16, ffn_ratio=6.0, layerscale_init=1.0e-05, norm_layer="layernormbf16",
                ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True,
            )
        elif model == "vit_h+":
            encoder = ViT_DINOv3(
                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1280,
                depth=32, num_heads=20, ffn_ratio=6.0, layerscale_init=1.0e-05, norm_layer="layernormbf16",
                ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True,
            )
        elif model == "vit_7b":
            encoder = ViT_DINOv3(
                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=4096,
                depth=40, num_heads=32, ffn_ratio=3, qkv_bias=False, drop_path_rate=0.0, layerscale_init=1.0e-05,
                norm_layer="layernormbf16", ffn_layer="swiglu64", n_storage_tokens=4, mask_k_bias=True,
                untie_global_and_local_cls_norm=True,
            )
        else:
            # FIX: the model list previously had a stray period ("'vit_h+'. 'vit_7b'") and a double space.
            raise ValueError(
                f"'{model}' is not supported by DINOv3. Currently, "
                "'vit_s', 'vit_s+', 'vit_b', 'vit_l', 'vit_l+', 'vit_h+', 'vit_7b' are supported."
            )

    else:
        raise ValueError(
            "The 'UNETR' supported backbones are 'sam', 'cellpose_sam', 'sam2', 'sam3', "
            "'mae', 'scalemae', 'dinov2' or 'dinov3'. Please choose one of them."
        )

    return encoder
Get vision transformer encoder.
Arguments:
- backbone: The name of the vision transformer implementation. One of "sam" / "cellpose_sam" / "sam2" / "sam3" / "mae" / "scalemae" / "dinov2" / "dinov3".
- model: The name of the model. One of "vit_b", "vit_l" or "vit_h".
- img_size: The size of the input for the image encoder. Input images will be resized to match this size.
- kwargs: Additional kwargs which can be expected by the vision transformer, e.g. `base_resolution` for `ViT_ScaleMAE`.
Returns:
The vision transformer.