torch_em.model.vit
from functools import partial
from typing import Tuple, List

import torch
import torch.nn as nn

# we catch ImportErrors here because segment_anything, micro_sam, scale_mae and timm should
# only be optional dependencies for torch_em
try:
    from segment_anything.modeling import ImageEncoderViT
    _sam_import_success = True
except ImportError:
    ImageEncoderViT = object
    _sam_import_success = False

try:
    from timm.models.vision_transformer import VisionTransformer, PatchEmbed
    _timm_import_success = True
except ImportError:
    VisionTransformer = object
    PatchEmbed = object
    _timm_import_success = False

try:
    from sam2.modeling.backbones.hieradet import Hiera
    from sam2.modeling.position_encoding import PositionEmbeddingSine
    from sam2.modeling.backbones.image_encoder import ImageEncoder, FpnNeck
    _sam2_import_success = True
except ImportError:
    ImageEncoder = object
    _sam2_import_success = False


class ViT_Sam(ImageEncoderViT):
    """Vision Transformer derived from the Segment Anything Codebase (https://arxiv.org/abs/2304.02643).

    Based on:
    https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py

    Args:
        in_chans: The number of input channels.
        embed_dim: The embedding dimension, corresponding to the number of output channels of the vision transformer.
        global_attn_indexes: The global attention indices.
        kwargs: Keyword arguments for the image encoder base class.
    """
    def __init__(
        self,
        in_chans: int = 3,
        embed_dim: int = 768,
        global_attn_indexes: Tuple[int, ...] = ...,
        **kwargs,
    ) -> None:
        if not _sam_import_success:
            raise RuntimeError(
                "The vision transformer backend can only be initialized if segment anything is installed. "
                "Please install segment anything from https://github.com/facebookresearch/segment-anything "
                "and then rerun your code."
            )

        super().__init__(embed_dim=embed_dim, global_attn_indexes=global_attn_indexes, **kwargs)
        self.chunks_for_projection = global_attn_indexes
        self.in_chans = in_chans
        self.embed_dim = embed_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the vision transformer to input data.

        Args:
            x: The input data.

        Returns:
            The vision transformer output.
        """
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            x = x + self.pos_embed

        list_from_encoder = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in self.chunks_for_projection:
                list_from_encoder.append(x)

        x = x.permute(0, 3, 1, 2)
        list_from_encoder = [e.permute(0, 3, 1, 2) for e in list_from_encoder]
        return x, list_from_encoder[:3]


class ViT_MAE(VisionTransformer):
    """Vision Transformer derived from the Masked Auto Encoder Codebase (https://arxiv.org/abs/2111.06377).

    Based on:
    https://github.com/facebookresearch/mae/blob/main/models_vit.py#L20-L53

    Args:
        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
        in_chans: The number of input channels.
        depth: The depth of the vision transformer.
        kwargs: Additional keyword arguments for the vision transformer base class.
100 """ 101 def __init__( 102 self, 103 img_size: int = 1024, # chosen to match our experiments with segment anything 104 in_chans: int = 3, 105 depth: int = 12, 106 **kwargs 107 ): 108 if not _timm_import_success: 109 raise RuntimeError( 110 "The vision transformer backend can only be initialized if timm is installed. " 111 "Please install timm (using conda/mamba) for using https://github.com/facebookresearch/mae/ " 112 "and then rerun your code" 113 ) 114 super().__init__(img_size=img_size, depth=depth, **kwargs) 115 self.img_size = img_size 116 self.in_chans = in_chans 117 self.depth = depth 118 119 def convert_to_expected_dim(self, inputs_): 120 """@private 121 """ 122 inputs_ = inputs_[:, 1:, :] # removing the class tokens 123 # reshape the outputs to desired shape (N x H*W X C -> N x H x W x C) 124 rdim = inputs_.shape[1] 125 dshape = int(rdim ** 0.5) # finding the square root of the outputs for obtaining the patch shape 126 inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape)) 127 inputs_ = inputs_.permute(0, 3, 1, 2) 128 return inputs_ 129 130 def forward_features(self, x): 131 """@private 132 """ 133 B = x.shape[0] 134 x = self.patch_embed(x) 135 136 cls_tokens = self.cls_token.expand(B, -1, -1) 137 x = torch.cat((cls_tokens, x), dim=1) 138 139 x = x + self.pos_embed 140 x = self.pos_drop(x) 141 142 # chunks obtained for getting the projections for conjuctions with upsampling blocks 143 _chunks = int(self.depth / 4) 144 chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1] 145 146 list_from_encoder = [] 147 for i, blk in enumerate(self.blocks): 148 x = blk(x) 149 if i in chunks_for_projection: 150 list_from_encoder.append(self.convert_to_expected_dim(x)) 151 152 x = self.convert_to_expected_dim(x) 153 return x, list_from_encoder[:3] 154 155 def forward(self, x: torch.Tensor) -> torch.Tensor: 156 """Apply the vision transformer to input data. 157 158 Args: 159 x: The input data. 160 161 Returns: 162 The vision transformer output. 163 """ 164 x, list_from_encoder = self.forward_features(x) 165 return x, list_from_encoder 166 167 168class ViT_Sam2(ImageEncoder): 169 """Vision Transformer derived from the Segment Anything 2 Codebase (https://arxiv.org/abs/2408.00714). 170 171 Based on https://github.com/facebookresearch/sam2/blob/main/sam2/modeling/backbones/image_encoder.py. 172 173 Args: 174 backbone_channel_list: The channels throughout the entire backbone. 175 embed_dim: The initial embedding dimension. 176 num_heads: The initial number of heads. 177 stages: The number of blocks per stage. 178 global_att_blocks: The parameter to decide which blocks have global attention. 179 window_pos_embed_bkg_spatial_size: The spatial size of windowed positional embedding. 180 window_spec: The window size per stage, when not using global attention. 181 scalp: The count of lowest resolution features to discard. 182 """ 183 def __init__( 184 self, 185 backbone_channel_list: List[int], 186 embed_dim: int = 96, 187 num_heads: int = 1, 188 stages: Tuple[int, ...] = (2, 3, 16, 3), 189 global_att_blocks: Tuple[int, ...] = (12, 16, 20), 190 window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14), 191 window_spec: Tuple[int, ...] = (8, 4, 14, 7), 192 scalp: int = 1, 193 ): 194 if not _sam2_import_success: 195 raise RuntimeError( 196 "The vision transformer backend can only be initialized if segment anything 2 is installed. 
" 197 "Please install segment anything 2 from https://github.com/facebookresearch/sam2 " 198 "and then rerun your code" 199 ) 200 201 trunk = Hiera( 202 embed_dim=embed_dim, 203 num_heads=num_heads, 204 stages=stages, 205 global_att_blocks=global_att_blocks, 206 window_pos_embed_bkg_spatial_size=window_pos_embed_bkg_spatial_size, 207 window_spec=window_spec, 208 ) 209 neck = FpnNeck( 210 position_encoding=PositionEmbeddingSine(num_pos_feats=256), 211 d_model=256, 212 backbone_channel_list=backbone_channel_list, 213 fpn_top_down_levels=[2, 3], 214 fpn_interp_model="nearest", 215 ) 216 217 super().__init__(trunk=trunk, neck=neck, scalp=scalp) 218 self.scalp = scalp 219 self.embed_dim = embed_dim 220 self.img_size = 1024 # NOTE: Hard-coded atm, declared in the configuration file. 221 222 def forward(self, x: torch.Tensor): 223 # The forward pass throught the backbone. 224 features, pos = self.neck(self.trunk(x)) 225 if self.scalp > 0: # This discard the "n" lowest resolution features. 226 features, pos = features[:-self.scalp], pos[:-self.scalp] 227 228 return features[-1], features 229 230 231# 232# Utilities for ScaleMAE's ViT 233# 234 235 236class CustomCompose: 237 def __init__(self, rescale_transform, other_transforms, src_transform): 238 self.rescale_transform = rescale_transform 239 self.other_transforms = other_transforms 240 self.src_transform = src_transform 241 242 def __call__(self, x, valid_masks=None): 243 if valid_masks is not None: 244 nodata = (x * (1 - valid_masks.float())).max() 245 x_aug = self.rescale_transform(x) 246 parms = self.rescale_transform._params 247 248 # sanity check, comment if this is working 249 # valid_masks = self.rescale_transform(valid_masks.float(), params=parms) 250 # assert (x_aug==self.rescale_transform(x, params=parms)).all() # 251 252 if valid_masks is not None: 253 valid_masks = x_aug != nodata 254 _, c, h, w = x_aug.shape 255 zero_ratio = ((valid_masks == 0).sum((1, 2, 3)) / (h * w * c)).cpu().numpy() 256 else: 257 zero_ratio = -1 258 259 if self.other_transforms: 260 x_aug = self.other_transforms(x_aug) 261 x_src = self.src_transform(x_aug) 262 dx = parms["src"][:, 1, 0] - parms["src"][:, 0, 0] 263 264 # dy = (parms['src'][:,2,1] - parms['src'][:,1,1]) 265 # assert (dx == dy).all() 266 267 h, w = x_aug.shape[-2:] 268 # assert h == w 269 270 return x_aug, x_src, dx / h, zero_ratio, valid_masks 271 272 273def get_2d_sincos_pos_embed_with_resolution(embed_dim, grid_size, res, cls_token=False, device="cpu"): 274 """ 275 grid_size: int of the grid height and width 276 res: array of size n, representing the resolution of a pixel (say, in meters), 277 return: 278 pos_embed: [n,grid_size*grid_size, embed_dim] or [n,1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 279 """ 280 # res = torch.FloatTensor(res).to(device) 281 res = res.to(device) 282 grid_h = torch.arange(grid_size, dtype=torch.float32, device=device) 283 grid_w = torch.arange(grid_size, dtype=torch.float32, device=device) 284 grid = torch.meshgrid(grid_w, grid_h, indexing="xy") # here h goes first,direction reversed for numpy 285 grid = torch.stack(grid, dim=0) # 2 x h x w 286 287 # grid = grid.reshape([2, 1, grid_size, grid_size]) 288 grid = torch.einsum("chw,n->cnhw", grid, res) # 2 x n x h x w 289 _, n, h, w = grid.shape 290 pos_embed = get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid) # (nxH*W, D/2) 291 pos_embed = pos_embed.reshape(n, h * w, embed_dim) 292 if cls_token: 293 pos_embed = torch.cat( 294 [torch.zeros([n, 1, embed_dim], dtype=torch.float32, 
        )

    return pos_embed


def get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = torch.cat([emb_h, emb_w], dim=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    # old_shape = pos
    omega = torch.arange(embed_dim // 2, dtype=torch.float32, device=pos.device)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = torch.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = torch.sin(out)  # (M, D/2)
    emb_cos = torch.cos(out)  # (M, D/2)

    emb = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)
    return emb


class PatchEmbedUnSafe(PatchEmbed):
    """Image to Patch Embedding"""

    def forward(self, x):
        B, C, H, W = x.shape

        # NOTE: Comment code from ScaleMAE: Dropped size check in timm
        # assert H == self.img_size[0] and W == self.img_size[1], \
        #     f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."

        x = self.proj(x).flatten(2).transpose(1, 2)
        return x


class ViT_ScaleMAE(VisionTransformer):
    """Vision Transformer derived from the Scale Masked Auto Encoder codebase (TODO: paper and github link).

    NOTE: For downstream tasks, the "base_resolution" parameter needs to be adjusted manually when using
    the model on a dataset with a different zoom factor.
352 """ 353 354 def __init__( 355 self, img_size=224, patch_size=16, in_chans=3, embed_dim=1024, depth=12, base_resolution=2.5, **kwargs 356 ): 357 super().__init__(img_size=img_size, embed_dim=embed_dim, **kwargs) 358 self.img_size = img_size 359 self.in_chans = in_chans 360 self.depth = depth 361 self.base_resolution = base_resolution 362 363 self.patch_embed = PatchEmbedUnSafe( 364 img_size=img_size, 365 patch_size=patch_size, 366 in_chans=in_chans, 367 embed_dim=embed_dim, 368 ) 369 370 def transform_inputs(self, x): 371 import kornia.augmentation as K 372 from kornia.constants import Resample 373 374 self._transforms = CustomCompose( 375 rescale_transform=K.RandomResizedCrop( 376 (448, 448), 377 ratio=(1.0, 1.0), 378 scale=(1.0, 1.0), 379 resample=Resample.BICUBIC.name, 380 ), 381 other_transforms=None, 382 src_transform=K.Resize((224, 224)), 383 ) 384 x, _, ratios, _, _ = self._transforms(x) 385 input_res = ratios * self.base_resolution 386 return x, input_res 387 388 def convert_to_expected_dim(self, x): 389 inputs_ = x[:, 1:, :] # removing the class tokens 390 # reshape the outputs to desired shape (N X H*W X C -> N X H X W X C) 391 rdim = inputs_.shape[1] 392 dshape = int(rdim ** 0.5) # finding square root of the outputs for obtaining the patch shape 393 inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape)) 394 inputs_ = inputs_.permute(0, 3, 1, 2) 395 return inputs_ 396 397 def forward_features(self, x): 398 x, input_res = self.transform_inputs(x) 399 400 B, _, h, w = x.shape 401 x = self.patch_embed(x) 402 403 num_patches = int((h * w) / (self.patch_embed.patch_size[0] * self.patch_embed.patch_size[1])) 404 pos_embed = get_2d_sincos_pos_embed_with_resolution( 405 x.shape[-1], 406 int(num_patches ** 0.5), 407 input_res, 408 cls_token=True, 409 device=x.device, 410 ) 411 412 cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 413 x = torch.cat((cls_tokens, x), dim=1) 414 x = x + pos_embed 415 x = self.pos_drop(x) 416 417 # chunks obtained for getting the projections for conjuctions with upsampling blocks 418 _chunks = int(self.depth / 4) 419 chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1] 420 421 list_from_encoder = [] 422 for i, blk in enumerate(self.blocks): 423 x = blk(x) 424 if i in chunks_for_projection: 425 list_from_encoder.append(self.convert_to_expected_dim(x)) 426 427 x = self.convert_to_expected_dim(x) 428 429 return x, list_from_encoder 430 431 def forward(self, x): 432 x, list_from_encoder = self.forward_features(x) 433 return x, list_from_encoder 434 435 436def get_vision_transformer(backbone: str, model: str, img_size: int = 1024, **kwargs) -> nn.Module: 437 """Get vision transformer encoder. 438 439 Args: 440 backbone: The name of the vision transformer implementation. One of "sam" / "mae" / "scalemae". 441 model: The name of the model. One of "vit_b", "vit_l" or "vit_h". 442 img_size: The size of the input for the image encoder. Input images will be resized to match this size. 443 kwargs: Additional kwargs which can be expected by the vision transformer, 444 e.g. 'base_resolution' for `ViT_ScaleMAE`. 445 446 Returns: 447 The vision transformer. 
448 """ 449 if backbone == "sam": 450 if model == "vit_b": 451 encoder = ViT_Sam( 452 depth=12, embed_dim=768, img_size=1024, mlp_ratio=4, 453 norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 454 num_heads=12, patch_size=16, qkv_bias=True, use_rel_pos=True, 455 global_attn_indexes=[2, 5, 8, 11], 456 window_size=14, out_chans=256, 457 ) 458 elif model == "vit_l": 459 encoder = ViT_Sam( 460 depth=24, embed_dim=1024, img_size=1024, mlp_ratio=4, 461 norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 462 num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True, 463 global_attn_indexes=[5, 11, 17, 23], 464 window_size=14, out_chans=256, 465 ) 466 elif model == "vit_h": 467 encoder = ViT_Sam( 468 depth=32, embed_dim=1280, img_size=1024, mlp_ratio=4, 469 norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 470 num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True, 471 global_attn_indexes=[7, 15, 23, 31], 472 window_size=14, out_chans=256, 473 ) 474 else: 475 raise ValueError(f"'{model}' is not supported by SAM. Currently, 'vit_b', 'vit_l', 'vit_h' are supported.") 476 477 elif backbone == "sam2": 478 if model == "hvit_t": 479 encoder = ViT_Sam2( 480 embed_dim=96, num_heads=1, stages=[1, 2, 7, 2], global_att_blocks=[5, 7, 9], 481 window_pos_embed_bkg_spatial_size=[7, 7], backbone_channel_list=[768, 384, 192, 96], 482 ) 483 elif model == "hvit_s": 484 encoder = ViT_Sam2( 485 embed_dim=96, num_heads=1, stages=[1, 2, 11, 2], global_att_blocks=[7, 10, 13], 486 window_pos_embed_bkg_spatial_size=[7, 7], backbone_channel_list=[768, 384, 192, 96], 487 ) 488 elif model == "hvit_b": 489 encoder = ViT_Sam2( 490 embed_dim=112, num_heads=2, backbone_channel_list=[896, 448, 224, 112], 491 ) 492 elif model == "hvit_l": 493 encoder = ViT_Sam2( 494 embed_dim=144, num_heads=2, stages=[2, 6, 36, 4], global_att_blocks=[23, 33, 43], 495 window_spec=[8, 4, 16, 8], backbone_channel_list=[1152, 576, 288, 144], 496 ) 497 else: 498 raise ValueError( 499 f"'{model}' is not supported by SAM2. Currently, 'hvit_t', 'hvit_s', 'hvit_b', 'hvit_l' are supported." 500 ) 501 502 elif backbone == "mae": 503 if model == "vit_b": 504 encoder = ViT_MAE( 505 img_size=img_size, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, 506 qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6) 507 ) 508 elif model == "vit_l": 509 encoder = ViT_MAE( 510 img_size=img_size, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, 511 qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6) 512 ) 513 elif model == "vit_h": 514 encoder = ViT_MAE( 515 img_size=img_size, patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, 516 qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6) 517 ) 518 else: 519 raise ValueError(f"'{model}' is not supported by MAE. 

    elif backbone == "scalemae":
        base_resolution = kwargs.get("base_resolution", 2.5)

        if model == "vit_b":
            encoder = ViT_ScaleMAE(
                img_size=img_size, patch_size=8, embed_dim=768, depth=12, num_heads=12,
                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
                base_resolution=base_resolution,
            )
        elif model == "vit_l":
            encoder = ViT_ScaleMAE(
                img_size=img_size, patch_size=8, embed_dim=1024, depth=24, num_heads=16,
                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
                base_resolution=base_resolution,
            )
        elif model == "vit_h":
            encoder = ViT_ScaleMAE(
                img_size=img_size, patch_size=8, embed_dim=1280, depth=32, num_heads=16,
                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
                base_resolution=base_resolution,
            )
        else:
            raise ValueError(
                f"'{model}' is not supported by ScaleMAE. Currently, 'vit_b', 'vit_l' and 'vit_h' are supported."
            )

    else:
        raise ValueError(
            "The 'UNETR' supported backbones are 'sam', 'sam2', 'mae' or 'scalemae'. Please choose one of them."
        )

    return encoder
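The factory at the end of the module is the intended entry point. A minimal usage sketch, assuming torch_em and segment-anything are installed; the printed shapes are indicative, not guaranteed:

import torch
from torch_em.model.vit import get_vision_transformer

# Build the SAM ViT-B encoder. This raises a RuntimeError if segment-anything is missing.
encoder = get_vision_transformer(backbone="sam", model="vit_b")

x = torch.randn(1, 3, 1024, 1024)  # the SAM encoders expect 1024 x 1024 inputs
features, intermediates = encoder(x)
print(features.shape)                     # final feature map in NCHW layout, e.g. (1, 768, 64, 64)
print([f.shape for f in intermediates])   # three intermediate maps taken at the global-attention blocks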
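ViT_MAE (and ViT_ScaleMAE) carry a class token, so the encoder output has to be folded back into an image-shaped feature map before it can feed a convolutional decoder. A self-contained sketch of the reshaping that convert_to_expected_dim performs, using plain torch only:

import torch

tokens = torch.randn(2, 1 + 64 * 64, 768)        # batch of 2: class token + 64*64 patch tokens
patches = tokens[:, 1:, :]                        # drop the class token
side = int(patches.shape[1] ** 0.5)               # 64, assuming a square patch grid
grid = torch.unflatten(patches, 1, (side, side))  # (2, 64, 64, 768)
fmap = grid.permute(0, 3, 1, 2)                   # (2, 768, 64, 64), NCHW for the decoder

# For depth 12 the projections are collected after blocks [2, 5, 8, 11], i.e. every depth // 4 blocks.
print(fmap.shape)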
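ViT_Sam2 wraps a Hiera trunk and an FPN neck, so its forward pass returns multi-scale feature maps rather than a single embedding. A hedged sketch, assuming the sam2 package is installed; shapes are indicative for a 1024 x 1024 input:

import torch
from torch_em.model.vit import get_vision_transformer

encoder = get_vision_transformer(backbone="sam2", model="hvit_t")

x = torch.randn(1, 3, 1024, 1024)
top, features = encoder(x)
# With scalp=1 the lowest-resolution level is dropped; all remaining levels carry 256 channels from the neck.
print(top.shape)                    # e.g. (1, 256, 64, 64), the last level that survives the scalp
print([f.shape for f in features])  # e.g. (1, 256, 256, 256), (1, 256, 128, 128), (1, 256, 64, 64)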
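The ScaleMAE utilities build a resolution-aware sine-cosine positional embedding: one embedding table per image, scaled by that image's pixel resolution. A runnable shape check, needing only torch and this module:

import torch
from torch_em.model.vit import get_2d_sincos_pos_embed_with_resolution

res = torch.tensor([1.0, 2.5])  # one ground sampling distance per image in the batch
pos = get_2d_sincos_pos_embed_with_resolution(embed_dim=768, grid_size=14, res=res, cls_token=True)
print(pos.shape)  # (2, 1 + 14 * 14, 768); a zero row is prepended for the class token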
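PatchEmbedUnSafe exists because ScaleMAE pushes a 448 x 448 crop through a patch embedding configured for 224 x 224; the timm versions this code was written against assert on the input size, so that check is dropped here. A small sketch, assuming timm is installed:

import torch
from torch_em.model.vit import PatchEmbedUnSafe

embed = PatchEmbedUnSafe(img_size=224, patch_size=8, in_chans=3, embed_dim=768)
tokens = embed(torch.randn(1, 3, 448, 448))  # no size check, just the strided patch projection
print(tokens.shape)  # (1, 3136, 768): (448 / 8) ** 2 patch tokens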
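For ViT_ScaleMAE the base_resolution value is folded into the positional embedding via transform_inputs, which is why it has to be set to match the zoom level of the target dataset. A hedged usage sketch, assuming timm and kornia are installed; shapes are indicative:

import torch
from torch_em.model.vit import get_vision_transformer

encoder = get_vision_transformer(backbone="scalemae", model="vit_b", img_size=224, base_resolution=2.5)

x = torch.randn(1, 3, 224, 224)
features, intermediates = encoder(x)
# The input is internally rescaled to 448 x 448 and embedded with 8 x 8 patches.
print(features.shape)  # e.g. (1, 768, 56, 56)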
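Because the heavy backbones are optional dependencies, the guarded imports at the top of the module only record a flag; the error surfaces when the corresponding encoder is actually requested. A sketch of handling a missing backend:

from torch_em.model.vit import get_vision_transformer

try:
    encoder = get_vision_transformer(backbone="sam2", model="hvit_t")
except RuntimeError as err:
    # Raised by ViT_Sam2.__init__ when the sam2 package is not installed.
    print(f"SAM2 backend unavailable: {err}")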