torch_em.model.vit
from typing import Tuple, List
from functools import partial

import torch
import torch.nn as nn

# we catch ImportErrors here because segment_anything, micro_sam, scale_mae and timm should
# only be optional dependencies for torch_em
try:
    from segment_anything.modeling import ImageEncoderViT
    _sam_import_success = True
except ImportError:
    ImageEncoderViT = object
    _sam_import_success = False

try:
    from timm.models.vision_transformer import VisionTransformer, PatchEmbed
    _timm_import_success = True
except ImportError:
    VisionTransformer = object
    PatchEmbed = object
    _timm_import_success = False


class ViT_Sam(ImageEncoderViT):
    """Vision Transformer derived from the Segment Anything Codebase (https://arxiv.org/abs/2304.02643).

    Based on:
    https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py

    Args:
        in_chans: The number of input channels.
        embed_dim: The embedding dimension, corresponding to the number of output channels of the vision transformer.
        global_attn_indexes: The global attention indices.
        kwargs: Keyword arguments for the image encoder base class.
    """
    def __init__(
        self,
        in_chans: int = 3,
        embed_dim: int = 768,
        global_attn_indexes: Tuple[int, ...] = ...,
        **kwargs,
    ) -> None:
        if not _sam_import_success:
            raise RuntimeError(
                "The vision transformer backend can only be initialized if segment anything is installed. "
                "Please install segment anything from https://github.com/facebookresearch/segment-anything "
                "and then rerun your code."
            )

        super().__init__(embed_dim=embed_dim, global_attn_indexes=global_attn_indexes, **kwargs)
        self.chunks_for_projection = global_attn_indexes
        self.in_chans = in_chans
        self.embed_dim = embed_dim

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Apply the vision transformer to input data.

        Args:
            x: The input data.

        Returns:
            The vision transformer output and the outputs of the selected intermediate blocks.
        """
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            x = x + self.pos_embed

        list_from_encoder = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in self.chunks_for_projection:
                list_from_encoder.append(x)

        x = x.permute(0, 3, 1, 2)
        list_from_encoder = [e.permute(0, 3, 1, 2) for e in list_from_encoder]
        return x, list_from_encoder[:3]


class ViT_MAE(VisionTransformer):
    """Vision Transformer derived from the Masked Auto Encoder Codebase (https://arxiv.org/abs/2111.06377).

    Based on:
    https://github.com/facebookresearch/mae/blob/main/models_vit.py#L20-L53

    Args:
        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
        in_chans: The number of input channels.
        depth: The depth of the vision transformer.
        kwargs: Additional keyword arguments for the vision transformer base class.
    """
    def __init__(
        self,
        img_size: int = 1024,  # chosen to match our experiments with segment anything
        in_chans: int = 3,
        depth: int = 12,
        **kwargs
    ):
        if not _timm_import_success:
            raise RuntimeError(
                "The vision transformer backend can only be initialized if timm is installed. "
                "Please install timm (using conda/mamba), which is required for using "
                "https://github.com/facebookresearch/mae/, and then rerun your code."
            )
        super().__init__(img_size=img_size, depth=depth, **kwargs)
        self.img_size = img_size
        self.in_chans = in_chans
        self.depth = depth

    def convert_to_expected_dim(self, inputs_):
        """@private
        """
        inputs_ = inputs_[:, 1:, :]  # remove the class token
        # Reshape the outputs to the desired shape (N x H*W x C -> N x C x H x W).
        rdim = inputs_.shape[1]
        dshape = int(rdim ** 0.5)  # the square root of the token number gives the patch grid shape
        inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape))
        inputs_ = inputs_.permute(0, 3, 1, 2)
        return inputs_

    def forward_features(self, x):
        """@private
        """
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        x = x + self.pos_embed
        x = self.pos_drop(x)

        # The blocks whose outputs are projected to the upsampling blocks, sampled every depth / 4 blocks.
        _chunks = int(self.depth / 4)
        chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1]

        list_from_encoder = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in chunks_for_projection:
                list_from_encoder.append(self.convert_to_expected_dim(x))

        x = self.convert_to_expected_dim(x)
        return x, list_from_encoder[:3]

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Apply the vision transformer to input data.

        Args:
            x: The input data.

        Returns:
            The vision transformer output and the outputs of the selected intermediate blocks.
        """
        x, list_from_encoder = self.forward_features(x)
        return x, list_from_encoder


#
# Utilities for ScaleMAE's ViT
#


class CustomCompose:
    def __init__(self, rescale_transform, other_transforms, src_transform):
        self.rescale_transform = rescale_transform
        self.other_transforms = other_transforms
        self.src_transform = src_transform

    def __call__(self, x, valid_masks=None):
        if valid_masks is not None:
            nodata = (x * (1 - valid_masks.float())).max()
        x_aug = self.rescale_transform(x)
        parms = self.rescale_transform._params

        # sanity check, comment if this is working
        # valid_masks = self.rescale_transform(valid_masks.float(), params=parms)
        # assert (x_aug == self.rescale_transform(x, params=parms)).all()

        if valid_masks is not None:
            valid_masks = x_aug != nodata
            _, c, h, w = x_aug.shape
            zero_ratio = ((valid_masks == 0).sum((1, 2, 3)) / (h * w * c)).cpu().numpy()
        else:
            zero_ratio = -1

        if self.other_transforms:
            x_aug = self.other_transforms(x_aug)
        x_src = self.src_transform(x_aug)
        dx = parms["src"][:, 1, 0] - parms["src"][:, 0, 0]

        # dy = (parms['src'][:, 2, 1] - parms['src'][:, 1, 1])
        # assert (dx == dy).all()

        h, w = x_aug.shape[-2:]
        # assert h == w

        return x_aug, x_src, dx / h, zero_ratio, valid_masks


def get_2d_sincos_pos_embed_with_resolution(embed_dim, grid_size, res, cls_token=False, device="cpu"):
    """
    grid_size: int of the grid height and width
    res: array of size n, representing the resolution of a pixel (say, in meters)
    return:
    pos_embed: [n, grid_size*grid_size, embed_dim] or [n, 1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    # res = torch.FloatTensor(res).to(device)
    res = res.to(device)
    grid_h = torch.arange(grid_size, dtype=torch.float32, device=device)
    grid_w = torch.arange(grid_size, dtype=torch.float32, device=device)
    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")  # here h goes first, direction reversed for numpy
    grid = torch.stack(grid, dim=0)  # 2 x h x w

    # grid = grid.reshape([2, 1, grid_size, grid_size])
    grid = torch.einsum("chw,n->cnhw", grid, res)  # 2 x n x h x w
    _, n, h, w = grid.shape
    pos_embed = get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid)  # (n*H*W, D)
    pos_embed = pos_embed.reshape(n, h * w, embed_dim)
    if cls_token:
        pos_embed = torch.cat(
            [torch.zeros([n, 1, embed_dim], dtype=torch.float32, device=pos_embed.device), pos_embed], dim=1
        )

    return pos_embed


def get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of the dimensions to encode grid_h, the other half for grid_w
    emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = torch.cat([emb_h, emb_w], dim=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = torch.arange(embed_dim // 2, dtype=torch.float32, device=pos.device)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = torch.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = torch.sin(out)  # (M, D/2)
    emb_cos = torch.cos(out)  # (M, D/2)

    emb = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)
    return emb


class PatchEmbedUnSafe(PatchEmbed):
    """Image to Patch Embedding without the fixed input-size check of timm's PatchEmbed."""

    def forward(self, x):
        B, C, H, W = x.shape

        # NOTE: comment from the ScaleMAE code base: the size check from timm is dropped here.
        # assert H == self.img_size[0] and W == self.img_size[1], \
        #     f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."

        x = self.proj(x).flatten(2).transpose(1, 2)
        return x


class ViT_ScaleMAE(VisionTransformer):
    """Vision Transformer derived from the Scale Masked Auto Encoder codebase (TODO: paper and github link).

    NOTE: For downstream tasks, the "base_resolution" parameter needs to be adjusted manually when
    the model is used on a dataset with a different zoom factor.
    """

    def __init__(
        self, img_size=224, patch_size=16, in_chans=3, embed_dim=1024, depth=12, base_resolution=2.5, **kwargs
    ):
        # Pass depth to the base class so that the requested number of transformer blocks is created.
        super().__init__(img_size=img_size, embed_dim=embed_dim, depth=depth, **kwargs)
        self.img_size = img_size
        self.in_chans = in_chans
        self.depth = depth
        self.base_resolution = base_resolution

        self.patch_embed = PatchEmbedUnSafe(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

    def transform_inputs(self, x):
        import kornia.augmentation as K
        from kornia.constants import Resample

        self._transforms = CustomCompose(
            rescale_transform=K.RandomResizedCrop(
                (448, 448),
                ratio=(1.0, 1.0),
                scale=(1.0, 1.0),
                resample=Resample.BICUBIC.name,
            ),
            other_transforms=None,
            src_transform=K.Resize((224, 224)),
        )
        x, _, ratios, _, _ = self._transforms(x)
        input_res = ratios * self.base_resolution
        return x, input_res

    def convert_to_expected_dim(self, x):
        inputs_ = x[:, 1:, :]  # remove the class token
        # Reshape the outputs to the desired shape (N x H*W x C -> N x C x H x W).
        rdim = inputs_.shape[1]
        dshape = int(rdim ** 0.5)  # the square root of the token number gives the patch grid shape
        inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape))
        inputs_ = inputs_.permute(0, 3, 1, 2)
        return inputs_

    def forward_features(self, x):
        x, input_res = self.transform_inputs(x)

        B, _, h, w = x.shape
        x = self.patch_embed(x)

        num_patches = int((h * w) / (self.patch_embed.patch_size[0] * self.patch_embed.patch_size[1]))
        pos_embed = get_2d_sincos_pos_embed_with_resolution(
            x.shape[-1],
            int(num_patches ** 0.5),
            input_res,
            cls_token=True,
            device=x.device,
        )

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + pos_embed
        x = self.pos_drop(x)

        # The blocks whose outputs are projected to the upsampling blocks, sampled every depth / 4 blocks.
        _chunks = int(self.depth / 4)
        chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1]

        list_from_encoder = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in chunks_for_projection:
                list_from_encoder.append(self.convert_to_expected_dim(x))

        x = self.convert_to_expected_dim(x)

        return x, list_from_encoder

    def forward(self, x):
        x, list_from_encoder = self.forward_features(x)
        return x, list_from_encoder


def get_vision_transformer(backbone: str, model: str, img_size: int = 1024, **kwargs) -> nn.Module:
    """Get vision transformer encoder.

    Args:
        backbone: The name of the vision transformer implementation. One of "sam" / "mae" / "scalemae".
        model: The name of the model. One of "vit_b", "vit_l" or "vit_h".
        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
        kwargs: Additional kwargs which can be expected by the vision transformer,
            e.g. 'base_resolution' for `ViT_ScaleMAE`.

    Returns:
        The vision transformer.
    """
    if backbone == "sam":
        if model == "vit_b":
            encoder = ViT_Sam(
                depth=12, embed_dim=768, img_size=1024, mlp_ratio=4,
                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
                num_heads=12, patch_size=16, qkv_bias=True, use_rel_pos=True,
                global_attn_indexes=[2, 5, 8, 11],
                window_size=14, out_chans=256,
            )
        elif model == "vit_l":
            encoder = ViT_Sam(
                depth=24, embed_dim=1024, img_size=1024, mlp_ratio=4,
                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
                num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True,
                global_attn_indexes=[5, 11, 17, 23],
                window_size=14, out_chans=256,
            )
        elif model == "vit_h":
            encoder = ViT_Sam(
                depth=32, embed_dim=1280, img_size=1024, mlp_ratio=4,
                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
                num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True,
                global_attn_indexes=[7, 15, 23, 31],
                window_size=14, out_chans=256,
            )
        else:
            raise ValueError(f"'{model}' is not supported by SAM. Currently, 'vit_b', 'vit_l', 'vit_h' are supported.")

    elif backbone == "mae":
        if model == "vit_b":
            encoder = ViT_MAE(
                img_size=img_size, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
            )
        elif model == "vit_l":
            encoder = ViT_MAE(
                img_size=img_size, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
            )
        elif model == "vit_h":
            encoder = ViT_MAE(
                img_size=img_size, patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4,
                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
            )
        else:
            raise ValueError(f"'{model}' is not supported by MAE. Currently, 'vit_b', 'vit_l', 'vit_h' are supported.")

    elif backbone == "scalemae":
        base_resolution = kwargs.get("base_resolution", 2.5)

        if model == "vit_b":
            encoder = ViT_ScaleMAE(
                img_size=img_size, patch_size=8, embed_dim=768, depth=12, num_heads=12,
                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
                base_resolution=base_resolution,
            )
        elif model == "vit_l":
            encoder = ViT_ScaleMAE(
                img_size=img_size, patch_size=8, embed_dim=1024, depth=24, num_heads=16,
                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
                base_resolution=base_resolution,
            )
        elif model == "vit_h":
            encoder = ViT_ScaleMAE(
                img_size=img_size, patch_size=8, embed_dim=1280, depth=32, num_heads=16,
                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
                base_resolution=base_resolution,
            )
        else:
            raise ValueError(
                f"'{model}' is not supported by ScaleMAE. Currently, 'vit_b', 'vit_l' and 'vit_h' are supported."
            )

    else:
        raise ValueError("The supported backbones for 'UNETR' are 'sam', 'mae' or 'scalemae'. Please choose one of them.")

    return encoder
class ViT_Sam(ImageEncoderViT)
Vision Transformer derived from the Segment Anything Codebase (https://arxiv.org/abs/2304.02643).
Arguments:
- in_chans: The number of input channels.
- embed_dim: The embedding dimension, corresponding to the number of output channels of the vision transformer.
- global_attn_indexes: The global attention indices.
- kwargs: Keyword arguments for the image encoder base class.
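A minimal usage sketch (it assumes segment_anything is installed; the configuration shown simply mirrors the "vit_b" setup used by get_vision_transformer below, and the input size is illustrative):

from functools import partial

import torch
from torch_em.model.vit import ViT_Sam

encoder = ViT_Sam(
    depth=12, embed_dim=768, img_size=1024, mlp_ratio=4,
    norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
    num_heads=12, patch_size=16, qkv_bias=True, use_rel_pos=True,
    global_attn_indexes=[2, 5, 8, 11], window_size=14, out_chans=256,
)

x = torch.randn(1, 3, 1024, 1024)
features, intermediates = encoder(x)
# features has shape (1, 768, 64, 64): 1024 / 16 = 64 patches per side, embed_dim channels.
# intermediates holds the (permuted) outputs of the first three global-attention blocks.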
def __init__(self, in_chans: int = 3, embed_dim: int = 768, global_attn_indexes: Tuple[int, ...] = ..., **kwargs) -> None
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]
Apply the vision transformer to input data.
Arguments:
- x: The input data.
Returns:
The vision transformer output and the outputs of the selected intermediate blocks.
class ViT_MAE(VisionTransformer)
Vision Transformer derived from the Masked Auto Encoder Codebase (https://arxiv.org/abs/2111.06377).
Based on: https://github.com/facebookresearch/mae/blob/main/models_vit.py#L20-L53
Arguments:
- img_size: The size of the input for the image encoder. Input images will be resized to match this size.
- in_chans: The number of input channels.
- depth: The depth of the vision transformer.
- kwargs: Additional keyword arguments for the vision transformer base class.
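A minimal usage sketch (it assumes timm is installed; the configuration mirrors the "vit_b" setup used by get_vision_transformer below, and the input size is illustrative):

from functools import partial

import torch
import torch.nn as nn
from torch_em.model.vit import ViT_MAE

encoder = ViT_MAE(
    img_size=1024, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
    qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
)

x = torch.randn(1, 3, 1024, 1024)
features, intermediates = encoder(x)
# features has shape (1, 768, 64, 64): the class token is removed and the 64*64 token grid
# is reshaped to an image-like layout by convert_to_expected_dim.
# intermediates holds the first three of the four block outputs sampled every depth / 4 blocks.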
def __init__(self, img_size: int = 1024, in_chans: int = 3, depth: int = 12, **kwargs)
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]
Apply the vision transformer to input data.
Arguments:
- x: The input data.
Returns:
The vision transformer output and the outputs of the selected intermediate blocks.
class CustomCompose

Composes the ScaleMAE input transforms: a rescaling transform, optional additional transforms, and a source transform. Calling an instance returns the rescaled image, a resized source view of it, the per-sample scale ratio of the crop, the fraction of invalid pixels, and the valid masks.

def __init__(self, rescale_transform, other_transforms, src_transform)

def __call__(self, x, valid_masks=None)
def get_2d_sincos_pos_embed_with_resolution(embed_dim, grid_size, res, cls_token=False, device="cpu")
Arguments:
- grid_size: int of the grid height and width.
- res: array of size n, representing the resolution of a pixel (say, in meters).

Returns:
pos_embed: [n, grid_size*grid_size, embed_dim] or [n, 1+grid_size*grid_size, embed_dim] (with or without cls_token).
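A small shape sketch (a minimal, illustrative call; the specific sizes are not from the original documentation):

import torch
from torch_em.model.vit import get_2d_sincos_pos_embed_with_resolution

res = torch.tensor([1.0, 2.0])  # two samples with different pixel resolutions
pos_embed = get_2d_sincos_pos_embed_with_resolution(embed_dim=768, grid_size=14, res=res, cls_token=True)
print(pos_embed.shape)  # torch.Size([2, 197, 768]) -> n=2, 1 + 14 * 14 tokens, embed_dim channels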
def get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid)

Uses half of the embedding dimensions to encode the h coordinates of the grid and the other half for the w coordinates, and concatenates the two embeddings.
def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos)
Arguments:
- embed_dim: output dimension for each position.
- pos: a list of positions to be encoded, of size (M,).

Returns:
out: (M, D).
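A short sketch of the resulting layout (illustrative values, calling the function directly):

import torch
from torch_em.model.vit import get_1d_sincos_pos_embed_from_grid_torch

pos = torch.arange(4, dtype=torch.float32)
emb = get_1d_sincos_pos_embed_from_grid_torch(8, pos)
# emb has shape (4, 8): for each position p the first half holds sin(p * omega_d)
# and the second half cos(p * omega_d), with omega_d = 1 / 10000 ** (d / 4) for d = 0..3.
print(emb.shape)  # torch.Size([4, 8])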
class PatchEmbedUnSafe(PatchEmbed)
Image to Patch Embedding without the fixed input-size check of timm's PatchEmbed.
def forward(self, x)
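A brief sketch of the effect of dropping the size check (it assumes timm is installed; the sizes are illustrative):

import torch
from torch_em.model.vit import PatchEmbedUnSafe

patch_embed = PatchEmbedUnSafe(img_size=224, patch_size=16, in_chans=3, embed_dim=768)

# timm's PatchEmbed checks by default that the input matches img_size; this subclass only
# applies the strided projection, so a 448x448 input is embedded without complaint.
tokens = patch_embed(torch.randn(1, 3, 448, 448))
print(tokens.shape)  # torch.Size([1, 784, 768]) -> (448 / 16) ** 2 = 784 patches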
class ViT_ScaleMAE(VisionTransformer)
Vision Transformer derived from the Scale Masked Auto Encoder codebase (TODO: paper and github link).

NOTE: For downstream tasks, the "base_resolution" parameter needs to be adjusted manually when the model is used on a dataset with a different zoom factor.
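A minimal usage sketch (it assumes timm and kornia are installed; the img_size and base_resolution values are illustrative):

import torch
from torch_em.model.vit import get_vision_transformer

encoder = get_vision_transformer(backbone="scalemae", model="vit_b", img_size=224, base_resolution=2.5)

x = torch.randn(1, 3, 224, 224)
features, intermediates = encoder(x)
# transform_inputs rescales the batch to 448x448 internally, so with patch_size=8 the token
# grid is 56x56 and features has shape (1, 768, 56, 56).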
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=1024, depth=12, base_resolution=2.5, **kwargs)

def transform_inputs(self, x)

def convert_to_expected_dim(self, x)

def forward_features(self, x)
def get_vision_transformer(backbone: str, model: str, img_size: int = 1024, **kwargs) -> nn.Module
Get vision transformer encoder.
Arguments:
- backbone: The name of the vision transformer implementation. One of "sam" / "mae" / "scalemae".
- model: The name of the model. One of "vit_b", "vit_l" or "vit_h".
- img_size: The size of the input for the image encoder. Input images will be resized to match this size.
- kwargs: Additional kwargs which can be expected by the vision transformer, e.g. 'base_resolution' for ViT_ScaleMAE.
Returns:
The vision transformer.
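A minimal usage sketch (the chosen backbone assumes its optional dependency, here segment_anything, is installed; the input size is illustrative):

import torch
from torch_em.model.vit import get_vision_transformer

encoder = get_vision_transformer(backbone="sam", model="vit_b")

x = torch.randn(1, 3, 1024, 1024)
features, intermediates = encoder(x)
print(features.shape)      # torch.Size([1, 768, 64, 64])
print(len(intermediates))  # 3 feature maps from intermediate blocks, e.g. for a UNETR decoder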