torch_em.model.vit
1import math 2from functools import partial 3from typing import Tuple, List 4 5import torch 6import torch.nn as nn 7 8# we catch ImportErrors here because segment_anything, micro_sam, scale_mae and timm should 9# only be optional dependencies for torch_em 10try: 11 from segment_anything.modeling import ImageEncoderViT 12 _sam_import_success = True 13except ImportError: 14 ImageEncoderViT = object 15 _sam_import_success = False 16 17try: 18 from timm.models.vision_transformer import VisionTransformer, PatchEmbed 19 _timm_import_success = True 20except ImportError: 21 VisionTransformer = object 22 PatchEmbed = object 23 _timm_import_success = False 24 25try: 26 from sam2.modeling.backbones.hieradet import Hiera 27 from sam2.modeling.position_encoding import PositionEmbeddingSine 28 from sam2.modeling.backbones.image_encoder import ImageEncoder, FpnNeck 29 _sam2_import_success = True 30except ImportError: 31 ImageEncoder = object 32 _sam2_import_success = False 33 34try: 35 from dinov2.models.vision_transformer import DinoVisionTransformer as DinoV2VisionTransformer 36 from dinov2.layers import MemEffAttention, NestedTensorBlock as Block 37 _dinov2_import_success = True 38except ImportError: 39 DinoV2VisionTransformer = object 40 _dinov2_import_success = False 41 42try: 43 from dinov3.models.vision_transformer import DinoVisionTransformer as DinoV3VisionTransformer 44 _dinov3_import_success = True 45except ImportError: 46 DinoV3VisionTransformer = object 47 _dinov3_import_success = False 48 49 50try: 51 from sam3.model.vitdet import ViT as SAM3ViT, get_abs_pos 52 _sam3_import_success = True 53except ImportError: 54 SAM3ViT = object 55 _sam3_import_success = False 56 57 58class ViT_Sam(ImageEncoderViT): 59 """Vision Transformer derived from the Segment Anything Codebase (https://arxiv.org/abs/2304.02643). 60 61 Based on: 62 https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py 63 64 Args: 65 in_chans: The number of input channels. 66 embed_dim: The embedding dimension, corresponding to the number of output channels of the vision transformer. 67 global_attn_indexes: The global attention indices. 68 kwargs: Keyword arguments for the image encoder base class. 69 """ 70 def __init__( 71 self, 72 in_chans: int = 3, 73 embed_dim: int = 768, 74 global_attn_indexes: Tuple[int, ...] = [2, 5, 8, 11], 75 **kwargs, 76 ) -> None: 77 if not _sam_import_success: 78 raise RuntimeError( 79 "The vision transformer backend can only be initialized if segment anything is installed. " 80 "Please install segment anything from https://github.com/facebookresearch/segment-anything " 81 "and then rerun your code." 82 ) 83 84 super().__init__(embed_dim=embed_dim, global_attn_indexes=global_attn_indexes, **kwargs) 85 self.chunks_for_projection = global_attn_indexes 86 self.in_chans = in_chans 87 self.embed_dim = embed_dim 88 89 def forward(self, x: torch.Tensor) -> torch.Tensor: 90 """Apply the vision transformer to input data. 91 92 Args: 93 x: The input data. 94 95 Returns: 96 The vision transformer output. 
97 """ 98 x = self.patch_embed(x) 99 if self.pos_embed is not None: 100 x = x + self.pos_embed 101 102 list_from_encoder = [] 103 for i, blk in enumerate(self.blocks): 104 x = blk(x) 105 if i in self.chunks_for_projection: 106 list_from_encoder.append(x) 107 108 x = x.permute(0, 3, 1, 2) 109 list_from_encoder = [e.permute(0, 3, 1, 2) for e in list_from_encoder] 110 return x, list_from_encoder[:3] 111 112 113class ViT_MAE(VisionTransformer): 114 """Vision Transformer derived from the Masked Auto Encoder Codebase (https://arxiv.org/abs/2111.06377). 115 116 Based on: 117 https://github.com/facebookresearch/mae/blob/main/models_vit.py#L20-L53 118 119 Args: 120 img_size: The size of the input for the image encoder. Input images will be resized to match this size. 121 in_chans: The number of input channels. 122 depth: The depth of the vision transformer. 123 kwargs: Additional keyword arguments for the vision transformer base class. 124 """ 125 def __init__( 126 self, 127 img_size: int = 1024, # chosen to match our experiments with segment anything 128 in_chans: int = 3, 129 depth: int = 12, 130 **kwargs 131 ): 132 if not _timm_import_success: 133 raise RuntimeError( 134 "The vision transformer backend can only be initialized if timm is installed. " 135 "Please install timm (using conda/mamba) for using https://github.com/facebookresearch/mae/ " 136 "and then rerun your code" 137 ) 138 super().__init__(img_size=img_size, depth=depth, **kwargs) 139 self.img_size = img_size 140 self.in_chans = in_chans 141 self.depth = depth 142 143 def convert_to_expected_dim(self, inputs_): 144 """@private 145 """ 146 inputs_ = inputs_[:, 1:, :] # removing the class tokens 147 # reshape the outputs to desired shape (N x H*W X C -> N x H x W x C) 148 rdim = inputs_.shape[1] 149 dshape = int(rdim ** 0.5) # finding the square root of the outputs for obtaining the patch shape 150 inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape)) 151 inputs_ = inputs_.permute(0, 3, 1, 2) 152 return inputs_ 153 154 def forward_features(self, x): 155 """@private 156 """ 157 B = x.shape[0] 158 x = self.patch_embed(x) 159 160 cls_tokens = self.cls_token.expand(B, -1, -1) 161 x = torch.cat((cls_tokens, x), dim=1) 162 163 x = x + self.pos_embed 164 x = self.pos_drop(x) 165 166 # chunks obtained for getting the projections for conjuctions with upsampling blocks 167 _chunks = int(self.depth / 4) 168 chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1] 169 170 list_from_encoder = [] 171 for i, blk in enumerate(self.blocks): 172 x = blk(x) 173 if i in chunks_for_projection: 174 list_from_encoder.append(self.convert_to_expected_dim(x)) 175 176 x = self.convert_to_expected_dim(x) 177 return x, list_from_encoder[:3] 178 179 def forward(self, x: torch.Tensor) -> torch.Tensor: 180 """Apply the vision transformer to input data. 181 182 Args: 183 x: The input data. 184 185 Returns: 186 The vision transformer output. 187 """ 188 x, list_from_encoder = self.forward_features(x) 189 return x, list_from_encoder 190 191 192class ViT_Sam2(ImageEncoder): 193 """Vision Transformer derived from the Segment Anything 2 Codebase (https://arxiv.org/abs/2408.00714). 194 195 Based on https://github.com/facebookresearch/sam2/blob/main/sam2/modeling/backbones/image_encoder.py. 196 197 Args: 198 backbone_channel_list: The channels throughout the entire backbone. 199 embed_dim: The initial embedding dimension. 200 num_heads: The initial number of heads. 201 stages: The number of blocks per stage. 
202 global_att_blocks: The parameter to decide which blocks have global attention. 203 window_pos_embed_bkg_spatial_size: The spatial size of windowed positional embedding. 204 window_spec: The window size per stage, when not using global attention. 205 scalp: The count of lowest resolution features to discard. 206 """ 207 def __init__( 208 self, 209 backbone_channel_list: List[int], 210 img_size: int = 1024, 211 embed_dim: int = 96, 212 num_heads: int = 1, 213 stages: Tuple[int, ...] = (2, 3, 16, 3), 214 global_att_blocks: Tuple[int, ...] = (12, 16, 20), 215 window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14), 216 window_spec: Tuple[int, ...] = (8, 4, 14, 7), 217 scalp: int = 1, 218 **kwargs 219 ): 220 if not _sam2_import_success: 221 raise RuntimeError( 222 "The vision transformer backend can only be initialized if segment anything 2 is installed. " 223 "Please install segment anything 2 from https://github.com/facebookresearch/sam2 " 224 "and then rerun your code" 225 ) 226 227 trunk = Hiera( 228 embed_dim=embed_dim, 229 num_heads=num_heads, 230 stages=stages, 231 global_att_blocks=global_att_blocks, 232 window_pos_embed_bkg_spatial_size=window_pos_embed_bkg_spatial_size, 233 window_spec=window_spec, 234 ) 235 neck = FpnNeck( 236 position_encoding=PositionEmbeddingSine(num_pos_feats=256), 237 d_model=256, 238 backbone_channel_list=backbone_channel_list, 239 fpn_top_down_levels=[2, 3], 240 fpn_interp_model="nearest", 241 ) 242 243 super().__init__(trunk=trunk, neck=neck, scalp=scalp, **kwargs) 244 self.scalp = scalp 245 self.embed_dim = embed_dim 246 self.img_size = img_size 247 248 def forward(self, x: torch.Tensor): 249 # The forward pass throught the backbone. 250 features, pos = self.neck(self.trunk(x)) 251 if self.scalp > 0: # This discard the "n" lowest resolution features. 252 features, pos = features[:-self.scalp], pos[:-self.scalp] 253 254 return features[-1], features 255 256 257class ViT_Sam3(SAM3ViT): 258 """Vision Transformer derived from the Segment Anything 3 Codebase (https://arxiv.org/abs/2511.16719). 259 260 Based on https://github.com/facebookresearch/sam3/blob/main/sam3/model/vitdet.py. 261 262 Args: 263 img_size: The input image size. 264 embed_dim: The embedding dimension, corresponding to the number of output channels of the vision transformer. 265 kwargs: Keyword arguments for the image encoder base class. 266 """ 267 def __init__(self, img_size: int = 1024, embed_dim: int = 768, **kwargs): 268 if not _sam3_import_success: 269 raise RuntimeError( 270 "The vision transformer backend can only be initialized if segment anything 3 is installed. " 271 "Please install segment anything 3 from https://github.com/facebookresearch/sam3 " 272 "and then rerun your code" 273 ) 274 275 super().__init__(img_size=img_size, embed_dim=embed_dim, **kwargs) 276 self.img_size = img_size 277 self.embed_dim = embed_dim 278 279 def forward_features(self, x): 280 """@private 281 """ 282 x = self.patch_embed(x) 283 h, w = x.shape[1], x.shape[2] 284 285 s = 0 286 if self.retain_cls_token: 287 # If the 'cls_token' is retained, we don't maintain the spatial shape. 
288 x = torch.cat([self.class_embedding, x.flatten(1, 2)], dim=1) 289 s = 1 290 291 if self.pos_embed is not None: 292 x = x + get_abs_pos( 293 self.pos_embed, self.pretrain_use_cls_token, (h, w), self.retain_cls_token, tiling=self.tile_abs_pos, 294 ) 295 296 x = self.ln_pre(x) 297 298 list_from_encoder = [] 299 for i, blk in enumerate(self.blocks): 300 if self.use_act_checkpoint and self.training: 301 x = torch.utils.checkpoint.checkpoint(blk, x, use_reentrant=False) 302 else: 303 x = blk(x) 304 305 x = self._convert_to_expected_dim(x, i, s) 306 307 if i in self.full_attn_ids: 308 list_from_encoder.append(x) 309 310 return x, list_from_encoder 311 312 def _convert_to_expected_dim(self, x, i, s): 313 if (i == self.full_attn_ids[-1]) or ( 314 self.return_interm_layers and i in self.full_attn_ids 315 ): 316 if i == self.full_attn_ids[-1]: 317 x = self.ln_post(x) 318 319 feats = x[:, s:] 320 if feats.ndim == 4: 321 feats = feats.permute(0, 3, 1, 2) 322 else: 323 assert feats.ndim == 3 324 h = w = math.sqrt(feats.shape[1]) 325 feats = feats.reshape(feats.shape[0], h, w, feats.shape[-1]).permute(0, 3, 1, 2) 326 return feats 327 328 else: 329 return x 330 331 def forward(self, x: torch.Tensor): 332 """Apply the vision transformer to input data. 333 334 Args: 335 x: The input data. 336 337 Returns: 338 The vision transformer output. 339 """ 340 x, list_from_encoder = self.forward_features(x) 341 return x, list_from_encoder 342 343# 344# Utilities for ScaleMAE's ViT 345# 346 347 348class CustomCompose: 349 def __init__(self, rescale_transform, other_transforms, src_transform): 350 self.rescale_transform = rescale_transform 351 self.other_transforms = other_transforms 352 self.src_transform = src_transform 353 354 def __call__(self, x, valid_masks=None): 355 if valid_masks is not None: 356 nodata = (x * (1 - valid_masks.float())).max() 357 x_aug = self.rescale_transform(x) 358 parms = self.rescale_transform._params 359 360 # sanity check, comment if this is working 361 # valid_masks = self.rescale_transform(valid_masks.float(), params=parms) 362 # assert (x_aug==self.rescale_transform(x, params=parms)).all() # 363 364 if valid_masks is not None: 365 valid_masks = x_aug != nodata 366 _, c, h, w = x_aug.shape 367 zero_ratio = ((valid_masks == 0).sum((1, 2, 3)) / (h * w * c)).cpu().numpy() 368 else: 369 zero_ratio = -1 370 371 if self.other_transforms: 372 x_aug = self.other_transforms(x_aug) 373 x_src = self.src_transform(x_aug) 374 dx = parms["src"][:, 1, 0] - parms["src"][:, 0, 0] 375 376 # dy = (parms['src'][:,2,1] - parms['src'][:,1,1]) 377 # assert (dx == dy).all() 378 379 h, w = x_aug.shape[-2:] 380 # assert h == w 381 382 return x_aug, x_src, dx / h, zero_ratio, valid_masks 383 384 385def get_2d_sincos_pos_embed_with_resolution(embed_dim, grid_size, res, cls_token=False, device="cpu"): 386 """ 387 grid_size: int of the grid height and width 388 res: array of size n, representing the resolution of a pixel (say, in meters), 389 return: 390 pos_embed: [n,grid_size*grid_size, embed_dim] or [n,1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 391 """ 392 # res = torch.FloatTensor(res).to(device) 393 res = res.to(device) 394 grid_h = torch.arange(grid_size, dtype=torch.float32, device=device) 395 grid_w = torch.arange(grid_size, dtype=torch.float32, device=device) 396 grid = torch.meshgrid(grid_w, grid_h, indexing="xy") # here h goes first,direction reversed for numpy 397 grid = torch.stack(grid, dim=0) # 2 x h x w 398 399 # grid = grid.reshape([2, 1, grid_size, grid_size]) 400 grid = 
torch.einsum("chw,n->cnhw", grid, res) # 2 x n x h x w 401 _, n, h, w = grid.shape 402 pos_embed = get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid) # (nxH*W, D/2) 403 pos_embed = pos_embed.reshape(n, h * w, embed_dim) 404 if cls_token: 405 pos_embed = torch.cat( 406 [torch.zeros([n, 1, embed_dim], dtype=torch.float32, device=pos_embed.device), pos_embed], dim=1 407 ) 408 409 return pos_embed 410 411 412def get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid): 413 assert embed_dim % 2 == 0 414 415 # use half of dimensions to encode grid_h 416 emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid[0]) # (H*W, D/2) 417 emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid[1]) # (H*W, D/2) 418 419 emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D) 420 return emb 421 422 423def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos): 424 """ 425 embed_dim: output dimension for each position 426 pos: a list of positions to be encoded: size (M,) 427 out: (M, D) 428 """ 429 assert embed_dim % 2 == 0 430 # old_shape = pos 431 omega = torch.arange(embed_dim // 2, dtype=torch.float32, device=pos.device) 432 omega /= embed_dim / 2.0 433 omega = 1.0 / 10000**omega # (D/2,) 434 435 pos = pos.reshape(-1) # (M,) 436 out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product 437 438 emb_sin = torch.sin(out) # (M, D/2) 439 emb_cos = torch.cos(out) # (M, D/2) 440 441 emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) 442 return emb 443 444 445class PatchEmbedUnSafe(PatchEmbed): 446 """Image to Patch Embedding""" 447 448 def forward(self, x): 449 B, C, H, W = x.shape 450 451 # NOTE: Comment code from ScaleMAE: Dropped size check in timm 452 # assert H == self.img_size[0] and W == self.img_size[1], \ 453 # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 454 455 x = self.proj(x).flatten(2).transpose(1, 2) 456 return x 457 458 459class ViT_ScaleMAE(VisionTransformer): 460 """Vision Transformer dervied from the Scale Masked Auto Encoder codebase (TODO: paper and github link). 461 462 NOTE: For downstream tasks, the "base_resoulution" parameter needs to be adjusted manually when using 463 the model on a different zoom factor dataset. 
464 """ 465 466 def __init__( 467 self, img_size=224, patch_size=16, in_chans=3, embed_dim=1024, depth=12, base_resolution=2.5, **kwargs 468 ): 469 super().__init__(img_size=img_size, embed_dim=embed_dim, **kwargs) 470 self.img_size = img_size 471 self.in_chans = in_chans 472 self.depth = depth 473 self.base_resolution = base_resolution 474 475 self.patch_embed = PatchEmbedUnSafe( 476 img_size=img_size, 477 patch_size=patch_size, 478 in_chans=in_chans, 479 embed_dim=embed_dim, 480 ) 481 482 def transform_inputs(self, x): 483 import kornia.augmentation as K 484 from kornia.constants import Resample 485 486 self._transforms = CustomCompose( 487 rescale_transform=K.RandomResizedCrop( 488 (448, 448), 489 ratio=(1.0, 1.0), 490 scale=(1.0, 1.0), 491 resample=Resample.BICUBIC.name, 492 ), 493 other_transforms=None, 494 src_transform=K.Resize((224, 224)), 495 ) 496 x, _, ratios, _, _ = self._transforms(x) 497 input_res = ratios * self.base_resolution 498 return x, input_res 499 500 def convert_to_expected_dim(self, x): 501 inputs_ = x[:, 1:, :] # removing the class tokens 502 # reshape the outputs to desired shape (N X H*W X C -> N X H X W X C) 503 rdim = inputs_.shape[1] 504 dshape = int(rdim ** 0.5) # finding square root of the outputs for obtaining the patch shape 505 inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape)) 506 inputs_ = inputs_.permute(0, 3, 1, 2) 507 return inputs_ 508 509 def forward_features(self, x): 510 x, input_res = self.transform_inputs(x) 511 512 B, _, h, w = x.shape 513 x = self.patch_embed(x) 514 515 num_patches = int((h * w) / (self.patch_embed.patch_size[0] * self.patch_embed.patch_size[1])) 516 pos_embed = get_2d_sincos_pos_embed_with_resolution( 517 x.shape[-1], 518 int(num_patches ** 0.5), 519 input_res, 520 cls_token=True, 521 device=x.device, 522 ) 523 524 cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 525 x = torch.cat((cls_tokens, x), dim=1) 526 x = x + pos_embed 527 x = self.pos_drop(x) 528 529 # chunks obtained for getting the projections for conjuctions with upsampling blocks 530 _chunks = int(self.depth / 4) 531 chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1] 532 533 list_from_encoder = [] 534 for i, blk in enumerate(self.blocks): 535 x = blk(x) 536 if i in chunks_for_projection: 537 list_from_encoder.append(self.convert_to_expected_dim(x)) 538 539 x = self.convert_to_expected_dim(x) 540 541 return x, list_from_encoder 542 543 def forward(self, x): 544 x, list_from_encoder = self.forward_features(x) 545 return x, list_from_encoder 546 547 548class ViT_DINOv2(DinoV2VisionTransformer): 549 """Vision Transformer derived from the DINOv2 Codebase (https://arxiv.org/abs/2304.07193). 550 551 Based on: 552 https://github.com/facebookresearch/dinov2/blob/main/dinov2/models/vision_transformer.py. 553 554 Args: 555 img_size: The input image size. 556 patch_size: The patch size. 557 depth: The depth of the network. 558 num_register_tokens: The number of registers added (in addition to the class tokens). 559 It's important to know for ViTs trained with registers, to remove them at the end. 560 """ 561 def __init__( 562 self, 563 img_size: int = 224, 564 patch_size: int = 16, 565 depth: int = 12, 566 num_register_tokens: int = 0, 567 **kwargs 568 ): 569 if not _dinov2_import_success: 570 raise RuntimeError( 571 "The vision transformer backend can only be initialized if DINOv2 is installed. 
" 572 "Please install DINOv2 from https://github.com/facebookresearch/dinov2 " 573 "and then rerun your code." 574 ) 575 576 super().__init__( 577 img_size=img_size, 578 depth=depth, 579 patch_size=patch_size, 580 num_register_tokens=num_register_tokens, 581 **kwargs 582 ) 583 584 self.img_size = img_size 585 self.num_register_tokens = num_register_tokens 586 self.patch_size = patch_size 587 self.attn_outs = [i for i in range(depth) if i % 3 == 2] 588 589 def forward(self, x, masks=None) -> torch.Tensor: 590 591 B = x.shape[0] 592 593 x = self.prepare_tokens_with_masks(x) 594 595 list_of_encoder = [] 596 for i, blk in enumerate(self.blocks): 597 x = blk(x) 598 if i in self.attn_outs: 599 list_of_encoder.append(x) 600 601 x = self.norm(x) 602 x = x[:, self.num_register_tokens + 1:].reshape( 603 B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1 604 ).permute(0, 3, 1, 2).contiguous() 605 606 list_of_encoder = [ 607 o[:, self.num_register_tokens + 1:].reshape( 608 B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1 609 ).permute(0, 3, 1, 2).contiguous() for o in list_of_encoder 610 ] 611 612 return x, list_of_encoder[:3] 613 614 615class ViT_DINOv3(DinoV3VisionTransformer): 616 """Vision Transformer derived from the DINOv3 Codebase (https://arxiv.org/abs/2508.10104). 617 618 Based on: 619 https://github.com/facebookresearch/dinov3/blob/main/dinov3/models/vision_transformer.py. 620 621 Args: 622 img_size: The input image size. 623 patch_size: The patch size. 624 embed_dim: The embedding dimension. 625 depth: The depth of the network. 626 num_heads: The number of heads. 627 ffn_ratio: The FFN rato. 628 n_storage_tokens: The number of storage (class) tokens to remove. 629 kwargs: Keyword arguments for the image encoder base class. 630 """ 631 def __init__( 632 self, 633 in_chans: int = 3, 634 img_size: int = 224, 635 patch_size: int = 16, 636 embed_dim: int = 768, 637 depth: int = 12, 638 num_heads: int = 12, 639 ffn_ratio: float = 4.0, 640 n_storage_tokens: int = 0, 641 **kwargs 642 ): 643 if not _dinov3_import_success: 644 raise RuntimeError( 645 "The vision transformer backend can only be initialized if DINOv3 is installed. " 646 "Please install DINOv3 from https://github.com/facebookresearch/dinov3 " 647 "and then rerun your code." 
648 ) 649 650 super().__init__( 651 in_chans=in_chans, 652 img_size=img_size, 653 patch_size=patch_size, 654 embed_dim=embed_dim, 655 depth=depth, 656 num_heads=num_heads, 657 ffn_ratio=ffn_ratio, 658 n_storage_tokens=n_storage_tokens, 659 **kwargs 660 ) 661 662 self.in_chans = in_chans 663 self.img_size = img_size 664 self.n_storage_tokens = n_storage_tokens 665 self.attn_outs = [i for i in range(depth) if i % 3 == 2] 666 667 def forward(self, x) -> torch.Tensor: 668 669 B = x.shape[0] 670 671 x, hw_tuple = self.prepare_tokens_with_masks(x) 672 673 list_of_encoder = [] 674 for i, blk in enumerate(self.blocks): 675 rope_sincos = self.rope_embed(H=hw_tuple[0], W=hw_tuple[1]) 676 x = blk(x, rope_sincos) 677 if i in self.attn_outs: 678 list_of_encoder.append(x) 679 680 x = self.norm(x) 681 x = x[:, self.n_storage_tokens + 1:].reshape( 682 B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1 683 ).permute(0, 3, 1, 2).contiguous() 684 685 list_of_encoder = [ 686 o[:, self.n_storage_tokens + 1:].reshape( 687 B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1 688 ).permute(0, 3, 1, 2).contiguous() for o in list_of_encoder 689 ] 690 691 return x, list_of_encoder[:3] 692 693 694def get_vision_transformer(backbone: str, model: str, img_size: int = 1024, **kwargs) -> nn.Module: 695 """Get vision transformer encoder. 696 697 Args: 698 backbone: The name of the vision transformer implementation. One of "sam" / "mae" / "scalemae". 699 model: The name of the model. One of "vit_b", "vit_l" or "vit_h". 700 img_size: The size of the input for the image encoder. Input images will be resized to match this size. 701 kwargs: Additional kwargs which can be expected by the vision transformer, 702 e.g. 'base_resolution' for `ViT_ScaleMAE`. 703 704 Returns: 705 The vision transformer. 706 """ 707 if backbone == "sam": 708 if model == "vit_b": 709 encoder = ViT_Sam( 710 depth=12, embed_dim=768, img_size=img_size, mlp_ratio=4, 711 norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 712 num_heads=12, patch_size=16, qkv_bias=True, use_rel_pos=True, 713 global_attn_indexes=[2, 5, 8, 11], 714 window_size=14, out_chans=256, 715 ) 716 elif model == "vit_l": 717 encoder = ViT_Sam( 718 depth=24, embed_dim=1024, img_size=img_size, mlp_ratio=4, 719 norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 720 num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True, 721 global_attn_indexes=[5, 11, 17, 23], 722 window_size=14, out_chans=256, 723 ) 724 elif model == "vit_h": 725 encoder = ViT_Sam( 726 depth=32, embed_dim=1280, img_size=img_size, mlp_ratio=4, 727 norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 728 num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True, 729 global_attn_indexes=[7, 15, 23, 31], 730 window_size=14, out_chans=256, 731 ) 732 else: 733 raise ValueError(f"'{model}' is not supported by SAM. 
Currently, 'vit_b', 'vit_l', 'vit_h' are supported.") 734 735 elif backbone == "sam2": 736 if model == "hvit_t": 737 encoder = ViT_Sam2( 738 img_size=img_size, embed_dim=96, num_heads=1, stages=[1, 2, 7, 2], global_att_blocks=[5, 7, 9], 739 window_pos_embed_bkg_spatial_size=[7, 7], backbone_channel_list=[768, 384, 192, 96], 740 ) 741 elif model == "hvit_s": 742 encoder = ViT_Sam2( 743 img_size=img_size, embed_dim=96, num_heads=1, stages=[1, 2, 11, 2], global_att_blocks=[7, 10, 13], 744 window_pos_embed_bkg_spatial_size=[7, 7], backbone_channel_list=[768, 384, 192, 96], 745 ) 746 elif model == "hvit_b": 747 encoder = ViT_Sam2( 748 img_size=img_size, embed_dim=112, num_heads=2, backbone_channel_list=[896, 448, 224, 112], 749 ) 750 elif model == "hvit_l": 751 encoder = ViT_Sam2( 752 img_size=img_size, embed_dim=144, num_heads=2, stages=[2, 6, 36, 4], global_att_blocks=[23, 33, 43], 753 window_spec=[8, 4, 16, 8], backbone_channel_list=[1152, 576, 288, 144], 754 ) 755 else: 756 raise ValueError( 757 f"'{model}' is not supported by SAM2. Currently, 'hvit_t', 'hvit_s', 'hvit_b', 'hvit_l' are supported." 758 ) 759 760 elif backbone == "sam3": 761 if model != "vit_pe": 762 raise ValueError( 763 "'sam3' does not have multiple model configurations. Please use 'vit_pe' as the model configuration." 764 ) 765 766 encoder = ViT_Sam3( 767 img_size=1008, pretrain_img_size=336, patch_size=14, embed_dim=1024, depth=32, num_heads=16, 768 mlp_ratio=4.625, norm_layer="LayerNorm", drop_path_rate=0.1, qkv_bias=True, use_abs_pos=True, 769 tile_abs_pos=True, global_att_blocks=(7, 15, 23, 31), rel_pos_blocks=(), use_rope=True, 770 use_interp_rope=True, window_size=24, pretrain_use_cls_token=True, retain_cls_token=False, ln_pre=True, 771 ln_post=False, return_interm_layers=False, bias_patch_embed=False, compile_mode=None, 772 ) 773 774 elif backbone == "mae": 775 if model == "vit_b": 776 encoder = ViT_MAE( 777 img_size=img_size, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, 778 qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6) 779 ) 780 elif model == "vit_l": 781 encoder = ViT_MAE( 782 img_size=img_size, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, 783 qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6) 784 ) 785 elif model == "vit_h": 786 encoder = ViT_MAE( 787 img_size=img_size, patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, 788 qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6) 789 ) 790 else: 791 raise ValueError(f"'{model}' is not supported by MAE. Currently, 'vit_b', 'vit_l', 'vit_h' are supported.") 792 793 elif backbone == "scalemae": 794 base_resolution = kwargs.get("base_resolution", 2.5) 795 796 if model == "vit_b": 797 encoder = ViT_ScaleMAE( 798 img_size=img_size, patch_size=8, embed_dim=768, depth=12, num_heads=12, 799 mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), 800 base_resolution=base_resolution, 801 ) 802 elif model == "vit_l": 803 encoder = ViT_ScaleMAE( 804 img_size=img_size, patch_size=8, embed_dim=1024, depth=24, num_heads=16, 805 mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), 806 base_resolution=base_resolution, 807 ) 808 elif model == "vit_h": 809 encoder = ViT_ScaleMAE( 810 img_size=img_size, patch_size=8, embed_dim=1280, depth=32, num_heads=16, 811 mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), 812 base_resolution=base_resolution, 813 ) 814 else: 815 raise ValueError( 816 f"'{model}' is not supported by ScaleMAE. 
Currently, 'vit_b', 'vit_l' and 'vit_h' are supported." 817 ) 818 819 elif backbone == "dinov2": 820 block_fn = partial(Block, attn_class=MemEffAttention) 821 msg = "The model name should be either 'vit_<X>' or 'vit_<X>_reg<Y>." 822 823 if model.startswith("vit_s"): 824 assert model in ["vit_s", "vit_s_reg4"], msg 825 encoder = ViT_DINOv2( 826 img_size=img_size, patch_size=14, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, 827 block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0, 828 num_register_tokens=4 if model.endswith("_reg4") else 0, 829 ) 830 elif model.startswith("vit_b"): 831 assert model in ["vit_b", "vit_b_reg4"], msg 832 encoder = ViT_DINOv2( 833 img_size=img_size, patch_size=14, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, 834 block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0, 835 num_register_tokens=4 if model.endswith("_reg4") else 0, 836 ) 837 elif model.startswith("vit_l"): 838 assert model in ["vit_l", "vit_l_reg4"], msg 839 encoder = ViT_DINOv2( 840 img_size=img_size, patch_size=14, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, 841 block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0, 842 num_register_tokens=4 if model.endswith("_reg4") else 0, 843 ) 844 elif model.startswith("vit_g"): 845 assert model in ["vit_g", "vit_g_reg4"], msg 846 encoder = ViT_DINOv2( 847 img_size=img_size, patch_size=14, embed_dim=1536, depth=40, num_heads=24, mlp_ratio=4, 848 block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0, 849 num_register_tokens=4 if model.endswith("_reg4") else 0, ffn_layer="swiglu", 850 ) 851 else: 852 raise ValueError( 853 f"'{model}' is not supported by DINOv2. Currently, 'vit_s', 'vit_b', 'vit_l' and 'vit_g' are supported." 
854 ) 855 856 elif backbone == "dinov3": 857 858 if model == "vit_s": 859 encoder = ViT_DINOv3( 860 img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=384, 861 num_heads=6, layerscale_init=1.0e-05, norm_layer="layernormbf16", n_storage_tokens=4, mask_k_bias=True, 862 ) 863 elif model == "vit_s+": 864 encoder = ViT_DINOv3( 865 img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=384, 866 num_heads=6, ffn_ratio=6, layerscale_init=1.0e-05, norm_layer="layernormbf16", 867 ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True, 868 ) 869 870 elif model == "vit_b": 871 encoder = ViT_DINOv3( 872 img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", 873 layerscale_init=1.0e-05, norm_layer="layernormbf16", n_storage_tokens=4, mask_k_bias=True, 874 ) 875 elif model == "vit_l": 876 encoder = ViT_DINOv3( 877 img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1024, 878 depth=24, num_heads=16, layerscale_init=1.0e-05, norm_layer="layernormbf16", 879 n_storage_tokens=4, mask_k_bias=True, 880 ) 881 elif model == "vit_l+": 882 encoder = ViT_DINOv3( 883 img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1024, 884 depth=24, num_heads=16, ffn_ratio=6.0, layerscale_init=1.0e-05, norm_layer="layernormbf16", 885 ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True, 886 ) 887 elif model == "vit_h+": 888 encoder = ViT_DINOv3( 889 img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1280, 890 depth=32, num_heads=20, ffn_ratio=6.0, layerscale_init=1.0e-05, norm_layer="layernormbf16", 891 ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True, 892 ) 893 elif model == "vit_7b": 894 encoder = ViT_DINOv3( 895 img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=4096, 896 depth=40, num_heads=32, ffn_ratio=3, qkv_bias=False, drop_path_rate=0.0, layerscale_init=1.0e-05, 897 norm_layer="layernormbf16", ffn_layer="swiglu64", n_storage_tokens=4, mask_k_bias=True, 898 untie_global_and_local_cls_norm=True, 899 ) 900 else: 901 raise ValueError( 902 f"'{model}' is not supported by DINOv3. Currently, " 903 " 'vit_s', 'vit_s+', 'vit_b', 'vit_l', 'vit_l+', 'vit_h+'. 'vit_7b' are supported." 904 ) 905 906 else: 907 raise ValueError( 908 "The 'UNETR' supported backbones are 'sam', 'sam2', 'sam3', 'mae', 'scalemae', 'dinov2' or 'dinov3'. " 909 "Please choose one of them." 910 ) 911 912 return encoder
59class ViT_Sam(ImageEncoderViT): 60 """Vision Transformer derived from the Segment Anything Codebase (https://arxiv.org/abs/2304.02643). 61 62 Based on: 63 https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py 64 65 Args: 66 in_chans: The number of input channels. 67 embed_dim: The embedding dimension, corresponding to the number of output channels of the vision transformer. 68 global_attn_indexes: The global attention indices. 69 kwargs: Keyword arguments for the image encoder base class. 70 """ 71 def __init__( 72 self, 73 in_chans: int = 3, 74 embed_dim: int = 768, 75 global_attn_indexes: Tuple[int, ...] = [2, 5, 8, 11], 76 **kwargs, 77 ) -> None: 78 if not _sam_import_success: 79 raise RuntimeError( 80 "The vision transformer backend can only be initialized if segment anything is installed. " 81 "Please install segment anything from https://github.com/facebookresearch/segment-anything " 82 "and then rerun your code." 83 ) 84 85 super().__init__(embed_dim=embed_dim, global_attn_indexes=global_attn_indexes, **kwargs) 86 self.chunks_for_projection = global_attn_indexes 87 self.in_chans = in_chans 88 self.embed_dim = embed_dim 89 90 def forward(self, x: torch.Tensor) -> torch.Tensor: 91 """Apply the vision transformer to input data. 92 93 Args: 94 x: The input data. 95 96 Returns: 97 The vision transformer output. 98 """ 99 x = self.patch_embed(x) 100 if self.pos_embed is not None: 101 x = x + self.pos_embed 102 103 list_from_encoder = [] 104 for i, blk in enumerate(self.blocks): 105 x = blk(x) 106 if i in self.chunks_for_projection: 107 list_from_encoder.append(x) 108 109 x = x.permute(0, 3, 1, 2) 110 list_from_encoder = [e.permute(0, 3, 1, 2) for e in list_from_encoder] 111 return x, list_from_encoder[:3]
Vision Transformer derived from the Segment Anything Codebase (https://arxiv.org/abs/2304.02643).
Arguments:
- in_chans: The number of input channels.
- embed_dim: The embedding dimension, corresponding to the number of output channels of the vision transformer.
- global_attn_indexes: The global attention indices.
- kwargs: Keyword arguments for the image encoder base class.
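A minimal usage sketch (assuming segment_anything is installed and the module is importable as torch_em.model.vit); the encoder is typically built through get_vision_transformer, defined at the end of this module:

import torch
from torch_em.model.vit import get_vision_transformer

encoder = get_vision_transformer(backbone="sam", model="vit_b", img_size=1024)
x = torch.randn(1, 3, 1024, 1024)      # (B, C, H, W)
features, intermediates = encoder(x)   # final feature map and the first three global-attention outputs
print(features.shape)                  # (1, 768, 64, 64): embed_dim channels on a 1024 / 16 patch grid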
71 def __init__( 72 self, 73 in_chans: int = 3, 74 embed_dim: int = 768, 75 global_attn_indexes: Tuple[int, ...] = [2, 5, 8, 11], 76 **kwargs, 77 ) -> None: 78 if not _sam_import_success: 79 raise RuntimeError( 80 "The vision transformer backend can only be initialized if segment anything is installed. " 81 "Please install segment anything from https://github.com/facebookresearch/segment-anything " 82 "and then rerun your code." 83 ) 84 85 super().__init__(embed_dim=embed_dim, global_attn_indexes=global_attn_indexes, **kwargs) 86 self.chunks_for_projection = global_attn_indexes 87 self.in_chans = in_chans 88 self.embed_dim = embed_dim
90 def forward(self, x: torch.Tensor) -> torch.Tensor: 91 """Apply the vision transformer to input data. 92 93 Args: 94 x: The input data. 95 96 Returns: 97 The vision transformer output. 98 """ 99 x = self.patch_embed(x) 100 if self.pos_embed is not None: 101 x = x + self.pos_embed 102 103 list_from_encoder = [] 104 for i, blk in enumerate(self.blocks): 105 x = blk(x) 106 if i in self.chunks_for_projection: 107 list_from_encoder.append(x) 108 109 x = x.permute(0, 3, 1, 2) 110 list_from_encoder = [e.permute(0, 3, 1, 2) for e in list_from_encoder] 111 return x, list_from_encoder[:3]
Apply the vision transformer to input data.
Arguments:
- x: The input data.
Returns:
The vision transformer output.
114class ViT_MAE(VisionTransformer): 115 """Vision Transformer derived from the Masked Auto Encoder Codebase (https://arxiv.org/abs/2111.06377). 116 117 Based on: 118 https://github.com/facebookresearch/mae/blob/main/models_vit.py#L20-L53 119 120 Args: 121 img_size: The size of the input for the image encoder. Input images will be resized to match this size. 122 in_chans: The number of input channels. 123 depth: The depth of the vision transformer. 124 kwargs: Additional keyword arguments for the vision transformer base class. 125 """ 126 def __init__( 127 self, 128 img_size: int = 1024, # chosen to match our experiments with segment anything 129 in_chans: int = 3, 130 depth: int = 12, 131 **kwargs 132 ): 133 if not _timm_import_success: 134 raise RuntimeError( 135 "The vision transformer backend can only be initialized if timm is installed. " 136 "Please install timm (using conda/mamba) for using https://github.com/facebookresearch/mae/ " 137 "and then rerun your code" 138 ) 139 super().__init__(img_size=img_size, depth=depth, **kwargs) 140 self.img_size = img_size 141 self.in_chans = in_chans 142 self.depth = depth 143 144 def convert_to_expected_dim(self, inputs_): 145 """@private 146 """ 147 inputs_ = inputs_[:, 1:, :] # removing the class tokens 148 # reshape the outputs to desired shape (N x H*W X C -> N x H x W x C) 149 rdim = inputs_.shape[1] 150 dshape = int(rdim ** 0.5) # finding the square root of the outputs for obtaining the patch shape 151 inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape)) 152 inputs_ = inputs_.permute(0, 3, 1, 2) 153 return inputs_ 154 155 def forward_features(self, x): 156 """@private 157 """ 158 B = x.shape[0] 159 x = self.patch_embed(x) 160 161 cls_tokens = self.cls_token.expand(B, -1, -1) 162 x = torch.cat((cls_tokens, x), dim=1) 163 164 x = x + self.pos_embed 165 x = self.pos_drop(x) 166 167 # chunks obtained for getting the projections for conjuctions with upsampling blocks 168 _chunks = int(self.depth / 4) 169 chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1] 170 171 list_from_encoder = [] 172 for i, blk in enumerate(self.blocks): 173 x = blk(x) 174 if i in chunks_for_projection: 175 list_from_encoder.append(self.convert_to_expected_dim(x)) 176 177 x = self.convert_to_expected_dim(x) 178 return x, list_from_encoder[:3] 179 180 def forward(self, x: torch.Tensor) -> torch.Tensor: 181 """Apply the vision transformer to input data. 182 183 Args: 184 x: The input data. 185 186 Returns: 187 The vision transformer output. 188 """ 189 x, list_from_encoder = self.forward_features(x) 190 return x, list_from_encoder
Vision Transformer derived from the Masked Auto Encoder Codebase (https://arxiv.org/abs/2111.06377).
Based on: https://github.com/facebookresearch/mae/blob/main/models_vit.py#L20-L53
Arguments:
- img_size: The size of the input for the image encoder. Input images will be resized to match this size.
- in_chans: The number of input channels.
- depth: The depth of the vision transformer.
- kwargs: Additional keyword arguments for the vision transformer base class.
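To make the projection points concrete: forward_features splits the encoder depth into four chunks and taps the last block of each chunk. For the default depth of 12 this reduces to:

depth = 12
_chunks = depth // 4
chunks_for_projection = [_chunks - 1, 2 * _chunks - 1, 3 * _chunks - 1, 4 * _chunks - 1]
print(chunks_for_projection)  # [2, 5, 8, 11], i.e. blocks 3, 6, 9 and 12 (1-based)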
126 def __init__( 127 self, 128 img_size: int = 1024, # chosen to match our experiments with segment anything 129 in_chans: int = 3, 130 depth: int = 12, 131 **kwargs 132 ): 133 if not _timm_import_success: 134 raise RuntimeError( 135 "The vision transformer backend can only be initialized if timm is installed. " 136 "Please install timm (using conda/mamba) for using https://github.com/facebookresearch/mae/ " 137 "and then rerun your code" 138 ) 139 super().__init__(img_size=img_size, depth=depth, **kwargs) 140 self.img_size = img_size 141 self.in_chans = in_chans 142 self.depth = depth
180 def forward(self, x: torch.Tensor) -> torch.Tensor: 181 """Apply the vision transformer to input data. 182 183 Args: 184 x: The input data. 185 186 Returns: 187 The vision transformer output. 188 """ 189 x, list_from_encoder = self.forward_features(x) 190 return x, list_from_encoder
Apply the vision transformer to input data.
Arguments:
- x: The input data.
Returns:
The vision transformer output.
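A minimal sketch of the MAE encoder output (assuming timm is installed); the shapes follow from the 16x16 patch embedding of the 'vit_b' configuration that get_vision_transformer builds below:

import torch
from torch_em.model.vit import get_vision_transformer

encoder = get_vision_transformer(backbone="mae", model="vit_b", img_size=1024)
x = torch.randn(1, 3, 1024, 1024)
features, intermediates = encoder(x)   # (1, 768, 64, 64) plus three intermediate maps of the same shape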
193class ViT_Sam2(ImageEncoder): 194 """Vision Transformer derived from the Segment Anything 2 Codebase (https://arxiv.org/abs/2408.00714). 195 196 Based on https://github.com/facebookresearch/sam2/blob/main/sam2/modeling/backbones/image_encoder.py. 197 198 Args: 199 backbone_channel_list: The channels throughout the entire backbone. 200 embed_dim: The initial embedding dimension. 201 num_heads: The initial number of heads. 202 stages: The number of blocks per stage. 203 global_att_blocks: The parameter to decide which blocks have global attention. 204 window_pos_embed_bkg_spatial_size: The spatial size of windowed positional embedding. 205 window_spec: The window size per stage, when not using global attention. 206 scalp: The count of lowest resolution features to discard. 207 """ 208 def __init__( 209 self, 210 backbone_channel_list: List[int], 211 img_size: int = 1024, 212 embed_dim: int = 96, 213 num_heads: int = 1, 214 stages: Tuple[int, ...] = (2, 3, 16, 3), 215 global_att_blocks: Tuple[int, ...] = (12, 16, 20), 216 window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14), 217 window_spec: Tuple[int, ...] = (8, 4, 14, 7), 218 scalp: int = 1, 219 **kwargs 220 ): 221 if not _sam2_import_success: 222 raise RuntimeError( 223 "The vision transformer backend can only be initialized if segment anything 2 is installed. " 224 "Please install segment anything 2 from https://github.com/facebookresearch/sam2 " 225 "and then rerun your code" 226 ) 227 228 trunk = Hiera( 229 embed_dim=embed_dim, 230 num_heads=num_heads, 231 stages=stages, 232 global_att_blocks=global_att_blocks, 233 window_pos_embed_bkg_spatial_size=window_pos_embed_bkg_spatial_size, 234 window_spec=window_spec, 235 ) 236 neck = FpnNeck( 237 position_encoding=PositionEmbeddingSine(num_pos_feats=256), 238 d_model=256, 239 backbone_channel_list=backbone_channel_list, 240 fpn_top_down_levels=[2, 3], 241 fpn_interp_model="nearest", 242 ) 243 244 super().__init__(trunk=trunk, neck=neck, scalp=scalp, **kwargs) 245 self.scalp = scalp 246 self.embed_dim = embed_dim 247 self.img_size = img_size 248 249 def forward(self, x: torch.Tensor): 250 # The forward pass throught the backbone. 251 features, pos = self.neck(self.trunk(x)) 252 if self.scalp > 0: # This discard the "n" lowest resolution features. 253 features, pos = features[:-self.scalp], pos[:-self.scalp] 254 255 return features[-1], features
Vision Transformer derived from the Segment Anything 2 Codebase (https://arxiv.org/abs/2408.00714).
Based on https://github.com/facebookresearch/sam2/blob/main/sam2/modeling/backbones/image_encoder.py.
Arguments:
- backbone_channel_list: The channels throughout the entire backbone.
- embed_dim: The initial embedding dimension.
- num_heads: The initial number of heads.
- stages: The number of blocks per stage.
- global_att_blocks: The indices of the blocks that use global attention.
- window_pos_embed_bkg_spatial_size: The spatial size of windowed positional embedding.
- window_spec: The window size per stage, when not using global attention.
- scalp: The number of lowest-resolution features to discard.
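A minimal sketch (assuming sam2 is installed), using the 'hvit_b' defaults that get_vision_transformer below passes to this class:

import torch
from torch_em.model.vit import ViT_Sam2

encoder = ViT_Sam2(img_size=1024, embed_dim=112, num_heads=2, backbone_channel_list=[896, 448, 224, 112])
x = torch.randn(1, 3, 1024, 1024)
last_feature, fpn_features = encoder(x)   # the last retained FPN level and the list of all retained levels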
208 def __init__( 209 self, 210 backbone_channel_list: List[int], 211 img_size: int = 1024, 212 embed_dim: int = 96, 213 num_heads: int = 1, 214 stages: Tuple[int, ...] = (2, 3, 16, 3), 215 global_att_blocks: Tuple[int, ...] = (12, 16, 20), 216 window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14), 217 window_spec: Tuple[int, ...] = (8, 4, 14, 7), 218 scalp: int = 1, 219 **kwargs 220 ): 221 if not _sam2_import_success: 222 raise RuntimeError( 223 "The vision transformer backend can only be initialized if segment anything 2 is installed. " 224 "Please install segment anything 2 from https://github.com/facebookresearch/sam2 " 225 "and then rerun your code" 226 ) 227 228 trunk = Hiera( 229 embed_dim=embed_dim, 230 num_heads=num_heads, 231 stages=stages, 232 global_att_blocks=global_att_blocks, 233 window_pos_embed_bkg_spatial_size=window_pos_embed_bkg_spatial_size, 234 window_spec=window_spec, 235 ) 236 neck = FpnNeck( 237 position_encoding=PositionEmbeddingSine(num_pos_feats=256), 238 d_model=256, 239 backbone_channel_list=backbone_channel_list, 240 fpn_top_down_levels=[2, 3], 241 fpn_interp_model="nearest", 242 ) 243 244 super().__init__(trunk=trunk, neck=neck, scalp=scalp, **kwargs) 245 self.scalp = scalp 246 self.embed_dim = embed_dim 247 self.img_size = img_size
249 def forward(self, x: torch.Tensor): 250 # The forward pass throught the backbone. 251 features, pos = self.neck(self.trunk(x)) 252 if self.scalp > 0: # This discard the "n" lowest resolution features. 253 features, pos = features[:-self.scalp], pos[:-self.scalp] 254 255 return features[-1], features
258class ViT_Sam3(SAM3ViT): 259 """Vision Transformer derived from the Segment Anything 3 Codebase (https://arxiv.org/abs/2511.16719). 260 261 Based on https://github.com/facebookresearch/sam3/blob/main/sam3/model/vitdet.py. 262 263 Args: 264 img_size: The input image size. 265 embed_dim: The embedding dimension, corresponding to the number of output channels of the vision transformer. 266 kwargs: Keyword arguments for the image encoder base class. 267 """ 268 def __init__(self, img_size: int = 1024, embed_dim: int = 768, **kwargs): 269 if not _sam3_import_success: 270 raise RuntimeError( 271 "The vision transformer backend can only be initialized if segment anything 3 is installed. " 272 "Please install segment anything 3 from https://github.com/facebookresearch/sam3 " 273 "and then rerun your code" 274 ) 275 276 super().__init__(img_size=img_size, embed_dim=embed_dim, **kwargs) 277 self.img_size = img_size 278 self.embed_dim = embed_dim 279 280 def forward_features(self, x): 281 """@private 282 """ 283 x = self.patch_embed(x) 284 h, w = x.shape[1], x.shape[2] 285 286 s = 0 287 if self.retain_cls_token: 288 # If the 'cls_token' is retained, we don't maintain the spatial shape. 289 x = torch.cat([self.class_embedding, x.flatten(1, 2)], dim=1) 290 s = 1 291 292 if self.pos_embed is not None: 293 x = x + get_abs_pos( 294 self.pos_embed, self.pretrain_use_cls_token, (h, w), self.retain_cls_token, tiling=self.tile_abs_pos, 295 ) 296 297 x = self.ln_pre(x) 298 299 list_from_encoder = [] 300 for i, blk in enumerate(self.blocks): 301 if self.use_act_checkpoint and self.training: 302 x = torch.utils.checkpoint.checkpoint(blk, x, use_reentrant=False) 303 else: 304 x = blk(x) 305 306 x = self._convert_to_expected_dim(x, i, s) 307 308 if i in self.full_attn_ids: 309 list_from_encoder.append(x) 310 311 return x, list_from_encoder 312 313 def _convert_to_expected_dim(self, x, i, s): 314 if (i == self.full_attn_ids[-1]) or ( 315 self.return_interm_layers and i in self.full_attn_ids 316 ): 317 if i == self.full_attn_ids[-1]: 318 x = self.ln_post(x) 319 320 feats = x[:, s:] 321 if feats.ndim == 4: 322 feats = feats.permute(0, 3, 1, 2) 323 else: 324 assert feats.ndim == 3 325 h = w = math.sqrt(feats.shape[1]) 326 feats = feats.reshape(feats.shape[0], h, w, feats.shape[-1]).permute(0, 3, 1, 2) 327 return feats 328 329 else: 330 return x 331 332 def forward(self, x: torch.Tensor): 333 """Apply the vision transformer to input data. 334 335 Args: 336 x: The input data. 337 338 Returns: 339 The vision transformer output. 340 """ 341 x, list_from_encoder = self.forward_features(x) 342 return x, list_from_encoder
Vision Transformer derived from the Segment Anything 3 Codebase (https://arxiv.org/abs/2511.16719).
Based on https://github.com/facebookresearch/sam3/blob/main/sam3/model/vitdet.py.
Arguments:
- img_size: The input image size.
- embed_dim: The embedding dimension, corresponding to the number of output channels of the vision transformer.
- kwargs: Keyword arguments for the image encoder base class.
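A sketch of the only configuration wired up for this backbone (assuming sam3 is installed); get_vision_transformer below hard-codes img_size=1008 with patch_size=14, i.e. a 72x72 token grid:

from torch_em.model.vit import get_vision_transformer

encoder = get_vision_transformer(backbone="sam3", model="vit_pe")   # img_size is fixed to 1008 internally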
268 def __init__(self, img_size: int = 1024, embed_dim: int = 768, **kwargs): 269 if not _sam3_import_success: 270 raise RuntimeError( 271 "The vision transformer backend can only be initialized if segment anything 3 is installed. " 272 "Please install segment anything 3 from https://github.com/facebookresearch/sam3 " 273 "and then rerun your code" 274 ) 275 276 super().__init__(img_size=img_size, embed_dim=embed_dim, **kwargs) 277 self.img_size = img_size 278 self.embed_dim = embed_dim
332 def forward(self, x: torch.Tensor): 333 """Apply the vision transformer to input data. 334 335 Args: 336 x: The input data. 337 338 Returns: 339 The vision transformer output. 340 """ 341 x, list_from_encoder = self.forward_features(x) 342 return x, list_from_encoder
Apply the vision transformer to input data.
Arguments:
- x: The input data.
Returns:
The vision transformer output.
349class CustomCompose: 350 def __init__(self, rescale_transform, other_transforms, src_transform): 351 self.rescale_transform = rescale_transform 352 self.other_transforms = other_transforms 353 self.src_transform = src_transform 354 355 def __call__(self, x, valid_masks=None): 356 if valid_masks is not None: 357 nodata = (x * (1 - valid_masks.float())).max() 358 x_aug = self.rescale_transform(x) 359 parms = self.rescale_transform._params 360 361 # sanity check, comment if this is working 362 # valid_masks = self.rescale_transform(valid_masks.float(), params=parms) 363 # assert (x_aug==self.rescale_transform(x, params=parms)).all() # 364 365 if valid_masks is not None: 366 valid_masks = x_aug != nodata 367 _, c, h, w = x_aug.shape 368 zero_ratio = ((valid_masks == 0).sum((1, 2, 3)) / (h * w * c)).cpu().numpy() 369 else: 370 zero_ratio = -1 371 372 if self.other_transforms: 373 x_aug = self.other_transforms(x_aug) 374 x_src = self.src_transform(x_aug) 375 dx = parms["src"][:, 1, 0] - parms["src"][:, 0, 0] 376 377 # dy = (parms['src'][:,2,1] - parms['src'][:,1,1]) 378 # assert (dx == dy).all() 379 380 h, w = x_aug.shape[-2:] 381 # assert h == w 382 383 return x_aug, x_src, dx / h, zero_ratio, valid_masks
386def get_2d_sincos_pos_embed_with_resolution(embed_dim, grid_size, res, cls_token=False, device="cpu"): 387 """ 388 grid_size: int of the grid height and width 389 res: array of size n, representing the resolution of a pixel (say, in meters), 390 return: 391 pos_embed: [n,grid_size*grid_size, embed_dim] or [n,1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 392 """ 393 # res = torch.FloatTensor(res).to(device) 394 res = res.to(device) 395 grid_h = torch.arange(grid_size, dtype=torch.float32, device=device) 396 grid_w = torch.arange(grid_size, dtype=torch.float32, device=device) 397 grid = torch.meshgrid(grid_w, grid_h, indexing="xy") # here h goes first,direction reversed for numpy 398 grid = torch.stack(grid, dim=0) # 2 x h x w 399 400 # grid = grid.reshape([2, 1, grid_size, grid_size]) 401 grid = torch.einsum("chw,n->cnhw", grid, res) # 2 x n x h x w 402 _, n, h, w = grid.shape 403 pos_embed = get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid) # (nxH*W, D/2) 404 pos_embed = pos_embed.reshape(n, h * w, embed_dim) 405 if cls_token: 406 pos_embed = torch.cat( 407 [torch.zeros([n, 1, embed_dim], dtype=torch.float32, device=pos_embed.device), pos_embed], dim=1 408 ) 409 410 return pos_embed
Arguments:
- grid_size: int of the grid height and width.
- res: array of size n, representing the resolution of a pixel (say, in meters).

Returns:
pos_embed: [n, grid_size*grid_size, embed_dim] or [n, 1+grid_size*grid_size, embed_dim] (with or without cls_token).
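A shape check for the resolution-aware embedding, following directly from the implementation above:

import torch
from torch_em.model.vit import get_2d_sincos_pos_embed_with_resolution

res = torch.tensor([1.0, 2.5])   # one resolution value per sample, shape (n,)
pos_embed = get_2d_sincos_pos_embed_with_resolution(embed_dim=768, grid_size=14, res=res, cls_token=True)
print(pos_embed.shape)           # torch.Size([2, 1 + 14 * 14, 768])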
413def get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid): 414 assert embed_dim % 2 == 0 415 416 # use half of dimensions to encode grid_h 417 emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid[0]) # (H*W, D/2) 418 emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid[1]) # (H*W, D/2) 419 420 emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D) 421 return emb
424def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos): 425 """ 426 embed_dim: output dimension for each position 427 pos: a list of positions to be encoded: size (M,) 428 out: (M, D) 429 """ 430 assert embed_dim % 2 == 0 431 # old_shape = pos 432 omega = torch.arange(embed_dim // 2, dtype=torch.float32, device=pos.device) 433 omega /= embed_dim / 2.0 434 omega = 1.0 / 10000**omega # (D/2,) 435 436 pos = pos.reshape(-1) # (M,) 437 out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product 438 439 emb_sin = torch.sin(out) # (M, D/2) 440 emb_cos = torch.cos(out) # (M, D/2) 441 442 emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) 443 return emb
Arguments:
- embed_dim: output dimension for each position.
- pos: a list of positions to be encoded, of size (M,).

Returns:
out: (M, D).
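And the corresponding shape check for the 1D helper:

import torch
from torch_em.model.vit import get_1d_sincos_pos_embed_from_grid_torch

pos = torch.arange(196, dtype=torch.float32)             # (M,) positions
emb = get_1d_sincos_pos_embed_from_grid_torch(384, pos)  # (M, D) with D = 384
print(emb.shape)                                         # torch.Size([196, 384])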
446class PatchEmbedUnSafe(PatchEmbed): 447 """Image to Patch Embedding""" 448 449 def forward(self, x): 450 B, C, H, W = x.shape 451 452 # NOTE: Comment code from ScaleMAE: Dropped size check in timm 453 # assert H == self.img_size[0] and W == self.img_size[1], \ 454 # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 455 456 x = self.proj(x).flatten(2).transpose(1, 2) 457 return x
Image to Patch Embedding
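The size check from timm is dropped because ScaleMAE rescales its inputs to 448x448 before embedding them (see transform_inputs in ViT_ScaleMAE below), so the embedding must accept sizes other than the img_size it was constructed with. A small sketch of the behaviour, assuming timm is installed:

import torch
from torch_em.model.vit import PatchEmbedUnSafe

patch_embed = PatchEmbedUnSafe(img_size=224, patch_size=16, in_chans=3, embed_dim=1024)
tokens = patch_embed(torch.randn(1, 3, 448, 448))   # timm's stock PatchEmbed would typically reject this size
print(tokens.shape)                                 # torch.Size([1, 784, 1024]), i.e. (B, (448 // 16) ** 2, C)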
449 def forward(self, x): 450 B, C, H, W = x.shape 451 452 # NOTE: Comment code from ScaleMAE: Dropped size check in timm 453 # assert H == self.img_size[0] and W == self.img_size[1], \ 454 # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 455 456 x = self.proj(x).flatten(2).transpose(1, 2) 457 return x
460class ViT_ScaleMAE(VisionTransformer): 461 """Vision Transformer dervied from the Scale Masked Auto Encoder codebase (TODO: paper and github link). 462 463 NOTE: For downstream tasks, the "base_resoulution" parameter needs to be adjusted manually when using 464 the model on a different zoom factor dataset. 465 """ 466 467 def __init__( 468 self, img_size=224, patch_size=16, in_chans=3, embed_dim=1024, depth=12, base_resolution=2.5, **kwargs 469 ): 470 super().__init__(img_size=img_size, embed_dim=embed_dim, **kwargs) 471 self.img_size = img_size 472 self.in_chans = in_chans 473 self.depth = depth 474 self.base_resolution = base_resolution 475 476 self.patch_embed = PatchEmbedUnSafe( 477 img_size=img_size, 478 patch_size=patch_size, 479 in_chans=in_chans, 480 embed_dim=embed_dim, 481 ) 482 483 def transform_inputs(self, x): 484 import kornia.augmentation as K 485 from kornia.constants import Resample 486 487 self._transforms = CustomCompose( 488 rescale_transform=K.RandomResizedCrop( 489 (448, 448), 490 ratio=(1.0, 1.0), 491 scale=(1.0, 1.0), 492 resample=Resample.BICUBIC.name, 493 ), 494 other_transforms=None, 495 src_transform=K.Resize((224, 224)), 496 ) 497 x, _, ratios, _, _ = self._transforms(x) 498 input_res = ratios * self.base_resolution 499 return x, input_res 500 501 def convert_to_expected_dim(self, x): 502 inputs_ = x[:, 1:, :] # removing the class tokens 503 # reshape the outputs to desired shape (N X H*W X C -> N X H X W X C) 504 rdim = inputs_.shape[1] 505 dshape = int(rdim ** 0.5) # finding square root of the outputs for obtaining the patch shape 506 inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape)) 507 inputs_ = inputs_.permute(0, 3, 1, 2) 508 return inputs_ 509 510 def forward_features(self, x): 511 x, input_res = self.transform_inputs(x) 512 513 B, _, h, w = x.shape 514 x = self.patch_embed(x) 515 516 num_patches = int((h * w) / (self.patch_embed.patch_size[0] * self.patch_embed.patch_size[1])) 517 pos_embed = get_2d_sincos_pos_embed_with_resolution( 518 x.shape[-1], 519 int(num_patches ** 0.5), 520 input_res, 521 cls_token=True, 522 device=x.device, 523 ) 524 525 cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 526 x = torch.cat((cls_tokens, x), dim=1) 527 x = x + pos_embed 528 x = self.pos_drop(x) 529 530 # chunks obtained for getting the projections for conjuctions with upsampling blocks 531 _chunks = int(self.depth / 4) 532 chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1] 533 534 list_from_encoder = [] 535 for i, blk in enumerate(self.blocks): 536 x = blk(x) 537 if i in chunks_for_projection: 538 list_from_encoder.append(self.convert_to_expected_dim(x)) 539 540 x = self.convert_to_expected_dim(x) 541 542 return x, list_from_encoder 543 544 def forward(self, x): 545 x, list_from_encoder = self.forward_features(x) 546 return x, list_from_encoder
Vision Transformer derived from the Scale Masked Auto Encoder codebase (TODO: paper and github link).
NOTE: For downstream tasks, the "base_resolution" parameter needs to be adjusted manually when using the model on a dataset with a different zoom factor.
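A sketch of how the parameter is passed; it is forwarded through the kwargs of get_vision_transformer (see below) and should reflect the zoom factor of the target dataset:

from torch_em.model.vit import get_vision_transformer

# base_resolution defaults to 2.5; adjust it when the target data has a different zoom factor.
encoder = get_vision_transformer(backbone="scalemae", model="vit_b", img_size=224, base_resolution=2.5)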
467 def __init__( 468 self, img_size=224, patch_size=16, in_chans=3, embed_dim=1024, depth=12, base_resolution=2.5, **kwargs 469 ): 470 super().__init__(img_size=img_size, embed_dim=embed_dim, **kwargs) 471 self.img_size = img_size 472 self.in_chans = in_chans 473 self.depth = depth 474 self.base_resolution = base_resolution 475 476 self.patch_embed = PatchEmbedUnSafe( 477 img_size=img_size, 478 patch_size=patch_size, 479 in_chans=in_chans, 480 embed_dim=embed_dim, 481 )
483 def transform_inputs(self, x): 484 import kornia.augmentation as K 485 from kornia.constants import Resample 486 487 self._transforms = CustomCompose( 488 rescale_transform=K.RandomResizedCrop( 489 (448, 448), 490 ratio=(1.0, 1.0), 491 scale=(1.0, 1.0), 492 resample=Resample.BICUBIC.name, 493 ), 494 other_transforms=None, 495 src_transform=K.Resize((224, 224)), 496 ) 497 x, _, ratios, _, _ = self._transforms(x) 498 input_res = ratios * self.base_resolution 499 return x, input_res
501 def convert_to_expected_dim(self, x): 502 inputs_ = x[:, 1:, :] # removing the class tokens 503 # reshape the outputs to desired shape (N X H*W X C -> N X H X W X C) 504 rdim = inputs_.shape[1] 505 dshape = int(rdim ** 0.5) # finding square root of the outputs for obtaining the patch shape 506 inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape)) 507 inputs_ = inputs_.permute(0, 3, 1, 2) 508 return inputs_
510 def forward_features(self, x): 511 x, input_res = self.transform_inputs(x) 512 513 B, _, h, w = x.shape 514 x = self.patch_embed(x) 515 516 num_patches = int((h * w) / (self.patch_embed.patch_size[0] * self.patch_embed.patch_size[1])) 517 pos_embed = get_2d_sincos_pos_embed_with_resolution( 518 x.shape[-1], 519 int(num_patches ** 0.5), 520 input_res, 521 cls_token=True, 522 device=x.device, 523 ) 524 525 cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 526 x = torch.cat((cls_tokens, x), dim=1) 527 x = x + pos_embed 528 x = self.pos_drop(x) 529 530 # chunks obtained for getting the projections for conjuctions with upsampling blocks 531 _chunks = int(self.depth / 4) 532 chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1] 533 534 list_from_encoder = [] 535 for i, blk in enumerate(self.blocks): 536 x = blk(x) 537 if i in chunks_for_projection: 538 list_from_encoder.append(self.convert_to_expected_dim(x)) 539 540 x = self.convert_to_expected_dim(x) 541 542 return x, list_from_encoder
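The sketch below (not part of the module) shows how the ScaleMAE encoder is typically obtained through the `get_vision_transformer` factory defined further down and applied to a dummy batch. It assumes that the optional `timm` and `kornia` dependencies are installed; the stated output shape follows from the internal crop/resize to 224 x 224 and the patch size of 8 used by the factory, and the helpers `PatchEmbedUnSafe`, `CustomCompose` and `get_2d_sincos_pos_embed_with_resolution` come from elsewhere in this module.

# Minimal usage sketch (assumes timm and kornia are installed).
import torch
from torch_em.model.vit import get_vision_transformer

# "base_resolution" has to be adapted to the zoom factor of the target dataset.
encoder = get_vision_transformer(backbone="scalemae", model="vit_b", img_size=224, base_resolution=2.5)

x = torch.randn(2, 3, 224, 224)
features, from_encoder = encoder(x)
# The inputs are internally cropped/resized to 224 x 224 and embedded with patch size 8,
# so the final features are expected to have shape (2, 768, 28, 28); "from_encoder" holds
# the intermediate feature maps collected for the upsampling blocks.
print(features.shape, [f.shape for f in from_encoder])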
class ViT_DINOv2(DinoV2VisionTransformer):
    """Vision Transformer derived from the DINOv2 Codebase (https://arxiv.org/abs/2304.07193).

    Based on:
    https://github.com/facebookresearch/dinov2/blob/main/dinov2/models/vision_transformer.py.

    Args:
        img_size: The input image size.
        patch_size: The patch size.
        depth: The depth of the network.
        num_register_tokens: The number of registers added (in addition to the class token).
            This is important to know for ViTs trained with registers, so that they can be removed at the end.
        kwargs: Keyword arguments for the vision transformer base class.
    """
    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 16,
        depth: int = 12,
        num_register_tokens: int = 0,
        **kwargs
    ):
        if not _dinov2_import_success:
            raise RuntimeError(
                "The vision transformer backend can only be initialized if DINOv2 is installed. "
                "Please install DINOv2 from https://github.com/facebookresearch/dinov2 "
                "and then rerun your code."
            )

        super().__init__(
            img_size=img_size,
            depth=depth,
            patch_size=patch_size,
            num_register_tokens=num_register_tokens,
            **kwargs
        )

        self.img_size = img_size
        self.num_register_tokens = num_register_tokens
        self.patch_size = patch_size
        self.attn_outs = [i for i in range(depth) if i % 3 == 2]

    def forward(self, x: torch.Tensor, masks=None) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Apply the vision transformer to input data.

        Args:
            x: The input data.
            masks: Not used; kept for compatibility with the base class interface.

        Returns:
            The vision transformer output and the intermediate encoder features.
        """
        B = x.shape[0]
        x = self.prepare_tokens_with_masks(x)

        list_of_encoder = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in self.attn_outs:
                list_of_encoder.append(x)

        x = self.norm(x)
        # Remove the class and register tokens and reshape to an image-like layout (N x C x H x W).
        x = x[:, self.num_register_tokens + 1:].reshape(
            B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
        ).permute(0, 3, 1, 2).contiguous()

        list_of_encoder = [
            o[:, self.num_register_tokens + 1:].reshape(
                B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
            ).permute(0, 3, 1, 2).contiguous() for o in list_of_encoder
        ]

        return x, list_of_encoder[:3]
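A minimal usage sketch for the DINOv2 backbone (not part of the module), assuming the optional `dinov2` dependency is installed. The image size has to be divisible by the patch size of 14 used by the factory for the reshape in `forward` to produce an even grid; with `img_size=224` the features have a 16 x 16 spatial layout.

# Minimal usage sketch (assumes the dinov2 package is installed).
import torch
from torch_em.model.vit import get_vision_transformer

# "vit_s_reg4" selects the small DINOv2 ViT trained with 4 register tokens,
# which ViT_DINOv2.forward strips from the outputs.
encoder = get_vision_transformer(backbone="dinov2", model="vit_s_reg4", img_size=224)

x = torch.randn(1, 3, 224, 224)
features, from_encoder = encoder(x)
print(features.shape)               # expected: (1, 384, 16, 16) for img_size=224 and patch_size=14
print(len(from_encoder))            # three intermediate feature maps in the same layout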
class ViT_DINOv3(DinoV3VisionTransformer):
    """Vision Transformer derived from the DINOv3 Codebase (https://arxiv.org/abs/2508.10104).

    Based on:
    https://github.com/facebookresearch/dinov3/blob/main/dinov3/models/vision_transformer.py.

    Args:
        in_chans: The number of input channels.
        img_size: The input image size.
        patch_size: The patch size.
        embed_dim: The embedding dimension.
        depth: The depth of the network.
        num_heads: The number of heads.
        ffn_ratio: The FFN ratio.
        n_storage_tokens: The number of storage tokens (registers) added in addition to the class token.
            These are removed from the outputs.
        kwargs: Keyword arguments for the image encoder base class.
    """
    def __init__(
        self,
        in_chans: int = 3,
        img_size: int = 224,
        patch_size: int = 16,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        ffn_ratio: float = 4.0,
        n_storage_tokens: int = 0,
        **kwargs
    ):
        if not _dinov3_import_success:
            raise RuntimeError(
                "The vision transformer backend can only be initialized if DINOv3 is installed. "
                "Please install DINOv3 from https://github.com/facebookresearch/dinov3 "
                "and then rerun your code."
            )

        super().__init__(
            in_chans=in_chans,
            img_size=img_size,
            patch_size=patch_size,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            ffn_ratio=ffn_ratio,
            n_storage_tokens=n_storage_tokens,
            **kwargs
        )

        self.in_chans = in_chans
        self.img_size = img_size
        self.n_storage_tokens = n_storage_tokens
        self.attn_outs = [i for i in range(depth) if i % 3 == 2]

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Apply the vision transformer to input data.

        Args:
            x: The input data.

        Returns:
            The vision transformer output and the intermediate encoder features.
        """
        B = x.shape[0]
        x, hw_tuple = self.prepare_tokens_with_masks(x)

        list_of_encoder = []
        for i, blk in enumerate(self.blocks):
            rope_sincos = self.rope_embed(H=hw_tuple[0], W=hw_tuple[1])
            x = blk(x, rope_sincos)
            if i in self.attn_outs:
                list_of_encoder.append(x)

        x = self.norm(x)
        # Remove the class and storage tokens and reshape to an image-like layout (N x C x H x W).
        x = x[:, self.n_storage_tokens + 1:].reshape(
            B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
        ).permute(0, 3, 1, 2).contiguous()

        list_of_encoder = [
            o[:, self.n_storage_tokens + 1:].reshape(
                B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
            ).permute(0, 3, 1, 2).contiguous() for o in list_of_encoder
        ]

        return x, list_of_encoder[:3]
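A minimal usage sketch for the DINOv3 backbone (not part of the module), assuming the optional `dinov3` dependency is installed. The factory configures `n_storage_tokens=4`, which `ViT_DINOv3.forward` strips together with the class token; with the default patch size of 16 and `img_size=224` the output grid is 14 x 14.

# Minimal usage sketch (assumes the dinov3 package is installed).
import torch
from torch_em.model.vit import get_vision_transformer

encoder = get_vision_transformer(backbone="dinov3", model="vit_s", img_size=224)

x = torch.randn(1, 3, 224, 224)
features, from_encoder = encoder(x)
print(features.shape)  # expected: (1, 384, 14, 14) for img_size=224 and the default patch_size=16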
def get_vision_transformer(backbone: str, model: str, img_size: int = 1024, **kwargs) -> nn.Module:
    """Get the vision transformer encoder.

    Args:
        backbone: The name of the vision transformer implementation.
            One of "sam", "sam2", "sam3", "mae", "scalemae", "dinov2" or "dinov3".
        model: The name of the model configuration. The supported values depend on the backbone,
            e.g. "vit_b", "vit_l" or "vit_h" for "sam", "mae" and "scalemae".
        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
        kwargs: Additional kwargs which can be expected by the vision transformer,
            e.g. 'base_resolution' for `ViT_ScaleMAE`.

    Returns:
        The vision transformer.
    """
    if backbone == "sam":
        if model == "vit_b":
            encoder = ViT_Sam(
                depth=12, embed_dim=768, img_size=img_size, mlp_ratio=4,
                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
                num_heads=12, patch_size=16, qkv_bias=True, use_rel_pos=True,
                global_attn_indexes=[2, 5, 8, 11],
                window_size=14, out_chans=256,
            )
        elif model == "vit_l":
            encoder = ViT_Sam(
                depth=24, embed_dim=1024, img_size=img_size, mlp_ratio=4,
                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
                num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True,
                global_attn_indexes=[5, 11, 17, 23],
                window_size=14, out_chans=256,
            )
        elif model == "vit_h":
            encoder = ViT_Sam(
                depth=32, embed_dim=1280, img_size=img_size, mlp_ratio=4,
                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
                num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True,
                global_attn_indexes=[7, 15, 23, 31],
                window_size=14, out_chans=256,
            )
        else:
            raise ValueError(f"'{model}' is not supported by SAM. Currently, 'vit_b', 'vit_l', 'vit_h' are supported.")

    elif backbone == "sam2":
        if model == "hvit_t":
            encoder = ViT_Sam2(
                img_size=img_size, embed_dim=96, num_heads=1, stages=[1, 2, 7, 2], global_att_blocks=[5, 7, 9],
                window_pos_embed_bkg_spatial_size=[7, 7], backbone_channel_list=[768, 384, 192, 96],
            )
        elif model == "hvit_s":
            encoder = ViT_Sam2(
                img_size=img_size, embed_dim=96, num_heads=1, stages=[1, 2, 11, 2], global_att_blocks=[7, 10, 13],
                window_pos_embed_bkg_spatial_size=[7, 7], backbone_channel_list=[768, 384, 192, 96],
            )
        elif model == "hvit_b":
            encoder = ViT_Sam2(
                img_size=img_size, embed_dim=112, num_heads=2, backbone_channel_list=[896, 448, 224, 112],
            )
        elif model == "hvit_l":
            encoder = ViT_Sam2(
                img_size=img_size, embed_dim=144, num_heads=2, stages=[2, 6, 36, 4], global_att_blocks=[23, 33, 43],
                window_spec=[8, 4, 16, 8], backbone_channel_list=[1152, 576, 288, 144],
            )
        else:
            raise ValueError(
                f"'{model}' is not supported by SAM2. Currently, 'hvit_t', 'hvit_s', 'hvit_b', 'hvit_l' are supported."
            )

    elif backbone == "sam3":
        if model != "vit_pe":
            raise ValueError(
                "'sam3' does not have multiple model configurations. Please use 'vit_pe' as the model configuration."
            )

        encoder = ViT_Sam3(
            img_size=1008, pretrain_img_size=336, patch_size=14, embed_dim=1024, depth=32, num_heads=16,
            mlp_ratio=4.625, norm_layer="LayerNorm", drop_path_rate=0.1, qkv_bias=True, use_abs_pos=True,
            tile_abs_pos=True, global_att_blocks=(7, 15, 23, 31), rel_pos_blocks=(), use_rope=True,
            use_interp_rope=True, window_size=24, pretrain_use_cls_token=True, retain_cls_token=False, ln_pre=True,
            ln_post=False, return_interm_layers=False, bias_patch_embed=False, compile_mode=None,
        )

    elif backbone == "mae":
        if model == "vit_b":
            encoder = ViT_MAE(
                img_size=img_size, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
            )
        elif model == "vit_l":
            encoder = ViT_MAE(
                img_size=img_size, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
            )
        elif model == "vit_h":
            encoder = ViT_MAE(
                img_size=img_size, patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4,
                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
            )
        else:
            raise ValueError(f"'{model}' is not supported by MAE. Currently, 'vit_b', 'vit_l', 'vit_h' are supported.")

    elif backbone == "scalemae":
        base_resolution = kwargs.get("base_resolution", 2.5)

        if model == "vit_b":
            encoder = ViT_ScaleMAE(
                img_size=img_size, patch_size=8, embed_dim=768, depth=12, num_heads=12,
                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
                base_resolution=base_resolution,
            )
        elif model == "vit_l":
            encoder = ViT_ScaleMAE(
                img_size=img_size, patch_size=8, embed_dim=1024, depth=24, num_heads=16,
                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
                base_resolution=base_resolution,
            )
        elif model == "vit_h":
            encoder = ViT_ScaleMAE(
                img_size=img_size, patch_size=8, embed_dim=1280, depth=32, num_heads=16,
                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
                base_resolution=base_resolution,
            )
        else:
            raise ValueError(
                f"'{model}' is not supported by ScaleMAE. Currently, 'vit_b', 'vit_l' and 'vit_h' are supported."
            )

    elif backbone == "dinov2":
        block_fn = partial(Block, attn_class=MemEffAttention)
        msg = "The model name should be either 'vit_<X>' or 'vit_<X>_reg<Y>'."

        if model.startswith("vit_s"):
            assert model in ["vit_s", "vit_s_reg4"], msg
            encoder = ViT_DINOv2(
                img_size=img_size, patch_size=14, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
                block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0,
                num_register_tokens=4 if model.endswith("_reg4") else 0,
            )
        elif model.startswith("vit_b"):
            assert model in ["vit_b", "vit_b_reg4"], msg
            encoder = ViT_DINOv2(
                img_size=img_size, patch_size=14, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
                block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0,
                num_register_tokens=4 if model.endswith("_reg4") else 0,
            )
        elif model.startswith("vit_l"):
            assert model in ["vit_l", "vit_l_reg4"], msg
            encoder = ViT_DINOv2(
                img_size=img_size, patch_size=14, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
                block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0,
                num_register_tokens=4 if model.endswith("_reg4") else 0,
            )
        elif model.startswith("vit_g"):
            assert model in ["vit_g", "vit_g_reg4"], msg
            encoder = ViT_DINOv2(
                img_size=img_size, patch_size=14, embed_dim=1536, depth=40, num_heads=24, mlp_ratio=4,
                block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0,
                num_register_tokens=4 if model.endswith("_reg4") else 0, ffn_layer="swiglu",
            )
        else:
            raise ValueError(
                f"'{model}' is not supported by DINOv2. Currently, 'vit_s', 'vit_b', 'vit_l' and 'vit_g' are supported."
            )

    elif backbone == "dinov3":
        if model == "vit_s":
            encoder = ViT_DINOv3(
                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=384,
                num_heads=6, layerscale_init=1.0e-05, norm_layer="layernormbf16", n_storage_tokens=4, mask_k_bias=True,
            )
        elif model == "vit_s+":
            encoder = ViT_DINOv3(
                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=384,
                num_heads=6, ffn_ratio=6, layerscale_init=1.0e-05, norm_layer="layernormbf16",
                ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True,
            )
        elif model == "vit_b":
            encoder = ViT_DINOv3(
                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32",
                layerscale_init=1.0e-05, norm_layer="layernormbf16", n_storage_tokens=4, mask_k_bias=True,
            )
        elif model == "vit_l":
            encoder = ViT_DINOv3(
                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1024,
                depth=24, num_heads=16, layerscale_init=1.0e-05, norm_layer="layernormbf16",
                n_storage_tokens=4, mask_k_bias=True,
            )
        elif model == "vit_l+":
            encoder = ViT_DINOv3(
                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1024,
                depth=24, num_heads=16, ffn_ratio=6.0, layerscale_init=1.0e-05, norm_layer="layernormbf16",
                ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True,
            )
        elif model == "vit_h+":
            encoder = ViT_DINOv3(
                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1280,
                depth=32, num_heads=20, ffn_ratio=6.0, layerscale_init=1.0e-05, norm_layer="layernormbf16",
                ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True,
            )
        elif model == "vit_7b":
            encoder = ViT_DINOv3(
                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=4096,
                depth=40, num_heads=32, ffn_ratio=3, qkv_bias=False, drop_path_rate=0.0, layerscale_init=1.0e-05,
                norm_layer="layernormbf16", ffn_layer="swiglu64", n_storage_tokens=4, mask_k_bias=True,
                untie_global_and_local_cls_norm=True,
            )
        else:
            raise ValueError(
                f"'{model}' is not supported by DINOv3. Currently, "
                "'vit_s', 'vit_s+', 'vit_b', 'vit_l', 'vit_l+', 'vit_h+' and 'vit_7b' are supported."
            )

    else:
        raise ValueError(
            "The supported backbones for 'UNETR' are 'sam', 'sam2', 'sam3', 'mae', 'scalemae', 'dinov2' and 'dinov3'. "
            "Please choose one of them."
        )

    return encoder
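The factory above is the intended entry point for obtaining any of the encoders in this module. A minimal usage sketch (not part of the module), assuming the optional `segment_anything` dependency is installed for the "sam" backbone:

# Minimal usage sketch (assumes segment_anything is installed for the "sam" backbone).
import torch
from torch_em.model.vit import get_vision_transformer

encoder = get_vision_transformer(backbone="sam", model="vit_b", img_size=1024)

x = torch.randn(1, 3, 1024, 1024)
features, from_encoder = encoder(x)
print(features.shape)                    # expected: (1, 768, 64, 64) for vit_b with img_size=1024 and patch_size=16
print([f.shape for f in from_encoder])   # three intermediate feature maps from the global attention blocks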