torch_em.model.vit
from functools import partial
from typing import Tuple, List

import torch
import torch.nn as nn

# we catch ImportErrors here because segment_anything, micro_sam, scale_mae and timm should
# only be optional dependencies for torch_em
try:
    from segment_anything.modeling import ImageEncoderViT
    _sam_import_success = True
except ImportError:
    ImageEncoderViT = object
    _sam_import_success = False

try:
    from timm.models.vision_transformer import VisionTransformer, PatchEmbed
    _timm_import_success = True
except ImportError:
    VisionTransformer = object
    PatchEmbed = object
    _timm_import_success = False

try:
    from sam2.modeling.backbones.hieradet import Hiera
    from sam2.modeling.position_encoding import PositionEmbeddingSine
    from sam2.modeling.backbones.image_encoder import ImageEncoder, FpnNeck
    _sam2_import_success = True
except ImportError:
    ImageEncoder = object
    _sam2_import_success = False

try:
    from dinov3.models.vision_transformer import DinoVisionTransformer
    _dinov3_import_success = True
except ImportError:
    DinoVisionTransformer = object
    _dinov3_import_success = False
43class ViT_Sam(ImageEncoderViT): 44 """Vision Transformer derived from the Segment Anything Codebase (https://arxiv.org/abs/2304.02643). 45 46 Based on: 47 https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py 48 49 Args: 50 in_chans: The number of input channels. 51 embed_dim: The embedding dimension, corresponding to the number of output channels of the vision transformer. 52 global_attn_indexes: The global attention indices. 53 kwargs: Keyword arguments for the image encoder base class. 54 """ 55 def __init__( 56 self, 57 in_chans: int = 3, 58 embed_dim: int = 768, 59 global_attn_indexes: Tuple[int, ...] = [2, 5, 8, 11], 60 **kwargs, 61 ) -> None: 62 if not _sam_import_success: 63 raise RuntimeError( 64 "The vision transformer backend can only be initialized if segment anything is installed. " 65 "Please install segment anything from https://github.com/facebookresearch/segment-anything " 66 "and then rerun your code." 67 ) 68 69 super().__init__(embed_dim=embed_dim, global_attn_indexes=global_attn_indexes, **kwargs) 70 self.chunks_for_projection = global_attn_indexes 71 self.in_chans = in_chans 72 self.embed_dim = embed_dim 73 74 def forward(self, x: torch.Tensor) -> torch.Tensor: 75 """Apply the vision transformer to input data. 76 77 Args: 78 x: The input data. 79 80 Returns: 81 The vision transformer output. 82 """ 83 x = self.patch_embed(x) 84 if self.pos_embed is not None: 85 x = x + self.pos_embed 86 87 list_from_encoder = [] 88 for i, blk in enumerate(self.blocks): 89 x = blk(x) 90 if i in self.chunks_for_projection: 91 list_from_encoder.append(x) 92 93 x = x.permute(0, 3, 1, 2) 94 list_from_encoder = [e.permute(0, 3, 1, 2) for e in list_from_encoder] 95 return x, list_from_encoder[:3]
Vision Transformer derived from the Segment Anything Codebase (https://arxiv.org/abs/2304.02643).
Arguments:
- in_chans: The number of input channels.
- embed_dim: The embedding dimension, corresponding to the number of output channels of the vision transformer.
- global_attn_indexes: The global attention indices.
- kwargs: Keyword arguments for the image encoder base class.
55 def __init__( 56 self, 57 in_chans: int = 3, 58 embed_dim: int = 768, 59 global_attn_indexes: Tuple[int, ...] = [2, 5, 8, 11], 60 **kwargs, 61 ) -> None: 62 if not _sam_import_success: 63 raise RuntimeError( 64 "The vision transformer backend can only be initialized if segment anything is installed. " 65 "Please install segment anything from https://github.com/facebookresearch/segment-anything " 66 "and then rerun your code." 67 ) 68 69 super().__init__(embed_dim=embed_dim, global_attn_indexes=global_attn_indexes, **kwargs) 70 self.chunks_for_projection = global_attn_indexes 71 self.in_chans = in_chans 72 self.embed_dim = embed_dim
74 def forward(self, x: torch.Tensor) -> torch.Tensor: 75 """Apply the vision transformer to input data. 76 77 Args: 78 x: The input data. 79 80 Returns: 81 The vision transformer output. 82 """ 83 x = self.patch_embed(x) 84 if self.pos_embed is not None: 85 x = x + self.pos_embed 86 87 list_from_encoder = [] 88 for i, blk in enumerate(self.blocks): 89 x = blk(x) 90 if i in self.chunks_for_projection: 91 list_from_encoder.append(x) 92 93 x = x.permute(0, 3, 1, 2) 94 list_from_encoder = [e.permute(0, 3, 1, 2) for e in list_from_encoder] 95 return x, list_from_encoder[:3]
Apply the vision transformer to input data.
Arguments:
- x: The input data.
Returns:
The vision transformer output.
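As a minimal usage sketch (assuming segment-anything is installed and that the module is importable as torch_em.model.vit), the configuration below mirrors the "vit_b" setup used by get_vision_transformer further down; the forward pass returns the final feature map together with the first three feature maps collected at the global attention blocks:

import torch
from functools import partial
from torch_em.model.vit import ViT_Sam

# "vit_b" configuration, matching get_vision_transformer(backbone="sam", model="vit_b").
encoder = ViT_Sam(
    depth=12, embed_dim=768, img_size=1024, mlp_ratio=4,
    norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
    num_heads=12, patch_size=16, qkv_bias=True, use_rel_pos=True,
    global_attn_indexes=[2, 5, 8, 11], window_size=14, out_chans=256,
)

x = torch.randn(1, 3, 1024, 1024)
features, intermediates = encoder(x)
print(features.shape)                     # (1, 768, 64, 64): 1024 / 16 = 64 patches per side
print([f.shape for f in intermediates])   # three intermediate maps with the same layout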
98class ViT_MAE(VisionTransformer): 99 """Vision Transformer derived from the Masked Auto Encoder Codebase (https://arxiv.org/abs/2111.06377). 100 101 Based on: 102 https://github.com/facebookresearch/mae/blob/main/models_vit.py#L20-L53 103 104 Args: 105 img_size: The size of the input for the image encoder. Input images will be resized to match this size. 106 in_chans: The number of input channels. 107 depth: The depth of the vision transformer. 108 kwargs: Additional keyword arguments for the vision transformer base class. 109 """ 110 def __init__( 111 self, 112 img_size: int = 1024, # chosen to match our experiments with segment anything 113 in_chans: int = 3, 114 depth: int = 12, 115 **kwargs 116 ): 117 if not _timm_import_success: 118 raise RuntimeError( 119 "The vision transformer backend can only be initialized if timm is installed. " 120 "Please install timm (using conda/mamba) for using https://github.com/facebookresearch/mae/ " 121 "and then rerun your code" 122 ) 123 super().__init__(img_size=img_size, depth=depth, **kwargs) 124 self.img_size = img_size 125 self.in_chans = in_chans 126 self.depth = depth 127 128 def convert_to_expected_dim(self, inputs_): 129 """@private 130 """ 131 inputs_ = inputs_[:, 1:, :] # removing the class tokens 132 # reshape the outputs to desired shape (N x H*W X C -> N x H x W x C) 133 rdim = inputs_.shape[1] 134 dshape = int(rdim ** 0.5) # finding the square root of the outputs for obtaining the patch shape 135 inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape)) 136 inputs_ = inputs_.permute(0, 3, 1, 2) 137 return inputs_ 138 139 def forward_features(self, x): 140 """@private 141 """ 142 B = x.shape[0] 143 x = self.patch_embed(x) 144 145 cls_tokens = self.cls_token.expand(B, -1, -1) 146 x = torch.cat((cls_tokens, x), dim=1) 147 148 x = x + self.pos_embed 149 x = self.pos_drop(x) 150 151 # chunks obtained for getting the projections for conjuctions with upsampling blocks 152 _chunks = int(self.depth / 4) 153 chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1] 154 155 list_from_encoder = [] 156 for i, blk in enumerate(self.blocks): 157 x = blk(x) 158 if i in chunks_for_projection: 159 list_from_encoder.append(self.convert_to_expected_dim(x)) 160 161 x = self.convert_to_expected_dim(x) 162 return x, list_from_encoder[:3] 163 164 def forward(self, x: torch.Tensor) -> torch.Tensor: 165 """Apply the vision transformer to input data. 166 167 Args: 168 x: The input data. 169 170 Returns: 171 The vision transformer output. 172 """ 173 x, list_from_encoder = self.forward_features(x) 174 return x, list_from_encoder
Vision Transformer derived from the Masked Auto Encoder Codebase (https://arxiv.org/abs/2111.06377).
Based on: https://github.com/facebookresearch/mae/blob/main/models_vit.py#L20-L53
Arguments:
- img_size: The size of the input for the image encoder. Input images will be resized to match this size.
- in_chans: The number of input channels.
- depth: The depth of the vision transformer.
- kwargs: Additional keyword arguments for the vision transformer base class.
110 def __init__( 111 self, 112 img_size: int = 1024, # chosen to match our experiments with segment anything 113 in_chans: int = 3, 114 depth: int = 12, 115 **kwargs 116 ): 117 if not _timm_import_success: 118 raise RuntimeError( 119 "The vision transformer backend can only be initialized if timm is installed. " 120 "Please install timm (using conda/mamba) for using https://github.com/facebookresearch/mae/ " 121 "and then rerun your code" 122 ) 123 super().__init__(img_size=img_size, depth=depth, **kwargs) 124 self.img_size = img_size 125 self.in_chans = in_chans 126 self.depth = depth
164 def forward(self, x: torch.Tensor) -> torch.Tensor: 165 """Apply the vision transformer to input data. 166 167 Args: 168 x: The input data. 169 170 Returns: 171 The vision transformer output. 172 """ 173 x, list_from_encoder = self.forward_features(x) 174 return x, list_from_encoder
Apply the vision transformer to input data.
Arguments:
- x: The input data.
Returns:
The vision transformer output.
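The class token is stripped and the remaining patch tokens are folded back into a 2D grid by convert_to_expected_dim before being returned; a torch-only sketch of that reshaping (the 14 x 14 grid and embedding size 768 are illustrative values):

import torch

tokens = torch.randn(2, 1 + 196, 768)                   # one class token plus 196 patch tokens

patch_tokens = tokens[:, 1:, :]                         # drop the class token -> (N, H*W, C)
side = int(patch_tokens.shape[1] ** 0.5)                # 14, assuming a square token grid
grid = torch.unflatten(patch_tokens, 1, (side, side))   # (N, H, W, C)
grid = grid.permute(0, 3, 1, 2)                         # (N, C, H, W)
print(grid.shape)                                       # torch.Size([2, 768, 14, 14])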
177class ViT_Sam2(ImageEncoder): 178 """Vision Transformer derived from the Segment Anything 2 Codebase (https://arxiv.org/abs/2408.00714). 179 180 Based on https://github.com/facebookresearch/sam2/blob/main/sam2/modeling/backbones/image_encoder.py. 181 182 Args: 183 backbone_channel_list: The channels throughout the entire backbone. 184 embed_dim: The initial embedding dimension. 185 num_heads: The initial number of heads. 186 stages: The number of blocks per stage. 187 global_att_blocks: The parameter to decide which blocks have global attention. 188 window_pos_embed_bkg_spatial_size: The spatial size of windowed positional embedding. 189 window_spec: The window size per stage, when not using global attention. 190 scalp: The count of lowest resolution features to discard. 191 """ 192 def __init__( 193 self, 194 backbone_channel_list: List[int], 195 img_size: int = 1024, 196 embed_dim: int = 96, 197 num_heads: int = 1, 198 stages: Tuple[int, ...] = (2, 3, 16, 3), 199 global_att_blocks: Tuple[int, ...] = (12, 16, 20), 200 window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14), 201 window_spec: Tuple[int, ...] = (8, 4, 14, 7), 202 scalp: int = 1, 203 ): 204 if not _sam2_import_success: 205 raise RuntimeError( 206 "The vision transformer backend can only be initialized if segment anything 2 is installed. " 207 "Please install segment anything 2 from https://github.com/facebookresearch/sam2 " 208 "and then rerun your code" 209 ) 210 211 trunk = Hiera( 212 embed_dim=embed_dim, 213 num_heads=num_heads, 214 stages=stages, 215 global_att_blocks=global_att_blocks, 216 window_pos_embed_bkg_spatial_size=window_pos_embed_bkg_spatial_size, 217 window_spec=window_spec, 218 ) 219 neck = FpnNeck( 220 position_encoding=PositionEmbeddingSine(num_pos_feats=256), 221 d_model=256, 222 backbone_channel_list=backbone_channel_list, 223 fpn_top_down_levels=[2, 3], 224 fpn_interp_model="nearest", 225 ) 226 227 super().__init__(trunk=trunk, neck=neck, scalp=scalp) 228 self.scalp = scalp 229 self.embed_dim = embed_dim 230 self.img_size = img_size 231 232 def forward(self, x: torch.Tensor): 233 # The forward pass throught the backbone. 234 features, pos = self.neck(self.trunk(x)) 235 if self.scalp > 0: # This discard the "n" lowest resolution features. 236 features, pos = features[:-self.scalp], pos[:-self.scalp] 237 238 return features[-1], features
Vision Transformer derived from the Segment Anything 2 Codebase (https://arxiv.org/abs/2408.00714).
Based on https://github.com/facebookresearch/sam2/blob/main/sam2/modeling/backbones/image_encoder.py.
Arguments:
- backbone_channel_list: The channels throughout the entire backbone.
- embed_dim: The initial embedding dimension.
- num_heads: The initial number of heads.
- stages: The number of blocks per stage.
- global_att_blocks: The parameter to decide which blocks have global attention.
- window_pos_embed_bkg_spatial_size: The spatial size of windowed positional embedding.
- window_spec: The window size per stage, when not using global attention.
- scalp: The count of lowest resolution features to discard.
192 def __init__( 193 self, 194 backbone_channel_list: List[int], 195 img_size: int = 1024, 196 embed_dim: int = 96, 197 num_heads: int = 1, 198 stages: Tuple[int, ...] = (2, 3, 16, 3), 199 global_att_blocks: Tuple[int, ...] = (12, 16, 20), 200 window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14), 201 window_spec: Tuple[int, ...] = (8, 4, 14, 7), 202 scalp: int = 1, 203 ): 204 if not _sam2_import_success: 205 raise RuntimeError( 206 "The vision transformer backend can only be initialized if segment anything 2 is installed. " 207 "Please install segment anything 2 from https://github.com/facebookresearch/sam2 " 208 "and then rerun your code" 209 ) 210 211 trunk = Hiera( 212 embed_dim=embed_dim, 213 num_heads=num_heads, 214 stages=stages, 215 global_att_blocks=global_att_blocks, 216 window_pos_embed_bkg_spatial_size=window_pos_embed_bkg_spatial_size, 217 window_spec=window_spec, 218 ) 219 neck = FpnNeck( 220 position_encoding=PositionEmbeddingSine(num_pos_feats=256), 221 d_model=256, 222 backbone_channel_list=backbone_channel_list, 223 fpn_top_down_levels=[2, 3], 224 fpn_interp_model="nearest", 225 ) 226 227 super().__init__(trunk=trunk, neck=neck, scalp=scalp) 228 self.scalp = scalp 229 self.embed_dim = embed_dim 230 self.img_size = img_size
def forward(self, x: torch.Tensor):
    # The forward pass through the backbone.
    features, pos = self.neck(self.trunk(x))
    if self.scalp > 0:  # This discards the "n" lowest resolution features.
        features, pos = features[:-self.scalp], pos[:-self.scalp]

    return features[-1], features
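A usage sketch, assuming sam2 is installed; "hvit_b" relies on the default stages, window sizes and global attention blocks:

import torch
from torch_em.model.vit import get_vision_transformer

encoder = get_vision_transformer(backbone="sam2", model="hvit_b", img_size=1024)
x = torch.randn(1, 3, 1024, 1024)
embedding, feature_list = encoder(x)
# With scalp=1 the lowest resolution FPN output is discarded; the remaining feature maps
# are returned as a list, together with the last entry of that list (features[-1]).
print([f.shape for f in feature_list])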
#
# Utilities for ScaleMAE's ViT
#


class CustomCompose:
    def __init__(self, rescale_transform, other_transforms, src_transform):
        self.rescale_transform = rescale_transform
        self.other_transforms = other_transforms
        self.src_transform = src_transform

    def __call__(self, x, valid_masks=None):
        if valid_masks is not None:
            nodata = (x * (1 - valid_masks.float())).max()
        x_aug = self.rescale_transform(x)
        parms = self.rescale_transform._params

        # sanity check, comment if this is working
        # valid_masks = self.rescale_transform(valid_masks.float(), params=parms)
        # assert (x_aug==self.rescale_transform(x, params=parms)).all()

        if valid_masks is not None:
            valid_masks = x_aug != nodata
            _, c, h, w = x_aug.shape
            zero_ratio = ((valid_masks == 0).sum((1, 2, 3)) / (h * w * c)).cpu().numpy()
        else:
            zero_ratio = -1

        if self.other_transforms:
            x_aug = self.other_transforms(x_aug)
        x_src = self.src_transform(x_aug)
        dx = parms["src"][:, 1, 0] - parms["src"][:, 0, 0]

        # dy = (parms['src'][:,2,1] - parms['src'][:,1,1])
        # assert (dx == dy).all()

        h, w = x_aug.shape[-2:]
        # assert h == w

        return x_aug, x_src, dx / h, zero_ratio, valid_masks
def get_2d_sincos_pos_embed_with_resolution(embed_dim, grid_size, res, cls_token=False, device="cpu"):
    """
    grid_size: int of the grid height and width
    res: array of size n, representing the resolution of a pixel (say, in meters),
    return:
        pos_embed: [n, grid_size*grid_size, embed_dim] or [n, 1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    # res = torch.FloatTensor(res).to(device)
    res = res.to(device)
    grid_h = torch.arange(grid_size, dtype=torch.float32, device=device)
    grid_w = torch.arange(grid_size, dtype=torch.float32, device=device)
    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")  # here h goes first, direction reversed for numpy
    grid = torch.stack(grid, dim=0)  # 2 x h x w

    # grid = grid.reshape([2, 1, grid_size, grid_size])
    grid = torch.einsum("chw,n->cnhw", grid, res)  # 2 x n x h x w
    _, n, h, w = grid.shape
    pos_embed = get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid)  # (nxH*W, D/2)
    pos_embed = pos_embed.reshape(n, h * w, embed_dim)
    if cls_token:
        pos_embed = torch.cat(
            [torch.zeros([n, 1, embed_dim], dtype=torch.float32, device=pos_embed.device), pos_embed], dim=1
        )

    return pos_embed
Arguments:
- grid_size: int of the grid height and width.
- res: array of size n, representing the resolution of a pixel (say, in meters).
Returns:
    pos_embed: [n, grid_size*grid_size, embed_dim] or [n, 1+grid_size*grid_size, embed_dim] (with or without cls_token)
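A quick shape check (the function itself only needs torch; the import assumes the module path shown above):

import torch
from torch_em.model.vit import get_2d_sincos_pos_embed_with_resolution

res = torch.tensor([1.0, 2.5])   # two samples with different ground resolutions
pos_embed = get_2d_sincos_pos_embed_with_resolution(embed_dim=768, grid_size=14, res=res, cls_token=True)
print(pos_embed.shape)           # torch.Size([2, 197, 768]): one embedding per input resolution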
def get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = torch.cat([emb_h, emb_w], dim=1)  # (H*W, D)
    return emb
def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    # old_shape = pos
    omega = torch.arange(embed_dim // 2, dtype=torch.float32, device=pos.device)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = torch.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = torch.sin(out)  # (M, D/2)
    emb_cos = torch.cos(out)  # (M, D/2)

    emb = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)
    return emb
Arguments:
- embed_dim: output dimension for each position.
- pos: a list of positions to be encoded: size (M,).
Returns:
    out: (M, D)
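And the corresponding 1D building block, again only requiring torch:

import torch
from torch_em.model.vit import get_1d_sincos_pos_embed_from_grid_torch

positions = torch.arange(4, dtype=torch.float32)            # four scalar positions
emb = get_1d_sincos_pos_embed_from_grid_torch(embed_dim=8, pos=positions)
print(emb.shape)   # torch.Size([4, 8]): first half sine terms, second half cosine terms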
class PatchEmbedUnSafe(PatchEmbed):
    """Image to Patch Embedding"""

    def forward(self, x):
        B, C, H, W = x.shape

        # NOTE: Comment code from ScaleMAE: Dropped size check in timm
        # assert H == self.img_size[0] and W == self.img_size[1], \
        #     f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."

        x = self.proj(x).flatten(2).transpose(1, 2)
        return x
Image to Patch Embedding
def forward(self, x):
    B, C, H, W = x.shape

    # NOTE: Comment code from ScaleMAE: Dropped size check in timm
    # assert H == self.img_size[0] and W == self.img_size[1], \
    #     f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."

    x = self.proj(x).flatten(2).transpose(1, 2)
    return x
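Since the size assertion from timm is dropped, inputs larger than the nominal img_size pass through unchanged; a torch-only sketch of the projection that this forward performs (the Conv2d mirrors timm's patch projection, with illustrative sizes):

import torch
import torch.nn as nn

proj = nn.Conv2d(in_channels=3, out_channels=768, kernel_size=16, stride=16)
x = torch.randn(1, 3, 448, 448)                   # larger than a nominal img_size of 224
patches = proj(x).flatten(2).transpose(1, 2)      # (1, (448 // 16) ** 2, 768) = (1, 784, 768)
print(patches.shape)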
357class ViT_ScaleMAE(VisionTransformer): 358 """Vision Transformer dervied from the Scale Masked Auto Encoder codebase (TODO: paper and github link). 359 360 NOTE: For downstream tasks, the "base_resoulution" parameter needs to be adjusted manually when using 361 the model on a different zoom factor dataset. 362 """ 363 364 def __init__( 365 self, img_size=224, patch_size=16, in_chans=3, embed_dim=1024, depth=12, base_resolution=2.5, **kwargs 366 ): 367 super().__init__(img_size=img_size, embed_dim=embed_dim, **kwargs) 368 self.img_size = img_size 369 self.in_chans = in_chans 370 self.depth = depth 371 self.base_resolution = base_resolution 372 373 self.patch_embed = PatchEmbedUnSafe( 374 img_size=img_size, 375 patch_size=patch_size, 376 in_chans=in_chans, 377 embed_dim=embed_dim, 378 ) 379 380 def transform_inputs(self, x): 381 import kornia.augmentation as K 382 from kornia.constants import Resample 383 384 self._transforms = CustomCompose( 385 rescale_transform=K.RandomResizedCrop( 386 (448, 448), 387 ratio=(1.0, 1.0), 388 scale=(1.0, 1.0), 389 resample=Resample.BICUBIC.name, 390 ), 391 other_transforms=None, 392 src_transform=K.Resize((224, 224)), 393 ) 394 x, _, ratios, _, _ = self._transforms(x) 395 input_res = ratios * self.base_resolution 396 return x, input_res 397 398 def convert_to_expected_dim(self, x): 399 inputs_ = x[:, 1:, :] # removing the class tokens 400 # reshape the outputs to desired shape (N X H*W X C -> N X H X W X C) 401 rdim = inputs_.shape[1] 402 dshape = int(rdim ** 0.5) # finding square root of the outputs for obtaining the patch shape 403 inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape)) 404 inputs_ = inputs_.permute(0, 3, 1, 2) 405 return inputs_ 406 407 def forward_features(self, x): 408 x, input_res = self.transform_inputs(x) 409 410 B, _, h, w = x.shape 411 x = self.patch_embed(x) 412 413 num_patches = int((h * w) / (self.patch_embed.patch_size[0] * self.patch_embed.patch_size[1])) 414 pos_embed = get_2d_sincos_pos_embed_with_resolution( 415 x.shape[-1], 416 int(num_patches ** 0.5), 417 input_res, 418 cls_token=True, 419 device=x.device, 420 ) 421 422 cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 423 x = torch.cat((cls_tokens, x), dim=1) 424 x = x + pos_embed 425 x = self.pos_drop(x) 426 427 # chunks obtained for getting the projections for conjuctions with upsampling blocks 428 _chunks = int(self.depth / 4) 429 chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1] 430 431 list_from_encoder = [] 432 for i, blk in enumerate(self.blocks): 433 x = blk(x) 434 if i in chunks_for_projection: 435 list_from_encoder.append(self.convert_to_expected_dim(x)) 436 437 x = self.convert_to_expected_dim(x) 438 439 return x, list_from_encoder 440 441 def forward(self, x): 442 x, list_from_encoder = self.forward_features(x) 443 return x, list_from_encoder
Vision Transformer derived from the Scale Masked Auto Encoder codebase (TODO: paper and github link).
NOTE: For downstream tasks, the "base_resolution" parameter needs to be adjusted manually when using the model on a dataset with a different zoom factor.
364 def __init__( 365 self, img_size=224, patch_size=16, in_chans=3, embed_dim=1024, depth=12, base_resolution=2.5, **kwargs 366 ): 367 super().__init__(img_size=img_size, embed_dim=embed_dim, **kwargs) 368 self.img_size = img_size 369 self.in_chans = in_chans 370 self.depth = depth 371 self.base_resolution = base_resolution 372 373 self.patch_embed = PatchEmbedUnSafe( 374 img_size=img_size, 375 patch_size=patch_size, 376 in_chans=in_chans, 377 embed_dim=embed_dim, 378 )
def transform_inputs(self, x):
    import kornia.augmentation as K
    from kornia.constants import Resample

    self._transforms = CustomCompose(
        rescale_transform=K.RandomResizedCrop(
            (448, 448),
            ratio=(1.0, 1.0),
            scale=(1.0, 1.0),
            resample=Resample.BICUBIC.name,
        ),
        other_transforms=None,
        src_transform=K.Resize((224, 224)),
    )
    x, _, ratios, _, _ = self._transforms(x)
    input_res = ratios * self.base_resolution
    return x, input_res
def convert_to_expected_dim(self, x):
    inputs_ = x[:, 1:, :]  # removing the class tokens
    # reshape the outputs to desired shape (N X H*W X C -> N X H X W X C)
    rdim = inputs_.shape[1]
    dshape = int(rdim ** 0.5)  # finding square root of the outputs for obtaining the patch shape
    inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape))
    inputs_ = inputs_.permute(0, 3, 1, 2)
    return inputs_
407 def forward_features(self, x): 408 x, input_res = self.transform_inputs(x) 409 410 B, _, h, w = x.shape 411 x = self.patch_embed(x) 412 413 num_patches = int((h * w) / (self.patch_embed.patch_size[0] * self.patch_embed.patch_size[1])) 414 pos_embed = get_2d_sincos_pos_embed_with_resolution( 415 x.shape[-1], 416 int(num_patches ** 0.5), 417 input_res, 418 cls_token=True, 419 device=x.device, 420 ) 421 422 cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 423 x = torch.cat((cls_tokens, x), dim=1) 424 x = x + pos_embed 425 x = self.pos_drop(x) 426 427 # chunks obtained for getting the projections for conjuctions with upsampling blocks 428 _chunks = int(self.depth / 4) 429 chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1] 430 431 list_from_encoder = [] 432 for i, blk in enumerate(self.blocks): 433 x = blk(x) 434 if i in chunks_for_projection: 435 list_from_encoder.append(self.convert_to_expected_dim(x)) 436 437 x = self.convert_to_expected_dim(x) 438 439 return x, list_from_encoder
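The base_resolution is forwarded via the kwargs of get_vision_transformer (see below) and is rescaled per sample by the crop ratio inside transform_inputs; a construction sketch, assuming timm is installed (kornia is additionally required at forward time):

from torch_em.model.vit import get_vision_transformer

# 'base_resolution' should be adapted when the downstream dataset has a different zoom factor.
encoder = get_vision_transformer(backbone="scalemae", model="vit_b", img_size=224, base_resolution=1.0)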
446class ViT_DINOv3(DinoVisionTransformer): 447 """Vision Transformer derived from the DINOv3 Codebase (https://arxiv.org/abs/2508.10104). 448 449 Based on: 450 https://github.com/facebookresearch/dinov3/blob/main/dinov3/models/vision_transformer.py. 451 452 Args: 453 img_size: The input image size. 454 patch_size: The patch size. 455 embed_dim: The embedding dimension. 456 depth: The depth of the network. 457 num_heads: The number of heads. 458 ffn_ratio: The FFN rato. 459 n_storage_tokens: The number of storage (class) tokens to remove. 460 kwargs: Keyword arguments for the image encoder base class. 461 """ 462 def __init__( 463 self, 464 in_chans: int = 3, 465 img_size: int = 224, 466 patch_size: int = 16, 467 embed_dim: int = 768, 468 depth: int = 12, 469 num_heads: int = 12, 470 ffn_ratio: float = 4.0, 471 n_storage_tokens: int = 0, 472 **kwargs 473 ): 474 if not _dinov3_import_success: 475 raise RuntimeError( 476 "The vision transformer backend can only be initialized if DINOv3 is installed. " 477 "Please install DINOv3 from https://github.com/facebookresearch/dinov3 " 478 "and then rerun your code." 479 ) 480 481 super().__init__( 482 in_chans=in_chans, 483 img_size=img_size, 484 patch_size=patch_size, 485 embed_dim=embed_dim, 486 depth=depth, 487 num_heads=num_heads, 488 ffn_ratio=ffn_ratio, 489 n_storage_tokens=n_storage_tokens, 490 **kwargs 491 ) 492 493 self.in_chans = in_chans 494 self.img_size = img_size 495 self.n_storage_tokens = n_storage_tokens 496 self.attn_outs = [i for i in range(depth) if i % 3 == 2] 497 498 def forward(self, x) -> torch.Tensor: 499 500 B = x.shape[0] 501 502 x, hw_tuple = self.prepare_tokens_with_masks(x) 503 504 list_of_encoder = [] 505 for i, blk in enumerate(self.blocks): 506 rope_sincos = self.rope_embed(H=hw_tuple[0], W=hw_tuple[1]) 507 x = blk(x, rope_sincos) 508 if i in self.attn_outs: 509 list_of_encoder.append(x) 510 511 x = self.norm(x) 512 x = x[:, self.n_storage_tokens + 1:].reshape( 513 B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1 514 ).permute(0, 3, 1, 2).contiguous() 515 516 list_of_encoder = [ 517 o[:, self.n_storage_tokens + 1:].reshape( 518 B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1 519 ).permute(0, 3, 1, 2).contiguous() for o in list_of_encoder 520 ] 521 522 return x, list_of_encoder[:3]
Vision Transformer derived from the DINOv3 Codebase (https://arxiv.org/abs/2508.10104).
Based on: https://github.com/facebookresearch/dinov3/blob/main/dinov3/models/vision_transformer.py.
Arguments:
- img_size: The input image size.
- patch_size: The patch size.
- embed_dim: The embedding dimension.
- depth: The depth of the network.
- num_heads: The number of heads.
- ffn_ratio: The FFN ratio.
- n_storage_tokens: The number of storage (class) tokens to remove.
- kwargs: Keyword arguments for the image encoder base class.
462 def __init__( 463 self, 464 in_chans: int = 3, 465 img_size: int = 224, 466 patch_size: int = 16, 467 embed_dim: int = 768, 468 depth: int = 12, 469 num_heads: int = 12, 470 ffn_ratio: float = 4.0, 471 n_storage_tokens: int = 0, 472 **kwargs 473 ): 474 if not _dinov3_import_success: 475 raise RuntimeError( 476 "The vision transformer backend can only be initialized if DINOv3 is installed. " 477 "Please install DINOv3 from https://github.com/facebookresearch/dinov3 " 478 "and then rerun your code." 479 ) 480 481 super().__init__( 482 in_chans=in_chans, 483 img_size=img_size, 484 patch_size=patch_size, 485 embed_dim=embed_dim, 486 depth=depth, 487 num_heads=num_heads, 488 ffn_ratio=ffn_ratio, 489 n_storage_tokens=n_storage_tokens, 490 **kwargs 491 ) 492 493 self.in_chans = in_chans 494 self.img_size = img_size 495 self.n_storage_tokens = n_storage_tokens 496 self.attn_outs = [i for i in range(depth) if i % 3 == 2]
498 def forward(self, x) -> torch.Tensor: 499 500 B = x.shape[0] 501 502 x, hw_tuple = self.prepare_tokens_with_masks(x) 503 504 list_of_encoder = [] 505 for i, blk in enumerate(self.blocks): 506 rope_sincos = self.rope_embed(H=hw_tuple[0], W=hw_tuple[1]) 507 x = blk(x, rope_sincos) 508 if i in self.attn_outs: 509 list_of_encoder.append(x) 510 511 x = self.norm(x) 512 x = x[:, self.n_storage_tokens + 1:].reshape( 513 B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1 514 ).permute(0, 3, 1, 2).contiguous() 515 516 list_of_encoder = [ 517 o[:, self.n_storage_tokens + 1:].reshape( 518 B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1 519 ).permute(0, 3, 1, 2).contiguous() for o in list_of_encoder 520 ] 521 522 return x, list_of_encoder[:3]
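A usage sketch, assuming the dinov3 package is installed; the forward pass removes the class and storage tokens and reshapes the remaining patch tokens into feature maps of spatial size img_size // patch_size:

import torch
from torch_em.model.vit import get_vision_transformer

encoder = get_vision_transformer(backbone="dinov3", model="vit_s", img_size=224)
x = torch.randn(1, 3, 224, 224)
features, intermediates = encoder(x)
print(features.shape)      # (1, 384, 14, 14) for "vit_s" with patch size 16
print(len(intermediates))  # 3: the first three of the collected intermediate maps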
525def get_vision_transformer(backbone: str, model: str, img_size: int = 1024, **kwargs) -> nn.Module: 526 """Get vision transformer encoder. 527 528 Args: 529 backbone: The name of the vision transformer implementation. One of "sam" / "mae" / "scalemae". 530 model: The name of the model. One of "vit_b", "vit_l" or "vit_h". 531 img_size: The size of the input for the image encoder. Input images will be resized to match this size. 532 kwargs: Additional kwargs which can be expected by the vision transformer, 533 e.g. 'base_resolution' for `ViT_ScaleMAE`. 534 535 Returns: 536 The vision transformer. 537 """ 538 if backbone == "sam": 539 if model == "vit_b": 540 encoder = ViT_Sam( 541 depth=12, embed_dim=768, img_size=img_size, mlp_ratio=4, 542 norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 543 num_heads=12, patch_size=16, qkv_bias=True, use_rel_pos=True, 544 global_attn_indexes=[2, 5, 8, 11], 545 window_size=14, out_chans=256, 546 ) 547 elif model == "vit_l": 548 encoder = ViT_Sam( 549 depth=24, embed_dim=1024, img_size=img_size, mlp_ratio=4, 550 norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 551 num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True, 552 global_attn_indexes=[5, 11, 17, 23], 553 window_size=14, out_chans=256, 554 ) 555 elif model == "vit_h": 556 encoder = ViT_Sam( 557 depth=32, embed_dim=1280, img_size=img_size, mlp_ratio=4, 558 norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 559 num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True, 560 global_attn_indexes=[7, 15, 23, 31], 561 window_size=14, out_chans=256, 562 ) 563 else: 564 raise ValueError(f"'{model}' is not supported by SAM. Currently, 'vit_b', 'vit_l', 'vit_h' are supported.") 565 566 elif backbone == "sam2": 567 if model == "hvit_t": 568 encoder = ViT_Sam2( 569 img_size=img_size, embed_dim=96, num_heads=1, stages=[1, 2, 7, 2], global_att_blocks=[5, 7, 9], 570 window_pos_embed_bkg_spatial_size=[7, 7], backbone_channel_list=[768, 384, 192, 96], 571 ) 572 elif model == "hvit_s": 573 encoder = ViT_Sam2( 574 img_size=img_size, embed_dim=96, num_heads=1, stages=[1, 2, 11, 2], global_att_blocks=[7, 10, 13], 575 window_pos_embed_bkg_spatial_size=[7, 7], backbone_channel_list=[768, 384, 192, 96], 576 ) 577 elif model == "hvit_b": 578 encoder = ViT_Sam2( 579 img_size=img_size, embed_dim=112, num_heads=2, backbone_channel_list=[896, 448, 224, 112], 580 ) 581 elif model == "hvit_l": 582 encoder = ViT_Sam2( 583 img_size=img_size, embed_dim=144, num_heads=2, stages=[2, 6, 36, 4], global_att_blocks=[23, 33, 43], 584 window_spec=[8, 4, 16, 8], backbone_channel_list=[1152, 576, 288, 144], 585 ) 586 else: 587 raise ValueError( 588 f"'{model}' is not supported by SAM2. Currently, 'hvit_t', 'hvit_s', 'hvit_b', 'hvit_l' are supported." 589 ) 590 591 elif backbone == "mae": 592 if model == "vit_b": 593 encoder = ViT_MAE( 594 img_size=img_size, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, 595 qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6) 596 ) 597 elif model == "vit_l": 598 encoder = ViT_MAE( 599 img_size=img_size, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, 600 qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6) 601 ) 602 elif model == "vit_h": 603 encoder = ViT_MAE( 604 img_size=img_size, patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, 605 qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6) 606 ) 607 else: 608 raise ValueError(f"'{model}' is not supported by MAE. 
Currently, 'vit_b', 'vit_l', 'vit_h' are supported.") 609 610 elif backbone == "scalemae": 611 base_resolution = kwargs.get("base_resolution", 2.5) 612 613 if model == "vit_b": 614 encoder = ViT_ScaleMAE( 615 img_size=img_size, patch_size=8, embed_dim=768, depth=12, num_heads=12, 616 mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), 617 base_resolution=base_resolution, 618 ) 619 elif model == "vit_l": 620 encoder = ViT_ScaleMAE( 621 img_size=img_size, patch_size=8, embed_dim=1024, depth=24, num_heads=16, 622 mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), 623 base_resolution=base_resolution, 624 ) 625 elif model == "vit_h": 626 encoder = ViT_ScaleMAE( 627 img_size=img_size, patch_size=8, embed_dim=1280, depth=32, num_heads=16, 628 mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), 629 base_resolution=base_resolution, 630 ) 631 else: 632 raise ValueError( 633 f"'{model}' is not supported by ScaleMAE. Currently, 'vit_b', 'vit_l' and 'vit_h' are supported." 634 ) 635 636 elif backbone == "dinov3": 637 638 if model == "vit_s": 639 encoder = ViT_DINOv3( 640 img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=384, 641 num_heads=6, layerscale_init=1.0e-05, norm_layer="layernormbf16", n_storage_tokens=4, mask_k_bias=True, 642 ) 643 elif model == "vit_s+": 644 encoder = ViT_DINOv3( 645 img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=384, 646 num_heads=6, ffn_ratio=6, layerscale_init=1.0e-05, norm_layer="layernormbf16", 647 ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True, 648 ) 649 650 elif model == "vit_b": 651 encoder = ViT_DINOv3( 652 img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", 653 layerscale_init=1.0e-05, norm_layer="layernormbf16", n_storage_tokens=4, mask_k_bias=True, 654 ) 655 elif model == "vit_l": 656 encoder = ViT_DINOv3( 657 img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1024, 658 depth=24, num_heads=16, layerscale_init=1.0e-05, norm_layer="layernormbf16", 659 n_storage_tokens=4, mask_k_bias=True, 660 ) 661 elif model == "vit_l+": 662 encoder = ViT_DINOv3( 663 img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1024, 664 depth=24, num_heads=16, ffn_ratio=6.0, layerscale_init=1.0e-05, norm_layer="layernormbf16", 665 ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True, 666 ) 667 elif model == "vit_h+": 668 encoder = ViT_DINOv3( 669 img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1280, 670 depth=32, num_heads=20, ffn_ratio=6.0, layerscale_init=1.0e-05, norm_layer="layernormbf16", 671 ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True, 672 ) 673 elif model == "vit_7b": 674 encoder = ViT_DINOv3( 675 img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=4096, 676 depth=40, num_heads=32, ffn_ratio=3, qkv_bias=False, drop_path_rate=0.0, layerscale_init=1.0e-05, 677 norm_layer="layernormbf16", ffn_layer="swiglu64", n_storage_tokens=4, mask_k_bias=True, 678 untie_global_and_local_cls_norm=True, 679 ) 680 else: 681 raise ValueError( 682 f"'{model}' is not supported by DINOv3. Currently, " 683 " 'vit_s', 'vit_s+', 'vit_b', 'vit_l', 'vit_l+', 'vit_h+'. 'vit_7b' are supported." 684 ) 685 686 else: 687 raise ValueError( 688 "The 'UNETR' supported backbones are 'sam', 'sam2', 'mae', 'scalemae' or 'dinov3'. " 689 "Please choose one of them." 
690 ) 691 692 return encoder
Get vision transformer encoder.
Arguments:
- backbone: The name of the vision transformer implementation. One of "sam", "sam2", "mae", "scalemae" or "dinov3".
- model: The name of the model. One of "vit_b", "vit_l" or "vit_h" for the "sam", "mae" and "scalemae" backbones; "hvit_t", "hvit_s", "hvit_b" or "hvit_l" for "sam2"; and "vit_s", "vit_s+", "vit_b", "vit_l", "vit_l+", "vit_h+" or "vit_7b" for "dinov3".
- img_size: The size of the input for the image encoder. Input images will be resized to match this size.
- kwargs: Additional kwargs which can be expected by the vision transformer, e.g. 'base_resolution' for ViT_ScaleMAE.
Returns:
The vision transformer.
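A brief selection sketch; each backbone only works when its optional dependency is installed (segment-anything for "sam", sam2 for "sam2", timm for "mae" and "scalemae", dinov3 for "dinov3"):

from torch_em.model.vit import get_vision_transformer

sam_encoder = get_vision_transformer(backbone="sam", model="vit_b", img_size=1024)
sam2_encoder = get_vision_transformer(backbone="sam2", model="hvit_t", img_size=1024)
mae_encoder = get_vision_transformer(backbone="mae", model="vit_l", img_size=1024)
scalemae_encoder = get_vision_transformer(backbone="scalemae", model="vit_b", img_size=224, base_resolution=2.5)
dinov3_encoder = get_vision_transformer(backbone="dinov3", model="vit_s", img_size=224)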