torch_em.model.vit

from typing import List, Tuple
from functools import partial

import torch
import torch.nn as nn

# We catch ImportErrors here because segment_anything and timm should
# only be optional dependencies for torch_em.
try:
    from segment_anything.modeling import ImageEncoderViT
    _sam_import_success = True
except ImportError:
    ImageEncoderViT = object
    _sam_import_success = False

try:
    from timm.models.vision_transformer import VisionTransformer, PatchEmbed
    _timm_import_success = True
except ImportError:
    VisionTransformer = object
    PatchEmbed = object
    _timm_import_success = False


class ViT_Sam(ImageEncoderViT):
    """Vision Transformer derived from the Segment Anything Codebase (https://arxiv.org/abs/2304.02643).

    Based on:
    https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py

    Args:
        in_chans: The number of input channels.
        embed_dim: The embedding dimension, corresponding to the number of output channels of the vision transformer.
        global_attn_indexes: The global attention indices.
        kwargs: Keyword arguments for the image encoder base class.
    """
    def __init__(
        self,
        in_chans: int = 3,
        embed_dim: int = 768,
        global_attn_indexes: Tuple[int, ...] = ...,
        **kwargs,
    ) -> None:
        if not _sam_import_success:
            raise RuntimeError(
                "The vision transformer backend can only be initialized if segment anything is installed. "
                "Please install segment anything from https://github.com/facebookresearch/segment-anything "
                "and then rerun your code."
            )

        super().__init__(embed_dim=embed_dim, global_attn_indexes=global_attn_indexes, **kwargs)
        self.chunks_for_projection = global_attn_indexes
        self.in_chans = in_chans
        self.embed_dim = embed_dim

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Apply the vision transformer to input data.

        Args:
            x: The input data.

        Returns:
            The vision transformer output and a list of intermediate features from the global attention blocks.
        """
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            x = x + self.pos_embed

        list_from_encoder = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in self.chunks_for_projection:
                list_from_encoder.append(x)

        # Convert from (B, H, W, C) to the channel-first (B, C, H, W) layout.
        x = x.permute(0, 3, 1, 2)
        list_from_encoder = [e.permute(0, 3, 1, 2) for e in list_from_encoder]
        return x, list_from_encoder[:3]
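
# A minimal usage sketch (not part of the module), assuming segment-anything is installed.
# The configuration mirrors the "vit_b" setup used in `get_vision_transformer` further down.
sam_encoder = ViT_Sam(
    depth=12, embed_dim=768, img_size=1024, mlp_ratio=4,
    norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
    num_heads=12, patch_size=16, qkv_bias=True, use_rel_pos=True,
    global_attn_indexes=[2, 5, 8, 11], window_size=14, out_chans=256,
)
features, intermediates = sam_encoder(torch.randn(1, 3, 1024, 1024))
# features has shape (1, 768, 64, 64); intermediates holds three tensors of the same shape.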
class ViT_MAE(VisionTransformer):
    """Vision Transformer derived from the Masked Auto Encoder Codebase (https://arxiv.org/abs/2111.06377).

    Based on:
    https://github.com/facebookresearch/mae/blob/main/models_vit.py#L20-L53

    Args:
        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
        in_chans: The number of input channels.
        depth: The depth of the vision transformer.
        kwargs: Additional keyword arguments for the vision transformer base class.
    """
    def __init__(
        self,
        img_size: int = 1024,  # chosen to match our experiments with segment anything
        in_chans: int = 3,
        depth: int = 12,
        **kwargs
    ):
        if not _timm_import_success:
            raise RuntimeError(
                "The vision transformer backend can only be initialized if timm is installed. "
                "Please install timm (using conda/mamba) to use the models from https://github.com/facebookresearch/mae/ "
                "and then rerun your code."
            )
        super().__init__(img_size=img_size, depth=depth, **kwargs)
        self.img_size = img_size
        self.in_chans = in_chans
        self.depth = depth

    def convert_to_expected_dim(self, inputs_):
        """@private
        """
        inputs_ = inputs_[:, 1:, :]  # remove the class token
        # Reshape the outputs to the desired shape (N x H*W x C -> N x C x H x W).
        rdim = inputs_.shape[1]
        dshape = int(rdim ** 0.5)  # take the square root of the number of tokens to obtain the patch grid shape
        inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape))
        inputs_ = inputs_.permute(0, 3, 1, 2)
        return inputs_

    def forward_features(self, x):
        """@private
        """
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        x = x + self.pos_embed
        x = self.pos_drop(x)

        # The blocks after which intermediate features are extracted, used in conjunction with the upsampling blocks.
        _chunks = int(self.depth / 4)
        chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1]

        list_from_encoder = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in chunks_for_projection:
                list_from_encoder.append(self.convert_to_expected_dim(x))

        x = self.convert_to_expected_dim(x)
        return x, list_from_encoder[:3]

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Apply the vision transformer to input data.

        Args:
            x: The input data.

        Returns:
            The vision transformer output and a list of intermediate features.
        """
        x, list_from_encoder = self.forward_features(x)
        return x, list_from_encoder
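
# A minimal usage sketch (not part of the module), assuming timm is installed.
# This mirrors the "vit_b" configuration from `get_vision_transformer`, here with a smaller input size.
mae_encoder = ViT_MAE(
    img_size=512, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
    qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
)
features, intermediates = mae_encoder(torch.randn(1, 3, 512, 512))
# features has shape (1, 768, 32, 32); intermediates holds three tensors of the same shape.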


#
# Utilities for ScaleMAE's ViT
#


class CustomCompose:
    def __init__(self, rescale_transform, other_transforms, src_transform):
        self.rescale_transform = rescale_transform
        self.other_transforms = other_transforms
        self.src_transform = src_transform

    def __call__(self, x, valid_masks=None):
        if valid_masks is not None:
            nodata = (x * (1 - valid_masks.float())).max()
        x_aug = self.rescale_transform(x)
        parms = self.rescale_transform._params

        # sanity check, comment this out once it is verified to work
        # valid_masks = self.rescale_transform(valid_masks.float(), params=parms)
        # assert (x_aug==self.rescale_transform(x, params=parms)).all() #

        if valid_masks is not None:
            valid_masks = x_aug != nodata
            _, c, h, w = x_aug.shape
            zero_ratio = ((valid_masks == 0).sum((1, 2, 3)) / (h * w * c)).cpu().numpy()
        else:
            zero_ratio = -1

        if self.other_transforms:
            x_aug = self.other_transforms(x_aug)
        x_src = self.src_transform(x_aug)
        dx = parms["src"][:, 1, 0] - parms["src"][:, 0, 0]

        # dy = (parms['src'][:,2,1] - parms['src'][:,1,1])
        # assert (dx == dy).all()

        h, w = x_aug.shape[-2:]
        # assert h == w

        return x_aug, x_src, dx / h, zero_ratio, valid_masks


def get_2d_sincos_pos_embed_with_resolution(embed_dim, grid_size, res, cls_token=False, device="cpu"):
    """
    grid_size: int of the grid height and width
    res: array of size n, representing the resolution of a pixel (say, in meters)
    return:
    pos_embed: [n, grid_size*grid_size, embed_dim] or [n, 1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    # res = torch.FloatTensor(res).to(device)
    res = res.to(device)
    grid_h = torch.arange(grid_size, dtype=torch.float32, device=device)
    grid_w = torch.arange(grid_size, dtype=torch.float32, device=device)
    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")  # here h goes first, direction reversed compared to numpy
    grid = torch.stack(grid, dim=0)  # 2 x h x w

    # grid = grid.reshape([2, 1, grid_size, grid_size])
    grid = torch.einsum("chw,n->cnhw", grid, res)  # 2 x n x h x w
    _, n, h, w = grid.shape
    pos_embed = get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid)  # (n*H*W, D)
    pos_embed = pos_embed.reshape(n, h * w, embed_dim)
    if cls_token:
        pos_embed = torch.cat(
            [torch.zeros([n, 1, embed_dim], dtype=torch.float32, device=pos_embed.device), pos_embed], dim=1
        )

    return pos_embed
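
# A small shape check for the resolution-aware positional embedding (hypothetical resolution values).
res = torch.tensor([1.0, 2.0])  # per-sample pixel resolution
pos = get_2d_sincos_pos_embed_with_resolution(embed_dim=768, grid_size=14, res=res, cls_token=True)
# pos has shape (2, 1 + 14 * 14, 768): one embedding per sample, scaled by that sample's resolution.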


def get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = torch.cat([emb_h, emb_w], dim=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    # old_shape = pos
    omega = torch.arange(embed_dim // 2, dtype=torch.float32, device=pos.device)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = torch.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = torch.sin(out)  # (M, D/2)
    emb_cos = torch.cos(out)  # (M, D/2)

    emb = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)
    return emb
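
# A worked shape example for the 1D helper (illustration only).
positions = torch.arange(4, dtype=torch.float32)
emb = get_1d_sincos_pos_embed_from_grid_torch(embed_dim=8, pos=positions)
# emb has shape (4, 8): the first 4 columns are sines, the last 4 are cosines,
# so emb[0] is [0, 0, 0, 0, 1, 1, 1, 1] for position 0.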


class PatchEmbedUnSafe(PatchEmbed):
    """Image to Patch Embedding"""

    def forward(self, x):
        B, C, H, W = x.shape

        # NOTE (kept from ScaleMAE): the input size check from timm's PatchEmbed is dropped on purpose.
        # assert H == self.img_size[0] and W == self.img_size[1], \
        #     f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."

        x = self.proj(x).flatten(2).transpose(1, 2)
        return x
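
# A brief sketch of why the size check is dropped (assuming timm is installed): the same patch
# embedding can then be applied to inputs whose size differs from `img_size`.
patch_embed = PatchEmbedUnSafe(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
tokens = patch_embed(torch.randn(1, 3, 448, 448))  # larger than img_size, no assertion raised
# tokens has shape (1, (448 // 16) ** 2, 768) == (1, 784, 768).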


class ViT_ScaleMAE(VisionTransformer):
    """Vision Transformer derived from the Scale Masked Auto Encoder codebase (TODO: paper and github link).

    NOTE: For downstream tasks, the "base_resolution" parameter needs to be adjusted manually when using
    the model on a dataset with a different zoom factor.
    """

    def __init__(
        self, img_size=224, patch_size=16, in_chans=3, embed_dim=1024, depth=12, base_resolution=2.5, **kwargs
    ):
        super().__init__(img_size=img_size, embed_dim=embed_dim, depth=depth, **kwargs)
        self.img_size = img_size
        self.in_chans = in_chans
        self.depth = depth
        self.base_resolution = base_resolution

        self.patch_embed = PatchEmbedUnSafe(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

    def transform_inputs(self, x):
        import kornia.augmentation as K
        from kornia.constants import Resample

        self._transforms = CustomCompose(
            rescale_transform=K.RandomResizedCrop(
                (448, 448),
                ratio=(1.0, 1.0),
                scale=(1.0, 1.0),
                resample=Resample.BICUBIC.name,
            ),
            other_transforms=None,
            src_transform=K.Resize((224, 224)),
        )
        x, _, ratios, _, _ = self._transforms(x)
        input_res = ratios * self.base_resolution
        return x, input_res

    def convert_to_expected_dim(self, x):
        inputs_ = x[:, 1:, :]  # remove the class token
        # Reshape the outputs to the desired shape (N x H*W x C -> N x C x H x W).
        rdim = inputs_.shape[1]
        dshape = int(rdim ** 0.5)  # take the square root of the number of tokens to obtain the patch grid shape
        inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape))
        inputs_ = inputs_.permute(0, 3, 1, 2)
        return inputs_

    def forward_features(self, x):
        x, input_res = self.transform_inputs(x)

        B, _, h, w = x.shape
        x = self.patch_embed(x)

        num_patches = int((h * w) / (self.patch_embed.patch_size[0] * self.patch_embed.patch_size[1]))
        pos_embed = get_2d_sincos_pos_embed_with_resolution(
            x.shape[-1],
            int(num_patches ** 0.5),
            input_res,
            cls_token=True,
            device=x.device,
        )

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + pos_embed
        x = self.pos_drop(x)

        # The blocks after which intermediate features are extracted, used in conjunction with the upsampling blocks.
        _chunks = int(self.depth / 4)
        chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1]

        list_from_encoder = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in chunks_for_projection:
                list_from_encoder.append(self.convert_to_expected_dim(x))

        x = self.convert_to_expected_dim(x)

        return x, list_from_encoder

    def forward(self, x):
        x, list_from_encoder = self.forward_features(x)
        return x, list_from_encoder
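
# A minimal usage sketch (not part of the module), assuming timm and kornia are installed.
# Note that the forward pass internally rescales the input to 448 x 448 before patch embedding.
scalemae_encoder = ViT_ScaleMAE(
    img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12,
    mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
    base_resolution=2.5,
)
features, intermediates = scalemae_encoder(torch.randn(2, 3, 224, 224))
# features has shape (2, 768, 28, 28); intermediates holds four tensors of the same shape.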


def get_vision_transformer(backbone: str, model: str, img_size: int = 1024, **kwargs) -> nn.Module:
    """Get vision transformer encoder.

    Args:
        backbone: The name of the vision transformer implementation. One of "sam" / "mae" / "scalemae".
        model: The name of the model. One of "vit_b", "vit_l" or "vit_h".
        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
        kwargs: Additional kwargs which can be expected by the vision transformer,
            e.g. 'base_resolution' for `ViT_ScaleMAE`.

    Returns:
        The vision transformer.
    """
    if backbone == "sam":
        if model == "vit_b":
            encoder = ViT_Sam(
                depth=12, embed_dim=768, img_size=1024, mlp_ratio=4,
                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
                num_heads=12, patch_size=16, qkv_bias=True, use_rel_pos=True,
                global_attn_indexes=[2, 5, 8, 11],
                window_size=14, out_chans=256,
            )
        elif model == "vit_l":
            encoder = ViT_Sam(
                depth=24, embed_dim=1024, img_size=1024, mlp_ratio=4,
                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
                num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True,
                global_attn_indexes=[5, 11, 17, 23],
                window_size=14, out_chans=256,
            )
        elif model == "vit_h":
            encoder = ViT_Sam(
                depth=32, embed_dim=1280, img_size=1024, mlp_ratio=4,
                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
                num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True,
                global_attn_indexes=[7, 15, 23, 31],
                window_size=14, out_chans=256,
            )
        else:
            raise ValueError(f"'{model}' is not supported by SAM. Currently, 'vit_b', 'vit_l', 'vit_h' are supported.")

    elif backbone == "mae":
        if model == "vit_b":
            encoder = ViT_MAE(
                img_size=img_size, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
            )
        elif model == "vit_l":
            encoder = ViT_MAE(
                img_size=img_size, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
            )
        elif model == "vit_h":
            encoder = ViT_MAE(
                img_size=img_size, patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4,
                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
            )
        else:
            raise ValueError(f"'{model}' is not supported by MAE. Currently, 'vit_b', 'vit_l', 'vit_h' are supported.")

    elif backbone == "scalemae":
        base_resolution = kwargs.get("base_resolution", 2.5)

        if model == "vit_b":
            encoder = ViT_ScaleMAE(
                img_size=img_size, patch_size=8, embed_dim=768, depth=12, num_heads=12,
                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
                base_resolution=base_resolution,
            )
        elif model == "vit_l":
            encoder = ViT_ScaleMAE(
                img_size=img_size, patch_size=8, embed_dim=1024, depth=24, num_heads=16,
                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
                base_resolution=base_resolution,
            )
        elif model == "vit_h":
            encoder = ViT_ScaleMAE(
                img_size=img_size, patch_size=8, embed_dim=1280, depth=32, num_heads=16,
                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
                base_resolution=base_resolution,
            )
        else:
            raise ValueError(
                f"'{model}' is not supported by ScaleMAE. Currently, 'vit_b', 'vit_l' and 'vit_h' are supported."
            )

    else:
        raise ValueError(
            f"'{backbone}' is not a supported backbone. The backbones supported by 'UNETR' are 'sam', 'mae' and 'scalemae'."
        )

    return encoder
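
# A quick sketch of how the factory is typically used (assuming the corresponding optional dependency is installed).
encoder = get_vision_transformer(backbone="mae", model="vit_b", img_size=256)
features, intermediates = encoder(torch.randn(1, 3, 256, 256))
# features has shape (1, 768, 16, 16) for this configuration.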