torch_em.model.vit
import torch
import torch.nn as nn

from typing import Tuple
from functools import partial

# we catch ImportErrors here because segment_anything, micro_sam and timm should
# only be optional dependencies for torch_em
try:
    from segment_anything.modeling import ImageEncoderViT
    _sam_import_success = True
except ImportError:
    ImageEncoderViT = object
    _sam_import_success = False

try:
    from timm.models.vision_transformer import VisionTransformer
    _timm_import_success = True
except ImportError:
    VisionTransformer = object
    _timm_import_success = False


class ViT_Sam(ImageEncoderViT):
    """Vision Transformer derived from the Segment Anything Codebase (https://arxiv.org/abs/2304.02643):
    https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py
    """
    def __init__(
        self,
        in_chans: int = 3,
        embed_dim: int = 768,
        global_attn_indexes: Tuple[int, ...] = ...,
        **kwargs
    ) -> None:
        if not _sam_import_success:
            raise RuntimeError(
                "The vision transformer backend can only be initialized if segment anything is installed. "
                "Please install segment anything from https://github.com/facebookresearch/segment-anything "
                "and then rerun your code."
            )

        super().__init__(
            embed_dim=embed_dim,
            global_attn_indexes=global_attn_indexes,
            **kwargs,
        )
        self.chunks_for_projection = global_attn_indexes
        self.in_chans = in_chans
        self.embed_dim = embed_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            x = x + self.pos_embed

        list_from_encoder = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in self.chunks_for_projection:
                list_from_encoder.append(x)

        x = x.permute(0, 3, 1, 2)
        list_from_encoder = [e.permute(0, 3, 1, 2) for e in list_from_encoder]
        return x, list_from_encoder[:3]


class ViT_MAE(VisionTransformer):
    """Vision Transformer derived from the Masked Auto Encoder Codebase (https://arxiv.org/abs/2111.06377)
    https://github.com/facebookresearch/mae/blob/main/models_vit.py#L20-L53
    """
    def __init__(
        self,
        img_size=1024,  # chosen to match our experiments with segment anything
        in_chans=3,
        depth=12,
        **kwargs
    ):
        if not _timm_import_success:
            raise RuntimeError(
                "The vision transformer backend can only be initialized if timm is installed. "
                "Please install timm (using conda/mamba) for using https://github.com/facebookresearch/mae/ "
                "and then rerun your code."
            )
        super().__init__(img_size=img_size, depth=depth, **kwargs)
        self.img_size = img_size
        self.in_chans = in_chans
        self.depth = depth

    def convert_to_expected_dim(self, inputs_):
        inputs_ = inputs_[:, 1:, :]  # remove the class token
        # reshape the outputs to the desired shape (N x H*W x C -> N x H x W x C)
        rdim = inputs_.shape[1]
        dshape = int(rdim ** 0.5)  # take the square root of the token count to get the patch grid shape
        inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape))
        inputs_ = inputs_.permute(0, 3, 1, 2)
        return inputs_

    def forward_features(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        x = x + self.pos_embed
        x = self.pos_drop(x)

        # chunks obtained for getting the projections in conjunction with the upsampling blocks
        _chunks = int(self.depth / 4)
        chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1]

        list_from_encoder = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in chunks_for_projection:
                list_from_encoder.append(self.convert_to_expected_dim(x))

        x = self.convert_to_expected_dim(x)
        return x, list_from_encoder[:3]

    def forward(self, x):
        x, list_from_encoder = self.forward_features(x)
        return x, list_from_encoder


def get_vision_transformer(backbone: str, model: str, img_size: int = 1024):
    if backbone == "sam":
        if model == "vit_b":
            encoder = ViT_Sam(
                depth=12, embed_dim=768, img_size=1024, mlp_ratio=4,
                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
                num_heads=12, patch_size=16, qkv_bias=True, use_rel_pos=True,
                global_attn_indexes=[2, 5, 8, 11],
                window_size=14, out_chans=256,
            )
        elif model == "vit_l":
            encoder = ViT_Sam(
                depth=24, embed_dim=1024, img_size=1024, mlp_ratio=4,
                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
                num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True,
                global_attn_indexes=[5, 11, 17, 23],
                window_size=14, out_chans=256,
            )
        elif model == "vit_h":
            encoder = ViT_Sam(
                depth=32, embed_dim=1280, img_size=1024, mlp_ratio=4,
                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
                num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True,
                global_attn_indexes=[7, 15, 23, 31],
                window_size=14, out_chans=256,
            )
        else:
            raise ValueError(f"{model} is not supported by SAM. Currently vit_b, vit_l, vit_h are supported.")

    elif backbone == "mae":
        if model == "vit_b":
            encoder = ViT_MAE(
                img_size=img_size, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
            )
        elif model == "vit_l":
            encoder = ViT_MAE(
                img_size=img_size, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
            )
        elif model == "vit_h":
            encoder = ViT_MAE(
                img_size=img_size, patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4,
                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
            )
        else:
            raise ValueError(f"{model} is not supported by MAE. Currently vit_b, vit_l, vit_h are supported.")

    else:
        raise ValueError("The UNETR supported backbones are `sam` or `mae`. Please choose either of the two.")

    return encoder
class ViT_Sam(ImageEncoderViT):
    """Vision Transformer derived from the Segment Anything Codebase (https://arxiv.org/abs/2304.02643):
    https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py
    """
    def __init__(
        self,
        in_chans: int = 3,
        embed_dim: int = 768,
        global_attn_indexes: Tuple[int, ...] = ...,
        **kwargs
    ) -> None:
        if not _sam_import_success:
            raise RuntimeError(
                "The vision transformer backend can only be initialized if segment anything is installed. "
                "Please install segment anything from https://github.com/facebookresearch/segment-anything "
                "and then rerun your code."
            )

        super().__init__(
            embed_dim=embed_dim,
            global_attn_indexes=global_attn_indexes,
            **kwargs,
        )
        self.chunks_for_projection = global_attn_indexes
        self.in_chans = in_chans
        self.embed_dim = embed_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            x = x + self.pos_embed

        list_from_encoder = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in self.chunks_for_projection:
                list_from_encoder.append(x)

        x = x.permute(0, 3, 1, 2)
        list_from_encoder = [e.permute(0, 3, 1, 2) for e in list_from_encoder]
        return x, list_from_encoder[:3]
Vision Transformer derived from the Segment Anything Codebase (https://arxiv.org/abs/2304.02643): https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py
    def __init__(
        self,
        in_chans: int = 3,
        embed_dim: int = 768,
        global_attn_indexes: Tuple[int, ...] = ...,
        **kwargs
    ) -> None:
        if not _sam_import_success:
            raise RuntimeError(
                "The vision transformer backend can only be initialized if segment anything is installed. "
                "Please install segment anything from https://github.com/facebookresearch/segment-anything "
                "and then rerun your code."
            )

        super().__init__(
            embed_dim=embed_dim,
            global_attn_indexes=global_attn_indexes,
            **kwargs,
        )
        self.chunks_for_projection = global_attn_indexes
        self.in_chans = in_chans
        self.embed_dim = embed_dim
Arguments:
- img_size (int): Input image size.
- patch_size (int): Patch size.
- in_chans (int): Number of input image channels.
- embed_dim (int): Patch embedding dimension.
- depth (int): Depth of ViT.
- num_heads (int): Number of attention heads in each ViT block.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
- qkv_bias (bool): If True, add a learnable bias to query, key, value.
- norm_layer (nn.Module): Normalization layer.
- act_layer (nn.Module): Activation layer.
- use_abs_pos (bool): If True, use absolute positional embeddings.
- use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
- rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
- window_size (int): Window size for window attention blocks.
- global_attn_indexes (list): Indexes for blocks using global attention.
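For orientation, here is a minimal construction sketch. It mirrors the `vit_b` configuration that `get_vision_transformer` (documented below) uses for the `sam` backbone and assumes `segment_anything` is installed:

```python
from functools import partial

import torch

from torch_em.model.vit import ViT_Sam

# same hyperparameters as the "vit_b" SAM configuration in get_vision_transformer
encoder = ViT_Sam(
    depth=12, embed_dim=768, img_size=1024, mlp_ratio=4,
    norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
    num_heads=12, patch_size=16, qkv_bias=True, use_rel_pos=True,
    global_attn_indexes=[2, 5, 8, 11], window_size=14, out_chans=256,
)
```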
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            x = x + self.pos_embed

        list_from_encoder = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in self.chunks_for_projection:
                list_from_encoder.append(x)

        x = x.permute(0, 3, 1, 2)
        list_from_encoder = [e.permute(0, 3, 1, 2) for e in list_from_encoder]
        return x, list_from_encoder[:3]
Defines the computation performed at every call. Should be overridden by all subclasses. Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
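A small shape sketch of what this forward returns, using the `get_vision_transformer` helper documented below (assumes `segment_anything` is installed; a 1024x1024 input with `patch_size=16` yields a 64x64 token grid, and this forward skips the SAM neck, so the channel dimension equals `embed_dim`):

```python
import torch

from torch_em.model.vit import get_vision_transformer

encoder = get_vision_transformer(backbone="sam", model="vit_b")
x = torch.randn(1, 3, 1024, 1024)
features, from_encoder = encoder(x)
print(features.shape)                   # torch.Size([1, 768, 64, 64])
print([f.shape for f in from_encoder])  # three intermediate features of the same shape
```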
class ViT_MAE(VisionTransformer):
    """Vision Transformer derived from the Masked Auto Encoder Codebase (https://arxiv.org/abs/2111.06377)
    https://github.com/facebookresearch/mae/blob/main/models_vit.py#L20-L53
    """
    def __init__(
        self,
        img_size=1024,  # chosen to match our experiments with segment anything
        in_chans=3,
        depth=12,
        **kwargs
    ):
        if not _timm_import_success:
            raise RuntimeError(
                "The vision transformer backend can only be initialized if timm is installed. "
                "Please install timm (using conda/mamba) for using https://github.com/facebookresearch/mae/ "
                "and then rerun your code."
            )
        super().__init__(img_size=img_size, depth=depth, **kwargs)
        self.img_size = img_size
        self.in_chans = in_chans
        self.depth = depth

    def convert_to_expected_dim(self, inputs_):
        inputs_ = inputs_[:, 1:, :]  # remove the class token
        # reshape the outputs to the desired shape (N x H*W x C -> N x H x W x C)
        rdim = inputs_.shape[1]
        dshape = int(rdim ** 0.5)  # take the square root of the token count to get the patch grid shape
        inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape))
        inputs_ = inputs_.permute(0, 3, 1, 2)
        return inputs_

    def forward_features(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        x = x + self.pos_embed
        x = self.pos_drop(x)

        # chunks obtained for getting the projections in conjunction with the upsampling blocks
        _chunks = int(self.depth / 4)
        chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1]

        list_from_encoder = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in chunks_for_projection:
                list_from_encoder.append(self.convert_to_expected_dim(x))

        x = self.convert_to_expected_dim(x)
        return x, list_from_encoder[:3]

    def forward(self, x):
        x, list_from_encoder = self.forward_features(x)
        return x, list_from_encoder
Vision Transformer derived from the Masked Auto Encoder Codebase (https://arxiv.org/abs/2111.06377) https://github.com/facebookresearch/mae/blob/main/models_vit.py#L20-L53
    def __init__(
        self,
        img_size=1024,  # chosen to match our experiments with segment anything
        in_chans=3,
        depth=12,
        **kwargs
    ):
        if not _timm_import_success:
            raise RuntimeError(
                "The vision transformer backend can only be initialized if timm is installed. "
                "Please install timm (using conda/mamba) for using https://github.com/facebookresearch/mae/ "
                "and then rerun your code."
            )
        super().__init__(img_size=img_size, depth=depth, **kwargs)
        self.img_size = img_size
        self.in_chans = in_chans
        self.depth = depth
Arguments:
- img_size: Input image size.
- patch_size: Patch size.
- in_chans: Number of image input channels.
- num_classes: Number of classes for the classification head.
- global_pool: Type of global pooling for final sequence (default: 'token').
- embed_dim: Transformer embedding dimension.
- depth: Depth of transformer.
- num_heads: Number of attention heads.
- mlp_ratio: Ratio of mlp hidden dim to embedding dim.
- qkv_bias: Enable bias for qkv projections if True.
- init_values: Layer-scale init values (layer-scale enabled if not None).
- class_token: Use class token.
- no_embed_class: Don't include position embeddings for class (or reg) tokens.
- reg_tokens: Number of register tokens.
- fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
- drop_rate: Head dropout rate.
- pos_drop_rate: Position embedding dropout rate.
- attn_drop_rate: Attention dropout rate.
- drop_path_rate: Stochastic depth rate.
- weight_init: Weight initialization scheme.
- embed_layer: Patch embedding layer.
- norm_layer: Normalization layer.
- act_layer: MLP activation layer.
- block_fn: Transformer block layer.
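As with ViT_Sam, a minimal construction sketch; it mirrors the `vit_b` configuration that `get_vision_transformer` (below) uses for the `mae` backbone and assumes `timm` is installed:

```python
from functools import partial

import torch.nn as nn

from torch_em.model.vit import ViT_MAE

encoder = ViT_MAE(
    img_size=1024, patch_size=16, embed_dim=768, depth=12, num_heads=12,
    mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
)
```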
    def convert_to_expected_dim(self, inputs_):
        inputs_ = inputs_[:, 1:, :]  # remove the class token
        # reshape the outputs to the desired shape (N x H*W x C -> N x H x W x C)
        rdim = inputs_.shape[1]
        dshape = int(rdim ** 0.5)  # take the square root of the token count to get the patch grid shape
        inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape))
        inputs_ = inputs_.permute(0, 3, 1, 2)
        return inputs_
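A shape walk-through of this conversion with hypothetical tensors for illustration, assuming a 1024x1024 input with `patch_size=16`, i.e. a 64x64 patch grid:

```python
import torch

tokens = torch.randn(2, 64 * 64 + 1, 768)          # N x (H*W + 1) x C, including the class token
patch_tokens = tokens[:, 1:, :]                     # drop the class token -> N x 4096 x C
grid = torch.unflatten(patch_tokens, 1, (64, 64))   # N x 64 x 64 x C
feature_map = grid.permute(0, 3, 1, 2)              # N x C x 64 x 64, ready for the decoder blocks
print(feature_map.shape)                            # torch.Size([2, 768, 64, 64])
```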
    def forward_features(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        x = x + self.pos_embed
        x = self.pos_drop(x)

        # chunks obtained for getting the projections in conjunction with the upsampling blocks
        _chunks = int(self.depth / 4)
        chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1]

        list_from_encoder = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in chunks_for_projection:
                list_from_encoder.append(self.convert_to_expected_dim(x))

        x = self.convert_to_expected_dim(x)
        return x, list_from_encoder[:3]
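The blocks whose outputs are kept for projection are spaced evenly over the encoder depth; for the depths used in `get_vision_transformer` below they coincide with the `global_attn_indexes` of the corresponding SAM models:

```python
# worked example of the chunk index computation above
for depth in (12, 24, 32):
    chunk = depth // 4
    print(depth, [chunk - 1, 2 * chunk - 1, 3 * chunk - 1, 4 * chunk - 1])
# 12 [2, 5, 8, 11]
# 24 [5, 11, 17, 23]
# 32 [7, 15, 23, 31]
```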
    def forward(self, x):
        x, list_from_encoder = self.forward_features(x)
        return x, list_from_encoder
Defines the computation performed at every call. Should be overridden by all subclasses. Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
def get_vision_transformer(backbone: str, model: str, img_size: int = 1024):
    if backbone == "sam":
        if model == "vit_b":
            encoder = ViT_Sam(
                depth=12, embed_dim=768, img_size=1024, mlp_ratio=4,
                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
                num_heads=12, patch_size=16, qkv_bias=True, use_rel_pos=True,
                global_attn_indexes=[2, 5, 8, 11],
                window_size=14, out_chans=256,
            )
        elif model == "vit_l":
            encoder = ViT_Sam(
                depth=24, embed_dim=1024, img_size=1024, mlp_ratio=4,
                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
                num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True,
                global_attn_indexes=[5, 11, 17, 23],
                window_size=14, out_chans=256,
            )
        elif model == "vit_h":
            encoder = ViT_Sam(
                depth=32, embed_dim=1280, img_size=1024, mlp_ratio=4,
                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
                num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True,
                global_attn_indexes=[7, 15, 23, 31],
                window_size=14, out_chans=256,
            )
        else:
            raise ValueError(f"{model} is not supported by SAM. Currently vit_b, vit_l, vit_h are supported.")

    elif backbone == "mae":
        if model == "vit_b":
            encoder = ViT_MAE(
                img_size=img_size, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
            )
        elif model == "vit_l":
            encoder = ViT_MAE(
                img_size=img_size, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
            )
        elif model == "vit_h":
            encoder = ViT_MAE(
                img_size=img_size, patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4,
                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
            )
        else:
            raise ValueError(f"{model} is not supported by MAE. Currently vit_b, vit_l, vit_h are supported.")

    else:
        raise ValueError("The UNETR supported backbones are `sam` or `mae`. Please choose either of the two.")

    return encoder
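Typical usage (sketch): the returned encoder is intended as the UNETR backbone. The `sam` backbone requires `segment_anything`, the `mae` backbone requires `timm`, and `img_size` only takes effect for the `mae` backbone (the SAM configurations are fixed to 1024):

```python
import torch

from torch_em.model.vit import get_vision_transformer

encoder = get_vision_transformer(backbone="mae", model="vit_b", img_size=256)
x = torch.randn(1, 3, 256, 256)
features, from_encoder = encoder(x)
print(features.shape)  # torch.Size([1, 768, 16, 16]) -- a 16x16 patch grid for 256x256 inputs
```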