torch_em.model.vit

from typing import Tuple, List
from functools import partial

import torch
import torch.nn as nn

# We catch ImportErrors here because segment_anything and timm should
# only be optional dependencies for torch_em.
try:
    from segment_anything.modeling import ImageEncoderViT
    _sam_import_success = True
except ImportError:
    ImageEncoderViT = object
    _sam_import_success = False

try:
    from timm.models.vision_transformer import VisionTransformer
    _timm_import_success = True
except ImportError:
    VisionTransformer = object
    _timm_import_success = False


class ViT_Sam(ImageEncoderViT):
    """Vision Transformer derived from the Segment Anything Codebase (https://arxiv.org/abs/2304.02643).

    Based on:
    https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py

    Args:
        in_chans: The number of input channels.
        embed_dim: The embedding dimension, corresponding to the number of output channels of the vision transformer.
        global_attn_indexes: The global attention indices.
        kwargs: Keyword arguments for the image encoder base class.
    """
    def __init__(
        self,
        in_chans: int = 3,
        embed_dim: int = 768,
        global_attn_indexes: Tuple[int, ...] = ...,
        **kwargs,
    ) -> None:
        if not _sam_import_success:
            raise RuntimeError(
                "The vision transformer backend can only be initialized if segment anything is installed. "
                "Please install segment anything from https://github.com/facebookresearch/segment-anything "
                "and then rerun your code."
            )

        super().__init__(
            embed_dim=embed_dim,
            global_attn_indexes=global_attn_indexes,
            **kwargs,
        )
        self.chunks_for_projection = global_attn_indexes
        self.in_chans = in_chans
        self.embed_dim = embed_dim

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Apply the vision transformer to input data.

        Args:
            x: The input data.

        Returns:
            The vision transformer output and the outputs of the global attention blocks.
        """
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            x = x + self.pos_embed

        list_from_encoder = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in self.chunks_for_projection:
                list_from_encoder.append(x)

        # Convert from (N, H, W, C) to the expected (N, C, H, W) layout.
        x = x.permute(0, 3, 1, 2)
        list_from_encoder = [e.permute(0, 3, 1, 2) for e in list_from_encoder]
        return x, list_from_encoder[:3]


class ViT_MAE(VisionTransformer):
    """Vision Transformer derived from the Masked Autoencoder Codebase (https://arxiv.org/abs/2111.06377).

    Based on:
    https://github.com/facebookresearch/mae/blob/main/models_vit.py#L20-L53

    Args:
        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
        in_chans: The number of input channels.
        depth: The depth of the vision transformer.
        kwargs: Additional keyword arguments for the vision transformer base class.
    """
    def __init__(
        self,
        img_size: int = 1024,  # chosen to match our experiments with segment anything
        in_chans: int = 3,
        depth: int = 12,
        **kwargs
    ):
        if not _timm_import_success:
            raise RuntimeError(
                "The vision transformer backend can only be initialized if timm is installed. "
                "Please install timm (e.g. via conda/mamba) to use the models from "
                "https://github.com/facebookresearch/mae/ and then rerun your code."
            )
        super().__init__(img_size=img_size, depth=depth, **kwargs)
        self.img_size = img_size
        self.in_chans = in_chans
        self.depth = depth

    def convert_to_expected_dim(self, inputs_):
        """@private
        """
        inputs_ = inputs_[:, 1:, :]  # remove the class token
        # Reshape the outputs to the desired shape (N x H*W x C -> N x C x H x W).
        rdim = inputs_.shape[1]
        dshape = int(rdim ** 0.5)  # take the square root of the token count to recover the patch grid shape
        inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape))
        inputs_ = inputs_.permute(0, 3, 1, 2)
        return inputs_

    def forward_features(self, x):
        """@private
        """
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        x = x + self.pos_embed
        x = self.pos_drop(x)

        # Chunks used for getting the projections in conjunction with the upsampling blocks.
        _chunks = int(self.depth / 4)
        chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1]

        list_from_encoder = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in chunks_for_projection:
                list_from_encoder.append(self.convert_to_expected_dim(x))

        x = self.convert_to_expected_dim(x)
        return x, list_from_encoder[:3]

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Apply the vision transformer to input data.

        Args:
            x: The input data.

        Returns:
            The vision transformer output and the outputs of the selected intermediate blocks.
        """
        x, list_from_encoder = self.forward_features(x)
        return x, list_from_encoder


def get_vision_transformer(backbone: str, model: str, img_size: int = 1024) -> nn.Module:
    """Get vision transformer encoder.

    Args:
        backbone: The name of the vision transformer implementation. One of "sam" or "mae".
        model: The name of the model. One of "vit_b", "vit_l" or "vit_h".
        img_size: The size of the input for the image encoder. Input images will be resized to match this size.

    Returns:
        The vision transformer.
    """
    if backbone == "sam":
        if model == "vit_b":
            encoder = ViT_Sam(
                depth=12, embed_dim=768, img_size=1024, mlp_ratio=4,
                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
                num_heads=12, patch_size=16, qkv_bias=True, use_rel_pos=True,
                global_attn_indexes=[2, 5, 8, 11],
                window_size=14, out_chans=256,
            )
        elif model == "vit_l":
            encoder = ViT_Sam(
                depth=24, embed_dim=1024, img_size=1024, mlp_ratio=4,
                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
                num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True,
                global_attn_indexes=[5, 11, 17, 23],
                window_size=14, out_chans=256,
            )
        elif model == "vit_h":
            encoder = ViT_Sam(
                depth=32, embed_dim=1280, img_size=1024, mlp_ratio=4,
                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
                num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True,
                global_attn_indexes=[7, 15, 23, 31],
                window_size=14, out_chans=256,
            )
        else:
            raise ValueError(f"{model} is not supported by SAM. Currently vit_b, vit_l, vit_h are supported.")

    elif backbone == "mae":
        if model == "vit_b":
            encoder = ViT_MAE(
                img_size=img_size, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
            )
        elif model == "vit_l":
            encoder = ViT_MAE(
                img_size=img_size, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
            )
        elif model == "vit_h":
            encoder = ViT_MAE(
                img_size=img_size, patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4,
                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
            )
        else:
            raise ValueError(f"{model} is not supported by MAE. Currently vit_b, vit_l, vit_h are supported.")

    else:
        raise ValueError("The UNETR supported backbones are `sam` or `mae`. Please choose either of the two.")

    return encoder
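
The reshaping done by ViT_MAE.convert_to_expected_dim is easy to miss in the listing above: the class token is dropped and the remaining H*W patch tokens are folded back into a square grid before the channel dimension is moved to the front. The following is a minimal sketch of that step with toy tensor sizes (the shapes are illustrative and not taken from the module):

import torch

# Toy token sequence: batch of 2, one class token plus 64*64 patch tokens, 768 channels.
tokens = torch.randn(2, 1 + 64 * 64, 768)

patch_tokens = tokens[:, 1:, :]                           # drop the class token -> (2, 4096, 768)
grid = int(patch_tokens.shape[1] ** 0.5)                  # 64, the side length of the patch grid
spatial = torch.unflatten(patch_tokens, 1, (grid, grid))  # (2, 64, 64, 768)
spatial = spatial.permute(0, 3, 1, 2)                     # (2, 768, 64, 64), ready for a convolutional decoder

print(spatial.shape)  # torch.Size([2, 768, 64, 64])
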
class ViT_Sam:

Vision Transformer derived from the Segment Anything Codebase (https://arxiv.org/abs/2304.02643).

Based on: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py

Arguments:
  • in_chans: The number of input channels.
  • embed_dim: The embedding dimension, corresponding to the number of output channels of the vision transformer.
  • global_attn_indexes: The global attention indices.
  • kwargs: Keyword arguments for the image encoder base class.
ViT_Sam(in_chans: int = 3, embed_dim: int = 768, global_attn_indexes: Tuple[int, ...] = Ellipsis, **kwargs)
Attributes: chunks_for_projection, in_chans, embed_dim
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:

Apply the vision transformer to input data.

Arguments:
  • x: The input data.
Returns:

The vision transformer output and the outputs of the global attention blocks.
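
As a usage sketch for this class (assuming segment_anything is installed; the hyperparameters mirror the vit_b configuration used by get_vision_transformer below), the forward pass returns the final feature map together with three of the intermediate global-attention outputs:

from functools import partial
import torch
from torch_em.model.vit import ViT_Sam

encoder = ViT_Sam(
    depth=12, embed_dim=768, img_size=1024, mlp_ratio=4,
    norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
    num_heads=12, patch_size=16, qkv_bias=True, use_rel_pos=True,
    global_attn_indexes=[2, 5, 8, 11], window_size=14, out_chans=256,
)

x = torch.randn(1, 3, 1024, 1024)  # running this at full resolution needs a fair amount of memory
features, intermediates = encoder(x)
print(features.shape)                    # (1, 768, 64, 64): embed_dim channels on the 64x64 patch grid
print([f.shape for f in intermediates])  # three tensors of the same shape (blocks 2, 5 and 8)

Note that this overridden forward does not apply the neck of the SAM image encoder, so the output has embed_dim channels rather than out_chans.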

class ViT_MAE:

Vision Transformer derived from the Masked Autoencoder Codebase (https://arxiv.org/abs/2111.06377).

Based on: https://github.com/facebookresearch/mae/blob/main/models_vit.py#L20-L53

Arguments:
  • img_size: The size of the input for the image encoder. Input images will be resized to match this size.
  • in_chans: The number of input channels.
  • depth: The depth of the vision transformer.
  • kwargs: Additional keyword arguments for the vision transformer base class.
ViT_MAE(img_size: int = 1024, in_chans: int = 3, depth: int = 12, **kwargs)
Attributes: img_size, in_chans, depth
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:

Apply the vision transformer to input data.

Arguments:
  • x: The input data.
Returns:

The vision transformer output and the outputs of the selected intermediate blocks.
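
A corresponding sketch for the MAE backbone (assuming timm is installed; the hyperparameters mirror the vit_b configuration used by get_vision_transformer, except for a smaller input size to keep the example light). With depth=12, intermediate features are collected after blocks 2, 5 and 8:

from functools import partial
import torch
import torch.nn as nn
from torch_em.model.vit import ViT_MAE

encoder = ViT_MAE(
    img_size=256, patch_size=16, embed_dim=768, depth=12, num_heads=12,
    mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
)

x = torch.randn(1, 3, 256, 256)
features, intermediates = encoder(x)
print(features.shape)      # (1, 768, 16, 16): 256 / 16 = 16 patches per side
print(len(intermediates))  # 3 intermediate feature maps with the same shape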

def get_vision_transformer(backbone: str, model: str, img_size: int = 1024) -> torch.nn.modules.module.Module:

Get vision transformer encoder.

Arguments:
  • backbone: The name of the vision transformer implementation. One of "sam" or "mae".
  • model: The name of the model. One of "vit_b", "vit_l" or "vit_h".
  • img_size: The size of the input for the image encoder. Input images will be resized to match this size.
Returns:

The vision transformer.
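
For reference, a minimal usage sketch of the factory function (the chosen backend package has to be installed):

import torch
from torch_em.model.vit import get_vision_transformer

encoder = get_vision_transformer(backbone="mae", model="vit_b", img_size=512)
x = torch.randn(1, 3, 512, 512)
features, intermediates = encoder(x)
print(features.shape)      # (1, 768, 32, 32) for vit_b with patch size 16
print(len(intermediates))  # 3

Note that img_size is only passed through for the mae backbone; the sam configurations are constructed with a fixed img_size of 1024.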