torch_em.model.vit

import torch
import torch.nn as nn

from typing import Tuple, List
from functools import partial

# we catch ImportErrors here because segment_anything and timm should
# only be optional dependencies for torch_em
try:
    from segment_anything.modeling import ImageEncoderViT
    _sam_import_success = True
except ImportError:
    ImageEncoderViT = object
    _sam_import_success = False

try:
    from timm.models.vision_transformer import VisionTransformer
    _timm_import_success = True
except ImportError:
    VisionTransformer = object
    _timm_import_success = False


class ViT_Sam(ImageEncoderViT):
    """Vision Transformer derived from the Segment Anything Codebase (https://arxiv.org/abs/2304.02643):
    https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py
    """
    def __init__(
        self,
        in_chans: int = 3,
        embed_dim: int = 768,
        global_attn_indexes: Tuple[int, ...] = ...,
        **kwargs
    ) -> None:
        if not _sam_import_success:
            raise RuntimeError(
                "The vision transformer backend can only be initialized if segment anything is installed. "
                "Please install segment anything from https://github.com/facebookresearch/segment-anything "
                "and then rerun your code."
            )

        super().__init__(
            embed_dim=embed_dim,
            global_attn_indexes=global_attn_indexes,
            **kwargs,
        )
        self.chunks_for_projection = global_attn_indexes
        self.in_chans = in_chans
        self.embed_dim = embed_dim

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            x = x + self.pos_embed

        # collect the outputs of the global attention blocks as intermediate features
        list_from_encoder = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in self.chunks_for_projection:
                list_from_encoder.append(x)

        # convert from NHWC to NCHW layout
        x = x.permute(0, 3, 1, 2)
        list_from_encoder = [e.permute(0, 3, 1, 2) for e in list_from_encoder]
        return x, list_from_encoder[:3]


class ViT_MAE(VisionTransformer):
    """Vision Transformer derived from the Masked Auto Encoder Codebase (https://arxiv.org/abs/2111.06377):
    https://github.com/facebookresearch/mae/blob/main/models_vit.py#L20-L53
    """
    def __init__(
            self,
            img_size=1024,  # chosen to match our experiments with segment anything
            in_chans=3,
            depth=12,
            **kwargs
    ):
        if not _timm_import_success:
            raise RuntimeError(
                "The vision transformer backend can only be initialized if timm is installed. "
                "Please install timm (e.g. via conda/mamba) to use the models from https://github.com/facebookresearch/mae/ "
                "and then rerun your code."
            )
        super().__init__(img_size=img_size, depth=depth, **kwargs)
        self.img_size = img_size
        self.in_chans = in_chans
        self.depth = depth

    def convert_to_expected_dim(self, inputs_):
        inputs_ = inputs_[:, 1:, :]  # remove the class token
        # reshape the outputs to the expected spatial layout (N x H*W x C -> N x C x H x W)
        rdim = inputs_.shape[1]
        dshape = int(rdim ** 0.5)  # the patch grid is square, so its side length is the square root of the token count
        inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape))
        inputs_ = inputs_.permute(0, 3, 1, 2)
        return inputs_

    def forward_features(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        x = x + self.pos_embed
        x = self.pos_drop(x)

        # the blocks whose outputs are tapped as projections for the upsampling blocks,
        # spaced at every quarter of the network depth
        _chunks = int(self.depth / 4)
        chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1]

        list_from_encoder = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in chunks_for_projection:
                list_from_encoder.append(self.convert_to_expected_dim(x))

        x = self.convert_to_expected_dim(x)
        return x, list_from_encoder[:3]

    def forward(self, x):
        x, list_from_encoder = self.forward_features(x)
        return x, list_from_encoder


def get_vision_transformer(backbone: str, model: str, img_size: int = 1024):
    """Get the vision transformer encoder for the given backbone ("sam" or "mae")
    and model size ("vit_b", "vit_l" or "vit_h").
    """
    if backbone == "sam":
        if model == "vit_b":
            encoder = ViT_Sam(
                depth=12, embed_dim=768, img_size=1024, mlp_ratio=4,
                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
                num_heads=12, patch_size=16, qkv_bias=True, use_rel_pos=True,
                global_attn_indexes=[2, 5, 8, 11],
                window_size=14, out_chans=256,
            )
        elif model == "vit_l":
            encoder = ViT_Sam(
                depth=24, embed_dim=1024, img_size=1024, mlp_ratio=4,
                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
                num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True,
                global_attn_indexes=[5, 11, 17, 23],
                window_size=14, out_chans=256,
            )
        elif model == "vit_h":
            encoder = ViT_Sam(
                depth=32, embed_dim=1280, img_size=1024, mlp_ratio=4,
                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
                num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True,
                global_attn_indexes=[7, 15, 23, 31],
                window_size=14, out_chans=256,
            )
        else:
            raise ValueError(f"{model} is not supported by SAM. Currently vit_b, vit_l, vit_h are supported.")

    elif backbone == "mae":
        if model == "vit_b":
            encoder = ViT_MAE(
                img_size=img_size, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
            )
        elif model == "vit_l":
            encoder = ViT_MAE(
                img_size=img_size, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
            )
        elif model == "vit_h":
            encoder = ViT_MAE(
                img_size=img_size, patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4,
                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
            )
        else:
            raise ValueError(f"{model} is not supported by MAE. Currently vit_b, vit_l, vit_h are supported.")

    else:
        raise ValueError("The backbones supported by the UNETR are `sam` and `mae`. Please choose one of the two.")

    return encoder
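
A minimal usage sketch (not part of the module source): depending on the chosen backbone, the optional dependency segment_anything or timm has to be installed. The returned encoder yields the final feature map together with three intermediate feature maps:

import torch
from torch_em.model.vit import get_vision_transformer

# SAM backbone (requires segment_anything); the SAM configurations are fixed to 1024 x 1024 inputs
sam_encoder = get_vision_transformer(backbone="sam", model="vit_b")

# MAE backbone (requires timm); here the input size can be chosen via img_size
mae_encoder = get_vision_transformer(backbone="mae", model="vit_b", img_size=256)

x = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    features, intermediates = mae_encoder(x)  # final feature map and three intermediate feature maps
print(features.shape)  # torch.Size([1, 768, 16, 16]): 256 / 16 patches per side, embed_dim channels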
class ViT_Sam(segment_anything.modeling.image_encoder.ImageEncoderViT):

Vision Transformer derived from the Segment Anything Codebase (https://arxiv.org/abs/2304.02643): https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py

ViT_Sam(in_chans: int = 3, embed_dim: int = 768, global_attn_indexes: Tuple[int, ...] = Ellipsis, **kwargs)
Arguments:
  • img_size (int): Input image size.
  • patch_size (int): Patch size.
  • in_chans (int): Number of input image channels.
  • embed_dim (int): Patch embedding dimension.
  • depth (int): Depth of ViT.
  • num_heads (int): Number of attention heads in each ViT block.
  • mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
  • qkv_bias (bool): If True, add a learnable bias to query, key, value.
  • norm_layer (nn.Module): Normalization layer.
  • act_layer (nn.Module): Activation layer.
  • use_abs_pos (bool): If True, use absolute positional embeddings.
  • use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
  • rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
  • window_size (int): Window size for window attention blocks.
  • global_attn_indexes (list): Indexes for blocks using global attention.
chunks_for_projection
in_chans
embed_dim
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:

Computes the image encoding. In contrast to the parent class, the neck is not applied: the output of the last transformer block is returned in NCHW layout, together with the first three intermediate feature maps collected at the global attention blocks (chunks_for_projection).
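
For illustration, a small sketch of the expected shapes, assuming segment_anything is installed; the hyperparameters are copied from the vit_b configuration in get_vision_transformer:

import torch
from functools import partial
from torch_em.model.vit import ViT_Sam

encoder = ViT_Sam(
    depth=12, embed_dim=768, img_size=1024, mlp_ratio=4,
    norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
    num_heads=12, patch_size=16, qkv_bias=True, use_rel_pos=True,
    global_attn_indexes=[2, 5, 8, 11], window_size=14, out_chans=256,
)

x = torch.randn(1, 3, 1024, 1024)
with torch.no_grad():
    out, intermediates = encoder(x)

# with patch_size=16 the 1024 x 1024 input becomes a 64 x 64 patch grid with embed_dim channels
print(out.shape)                         # torch.Size([1, 768, 64, 64])
print([f.shape for f in intermediates])  # three feature maps with the same shape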

class ViT_MAE(timm.models.vision_transformer.VisionTransformer):

Vision Transformer derived from the Masked Auto Encoder Codebase (https://arxiv.org/abs/2111.06377) https://github.com/facebookresearch/mae/blob/main/models_vit.py#L20-L53

ViT_MAE(img_size=1024, in_chans=3, depth=12, **kwargs)
Arguments:
  • img_size: Input image size.
  • patch_size: Patch size.
  • in_chans: Number of image input channels.
  • num_classes: Number of classes for classification head.
  • global_pool: Type of global pooling for final sequence (default: 'token').
  • embed_dim: Transformer embedding dimension.
  • depth: Depth of transformer.
  • num_heads: Number of attention heads.
  • mlp_ratio: Ratio of mlp hidden dim to embedding dim.
  • qkv_bias: Enable bias for qkv projections if True.
  • init_values: Layer-scale init values (layer-scale enabled if not None).
  • class_token: Use class token.
  • no_embed_class: Don't include position embeddings for class (or reg) tokens.
  • reg_tokens: Number of register tokens.
  • fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
  • drop_rate: Head dropout rate.
  • pos_drop_rate: Position embedding dropout rate.
  • attn_drop_rate: Attention dropout rate.
  • drop_path_rate: Stochastic depth rate.
  • weight_init: Weight initialization scheme.
  • embed_layer: Patch embedding layer.
  • norm_layer: Normalization layer.
  • act_layer: MLP activation layer.
  • block_fn: Transformer block layer.
img_size
in_chans
depth
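
For illustration, a direct construction of the vit_b variant with the hyperparameters that get_vision_transformer uses for the mae backbone (timm has to be installed); a smaller img_size is chosen here to keep the example light:

import torch
import torch.nn as nn
from functools import partial
from torch_em.model.vit import ViT_MAE

encoder = ViT_MAE(
    img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
    qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
)

x = torch.randn(2, 3, 224, 224)
with torch.no_grad():
    features, intermediates = encoder(x)

# 224 / 16 = 14 patches per side, so the feature maps are 14 x 14 with embed_dim channels
print(features.shape)                    # torch.Size([2, 768, 14, 14])
print([f.shape for f in intermediates])  # three feature maps with the same shape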
def convert_to_expected_dim(self, inputs_):
Converts a token sequence of shape N x (1 + H*W) x C into a feature map of shape N x C x H x W by dropping the class token and folding the remaining tokens back into a square patch grid.
def forward_features(self, x):
Embeds the input into patch tokens, prepends the class token, adds the positional embedding and runs the transformer blocks. The outputs of the blocks at the end of each quarter of the depth are collected, and both the final output and the first three collected outputs are converted to N x C x H x W feature maps via convert_to_expected_dim.
def forward(self, x):

Runs forward_features and returns its result: the feature map of the last transformer block together with three intermediate feature maps, each in N x C x H x W layout.
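
The two pieces of custom logic are easy to check in isolation: which blocks are tapped for the intermediate features, and how a token sequence is folded back into a feature map. A self-contained sketch (plain torch, no timm needed; the helper name projection_blocks is only used here for illustration):

import torch

def projection_blocks(depth):
    # mirrors the computation in forward_features: one block at the end of each quarter of the depth
    chunk = depth // 4
    return [chunk - 1, 2 * chunk - 1, 3 * chunk - 1, 4 * chunk - 1]

print(projection_blocks(12))  # [2, 5, 8, 11]
print(projection_blocks(24))  # [5, 11, 17, 23]
print(projection_blocks(32))  # [7, 15, 23, 31]
# these are the same indices that get_vision_transformer passes as global_attn_indexes for the SAM models

# token sequence -> feature map, as in convert_to_expected_dim:
# drop the class token, fold the H*W tokens back into a square grid and move the channels first
tokens = torch.randn(2, 1 + 14 * 14, 768)  # (N, 1 + H*W, C), e.g. img_size=224 with patch_size=16
spatial = tokens[:, 1:, :]
side = int(spatial.shape[1] ** 0.5)
feature_map = torch.unflatten(spatial, 1, (side, side)).permute(0, 3, 1, 2)
print(feature_map.shape)  # torch.Size([2, 768, 14, 14])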

def get_vision_transformer(backbone: str, model: str, img_size: int = 1024):
Returns the vision transformer encoder for the given backbone ("sam" or "mae") and model size ("vit_b", "vit_l" or "vit_h"). The SAM configurations are fixed to 1024 x 1024 inputs, while img_size is forwarded to the MAE models; unsupported backbone or model names raise a ValueError.