torch_em.model.vit
from typing import Tuple
from functools import partial

import torch
import torch.nn as nn

# we catch ImportErrors here because segment_anything, micro_sam and timm should
# only be optional dependencies for torch_em
try:
    from segment_anything.modeling import ImageEncoderViT
    _sam_import_success = True
except ImportError:
    ImageEncoderViT = object
    _sam_import_success = False

try:
    from timm.models.vision_transformer import VisionTransformer
    _timm_import_success = True
except ImportError:
    VisionTransformer = object
    _timm_import_success = False


class ViT_Sam(ImageEncoderViT):
    """Vision Transformer derived from the Segment Anything Codebase (https://arxiv.org/abs/2304.02643).

    Based on:
    https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py

    Args:
        in_chans: The number of input channels.
        embed_dim: The embedding dimension, corresponding to the number of output channels of the vision transformer.
        global_attn_indexes: The global attention indices.
        kwargs: Keyword arguments for the image encoder base class.
    """
    def __init__(
        self,
        in_chans: int = 3,
        embed_dim: int = 768,
        global_attn_indexes: Tuple[int, ...] = ...,
        **kwargs,
    ) -> None:
        if not _sam_import_success:
            raise RuntimeError(
                "The vision transformer backend can only be initialized if segment anything is installed. "
                "Please install segment anything from https://github.com/facebookresearch/segment-anything "
                "and then rerun your code."
            )

        super().__init__(
            embed_dim=embed_dim,
            global_attn_indexes=global_attn_indexes,
            **kwargs,
        )
        self.chunks_for_projection = global_attn_indexes
        self.in_chans = in_chans
        self.embed_dim = embed_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the vision transformer to input data.

        Args:
            x: The input data.

        Returns:
            The vision transformer output in NCHW layout and the intermediate features from the global attention blocks.
        """
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            x = x + self.pos_embed

        list_from_encoder = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in self.chunks_for_projection:
                list_from_encoder.append(x)

        x = x.permute(0, 3, 1, 2)
        list_from_encoder = [e.permute(0, 3, 1, 2) for e in list_from_encoder]
        return x, list_from_encoder[:3]


class ViT_MAE(VisionTransformer):
    """Vision Transformer derived from the Masked Auto Encoder Codebase (https://arxiv.org/abs/2111.06377).

    Based on:
    https://github.com/facebookresearch/mae/blob/main/models_vit.py#L20-L53

    Args:
        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
        in_chans: The number of input channels.
        depth: The depth of the vision transformer.
        kwargs: Additional keyword arguments for the vision transformer base class.
    """
    def __init__(
        self,
        img_size: int = 1024,  # chosen to match our experiments with segment anything
        in_chans: int = 3,
        depth: int = 12,
        **kwargs
    ):
        if not _timm_import_success:
            raise RuntimeError(
                "The vision transformer backend can only be initialized if timm is installed. "
                "Please install timm (using conda/mamba) for using https://github.com/facebookresearch/mae/ "
                "and then rerun your code."
            )
        super().__init__(img_size=img_size, depth=depth, **kwargs)
        self.img_size = img_size
        self.in_chans = in_chans
        self.depth = depth

    def convert_to_expected_dim(self, inputs_):
        """@private
        """
        inputs_ = inputs_[:, 1:, :]  # remove the class token
        # reshape the outputs to the desired shape (N x H*W x C -> N x H x W x C)
        rdim = inputs_.shape[1]
        dshape = int(rdim ** 0.5)  # the square root of the token count gives the patch grid size
        inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape))
        inputs_ = inputs_.permute(0, 3, 1, 2)
        return inputs_

    def forward_features(self, x):
        """@private
        """
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        x = x + self.pos_embed
        x = self.pos_drop(x)

        # chunks obtained for getting the projections in conjunction with the upsampling blocks
        _chunks = int(self.depth / 4)
        chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1]

        list_from_encoder = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in chunks_for_projection:
                list_from_encoder.append(self.convert_to_expected_dim(x))

        x = self.convert_to_expected_dim(x)
        return x, list_from_encoder[:3]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the vision transformer to input data.

        Args:
            x: The input data.

        Returns:
            The vision transformer output in NCHW layout and the intermediate features from evenly spaced transformer blocks.
        """
        x, list_from_encoder = self.forward_features(x)
        return x, list_from_encoder


def get_vision_transformer(backbone: str, model: str, img_size: int = 1024) -> nn.Module:
    """Get vision transformer encoder.

    Args:
        backbone: The name of the vision transformer implementation. One of "sam" or "mae".
        model: The name of the model. One of "vit_b", "vit_l" or "vit_h".
        img_size: The size of the input for the image encoder. Input images will be resized to match this size.

    Returns:
        The vision transformer.
    """
    if backbone == "sam":
        if model == "vit_b":
            encoder = ViT_Sam(
                depth=12, embed_dim=768, img_size=1024, mlp_ratio=4,
                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
                num_heads=12, patch_size=16, qkv_bias=True, use_rel_pos=True,
                global_attn_indexes=[2, 5, 8, 11],
                window_size=14, out_chans=256,
            )
        elif model == "vit_l":
            encoder = ViT_Sam(
                depth=24, embed_dim=1024, img_size=1024, mlp_ratio=4,
                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
                num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True,
                global_attn_indexes=[5, 11, 17, 23],
                window_size=14, out_chans=256
            )
        elif model == "vit_h":
            encoder = ViT_Sam(
                depth=32, embed_dim=1280, img_size=1024, mlp_ratio=4,
                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
                num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True,
                global_attn_indexes=[7, 15, 23, 31],
                window_size=14, out_chans=256
            )
        else:
            raise ValueError(f"{model} is not supported by SAM. Currently vit_b, vit_l, vit_h are supported.")

    elif backbone == "mae":
        if model == "vit_b":
            encoder = ViT_MAE(
                img_size=img_size, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
            )
        elif model == "vit_l":
            encoder = ViT_MAE(
                img_size=img_size, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
            )
        elif model == "vit_h":
            encoder = ViT_MAE(
                img_size=img_size, patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4,
                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
            )
        else:
            raise ValueError(f"{model} is not supported by MAE. Currently vit_b, vit_l, vit_h are supported.")

    else:
        raise ValueError("The UNETR supported backbones are `sam` or `mae`. Please choose either of the two.")

    return encoder
class ViT_Sam:
Vision Transformer derived from the Segment Anything Codebase (https://arxiv.org/abs/2304.02643).
Based on: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py
Arguments:
- in_chans: The number of input channels.
- embed_dim: The embedding dimension, corresponding to the number of output channels of the vision transformer.
- global_attn_indexes: The global attention indices.
- kwargs: Keyword arguments for the image encoder base class.
ViT_Sam(in_chans: int = 3, embed_dim: int = 768, global_attn_indexes: Tuple[int, ...] = Ellipsis, **kwargs)
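When segment_anything is installed, the class can also be constructed directly. A minimal sketch, using the same hyperparameters that get_vision_transformer below passes for its "vit_b" configuration (no values here are new, they are copied from that function):

from functools import partial
import torch
from torch_em.model.vit import ViT_Sam

# SAM ViT-B settings, identical to get_vision_transformer(backbone="sam", model="vit_b")
encoder = ViT_Sam(
    depth=12, embed_dim=768, img_size=1024, mlp_ratio=4,
    norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
    num_heads=12, patch_size=16, qkv_bias=True, use_rel_pos=True,
    global_attn_indexes=[2, 5, 8, 11], window_size=14, out_chans=256,
)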
def forward(self, x: torch.Tensor) -> torch.Tensor:
Apply the vision transformer to input data.
Arguments:
- x: The input data.
Returns:
The vision transformer output in NCHW layout and the intermediate features from the global attention blocks.
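Continuing the sketch above, a quick shape check. The numbers assume the "vit_b" configuration (1024x1024 input, patch size 16, embedding dimension 768); other configurations change the channel dimension accordingly:

x = torch.randn(1, 3, 1024, 1024)
features, from_encoder = encoder(x)
# a patch size of 16 on a 1024x1024 input yields a 64x64 token grid,
# and the outputs are permuted to NCHW before being returned
print(features.shape)                   # expected: torch.Size([1, 768, 64, 64])
print([f.shape for f in from_encoder])  # three intermediate maps with the same shape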
class ViT_MAE:
Vision Transformer derived from the Masked Auto Encoder Codebase (https://arxiv.org/abs/2111.06377).
Based on: https://github.com/facebookresearch/mae/blob/main/models_vit.py#L20-L53
Arguments:
- img_size: The size of the input for the image encoder. Input images will be resized to match this size.
- in_chans: The number of input channels.
- depth: The depth of the vision transformer.
- kwargs: Additional keyword arguments for the vision transformer base class.
ViT_MAE(img_size: int = 1024, in_chans: int = 3, depth: int = 12, **kwargs)
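Internally, forward_features collects the outputs of four evenly spaced transformer blocks and returns the first three of them as skip features for the decoder. A small illustration of that index computation (standalone, not part of the class), assuming the default depth of 12:

depth = 12
_chunks = depth // 4  # -> 3
chunks_for_projection = [_chunks - 1, 2 * _chunks - 1, 3 * _chunks - 1, 4 * _chunks - 1]
print(chunks_for_projection)  # [2, 5, 8, 11]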
def forward(self, x: torch.Tensor) -> torch.Tensor:
Apply the vision transformer to input data.
Arguments:
- x: The input data.
Returns:
The vision transformer output in NCHW layout and the intermediate features from evenly spaced transformer blocks.
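Each block output is a token sequence of shape N x (1 + H*W) x C, including the class token; convert_to_expected_dim turns it back into an image-shaped feature map. A standalone sketch of that reshape (the helper name tokens_to_grid is illustrative, the actual class method is convert_to_expected_dim):

import torch

def tokens_to_grid(tokens: torch.Tensor) -> torch.Tensor:
    """Reshape N x (1 + H*W) x C transformer tokens into an N x C x H x W feature map."""
    tokens = tokens[:, 1:, :]           # drop the class token
    side = int(tokens.shape[1] ** 0.5)  # recover the square patch grid size
    grid = torch.unflatten(tokens, 1, (side, side))  # N x H x W x C
    return grid.permute(0, 3, 1, 2)     # N x C x H x W

# e.g. a ViT-B style token sequence for a 64x64 patch grid:
print(tokens_to_grid(torch.randn(1, 1 + 64 * 64, 768)).shape)  # torch.Size([1, 768, 64, 64])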
def get_vision_transformer(backbone: str, model: str, img_size: int = 1024) -> torch.nn.modules.module.Module:
Get vision transformer encoder.
Arguments:
- backbone: The name of the vision transformer implementation. One of "sam" or "mae".
- model: The name of the model. One of "vit_b", "vit_l" or "vit_h".
- img_size: The size of the input for the image encoder. Input images will be resized to match this size.
Returns:
The vision transformer.
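A minimal usage sketch (assuming the relevant optional dependency, segment_anything or timm, is installed; note that the "sam" configurations are hard-coded to an image size of 1024, so img_size is only forwarded to the "mae" backbone):

import torch
from torch_em.model.vit import get_vision_transformer

# SAM image encoder, expects 1024x1024 inputs
sam_encoder = get_vision_transformer(backbone="sam", model="vit_b")

# MAE-style encoder with a custom input size
mae_encoder = get_vision_transformer(backbone="mae", model="vit_b", img_size=512)

x = torch.randn(1, 3, 512, 512)
features, from_encoder = mae_encoder(x)  # final feature map plus intermediate skip features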