torch_em.model.vit
1from functools import partial 2from typing import Tuple, List 3 4import torch 5import torch.nn as nn 6 7# we catch ImportErrors here because segment_anything, micro_sam, scale_mae and timm should 8# only be optional dependencies for torch_em 9try: 10 from segment_anything.modeling import ImageEncoderViT 11 _sam_import_success = True 12except ImportError: 13 ImageEncoderViT = object 14 _sam_import_success = False 15 16try: 17 from timm.models.vision_transformer import VisionTransformer, PatchEmbed 18 _timm_import_success = True 19except ImportError: 20 VisionTransformer = object 21 PatchEmbed = object 22 _timm_import_success = False 23 24try: 25 from sam2.modeling.backbones.hieradet import Hiera 26 from sam2.modeling.position_encoding import PositionEmbeddingSine 27 from sam2.modeling.backbones.image_encoder import ImageEncoder, FpnNeck 28 _sam2_import_success = True 29except ImportError: 30 ImageEncoder = object 31 _sam2_import_success = False 32 33try: 34 from dinov2.models.vision_transformer import DinoVisionTransformer as DinoV2VisionTransformer 35 from dinov2.layers import MemEffAttention, NestedTensorBlock as Block 36 _dinov2_import_success = True 37except ImportError: 38 DinoV2VisionTransformer = object 39 _dinov2_import_success = False 40 41try: 42 from dinov3.models.vision_transformer import DinoVisionTransformer as DinoV3VisionTransformer 43 _dinov3_import_success = True 44except ImportError: 45 DinoV3VisionTransformer = object 46 _dinov3_import_success = False 47 48 49class ViT_Sam(ImageEncoderViT): 50 """Vision Transformer derived from the Segment Anything Codebase (https://arxiv.org/abs/2304.02643). 51 52 Based on: 53 https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py 54 55 Args: 56 in_chans: The number of input channels. 57 embed_dim: The embedding dimension, corresponding to the number of output channels of the vision transformer. 58 global_attn_indexes: The global attention indices. 59 kwargs: Keyword arguments for the image encoder base class. 60 """ 61 def __init__( 62 self, 63 in_chans: int = 3, 64 embed_dim: int = 768, 65 global_attn_indexes: Tuple[int, ...] = [2, 5, 8, 11], 66 **kwargs, 67 ) -> None: 68 if not _sam_import_success: 69 raise RuntimeError( 70 "The vision transformer backend can only be initialized if segment anything is installed. " 71 "Please install segment anything from https://github.com/facebookresearch/segment-anything " 72 "and then rerun your code." 73 ) 74 75 super().__init__(embed_dim=embed_dim, global_attn_indexes=global_attn_indexes, **kwargs) 76 self.chunks_for_projection = global_attn_indexes 77 self.in_chans = in_chans 78 self.embed_dim = embed_dim 79 80 def forward(self, x: torch.Tensor) -> torch.Tensor: 81 """Apply the vision transformer to input data. 82 83 Args: 84 x: The input data. 85 86 Returns: 87 The vision transformer output. 88 """ 89 x = self.patch_embed(x) 90 if self.pos_embed is not None: 91 x = x + self.pos_embed 92 93 list_from_encoder = [] 94 for i, blk in enumerate(self.blocks): 95 x = blk(x) 96 if i in self.chunks_for_projection: 97 list_from_encoder.append(x) 98 99 x = x.permute(0, 3, 1, 2) 100 list_from_encoder = [e.permute(0, 3, 1, 2) for e in list_from_encoder] 101 return x, list_from_encoder[:3] 102 103 104class ViT_MAE(VisionTransformer): 105 """Vision Transformer derived from the Masked Auto Encoder Codebase (https://arxiv.org/abs/2111.06377). 
106 107 Based on: 108 https://github.com/facebookresearch/mae/blob/main/models_vit.py#L20-L53 109 110 Args: 111 img_size: The size of the input for the image encoder. Input images will be resized to match this size. 112 in_chans: The number of input channels. 113 depth: The depth of the vision transformer. 114 kwargs: Additional keyword arguments for the vision transformer base class. 115 """ 116 def __init__( 117 self, 118 img_size: int = 1024, # chosen to match our experiments with segment anything 119 in_chans: int = 3, 120 depth: int = 12, 121 **kwargs 122 ): 123 if not _timm_import_success: 124 raise RuntimeError( 125 "The vision transformer backend can only be initialized if timm is installed. " 126 "Please install timm (using conda/mamba) for using https://github.com/facebookresearch/mae/ " 127 "and then rerun your code" 128 ) 129 super().__init__(img_size=img_size, depth=depth, **kwargs) 130 self.img_size = img_size 131 self.in_chans = in_chans 132 self.depth = depth 133 134 def convert_to_expected_dim(self, inputs_): 135 """@private 136 """ 137 inputs_ = inputs_[:, 1:, :] # removing the class tokens 138 # reshape the outputs to desired shape (N x H*W X C -> N x H x W x C) 139 rdim = inputs_.shape[1] 140 dshape = int(rdim ** 0.5) # finding the square root of the outputs for obtaining the patch shape 141 inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape)) 142 inputs_ = inputs_.permute(0, 3, 1, 2) 143 return inputs_ 144 145 def forward_features(self, x): 146 """@private 147 """ 148 B = x.shape[0] 149 x = self.patch_embed(x) 150 151 cls_tokens = self.cls_token.expand(B, -1, -1) 152 x = torch.cat((cls_tokens, x), dim=1) 153 154 x = x + self.pos_embed 155 x = self.pos_drop(x) 156 157 # chunks obtained for getting the projections for conjuctions with upsampling blocks 158 _chunks = int(self.depth / 4) 159 chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1] 160 161 list_from_encoder = [] 162 for i, blk in enumerate(self.blocks): 163 x = blk(x) 164 if i in chunks_for_projection: 165 list_from_encoder.append(self.convert_to_expected_dim(x)) 166 167 x = self.convert_to_expected_dim(x) 168 return x, list_from_encoder[:3] 169 170 def forward(self, x: torch.Tensor) -> torch.Tensor: 171 """Apply the vision transformer to input data. 172 173 Args: 174 x: The input data. 175 176 Returns: 177 The vision transformer output. 178 """ 179 x, list_from_encoder = self.forward_features(x) 180 return x, list_from_encoder 181 182 183class ViT_Sam2(ImageEncoder): 184 """Vision Transformer derived from the Segment Anything 2 Codebase (https://arxiv.org/abs/2408.00714). 185 186 Based on https://github.com/facebookresearch/sam2/blob/main/sam2/modeling/backbones/image_encoder.py. 187 188 Args: 189 backbone_channel_list: The channels throughout the entire backbone. 190 embed_dim: The initial embedding dimension. 191 num_heads: The initial number of heads. 192 stages: The number of blocks per stage. 193 global_att_blocks: The parameter to decide which blocks have global attention. 194 window_pos_embed_bkg_spatial_size: The spatial size of windowed positional embedding. 195 window_spec: The window size per stage, when not using global attention. 196 scalp: The count of lowest resolution features to discard. 197 """ 198 def __init__( 199 self, 200 backbone_channel_list: List[int], 201 img_size: int = 1024, 202 embed_dim: int = 96, 203 num_heads: int = 1, 204 stages: Tuple[int, ...] = (2, 3, 16, 3), 205 global_att_blocks: Tuple[int, ...] 
= (12, 16, 20), 206 window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14), 207 window_spec: Tuple[int, ...] = (8, 4, 14, 7), 208 scalp: int = 1, 209 ): 210 if not _sam2_import_success: 211 raise RuntimeError( 212 "The vision transformer backend can only be initialized if segment anything 2 is installed. " 213 "Please install segment anything 2 from https://github.com/facebookresearch/sam2 " 214 "and then rerun your code" 215 ) 216 217 trunk = Hiera( 218 embed_dim=embed_dim, 219 num_heads=num_heads, 220 stages=stages, 221 global_att_blocks=global_att_blocks, 222 window_pos_embed_bkg_spatial_size=window_pos_embed_bkg_spatial_size, 223 window_spec=window_spec, 224 ) 225 neck = FpnNeck( 226 position_encoding=PositionEmbeddingSine(num_pos_feats=256), 227 d_model=256, 228 backbone_channel_list=backbone_channel_list, 229 fpn_top_down_levels=[2, 3], 230 fpn_interp_model="nearest", 231 ) 232 233 super().__init__(trunk=trunk, neck=neck, scalp=scalp) 234 self.scalp = scalp 235 self.embed_dim = embed_dim 236 self.img_size = img_size 237 238 def forward(self, x: torch.Tensor): 239 # The forward pass throught the backbone. 240 features, pos = self.neck(self.trunk(x)) 241 if self.scalp > 0: # This discard the "n" lowest resolution features. 242 features, pos = features[:-self.scalp], pos[:-self.scalp] 243 244 return features[-1], features 245 246 247# 248# Utilities for ScaleMAE's ViT 249# 250 251 252class CustomCompose: 253 def __init__(self, rescale_transform, other_transforms, src_transform): 254 self.rescale_transform = rescale_transform 255 self.other_transforms = other_transforms 256 self.src_transform = src_transform 257 258 def __call__(self, x, valid_masks=None): 259 if valid_masks is not None: 260 nodata = (x * (1 - valid_masks.float())).max() 261 x_aug = self.rescale_transform(x) 262 parms = self.rescale_transform._params 263 264 # sanity check, comment if this is working 265 # valid_masks = self.rescale_transform(valid_masks.float(), params=parms) 266 # assert (x_aug==self.rescale_transform(x, params=parms)).all() # 267 268 if valid_masks is not None: 269 valid_masks = x_aug != nodata 270 _, c, h, w = x_aug.shape 271 zero_ratio = ((valid_masks == 0).sum((1, 2, 3)) / (h * w * c)).cpu().numpy() 272 else: 273 zero_ratio = -1 274 275 if self.other_transforms: 276 x_aug = self.other_transforms(x_aug) 277 x_src = self.src_transform(x_aug) 278 dx = parms["src"][:, 1, 0] - parms["src"][:, 0, 0] 279 280 # dy = (parms['src'][:,2,1] - parms['src'][:,1,1]) 281 # assert (dx == dy).all() 282 283 h, w = x_aug.shape[-2:] 284 # assert h == w 285 286 return x_aug, x_src, dx / h, zero_ratio, valid_masks 287 288 289def get_2d_sincos_pos_embed_with_resolution(embed_dim, grid_size, res, cls_token=False, device="cpu"): 290 """ 291 grid_size: int of the grid height and width 292 res: array of size n, representing the resolution of a pixel (say, in meters), 293 return: 294 pos_embed: [n,grid_size*grid_size, embed_dim] or [n,1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 295 """ 296 # res = torch.FloatTensor(res).to(device) 297 res = res.to(device) 298 grid_h = torch.arange(grid_size, dtype=torch.float32, device=device) 299 grid_w = torch.arange(grid_size, dtype=torch.float32, device=device) 300 grid = torch.meshgrid(grid_w, grid_h, indexing="xy") # here h goes first,direction reversed for numpy 301 grid = torch.stack(grid, dim=0) # 2 x h x w 302 303 # grid = grid.reshape([2, 1, grid_size, grid_size]) 304 grid = torch.einsum("chw,n->cnhw", grid, res) # 2 x n x h x w 305 _, n, h, w = grid.shape 306 
pos_embed = get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid) # (nxH*W, D/2) 307 pos_embed = pos_embed.reshape(n, h * w, embed_dim) 308 if cls_token: 309 pos_embed = torch.cat( 310 [torch.zeros([n, 1, embed_dim], dtype=torch.float32, device=pos_embed.device), pos_embed], dim=1 311 ) 312 313 return pos_embed 314 315 316def get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid): 317 assert embed_dim % 2 == 0 318 319 # use half of dimensions to encode grid_h 320 emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid[0]) # (H*W, D/2) 321 emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid[1]) # (H*W, D/2) 322 323 emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D) 324 return emb 325 326 327def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos): 328 """ 329 embed_dim: output dimension for each position 330 pos: a list of positions to be encoded: size (M,) 331 out: (M, D) 332 """ 333 assert embed_dim % 2 == 0 334 # old_shape = pos 335 omega = torch.arange(embed_dim // 2, dtype=torch.float32, device=pos.device) 336 omega /= embed_dim / 2.0 337 omega = 1.0 / 10000**omega # (D/2,) 338 339 pos = pos.reshape(-1) # (M,) 340 out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product 341 342 emb_sin = torch.sin(out) # (M, D/2) 343 emb_cos = torch.cos(out) # (M, D/2) 344 345 emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) 346 return emb 347 348 349class PatchEmbedUnSafe(PatchEmbed): 350 """Image to Patch Embedding""" 351 352 def forward(self, x): 353 B, C, H, W = x.shape 354 355 # NOTE: Comment code from ScaleMAE: Dropped size check in timm 356 # assert H == self.img_size[0] and W == self.img_size[1], \ 357 # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 358 359 x = self.proj(x).flatten(2).transpose(1, 2) 360 return x 361 362 363class ViT_ScaleMAE(VisionTransformer): 364 """Vision Transformer dervied from the Scale Masked Auto Encoder codebase (TODO: paper and github link). 365 366 NOTE: For downstream tasks, the "base_resoulution" parameter needs to be adjusted manually when using 367 the model on a different zoom factor dataset. 
368 """ 369 370 def __init__( 371 self, img_size=224, patch_size=16, in_chans=3, embed_dim=1024, depth=12, base_resolution=2.5, **kwargs 372 ): 373 super().__init__(img_size=img_size, embed_dim=embed_dim, **kwargs) 374 self.img_size = img_size 375 self.in_chans = in_chans 376 self.depth = depth 377 self.base_resolution = base_resolution 378 379 self.patch_embed = PatchEmbedUnSafe( 380 img_size=img_size, 381 patch_size=patch_size, 382 in_chans=in_chans, 383 embed_dim=embed_dim, 384 ) 385 386 def transform_inputs(self, x): 387 import kornia.augmentation as K 388 from kornia.constants import Resample 389 390 self._transforms = CustomCompose( 391 rescale_transform=K.RandomResizedCrop( 392 (448, 448), 393 ratio=(1.0, 1.0), 394 scale=(1.0, 1.0), 395 resample=Resample.BICUBIC.name, 396 ), 397 other_transforms=None, 398 src_transform=K.Resize((224, 224)), 399 ) 400 x, _, ratios, _, _ = self._transforms(x) 401 input_res = ratios * self.base_resolution 402 return x, input_res 403 404 def convert_to_expected_dim(self, x): 405 inputs_ = x[:, 1:, :] # removing the class tokens 406 # reshape the outputs to desired shape (N X H*W X C -> N X H X W X C) 407 rdim = inputs_.shape[1] 408 dshape = int(rdim ** 0.5) # finding square root of the outputs for obtaining the patch shape 409 inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape)) 410 inputs_ = inputs_.permute(0, 3, 1, 2) 411 return inputs_ 412 413 def forward_features(self, x): 414 x, input_res = self.transform_inputs(x) 415 416 B, _, h, w = x.shape 417 x = self.patch_embed(x) 418 419 num_patches = int((h * w) / (self.patch_embed.patch_size[0] * self.patch_embed.patch_size[1])) 420 pos_embed = get_2d_sincos_pos_embed_with_resolution( 421 x.shape[-1], 422 int(num_patches ** 0.5), 423 input_res, 424 cls_token=True, 425 device=x.device, 426 ) 427 428 cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 429 x = torch.cat((cls_tokens, x), dim=1) 430 x = x + pos_embed 431 x = self.pos_drop(x) 432 433 # chunks obtained for getting the projections for conjuctions with upsampling blocks 434 _chunks = int(self.depth / 4) 435 chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1] 436 437 list_from_encoder = [] 438 for i, blk in enumerate(self.blocks): 439 x = blk(x) 440 if i in chunks_for_projection: 441 list_from_encoder.append(self.convert_to_expected_dim(x)) 442 443 x = self.convert_to_expected_dim(x) 444 445 return x, list_from_encoder 446 447 def forward(self, x): 448 x, list_from_encoder = self.forward_features(x) 449 return x, list_from_encoder 450 451 452class ViT_DINOv2(DinoV2VisionTransformer): 453 """Vision Transformer derived from the DINOv2 Codebase (https://arxiv.org/abs/2304.07193). 454 455 Based on: 456 https://github.com/facebookresearch/dinov2/blob/main/dinov2/models/vision_transformer.py. 457 """ 458 def __init__( 459 self, 460 img_size: int = 224, 461 patch_size: int = 16, 462 depth: int = 12, 463 num_register_tokens: int = 0, 464 **kwargs 465 ): 466 if not _dinov2_import_success: 467 raise RuntimeError( 468 "The vision transformer backend can only be initialized if DINOv2 is installed. " 469 "Please install DINOv2 from https://github.com/facebookresearch/dinov2 " 470 "and then rerun your code." 
471 ) 472 473 super().__init__( 474 img_size=img_size, 475 depth=depth, 476 patch_size=patch_size, 477 num_register_tokens=num_register_tokens, 478 **kwargs 479 ) 480 481 self.img_size = img_size 482 self.num_register_tokens = num_register_tokens 483 self.patch_size = patch_size 484 self.attn_outs = [i for i in range(depth) if i % 3 == 2] 485 486 def forward(self, x, masks=None) -> torch.Tensor: 487 488 B = x.shape[0] 489 490 x = self.prepare_tokens_with_masks(x) 491 492 list_of_encoder = [] 493 for i, blk in enumerate(self.blocks): 494 x = blk(x) 495 if i in self.attn_outs: 496 list_of_encoder.append(x) 497 498 x = self.norm(x) 499 x = x[:, self.num_register_tokens + 1:].reshape( 500 B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1 501 ).permute(0, 3, 1, 2).contiguous() 502 503 list_of_encoder = [ 504 o[:, self.num_register_tokens + 1:].reshape( 505 B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1 506 ).permute(0, 3, 1, 2).contiguous() for o in list_of_encoder 507 ] 508 509 return x, list_of_encoder[:3] 510 511 512class ViT_DINOv3(DinoV3VisionTransformer): 513 """Vision Transformer derived from the DINOv3 Codebase (https://arxiv.org/abs/2508.10104). 514 515 Based on: 516 https://github.com/facebookresearch/dinov3/blob/main/dinov3/models/vision_transformer.py. 517 518 Args: 519 img_size: The input image size. 520 patch_size: The patch size. 521 embed_dim: The embedding dimension. 522 depth: The depth of the network. 523 num_heads: The number of heads. 524 ffn_ratio: The FFN rato. 525 n_storage_tokens: The number of storage (class) tokens to remove. 526 kwargs: Keyword arguments for the image encoder base class. 527 """ 528 def __init__( 529 self, 530 in_chans: int = 3, 531 img_size: int = 224, 532 patch_size: int = 16, 533 embed_dim: int = 768, 534 depth: int = 12, 535 num_heads: int = 12, 536 ffn_ratio: float = 4.0, 537 n_storage_tokens: int = 0, 538 **kwargs 539 ): 540 if not _dinov3_import_success: 541 raise RuntimeError( 542 "The vision transformer backend can only be initialized if DINOv3 is installed. " 543 "Please install DINOv3 from https://github.com/facebookresearch/dinov3 " 544 "and then rerun your code." 
545 ) 546 547 super().__init__( 548 in_chans=in_chans, 549 img_size=img_size, 550 patch_size=patch_size, 551 embed_dim=embed_dim, 552 depth=depth, 553 num_heads=num_heads, 554 ffn_ratio=ffn_ratio, 555 n_storage_tokens=n_storage_tokens, 556 **kwargs 557 ) 558 559 self.in_chans = in_chans 560 self.img_size = img_size 561 self.n_storage_tokens = n_storage_tokens 562 self.attn_outs = [i for i in range(depth) if i % 3 == 2] 563 564 def forward(self, x) -> torch.Tensor: 565 566 B = x.shape[0] 567 568 x, hw_tuple = self.prepare_tokens_with_masks(x) 569 570 list_of_encoder = [] 571 for i, blk in enumerate(self.blocks): 572 rope_sincos = self.rope_embed(H=hw_tuple[0], W=hw_tuple[1]) 573 x = blk(x, rope_sincos) 574 if i in self.attn_outs: 575 list_of_encoder.append(x) 576 577 x = self.norm(x) 578 x = x[:, self.n_storage_tokens + 1:].reshape( 579 B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1 580 ).permute(0, 3, 1, 2).contiguous() 581 582 list_of_encoder = [ 583 o[:, self.n_storage_tokens + 1:].reshape( 584 B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1 585 ).permute(0, 3, 1, 2).contiguous() for o in list_of_encoder 586 ] 587 588 return x, list_of_encoder[:3] 589 590 591def get_vision_transformer(backbone: str, model: str, img_size: int = 1024, **kwargs) -> nn.Module: 592 """Get vision transformer encoder. 593 594 Args: 595 backbone: The name of the vision transformer implementation. One of "sam" / "mae" / "scalemae". 596 model: The name of the model. One of "vit_b", "vit_l" or "vit_h". 597 img_size: The size of the input for the image encoder. Input images will be resized to match this size. 598 kwargs: Additional kwargs which can be expected by the vision transformer, 599 e.g. 'base_resolution' for `ViT_ScaleMAE`. 600 601 Returns: 602 The vision transformer. 603 """ 604 if backbone == "sam": 605 if model == "vit_b": 606 encoder = ViT_Sam( 607 depth=12, embed_dim=768, img_size=img_size, mlp_ratio=4, 608 norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 609 num_heads=12, patch_size=16, qkv_bias=True, use_rel_pos=True, 610 global_attn_indexes=[2, 5, 8, 11], 611 window_size=14, out_chans=256, 612 ) 613 elif model == "vit_l": 614 encoder = ViT_Sam( 615 depth=24, embed_dim=1024, img_size=img_size, mlp_ratio=4, 616 norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 617 num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True, 618 global_attn_indexes=[5, 11, 17, 23], 619 window_size=14, out_chans=256, 620 ) 621 elif model == "vit_h": 622 encoder = ViT_Sam( 623 depth=32, embed_dim=1280, img_size=img_size, mlp_ratio=4, 624 norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 625 num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True, 626 global_attn_indexes=[7, 15, 23, 31], 627 window_size=14, out_chans=256, 628 ) 629 else: 630 raise ValueError(f"'{model}' is not supported by SAM. 
Currently, 'vit_b', 'vit_l', 'vit_h' are supported.") 631 632 elif backbone == "sam2": 633 if model == "hvit_t": 634 encoder = ViT_Sam2( 635 img_size=img_size, embed_dim=96, num_heads=1, stages=[1, 2, 7, 2], global_att_blocks=[5, 7, 9], 636 window_pos_embed_bkg_spatial_size=[7, 7], backbone_channel_list=[768, 384, 192, 96], 637 ) 638 elif model == "hvit_s": 639 encoder = ViT_Sam2( 640 img_size=img_size, embed_dim=96, num_heads=1, stages=[1, 2, 11, 2], global_att_blocks=[7, 10, 13], 641 window_pos_embed_bkg_spatial_size=[7, 7], backbone_channel_list=[768, 384, 192, 96], 642 ) 643 elif model == "hvit_b": 644 encoder = ViT_Sam2( 645 img_size=img_size, embed_dim=112, num_heads=2, backbone_channel_list=[896, 448, 224, 112], 646 ) 647 elif model == "hvit_l": 648 encoder = ViT_Sam2( 649 img_size=img_size, embed_dim=144, num_heads=2, stages=[2, 6, 36, 4], global_att_blocks=[23, 33, 43], 650 window_spec=[8, 4, 16, 8], backbone_channel_list=[1152, 576, 288, 144], 651 ) 652 else: 653 raise ValueError( 654 f"'{model}' is not supported by SAM2. Currently, 'hvit_t', 'hvit_s', 'hvit_b', 'hvit_l' are supported." 655 ) 656 657 elif backbone == "mae": 658 if model == "vit_b": 659 encoder = ViT_MAE( 660 img_size=img_size, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, 661 qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6) 662 ) 663 elif model == "vit_l": 664 encoder = ViT_MAE( 665 img_size=img_size, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, 666 qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6) 667 ) 668 elif model == "vit_h": 669 encoder = ViT_MAE( 670 img_size=img_size, patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, 671 qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6) 672 ) 673 else: 674 raise ValueError(f"'{model}' is not supported by MAE. Currently, 'vit_b', 'vit_l', 'vit_h' are supported.") 675 676 elif backbone == "scalemae": 677 base_resolution = kwargs.get("base_resolution", 2.5) 678 679 if model == "vit_b": 680 encoder = ViT_ScaleMAE( 681 img_size=img_size, patch_size=8, embed_dim=768, depth=12, num_heads=12, 682 mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), 683 base_resolution=base_resolution, 684 ) 685 elif model == "vit_l": 686 encoder = ViT_ScaleMAE( 687 img_size=img_size, patch_size=8, embed_dim=1024, depth=24, num_heads=16, 688 mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), 689 base_resolution=base_resolution, 690 ) 691 elif model == "vit_h": 692 encoder = ViT_ScaleMAE( 693 img_size=img_size, patch_size=8, embed_dim=1280, depth=32, num_heads=16, 694 mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), 695 base_resolution=base_resolution, 696 ) 697 else: 698 raise ValueError( 699 f"'{model}' is not supported by ScaleMAE. Currently, 'vit_b', 'vit_l' and 'vit_h' are supported." 700 ) 701 702 elif backbone == "dinov2": 703 block_fn = partial(Block, attn_class=MemEffAttention) 704 msg = "The model name should be either 'vit_<X>' or 'vit_<X>_reg<Y>." 
705 706 if model.startswith("vit_s"): 707 assert model in ["vit_s", "vit_s_reg4"], msg 708 encoder = ViT_DINOv2( 709 img_size=img_size, patch_size=14, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, 710 block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0, 711 num_register_tokens=4 if model.endswith("_reg4") else 0, 712 ) 713 elif model.startswith("vit_b"): 714 assert model in ["vit_b", "vit_b_reg4"], msg 715 encoder = ViT_DINOv2( 716 img_size=img_size, patch_size=14, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, 717 block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0, 718 num_register_tokens=4 if model.endswith("_reg4") else 0, 719 ) 720 elif model.startswith("vit_l"): 721 assert model in ["vit_l", "vit_l_reg4"], msg 722 encoder = ViT_DINOv2( 723 img_size=img_size, patch_size=14, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, 724 block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0, 725 num_register_tokens=4 if model.endswith("_reg4") else 0, 726 ) 727 elif model.startswith("vit_g"): 728 assert model in ["vit_g", "vit_g_reg4"], msg 729 encoder = ViT_DINOv2( 730 img_size=img_size, patch_size=14, embed_dim=1536, depth=40, num_heads=24, mlp_ratio=4, 731 block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0, 732 num_register_tokens=4 if model.endswith("_reg4") else 0, ffn_layer="swiglu", 733 ) 734 else: 735 raise ValueError( 736 f"'{model}' is not supported by DINOv2. Currently, 'vit_s', 'vit_b', 'vit_l' and 'vit_g' are supported." 737 ) 738 739 elif backbone == "dinov3": 740 741 if model == "vit_s": 742 encoder = ViT_DINOv3( 743 img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=384, 744 num_heads=6, layerscale_init=1.0e-05, norm_layer="layernormbf16", n_storage_tokens=4, mask_k_bias=True, 745 ) 746 elif model == "vit_s+": 747 encoder = ViT_DINOv3( 748 img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=384, 749 num_heads=6, ffn_ratio=6, layerscale_init=1.0e-05, norm_layer="layernormbf16", 750 ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True, 751 ) 752 753 elif model == "vit_b": 754 encoder = ViT_DINOv3( 755 img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", 756 layerscale_init=1.0e-05, norm_layer="layernormbf16", n_storage_tokens=4, mask_k_bias=True, 757 ) 758 elif model == "vit_l": 759 encoder = ViT_DINOv3( 760 img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1024, 761 depth=24, num_heads=16, layerscale_init=1.0e-05, norm_layer="layernormbf16", 762 n_storage_tokens=4, mask_k_bias=True, 763 ) 764 elif model == "vit_l+": 765 encoder = ViT_DINOv3( 766 img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1024, 767 depth=24, num_heads=16, ffn_ratio=6.0, layerscale_init=1.0e-05, norm_layer="layernormbf16", 768 ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True, 769 ) 770 elif model == "vit_h+": 771 encoder = ViT_DINOv3( 772 img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1280, 773 depth=32, num_heads=20, ffn_ratio=6.0, layerscale_init=1.0e-05, norm_layer="layernormbf16", 774 ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True, 775 ) 776 elif model == "vit_7b": 777 encoder = ViT_DINOv3( 778 img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=4096, 779 
depth=40, num_heads=32, ffn_ratio=3, qkv_bias=False, drop_path_rate=0.0, layerscale_init=1.0e-05, 780 norm_layer="layernormbf16", ffn_layer="swiglu64", n_storage_tokens=4, mask_k_bias=True, 781 untie_global_and_local_cls_norm=True, 782 ) 783 else: 784 raise ValueError( 785 f"'{model}' is not supported by DINOv3. Currently, " 786 "'vit_s', 'vit_s+', 'vit_b', 'vit_l', 'vit_l+', 'vit_h+', 'vit_7b' are supported." 787 ) 788 789 else: 790 raise ValueError( 791 "The 'UNETR' supported backbones are 'sam', 'sam2', 'mae', 'scalemae', 'dinov2' or 'dinov3'. " 792 "Please choose one of them." 793 ) 794 795 return encoder
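The factory function get_vision_transformer at the end of the module source above is the usual entry point. The following is a minimal usage sketch, assuming the optional segment_anything dependency is installed for the "sam" backbone; the shapes in the comments are expectations for this configuration, not guarantees.

import torch
from torch_em.model.vit import get_vision_transformer

# Assumes segment_anything is installed, since the "sam" backbone wraps its ImageEncoderViT.
encoder = get_vision_transformer(backbone="sam", model="vit_b", img_size=1024)

x = torch.randn(1, 3, 1024, 1024)
features, intermediates = encoder(x)
print(features.shape)      # expected: torch.Size([1, 768, 64, 64])
print(len(intermediates))  # 3: the first three globally attended feature maps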
50class ViT_Sam(ImageEncoderViT): 51 """Vision Transformer derived from the Segment Anything Codebase (https://arxiv.org/abs/2304.02643). 52 53 Based on: 54 https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py 55 56 Args: 57 in_chans: The number of input channels. 58 embed_dim: The embedding dimension, corresponding to the number of output channels of the vision transformer. 59 global_attn_indexes: The global attention indices. 60 kwargs: Keyword arguments for the image encoder base class. 61 """ 62 def __init__( 63 self, 64 in_chans: int = 3, 65 embed_dim: int = 768, 66 global_attn_indexes: Tuple[int, ...] = [2, 5, 8, 11], 67 **kwargs, 68 ) -> None: 69 if not _sam_import_success: 70 raise RuntimeError( 71 "The vision transformer backend can only be initialized if segment anything is installed. " 72 "Please install segment anything from https://github.com/facebookresearch/segment-anything " 73 "and then rerun your code." 74 ) 75 76 super().__init__(embed_dim=embed_dim, global_attn_indexes=global_attn_indexes, **kwargs) 77 self.chunks_for_projection = global_attn_indexes 78 self.in_chans = in_chans 79 self.embed_dim = embed_dim 80 81 def forward(self, x: torch.Tensor) -> torch.Tensor: 82 """Apply the vision transformer to input data. 83 84 Args: 85 x: The input data. 86 87 Returns: 88 The vision transformer output. 89 """ 90 x = self.patch_embed(x) 91 if self.pos_embed is not None: 92 x = x + self.pos_embed 93 94 list_from_encoder = [] 95 for i, blk in enumerate(self.blocks): 96 x = blk(x) 97 if i in self.chunks_for_projection: 98 list_from_encoder.append(x) 99 100 x = x.permute(0, 3, 1, 2) 101 list_from_encoder = [e.permute(0, 3, 1, 2) for e in list_from_encoder] 102 return x, list_from_encoder[:3]
Vision Transformer derived from the Segment Anything Codebase (https://arxiv.org/abs/2304.02643).
Arguments:
- in_chans: The number of input channels.
- embed_dim: The embedding dimension, corresponding to the number of output channels of the vision transformer.
- global_attn_indexes: The global attention indices.
- kwargs: Keyword arguments for the image encoder base class.
62 def __init__( 63 self, 64 in_chans: int = 3, 65 embed_dim: int = 768, 66 global_attn_indexes: Tuple[int, ...] = [2, 5, 8, 11], 67 **kwargs, 68 ) -> None: 69 if not _sam_import_success: 70 raise RuntimeError( 71 "The vision transformer backend can only be initialized if segment anything is installed. " 72 "Please install segment anything from https://github.com/facebookresearch/segment-anything " 73 "and then rerun your code." 74 ) 75 76 super().__init__(embed_dim=embed_dim, global_attn_indexes=global_attn_indexes, **kwargs) 77 self.chunks_for_projection = global_attn_indexes 78 self.in_chans = in_chans 79 self.embed_dim = embed_dim
81 def forward(self, x: torch.Tensor) -> torch.Tensor: 82 """Apply the vision transformer to input data. 83 84 Args: 85 x: The input data. 86 87 Returns: 88 The vision transformer output. 89 """ 90 x = self.patch_embed(x) 91 if self.pos_embed is not None: 92 x = x + self.pos_embed 93 94 list_from_encoder = [] 95 for i, blk in enumerate(self.blocks): 96 x = blk(x) 97 if i in self.chunks_for_projection: 98 list_from_encoder.append(x) 99 100 x = x.permute(0, 3, 1, 2) 101 list_from_encoder = [e.permute(0, 3, 1, 2) for e in list_from_encoder] 102 return x, list_from_encoder[:3]
Apply the vision transformer to input data.
Arguments:
- x: The input data.
Returns:
The vision transformer output and a list of intermediate features from the globally attending blocks.
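For direct construction, the class accepts the keyword arguments of segment anything's ImageEncoderViT. A minimal sketch, assuming segment_anything is installed and mirroring the "vit_b" configuration used by get_vision_transformer:

import torch
from functools import partial
from torch_em.model.vit import ViT_Sam

encoder = ViT_Sam(
    img_size=1024, patch_size=16, embed_dim=768, depth=12, num_heads=12,
    mlp_ratio=4, qkv_bias=True, use_rel_pos=True, window_size=14, out_chans=256,
    norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
    global_attn_indexes=[2, 5, 8, 11],
)
features, intermediates = encoder(torch.randn(1, 3, 1024, 1024))
# features: (1, 768, 64, 64); intermediates: the feature maps after blocks 2, 5 and 8.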
105class ViT_MAE(VisionTransformer): 106 """Vision Transformer derived from the Masked Auto Encoder Codebase (https://arxiv.org/abs/2111.06377). 107 108 Based on: 109 https://github.com/facebookresearch/mae/blob/main/models_vit.py#L20-L53 110 111 Args: 112 img_size: The size of the input for the image encoder. Input images will be resized to match this size. 113 in_chans: The number of input channels. 114 depth: The depth of the vision transformer. 115 kwargs: Additional keyword arguments for the vision transformer base class. 116 """ 117 def __init__( 118 self, 119 img_size: int = 1024, # chosen to match our experiments with segment anything 120 in_chans: int = 3, 121 depth: int = 12, 122 **kwargs 123 ): 124 if not _timm_import_success: 125 raise RuntimeError( 126 "The vision transformer backend can only be initialized if timm is installed. " 127 "Please install timm (using conda/mamba) for using https://github.com/facebookresearch/mae/ " 128 "and then rerun your code" 129 ) 130 super().__init__(img_size=img_size, depth=depth, **kwargs) 131 self.img_size = img_size 132 self.in_chans = in_chans 133 self.depth = depth 134 135 def convert_to_expected_dim(self, inputs_): 136 """@private 137 """ 138 inputs_ = inputs_[:, 1:, :] # removing the class tokens 139 # reshape the outputs to desired shape (N x H*W X C -> N x H x W x C) 140 rdim = inputs_.shape[1] 141 dshape = int(rdim ** 0.5) # finding the square root of the outputs for obtaining the patch shape 142 inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape)) 143 inputs_ = inputs_.permute(0, 3, 1, 2) 144 return inputs_ 145 146 def forward_features(self, x): 147 """@private 148 """ 149 B = x.shape[0] 150 x = self.patch_embed(x) 151 152 cls_tokens = self.cls_token.expand(B, -1, -1) 153 x = torch.cat((cls_tokens, x), dim=1) 154 155 x = x + self.pos_embed 156 x = self.pos_drop(x) 157 158 # chunks obtained for getting the projections for conjuctions with upsampling blocks 159 _chunks = int(self.depth / 4) 160 chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1] 161 162 list_from_encoder = [] 163 for i, blk in enumerate(self.blocks): 164 x = blk(x) 165 if i in chunks_for_projection: 166 list_from_encoder.append(self.convert_to_expected_dim(x)) 167 168 x = self.convert_to_expected_dim(x) 169 return x, list_from_encoder[:3] 170 171 def forward(self, x: torch.Tensor) -> torch.Tensor: 172 """Apply the vision transformer to input data. 173 174 Args: 175 x: The input data. 176 177 Returns: 178 The vision transformer output. 179 """ 180 x, list_from_encoder = self.forward_features(x) 181 return x, list_from_encoder
Vision Transformer derived from the Masked Auto Encoder Codebase (https://arxiv.org/abs/2111.06377).
Based on: https://github.com/facebookresearch/mae/blob/main/models_vit.py#L20-L53
Arguments:
- img_size: The size of the input for the image encoder. Input images will be resized to match this size.
- in_chans: The number of input channels.
- depth: The depth of the vision transformer.
- kwargs: Additional keyword arguments for the vision transformer base class.
117 def __init__( 118 self, 119 img_size: int = 1024, # chosen to match our experiments with segment anything 120 in_chans: int = 3, 121 depth: int = 12, 122 **kwargs 123 ): 124 if not _timm_import_success: 125 raise RuntimeError( 126 "The vision transformer backend can only be initialized if timm is installed. " 127 "Please install timm (using conda/mamba) for using https://github.com/facebookresearch/mae/ " 128 "and then rerun your code" 129 ) 130 super().__init__(img_size=img_size, depth=depth, **kwargs) 131 self.img_size = img_size 132 self.in_chans = in_chans 133 self.depth = depth
171 def forward(self, x: torch.Tensor) -> torch.Tensor: 172 """Apply the vision transformer to input data. 173 174 Args: 175 x: The input data. 176 177 Returns: 178 The vision transformer output. 179 """ 180 x, list_from_encoder = self.forward_features(x) 181 return x, list_from_encoder
Apply the vision transformer to input data.
Arguments:
- x: The input data.
Returns:
The vision transformer output and a list of intermediate features from selected encoder blocks.
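With the default depth of 12, _chunks is 3, so the projections are taken after blocks 2, 5, 8 and 11, and convert_to_expected_dim folds the remaining patch tokens back into an image-shaped feature map. A small, pure-torch sketch of that reshaping (the tensor below is a stand-in, not a real MAE output):

import torch

# Dummy encoder output for one image: 1 class token + 64*64 patch tokens, embed_dim 768.
tokens = torch.randn(1, 1 + 64 * 64, 768)

patch_tokens = tokens[:, 1:, :]                        # drop the class token -> (1, 4096, 768)
side = int(patch_tokens.shape[1] ** 0.5)               # 64, the patch grid side length
grid = torch.unflatten(patch_tokens, 1, (side, side))  # (1, 64, 64, 768)
grid = grid.permute(0, 3, 1, 2)                        # (1, 768, 64, 64), ready for a decoder
print(grid.shape)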
184class ViT_Sam2(ImageEncoder): 185 """Vision Transformer derived from the Segment Anything 2 Codebase (https://arxiv.org/abs/2408.00714). 186 187 Based on https://github.com/facebookresearch/sam2/blob/main/sam2/modeling/backbones/image_encoder.py. 188 189 Args: 190 backbone_channel_list: The channels throughout the entire backbone. 191 embed_dim: The initial embedding dimension. 192 num_heads: The initial number of heads. 193 stages: The number of blocks per stage. 194 global_att_blocks: The parameter to decide which blocks have global attention. 195 window_pos_embed_bkg_spatial_size: The spatial size of windowed positional embedding. 196 window_spec: The window size per stage, when not using global attention. 197 scalp: The count of lowest resolution features to discard. 198 """ 199 def __init__( 200 self, 201 backbone_channel_list: List[int], 202 img_size: int = 1024, 203 embed_dim: int = 96, 204 num_heads: int = 1, 205 stages: Tuple[int, ...] = (2, 3, 16, 3), 206 global_att_blocks: Tuple[int, ...] = (12, 16, 20), 207 window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14), 208 window_spec: Tuple[int, ...] = (8, 4, 14, 7), 209 scalp: int = 1, 210 ): 211 if not _sam2_import_success: 212 raise RuntimeError( 213 "The vision transformer backend can only be initialized if segment anything 2 is installed. " 214 "Please install segment anything 2 from https://github.com/facebookresearch/sam2 " 215 "and then rerun your code" 216 ) 217 218 trunk = Hiera( 219 embed_dim=embed_dim, 220 num_heads=num_heads, 221 stages=stages, 222 global_att_blocks=global_att_blocks, 223 window_pos_embed_bkg_spatial_size=window_pos_embed_bkg_spatial_size, 224 window_spec=window_spec, 225 ) 226 neck = FpnNeck( 227 position_encoding=PositionEmbeddingSine(num_pos_feats=256), 228 d_model=256, 229 backbone_channel_list=backbone_channel_list, 230 fpn_top_down_levels=[2, 3], 231 fpn_interp_model="nearest", 232 ) 233 234 super().__init__(trunk=trunk, neck=neck, scalp=scalp) 235 self.scalp = scalp 236 self.embed_dim = embed_dim 237 self.img_size = img_size 238 239 def forward(self, x: torch.Tensor): 240 # The forward pass throught the backbone. 241 features, pos = self.neck(self.trunk(x)) 242 if self.scalp > 0: # This discard the "n" lowest resolution features. 243 features, pos = features[:-self.scalp], pos[:-self.scalp] 244 245 return features[-1], features
Vision Transformer derived from the Segment Anything 2 Codebase (https://arxiv.org/abs/2408.00714).
Based on https://github.com/facebookresearch/sam2/blob/main/sam2/modeling/backbones/image_encoder.py.
Arguments:
- backbone_channel_list: The channels throughout the entire backbone.
- embed_dim: The initial embedding dimension.
- num_heads: The initial number of heads.
- stages: The number of blocks per stage.
- global_att_blocks: The indices of the blocks that use global attention.
- window_pos_embed_bkg_spatial_size: The spatial size of windowed positional embedding.
- window_spec: The window size per stage, when not using global attention.
- scalp: The number of lowest-resolution feature maps to discard.
199 def __init__( 200 self, 201 backbone_channel_list: List[int], 202 img_size: int = 1024, 203 embed_dim: int = 96, 204 num_heads: int = 1, 205 stages: Tuple[int, ...] = (2, 3, 16, 3), 206 global_att_blocks: Tuple[int, ...] = (12, 16, 20), 207 window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14), 208 window_spec: Tuple[int, ...] = (8, 4, 14, 7), 209 scalp: int = 1, 210 ): 211 if not _sam2_import_success: 212 raise RuntimeError( 213 "The vision transformer backend can only be initialized if segment anything 2 is installed. " 214 "Please install segment anything 2 from https://github.com/facebookresearch/sam2 " 215 "and then rerun your code" 216 ) 217 218 trunk = Hiera( 219 embed_dim=embed_dim, 220 num_heads=num_heads, 221 stages=stages, 222 global_att_blocks=global_att_blocks, 223 window_pos_embed_bkg_spatial_size=window_pos_embed_bkg_spatial_size, 224 window_spec=window_spec, 225 ) 226 neck = FpnNeck( 227 position_encoding=PositionEmbeddingSine(num_pos_feats=256), 228 d_model=256, 229 backbone_channel_list=backbone_channel_list, 230 fpn_top_down_levels=[2, 3], 231 fpn_interp_model="nearest", 232 ) 233 234 super().__init__(trunk=trunk, neck=neck, scalp=scalp) 235 self.scalp = scalp 236 self.embed_dim = embed_dim 237 self.img_size = img_size
239 def forward(self, x: torch.Tensor): 240 # The forward pass through the backbone. 241 features, pos = self.neck(self.trunk(x)) 242 if self.scalp > 0: # This discards the "scalp" lowest resolution features. 243 features, pos = features[:-self.scalp], pos[:-self.scalp] 244 245 return features[-1], features
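The forward pass keeps the FPN levels returned by the neck, drops the "scalp" lowest-resolution ones and returns the finest remaining level together with the trimmed list. A stand-alone sketch with dummy tensors, assuming (for this illustration) that the neck orders its outputs from highest to lowest resolution:

import torch

# Dummy FPN outputs for a 1024 input, ordered from highest to lowest resolution (an assumption).
features = [torch.randn(1, 256, s, s) for s in (256, 128, 64, 32)]

scalp = 1
if scalp > 0:  # discard the "scalp" lowest-resolution maps
    features = features[:-scalp]
final, all_levels = features[-1], features
print(final.shape)  # torch.Size([1, 256, 64, 64]) for this dummy list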
253class CustomCompose: 254 def __init__(self, rescale_transform, other_transforms, src_transform): 255 self.rescale_transform = rescale_transform 256 self.other_transforms = other_transforms 257 self.src_transform = src_transform 258 259 def __call__(self, x, valid_masks=None): 260 if valid_masks is not None: 261 nodata = (x * (1 - valid_masks.float())).max() 262 x_aug = self.rescale_transform(x) 263 parms = self.rescale_transform._params 264 265 # sanity check, comment if this is working 266 # valid_masks = self.rescale_transform(valid_masks.float(), params=parms) 267 # assert (x_aug==self.rescale_transform(x, params=parms)).all() # 268 269 if valid_masks is not None: 270 valid_masks = x_aug != nodata 271 _, c, h, w = x_aug.shape 272 zero_ratio = ((valid_masks == 0).sum((1, 2, 3)) / (h * w * c)).cpu().numpy() 273 else: 274 zero_ratio = -1 275 276 if self.other_transforms: 277 x_aug = self.other_transforms(x_aug) 278 x_src = self.src_transform(x_aug) 279 dx = parms["src"][:, 1, 0] - parms["src"][:, 0, 0] 280 281 # dy = (parms['src'][:,2,1] - parms['src'][:,1,1]) 282 # assert (dx == dy).all() 283 284 h, w = x_aug.shape[-2:] 285 # assert h == w 286 287 return x_aug, x_src, dx / h, zero_ratio, valid_masks
290def get_2d_sincos_pos_embed_with_resolution(embed_dim, grid_size, res, cls_token=False, device="cpu"): 291 """ 292 grid_size: int of the grid height and width 293 res: array of size n, representing the resolution of a pixel (say, in meters), 294 return: 295 pos_embed: [n,grid_size*grid_size, embed_dim] or [n,1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 296 """ 297 # res = torch.FloatTensor(res).to(device) 298 res = res.to(device) 299 grid_h = torch.arange(grid_size, dtype=torch.float32, device=device) 300 grid_w = torch.arange(grid_size, dtype=torch.float32, device=device) 301 grid = torch.meshgrid(grid_w, grid_h, indexing="xy") # here h goes first,direction reversed for numpy 302 grid = torch.stack(grid, dim=0) # 2 x h x w 303 304 # grid = grid.reshape([2, 1, grid_size, grid_size]) 305 grid = torch.einsum("chw,n->cnhw", grid, res) # 2 x n x h x w 306 _, n, h, w = grid.shape 307 pos_embed = get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid) # (nxH*W, D/2) 308 pos_embed = pos_embed.reshape(n, h * w, embed_dim) 309 if cls_token: 310 pos_embed = torch.cat( 311 [torch.zeros([n, 1, embed_dim], dtype=torch.float32, device=pos_embed.device), pos_embed], dim=1 312 ) 313 314 return pos_embed
Arguments:
- grid_size: The height and width of the grid.
- res: Tensor of size n, holding the resolution of a pixel (say, in meters).

Returns:
pos_embed: [n, grid_size*grid_size, embed_dim] or [n, 1+grid_size*grid_size, embed_dim] (with or without cls_token).
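A quick shape check for the resolution-aware positional embedding. The function only needs torch, so this sketch should run without any of the optional backbones installed; the import path from this module is the only assumption.

import torch
from torch_em.model.vit import get_2d_sincos_pos_embed_with_resolution

# Two images with per-image ground resolutions of 1.0 and 2.5 (e.g. meters per pixel).
res = torch.tensor([1.0, 2.5])
pos_embed = get_2d_sincos_pos_embed_with_resolution(embed_dim=768, grid_size=14, res=res, cls_token=True)
print(pos_embed.shape)  # torch.Size([2, 197, 768]) -> (n, 1 + grid_size**2, embed_dim)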
317def get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid): 318 assert embed_dim % 2 == 0 319 320 # use half of dimensions to encode grid_h 321 emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid[0]) # (H*W, D/2) 322 emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid[1]) # (H*W, D/2) 323 324 emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D) 325 return emb
328def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos): 329 """ 330 embed_dim: output dimension for each position 331 pos: a list of positions to be encoded: size (M,) 332 out: (M, D) 333 """ 334 assert embed_dim % 2 == 0 335 # old_shape = pos 336 omega = torch.arange(embed_dim // 2, dtype=torch.float32, device=pos.device) 337 omega /= embed_dim / 2.0 338 omega = 1.0 / 10000**omega # (D/2,) 339 340 pos = pos.reshape(-1) # (M,) 341 out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product 342 343 emb_sin = torch.sin(out) # (M, D/2) 344 emb_cos = torch.cos(out) # (M, D/2) 345 346 emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) 347 return emb
Arguments:
- embed_dim: The output dimension for each position.
- pos: A tensor of positions to be encoded, of size (M,).

Returns:
out: (M, D)
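And the 1D helper on its own, again depending only on torch; the first half of the output columns are sines, the second half cosines.

import torch
from torch_em.model.vit import get_1d_sincos_pos_embed_from_grid_torch

pos = torch.arange(4, dtype=torch.float32)             # M = 4 positions
emb = get_1d_sincos_pos_embed_from_grid_torch(8, pos)  # embed_dim D = 8
print(emb.shape)  # torch.Size([4, 8]); columns 0-3 are sin terms, columns 4-7 are cos terms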
350class PatchEmbedUnSafe(PatchEmbed): 351 """Image to Patch Embedding""" 352 353 def forward(self, x): 354 B, C, H, W = x.shape 355 356 # NOTE: Comment code from ScaleMAE: Dropped size check in timm 357 # assert H == self.img_size[0] and W == self.img_size[1], \ 358 # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 359 360 x = self.proj(x).flatten(2).transpose(1, 2) 361 return x
Image to Patch Embedding
353 def forward(self, x): 354 B, C, H, W = x.shape 355 356 # NOTE: Comment code from ScaleMAE: Dropped size check in timm 357 # assert H == self.img_size[0] and W == self.img_size[1], \ 358 # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 359 360 x = self.proj(x).flatten(2).transpose(1, 2) 361 return x
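The "unsafe" part is simply that timm's input-size assertion is dropped, which is what lets the 448x448 crops produced by ViT_ScaleMAE.transform_inputs pass through a patch embedding built for 224x224. A small sketch, assuming timm is installed:

import torch
from torch_em.model.vit import PatchEmbedUnSafe

# Built for 224x224 inputs, but larger inputs are accepted because the size check is dropped.
patch_embed = PatchEmbedUnSafe(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
tokens = patch_embed(torch.randn(1, 3, 448, 448))
print(tokens.shape)  # torch.Size([1, 784, 768]) -> (448 / 16)**2 patch tokens instead of 196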
364class ViT_ScaleMAE(VisionTransformer): 365 """Vision Transformer dervied from the Scale Masked Auto Encoder codebase (TODO: paper and github link). 366 367 NOTE: For downstream tasks, the "base_resoulution" parameter needs to be adjusted manually when using 368 the model on a different zoom factor dataset. 369 """ 370 371 def __init__( 372 self, img_size=224, patch_size=16, in_chans=3, embed_dim=1024, depth=12, base_resolution=2.5, **kwargs 373 ): 374 super().__init__(img_size=img_size, embed_dim=embed_dim, **kwargs) 375 self.img_size = img_size 376 self.in_chans = in_chans 377 self.depth = depth 378 self.base_resolution = base_resolution 379 380 self.patch_embed = PatchEmbedUnSafe( 381 img_size=img_size, 382 patch_size=patch_size, 383 in_chans=in_chans, 384 embed_dim=embed_dim, 385 ) 386 387 def transform_inputs(self, x): 388 import kornia.augmentation as K 389 from kornia.constants import Resample 390 391 self._transforms = CustomCompose( 392 rescale_transform=K.RandomResizedCrop( 393 (448, 448), 394 ratio=(1.0, 1.0), 395 scale=(1.0, 1.0), 396 resample=Resample.BICUBIC.name, 397 ), 398 other_transforms=None, 399 src_transform=K.Resize((224, 224)), 400 ) 401 x, _, ratios, _, _ = self._transforms(x) 402 input_res = ratios * self.base_resolution 403 return x, input_res 404 405 def convert_to_expected_dim(self, x): 406 inputs_ = x[:, 1:, :] # removing the class tokens 407 # reshape the outputs to desired shape (N X H*W X C -> N X H X W X C) 408 rdim = inputs_.shape[1] 409 dshape = int(rdim ** 0.5) # finding square root of the outputs for obtaining the patch shape 410 inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape)) 411 inputs_ = inputs_.permute(0, 3, 1, 2) 412 return inputs_ 413 414 def forward_features(self, x): 415 x, input_res = self.transform_inputs(x) 416 417 B, _, h, w = x.shape 418 x = self.patch_embed(x) 419 420 num_patches = int((h * w) / (self.patch_embed.patch_size[0] * self.patch_embed.patch_size[1])) 421 pos_embed = get_2d_sincos_pos_embed_with_resolution( 422 x.shape[-1], 423 int(num_patches ** 0.5), 424 input_res, 425 cls_token=True, 426 device=x.device, 427 ) 428 429 cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 430 x = torch.cat((cls_tokens, x), dim=1) 431 x = x + pos_embed 432 x = self.pos_drop(x) 433 434 # chunks obtained for getting the projections for conjuctions with upsampling blocks 435 _chunks = int(self.depth / 4) 436 chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1] 437 438 list_from_encoder = [] 439 for i, blk in enumerate(self.blocks): 440 x = blk(x) 441 if i in chunks_for_projection: 442 list_from_encoder.append(self.convert_to_expected_dim(x)) 443 444 x = self.convert_to_expected_dim(x) 445 446 return x, list_from_encoder 447 448 def forward(self, x): 449 x, list_from_encoder = self.forward_features(x) 450 return x, list_from_encoder
Vision Transformer derived from the Scale Masked Auto Encoder codebase (TODO: paper and github link).
NOTE: For downstream tasks, the "base_resolution" parameter needs to be adjusted manually when using the model on a dataset with a different zoom factor.
371 def __init__( 372 self, img_size=224, patch_size=16, in_chans=3, embed_dim=1024, depth=12, base_resolution=2.5, **kwargs 373 ): 374 super().__init__(img_size=img_size, embed_dim=embed_dim, **kwargs) 375 self.img_size = img_size 376 self.in_chans = in_chans 377 self.depth = depth 378 self.base_resolution = base_resolution 379 380 self.patch_embed = PatchEmbedUnSafe( 381 img_size=img_size, 382 patch_size=patch_size, 383 in_chans=in_chans, 384 embed_dim=embed_dim, 385 )
387 def transform_inputs(self, x): 388 import kornia.augmentation as K 389 from kornia.constants import Resample 390 391 self._transforms = CustomCompose( 392 rescale_transform=K.RandomResizedCrop( 393 (448, 448), 394 ratio=(1.0, 1.0), 395 scale=(1.0, 1.0), 396 resample=Resample.BICUBIC.name, 397 ), 398 other_transforms=None, 399 src_transform=K.Resize((224, 224)), 400 ) 401 x, _, ratios, _, _ = self._transforms(x) 402 input_res = ratios * self.base_resolution 403 return x, input_res
405 def convert_to_expected_dim(self, x): 406 inputs_ = x[:, 1:, :] # removing the class tokens 407 # reshape the outputs to desired shape (N X H*W X C -> N X H X W X C) 408 rdim = inputs_.shape[1] 409 dshape = int(rdim ** 0.5) # finding square root of the outputs for obtaining the patch shape 410 inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape)) 411 inputs_ = inputs_.permute(0, 3, 1, 2) 412 return inputs_
414 def forward_features(self, x): 415 x, input_res = self.transform_inputs(x) 416 417 B, _, h, w = x.shape 418 x = self.patch_embed(x) 419 420 num_patches = int((h * w) / (self.patch_embed.patch_size[0] * self.patch_embed.patch_size[1])) 421 pos_embed = get_2d_sincos_pos_embed_with_resolution( 422 x.shape[-1], 423 int(num_patches ** 0.5), 424 input_res, 425 cls_token=True, 426 device=x.device, 427 ) 428 429 cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 430 x = torch.cat((cls_tokens, x), dim=1) 431 x = x + pos_embed 432 x = self.pos_drop(x) 433 434 # chunks obtained for getting the projections for conjuctions with upsampling blocks 435 _chunks = int(self.depth / 4) 436 chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1] 437 438 list_from_encoder = [] 439 for i, blk in enumerate(self.blocks): 440 x = blk(x) 441 if i in chunks_for_projection: 442 list_from_encoder.append(self.convert_to_expected_dim(x)) 443 444 x = self.convert_to_expected_dim(x) 445 446 return x, list_from_encoder
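Putting this together, a hedged usage sketch for the ScaleMAE encoder; it needs both timm and kornia installed, and the output grid follows from the internal 448x448 crop with patch size 8 rather than from img_size.

import torch
from torch_em.model.vit import get_vision_transformer

# Assumes timm and kornia are installed.
encoder = get_vision_transformer(backbone="scalemae", model="vit_b", img_size=224, base_resolution=2.5)

x = torch.randn(2, 3, 224, 224)
features, intermediates = encoder(x)
print(features.shape)  # expected: torch.Size([2, 768, 56, 56]), i.e. a (448 / 8) patch grid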
453class ViT_DINOv2(DinoV2VisionTransformer): 454 """Vision Transformer derived from the DINOv2 Codebase (https://arxiv.org/abs/2304.07193). 455 456 Based on: 457 https://github.com/facebookresearch/dinov2/blob/main/dinov2/models/vision_transformer.py. 458 """ 459 def __init__( 460 self, 461 img_size: int = 224, 462 patch_size: int = 16, 463 depth: int = 12, 464 num_register_tokens: int = 0, 465 **kwargs 466 ): 467 if not _dinov2_import_success: 468 raise RuntimeError( 469 "The vision transformer backend can only be initialized if DINOv2 is installed. " 470 "Please install DINOv2 from https://github.com/facebookresearch/dinov2 " 471 "and then rerun your code." 472 ) 473 474 super().__init__( 475 img_size=img_size, 476 depth=depth, 477 patch_size=patch_size, 478 num_register_tokens=num_register_tokens, 479 **kwargs 480 ) 481 482 self.img_size = img_size 483 self.num_register_tokens = num_register_tokens 484 self.patch_size = patch_size 485 self.attn_outs = [i for i in range(depth) if i % 3 == 2] 486 487 def forward(self, x, masks=None) -> torch.Tensor: 488 489 B = x.shape[0] 490 491 x = self.prepare_tokens_with_masks(x) 492 493 list_of_encoder = [] 494 for i, blk in enumerate(self.blocks): 495 x = blk(x) 496 if i in self.attn_outs: 497 list_of_encoder.append(x) 498 499 x = self.norm(x) 500 x = x[:, self.num_register_tokens + 1:].reshape( 501 B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1 502 ).permute(0, 3, 1, 2).contiguous() 503 504 list_of_encoder = [ 505 o[:, self.num_register_tokens + 1:].reshape( 506 B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1 507 ).permute(0, 3, 1, 2).contiguous() for o in list_of_encoder 508 ] 509 510 return x, list_of_encoder[:3]
Vision Transformer derived from the DINOv2 Codebase (https://arxiv.org/abs/2304.07193).
Based on: https://github.com/facebookresearch/dinov2/blob/main/dinov2/models/vision_transformer.py.
459 def __init__( 460 self, 461 img_size: int = 224, 462 patch_size: int = 16, 463 depth: int = 12, 464 num_register_tokens: int = 0, 465 **kwargs 466 ): 467 if not _dinov2_import_success: 468 raise RuntimeError( 469 "The vision transformer backend can only be initialized if DINOv2 is installed. " 470 "Please install DINOv2 from https://github.com/facebookresearch/dinov2 " 471 "and then rerun your code." 472 ) 473 474 super().__init__( 475 img_size=img_size, 476 depth=depth, 477 patch_size=patch_size, 478 num_register_tokens=num_register_tokens, 479 **kwargs 480 ) 481 482 self.img_size = img_size 483 self.num_register_tokens = num_register_tokens 484 self.patch_size = patch_size 485 self.attn_outs = [i for i in range(depth) if i % 3 == 2]
487 def forward(self, x, masks=None) -> torch.Tensor: 488 489 B = x.shape[0] 490 491 x = self.prepare_tokens_with_masks(x) 492 493 list_of_encoder = [] 494 for i, blk in enumerate(self.blocks): 495 x = blk(x) 496 if i in self.attn_outs: 497 list_of_encoder.append(x) 498 499 x = self.norm(x) 500 x = x[:, self.num_register_tokens + 1:].reshape( 501 B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1 502 ).permute(0, 3, 1, 2).contiguous() 503 504 list_of_encoder = [ 505 o[:, self.num_register_tokens + 1:].reshape( 506 B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1 507 ).permute(0, 3, 1, 2).contiguous() for o in list_of_encoder 508 ] 509 510 return x, list_of_encoder[:3]
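With the default depth of 12, attn_outs evaluates to [2, 5, 8, 11], and the final reshape drops the class and register tokens before folding the patch tokens into a feature map. A stand-in sketch of that slicing with a dummy tensor (no DINOv2 install needed), mirroring the indexing used in forward:

import torch

img_size, patch_size, embed_dim, num_register_tokens = 224, 14, 384, 4
grid = img_size // patch_size                       # 16 patches per side

# Dummy token sequence: 1 class token + 4 register tokens + 16*16 patch tokens.
tokens = torch.randn(2, 1 + num_register_tokens + grid * grid, embed_dim)

patch_tokens = tokens[:, num_register_tokens + 1:]  # drop the class and register tokens
feature_map = patch_tokens.reshape(2, grid, grid, embed_dim).permute(0, 3, 1, 2).contiguous()
print(feature_map.shape)  # torch.Size([2, 384, 16, 16])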
513class ViT_DINOv3(DinoV3VisionTransformer): 514 """Vision Transformer derived from the DINOv3 Codebase (https://arxiv.org/abs/2508.10104). 515 516 Based on: 517 https://github.com/facebookresearch/dinov3/blob/main/dinov3/models/vision_transformer.py. 518 519 Args: 520 img_size: The input image size. 521 patch_size: The patch size. 522 embed_dim: The embedding dimension. 523 depth: The depth of the network. 524 num_heads: The number of heads. 525 ffn_ratio: The FFN rato. 526 n_storage_tokens: The number of storage (class) tokens to remove. 527 kwargs: Keyword arguments for the image encoder base class. 528 """ 529 def __init__( 530 self, 531 in_chans: int = 3, 532 img_size: int = 224, 533 patch_size: int = 16, 534 embed_dim: int = 768, 535 depth: int = 12, 536 num_heads: int = 12, 537 ffn_ratio: float = 4.0, 538 n_storage_tokens: int = 0, 539 **kwargs 540 ): 541 if not _dinov3_import_success: 542 raise RuntimeError( 543 "The vision transformer backend can only be initialized if DINOv3 is installed. " 544 "Please install DINOv3 from https://github.com/facebookresearch/dinov3 " 545 "and then rerun your code." 546 ) 547 548 super().__init__( 549 in_chans=in_chans, 550 img_size=img_size, 551 patch_size=patch_size, 552 embed_dim=embed_dim, 553 depth=depth, 554 num_heads=num_heads, 555 ffn_ratio=ffn_ratio, 556 n_storage_tokens=n_storage_tokens, 557 **kwargs 558 ) 559 560 self.in_chans = in_chans 561 self.img_size = img_size 562 self.n_storage_tokens = n_storage_tokens 563 self.attn_outs = [i for i in range(depth) if i % 3 == 2] 564 565 def forward(self, x) -> torch.Tensor: 566 567 B = x.shape[0] 568 569 x, hw_tuple = self.prepare_tokens_with_masks(x) 570 571 list_of_encoder = [] 572 for i, blk in enumerate(self.blocks): 573 rope_sincos = self.rope_embed(H=hw_tuple[0], W=hw_tuple[1]) 574 x = blk(x, rope_sincos) 575 if i in self.attn_outs: 576 list_of_encoder.append(x) 577 578 x = self.norm(x) 579 x = x[:, self.n_storage_tokens + 1:].reshape( 580 B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1 581 ).permute(0, 3, 1, 2).contiguous() 582 583 list_of_encoder = [ 584 o[:, self.n_storage_tokens + 1:].reshape( 585 B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1 586 ).permute(0, 3, 1, 2).contiguous() for o in list_of_encoder 587 ] 588 589 return x, list_of_encoder[:3]
Vision Transformer derived from the DINOv3 Codebase (https://arxiv.org/abs/2508.10104).
Based on: https://github.com/facebookresearch/dinov3/blob/main/dinov3/models/vision_transformer.py.
Arguments:
- img_size: The input image size.
- patch_size: The patch size.
- embed_dim: The embedding dimension.
- depth: The depth of the network.
- num_heads: The number of heads.
- ffn_ratio: The FFN ratio.
- n_storage_tokens: The number of storage (class) tokens to remove.
- kwargs: Keyword arguments for the image encoder base class.
def get_vision_transformer(backbone: str, model: str, img_size: int = 1024, **kwargs) -> nn.Module:
    """Get vision transformer encoder.

    Args:
        backbone: The name of the vision transformer implementation.
            One of "sam", "sam2", "mae", "scalemae", "dinov2" or "dinov3".
        model: The name of the model. The supported names depend on the backbone,
            e.g. "vit_b", "vit_l" or "vit_h" for "sam", "mae" and "scalemae".
        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
        kwargs: Additional kwargs which can be expected by the vision transformer,
            e.g. 'base_resolution' for `ViT_ScaleMAE`.

    Returns:
        The vision transformer.
    """
    if backbone == "sam":
        if model == "vit_b":
            encoder = ViT_Sam(
                depth=12, embed_dim=768, img_size=img_size, mlp_ratio=4,
                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
                num_heads=12, patch_size=16, qkv_bias=True, use_rel_pos=True,
                global_attn_indexes=[2, 5, 8, 11],
                window_size=14, out_chans=256,
            )
        elif model == "vit_l":
            encoder = ViT_Sam(
                depth=24, embed_dim=1024, img_size=img_size, mlp_ratio=4,
                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
                num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True,
                global_attn_indexes=[5, 11, 17, 23],
                window_size=14, out_chans=256,
            )
        elif model == "vit_h":
            encoder = ViT_Sam(
                depth=32, embed_dim=1280, img_size=img_size, mlp_ratio=4,
                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
                num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True,
                global_attn_indexes=[7, 15, 23, 31],
                window_size=14, out_chans=256,
            )
        else:
            raise ValueError(f"'{model}' is not supported by SAM. Currently, 'vit_b', 'vit_l', 'vit_h' are supported.")

    elif backbone == "sam2":
        if model == "hvit_t":
            encoder = ViT_Sam2(
                img_size=img_size, embed_dim=96, num_heads=1, stages=[1, 2, 7, 2], global_att_blocks=[5, 7, 9],
                window_pos_embed_bkg_spatial_size=[7, 7], backbone_channel_list=[768, 384, 192, 96],
            )
        elif model == "hvit_s":
            encoder = ViT_Sam2(
                img_size=img_size, embed_dim=96, num_heads=1, stages=[1, 2, 11, 2], global_att_blocks=[7, 10, 13],
                window_pos_embed_bkg_spatial_size=[7, 7], backbone_channel_list=[768, 384, 192, 96],
            )
        elif model == "hvit_b":
            encoder = ViT_Sam2(
                img_size=img_size, embed_dim=112, num_heads=2, backbone_channel_list=[896, 448, 224, 112],
            )
        elif model == "hvit_l":
            encoder = ViT_Sam2(
                img_size=img_size, embed_dim=144, num_heads=2, stages=[2, 6, 36, 4], global_att_blocks=[23, 33, 43],
                window_spec=[8, 4, 16, 8], backbone_channel_list=[1152, 576, 288, 144],
            )
        else:
            raise ValueError(
                f"'{model}' is not supported by SAM2. Currently, 'hvit_t', 'hvit_s', 'hvit_b', 'hvit_l' are supported."
            )

    elif backbone == "mae":
        if model == "vit_b":
            encoder = ViT_MAE(
                img_size=img_size, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
            )
        elif model == "vit_l":
            encoder = ViT_MAE(
                img_size=img_size, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
            )
        elif model == "vit_h":
            encoder = ViT_MAE(
                img_size=img_size, patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4,
                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
            )
        else:
            raise ValueError(f"'{model}' is not supported by MAE. Currently, 'vit_b', 'vit_l', 'vit_h' are supported.")

    elif backbone == "scalemae":
        base_resolution = kwargs.get("base_resolution", 2.5)

        if model == "vit_b":
            encoder = ViT_ScaleMAE(
                img_size=img_size, patch_size=8, embed_dim=768, depth=12, num_heads=12,
                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
                base_resolution=base_resolution,
            )
        elif model == "vit_l":
            encoder = ViT_ScaleMAE(
                img_size=img_size, patch_size=8, embed_dim=1024, depth=24, num_heads=16,
                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
                base_resolution=base_resolution,
            )
        elif model == "vit_h":
            encoder = ViT_ScaleMAE(
                img_size=img_size, patch_size=8, embed_dim=1280, depth=32, num_heads=16,
                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
                base_resolution=base_resolution,
            )
        else:
            raise ValueError(
                f"'{model}' is not supported by ScaleMAE. Currently, 'vit_b', 'vit_l' and 'vit_h' are supported."
            )

    elif backbone == "dinov2":
        block_fn = partial(Block, attn_class=MemEffAttention)
        msg = "The model name should be either 'vit_<X>' or 'vit_<X>_reg4'."

        if model.startswith("vit_s"):
            assert model in ["vit_s", "vit_s_reg4"], msg
            encoder = ViT_DINOv2(
                img_size=img_size, patch_size=14, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
                block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0,
                num_register_tokens=4 if model.endswith("_reg4") else 0,
            )
        elif model.startswith("vit_b"):
            assert model in ["vit_b", "vit_b_reg4"], msg
            encoder = ViT_DINOv2(
                img_size=img_size, patch_size=14, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
                block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0,
                num_register_tokens=4 if model.endswith("_reg4") else 0,
            )
        elif model.startswith("vit_l"):
            assert model in ["vit_l", "vit_l_reg4"], msg
            encoder = ViT_DINOv2(
                img_size=img_size, patch_size=14, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
                block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0,
                num_register_tokens=4 if model.endswith("_reg4") else 0,
            )
        elif model.startswith("vit_g"):
            assert model in ["vit_g", "vit_g_reg4"], msg
            encoder = ViT_DINOv2(
                img_size=img_size, patch_size=14, embed_dim=1536, depth=40, num_heads=24, mlp_ratio=4,
                block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0,
                num_register_tokens=4 if model.endswith("_reg4") else 0, ffn_layer="swiglu",
            )
        else:
            raise ValueError(
                f"'{model}' is not supported by DINOv2. Currently, 'vit_s', 'vit_b', 'vit_l' and 'vit_g' are supported."
            )

    elif backbone == "dinov3":
        if model == "vit_s":
            encoder = ViT_DINOv3(
                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=384,
                num_heads=6, layerscale_init=1.0e-05, norm_layer="layernormbf16", n_storage_tokens=4, mask_k_bias=True,
            )
        elif model == "vit_s+":
            encoder = ViT_DINOv3(
                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=384,
                num_heads=6, ffn_ratio=6, layerscale_init=1.0e-05, norm_layer="layernormbf16",
                ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True,
            )
        elif model == "vit_b":
            encoder = ViT_DINOv3(
                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32",
                layerscale_init=1.0e-05, norm_layer="layernormbf16", n_storage_tokens=4, mask_k_bias=True,
            )
        elif model == "vit_l":
            encoder = ViT_DINOv3(
                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1024,
                depth=24, num_heads=16, layerscale_init=1.0e-05, norm_layer="layernormbf16",
                n_storage_tokens=4, mask_k_bias=True,
            )
        elif model == "vit_l+":
            encoder = ViT_DINOv3(
                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1024,
                depth=24, num_heads=16, ffn_ratio=6.0, layerscale_init=1.0e-05, norm_layer="layernormbf16",
                ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True,
            )
        elif model == "vit_h+":
            encoder = ViT_DINOv3(
                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1280,
                depth=32, num_heads=20, ffn_ratio=6.0, layerscale_init=1.0e-05, norm_layer="layernormbf16",
                ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True,
            )
        elif model == "vit_7b":
            encoder = ViT_DINOv3(
                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=4096,
                depth=40, num_heads=32, ffn_ratio=3, qkv_bias=False, drop_path_rate=0.0, layerscale_init=1.0e-05,
                norm_layer="layernormbf16", ffn_layer="swiglu64", n_storage_tokens=4, mask_k_bias=True,
                untie_global_and_local_cls_norm=True,
            )
        else:
            raise ValueError(
                f"'{model}' is not supported by DINOv3. Currently, "
                "'vit_s', 'vit_s+', 'vit_b', 'vit_l', 'vit_l+', 'vit_h+', 'vit_7b' are supported."
            )

    else:
        raise ValueError(
            "The supported backbones for 'UNETR' are 'sam', 'sam2', 'mae', 'scalemae', 'dinov2' or 'dinov3'. "
            "Please choose one of them."
        )

    return encoder
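
# A minimal usage sketch for `get_vision_transformer`, not part of the original module: it assumes
# `segment_anything` is installed (required for the "sam" backbone); the backbone, model and input
# size are only examples and the function name `_example_get_vision_transformer` is illustrative.
def _example_get_vision_transformer():
    encoder = get_vision_transformer(backbone="sam", model="vit_b", img_size=1024)
    x = torch.randn(1, 3, 1024, 1024)
    features, intermediates = encoder(x)
    # For the SAM ViT-B encoder `features` has shape (1, 768, 64, 64) (1024 / 16 = 64 patches
    # per side) and `intermediates` holds three intermediate feature maps of the same shape.
    return features, intermediates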