torch_em.model.vit

from functools import partial
from typing import Tuple, List

import torch
import torch.nn as nn

# We catch ImportErrors here because segment_anything, timm and sam2 should
# only be optional dependencies for torch_em.
try:
    from segment_anything.modeling import ImageEncoderViT
    _sam_import_success = True
except ImportError:
    ImageEncoderViT = object
    _sam_import_success = False

try:
    from timm.models.vision_transformer import VisionTransformer, PatchEmbed
    _timm_import_success = True
except ImportError:
    VisionTransformer = object
    PatchEmbed = object
    _timm_import_success = False

try:
    from sam2.modeling.backbones.hieradet import Hiera
    from sam2.modeling.position_encoding import PositionEmbeddingSine
    from sam2.modeling.backbones.image_encoder import ImageEncoder, FpnNeck
    _sam2_import_success = True
except ImportError:
    ImageEncoder = object
    _sam2_import_success = False


class ViT_Sam(ImageEncoderViT):
    """Vision Transformer derived from the Segment Anything Codebase (https://arxiv.org/abs/2304.02643).

    Based on:
    https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py

    Args:
        in_chans: The number of input channels.
        embed_dim: The embedding dimension, corresponding to the number of output channels of the vision transformer.
        global_attn_indexes: The global attention indices.
        kwargs: Keyword arguments for the image encoder base class.
    """
    def __init__(
        self,
        in_chans: int = 3,
        embed_dim: int = 768,
        global_attn_indexes: Tuple[int, ...] = ...,
        **kwargs,
    ) -> None:
        if not _sam_import_success:
            raise RuntimeError(
                "The vision transformer backend can only be initialized if segment anything is installed. "
                "Please install segment anything from https://github.com/facebookresearch/segment-anything "
                "and then rerun your code."
            )

        super().__init__(embed_dim=embed_dim, global_attn_indexes=global_attn_indexes, **kwargs)
        self.chunks_for_projection = global_attn_indexes
        self.in_chans = in_chans
        self.embed_dim = embed_dim

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Apply the vision transformer to input data.

        Args:
            x: The input data.

        Returns:
            The vision transformer output and the features from the first three global attention blocks.
        """
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            x = x + self.pos_embed

        list_from_encoder = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in self.chunks_for_projection:
                list_from_encoder.append(x)

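        # The SAM patch embedding is channels-last, i.e. (B, H / patch_size, W / patch_size, C).
        # Permute to channels-first for downstream use, e.g. (B, 64, 64, 768) -> (B, 768, 64, 64)
        # for a 1024 x 1024 input to vit_b.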
        x = x.permute(0, 3, 1, 2)
        list_from_encoder = [e.permute(0, 3, 1, 2) for e in list_from_encoder]
        return x, list_from_encoder[:3]


class ViT_MAE(VisionTransformer):
    """Vision Transformer derived from the Masked Auto Encoder Codebase (https://arxiv.org/abs/2111.06377).

    Based on:
    https://github.com/facebookresearch/mae/blob/main/models_vit.py#L20-L53

    Args:
        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
        in_chans: The number of input channels.
        depth: The depth of the vision transformer.
        kwargs: Additional keyword arguments for the vision transformer base class.
    """
    def __init__(
        self,
        img_size: int = 1024,  # chosen to match our experiments with segment anything
        in_chans: int = 3,
        depth: int = 12,
        **kwargs
    ):
        if not _timm_import_success:
            raise RuntimeError(
                "The vision transformer backend can only be initialized if timm is installed. "
                "Please install timm (e.g. via conda/mamba) to use https://github.com/facebookresearch/mae/ "
                "and then rerun your code."
            )
        super().__init__(img_size=img_size, depth=depth, **kwargs)
        self.img_size = img_size
        self.in_chans = in_chans
        self.depth = depth

    def convert_to_expected_dim(self, inputs_):
        """@private
        """
        inputs_ = inputs_[:, 1:, :]  # removing the class tokens
        # reshape the outputs to the desired shape (N x H*W x C -> N x H x W x C)
        rdim = inputs_.shape[1]
        dshape = int(rdim ** 0.5)  # take the square root of the token count to obtain the patch grid shape
        inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape))
        inputs_ = inputs_.permute(0, 3, 1, 2)
        return inputs_

    def forward_features(self, x):
        """@private
        """
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        x = x + self.pos_embed
        x = self.pos_drop(x)

        # chunks obtained for getting the projections in conjunction with the upsampling blocks
        _chunks = int(self.depth / 4)
        chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1]
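        # For the default depth of 12 this selects the outputs of blocks [2, 5, 8, 11],
        # i.e. one feature map after every quarter of the transformer.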

        list_from_encoder = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in chunks_for_projection:
                list_from_encoder.append(self.convert_to_expected_dim(x))

        x = self.convert_to_expected_dim(x)
        return x, list_from_encoder[:3]

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Apply the vision transformer to input data.

        Args:
            x: The input data.

        Returns:
            The vision transformer output and a list of intermediate features.
        """
        x, list_from_encoder = self.forward_features(x)
        return x, list_from_encoder


class ViT_Sam2(ImageEncoder):
    """Vision Transformer derived from the Segment Anything 2 Codebase (https://arxiv.org/abs/2408.00714).

    Based on https://github.com/facebookresearch/sam2/blob/main/sam2/modeling/backbones/image_encoder.py.

    Args:
        backbone_channel_list: The channel dimensions throughout the backbone.
        embed_dim: The initial embedding dimension.
        num_heads: The initial number of heads.
        stages: The number of blocks per stage.
        global_att_blocks: The indices of the blocks that use global attention.
        window_pos_embed_bkg_spatial_size: The spatial size of the windowed positional embedding.
        window_spec: The window size per stage, when not using global attention.
        scalp: The number of lowest-resolution features to discard.
    """
    def __init__(
        self,
        backbone_channel_list: List[int],
        embed_dim: int = 96,
        num_heads: int = 1,
        stages: Tuple[int, ...] = (2, 3, 16, 3),
        global_att_blocks: Tuple[int, ...] = (12, 16, 20),
        window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
        window_spec: Tuple[int, ...] = (8, 4, 14, 7),
        scalp: int = 1,
    ):
        if not _sam2_import_success:
            raise RuntimeError(
                "The vision transformer backend can only be initialized if segment anything 2 is installed. "
                "Please install segment anything 2 from https://github.com/facebookresearch/sam2 "
                "and then rerun your code."
            )

        trunk = Hiera(
            embed_dim=embed_dim,
            num_heads=num_heads,
            stages=stages,
            global_att_blocks=global_att_blocks,
            window_pos_embed_bkg_spatial_size=window_pos_embed_bkg_spatial_size,
            window_spec=window_spec,
        )
        neck = FpnNeck(
            position_encoding=PositionEmbeddingSine(num_pos_feats=256),
            d_model=256,
            backbone_channel_list=backbone_channel_list,
            fpn_top_down_levels=[2, 3],
            fpn_interp_model="nearest",
        )

        super().__init__(trunk=trunk, neck=neck, scalp=scalp)
        self.scalp = scalp
        self.embed_dim = embed_dim
        self.img_size = 1024  # NOTE: Hard-coded at the moment, declared in the configuration file.

    def forward(self, x: torch.Tensor):
        # The forward pass through the backbone.
        features, pos = self.neck(self.trunk(x))
        if self.scalp > 0:  # This discards the 'scalp' lowest-resolution features.
            features, pos = features[:-self.scalp], pos[:-self.scalp]

        return features[-1], features


#
# Utilities for ScaleMAE's ViT
#


class CustomCompose:
    def __init__(self, rescale_transform, other_transforms, src_transform):
        self.rescale_transform = rescale_transform
        self.other_transforms = other_transforms
        self.src_transform = src_transform

    def __call__(self, x, valid_masks=None):
        if valid_masks is not None:
            nodata = (x * (1 - valid_masks.float())).max()
        x_aug = self.rescale_transform(x)
        parms = self.rescale_transform._params

        # sanity check, comment out once this is verified to work
        # valid_masks = self.rescale_transform(valid_masks.float(), params=parms)
        # assert (x_aug == self.rescale_transform(x, params=parms)).all()

        if valid_masks is not None:
            valid_masks = x_aug != nodata
            _, c, h, w = x_aug.shape
            zero_ratio = ((valid_masks == 0).sum((1, 2, 3)) / (h * w * c)).cpu().numpy()
        else:
            zero_ratio = -1

        if self.other_transforms:
            x_aug = self.other_transforms(x_aug)
        x_src = self.src_transform(x_aug)
        dx = parms["src"][:, 1, 0] - parms["src"][:, 0, 0]

        # dy = (parms['src'][:, 2, 1] - parms['src'][:, 1, 1])
        # assert (dx == dy).all()

        h, w = x_aug.shape[-2:]
        # assert h == w

        return x_aug, x_src, dx / h, zero_ratio, valid_masks


def get_2d_sincos_pos_embed_with_resolution(embed_dim, grid_size, res, cls_token=False, device="cpu"):
    """
    grid_size: int of the grid height and width
    res: array of size n, representing the resolution of a pixel (say, in meters)
    return:
    pos_embed: [n, grid_size*grid_size, embed_dim] or [n, 1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    # res = torch.FloatTensor(res).to(device)
    res = res.to(device)
    grid_h = torch.arange(grid_size, dtype=torch.float32, device=device)
    grid_w = torch.arange(grid_size, dtype=torch.float32, device=device)
    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")  # here h goes first, direction reversed compared to numpy
    grid = torch.stack(grid, dim=0)  # 2 x h x w

    # grid = grid.reshape([2, 1, grid_size, grid_size])
    grid = torch.einsum("chw,n->cnhw", grid, res)  # 2 x n x h x w
    _, n, h, w = grid.shape
    pos_embed = get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid)  # (n*H*W, D)
    pos_embed = pos_embed.reshape(n, h * w, embed_dim)
    if cls_token:
        pos_embed = torch.cat(
            [torch.zeros([n, 1, embed_dim], dtype=torch.float32, device=pos_embed.device), pos_embed], dim=1
        )

    return pos_embed


def get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of the dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = torch.cat([emb_h, emb_w], dim=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
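    # Brief worked example (illustration only, not part of the original code): for embed_dim = 4
    # the frequencies are omega = 1 / 10000 ** [0, 0.5] = [1.0, 0.01], so a single position p is
    # encoded as [sin(p), sin(0.01 * p), cos(p), cos(0.01 * p)].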
    assert embed_dim % 2 == 0
    # old_shape = pos
    omega = torch.arange(embed_dim // 2, dtype=torch.float32, device=pos.device)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = torch.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = torch.sin(out)  # (M, D/2)
    emb_cos = torch.cos(out)  # (M, D/2)

    emb = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)
    return emb


class PatchEmbedUnSafe(PatchEmbed):
    """Image to Patch Embedding"""

    def forward(self, x):
        B, C, H, W = x.shape

        # NOTE: Commented-out code from ScaleMAE: Dropped size check in timm
        # assert H == self.img_size[0] and W == self.img_size[1], \
        #     f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."

        x = self.proj(x).flatten(2).transpose(1, 2)
        return x


class ViT_ScaleMAE(VisionTransformer):
    """Vision Transformer derived from the Scale Masked Auto Encoder codebase (TODO: paper and github link).

    NOTE: For downstream tasks, the "base_resolution" parameter needs to be adjusted manually when using
    the model on a dataset with a different zoom factor.
    """

    def __init__(
        self, img_size=224, patch_size=16, in_chans=3, embed_dim=1024, depth=12, base_resolution=2.5, **kwargs
    ):
        super().__init__(img_size=img_size, embed_dim=embed_dim, **kwargs)
        self.img_size = img_size
        self.in_chans = in_chans
        self.depth = depth
        self.base_resolution = base_resolution

        self.patch_embed = PatchEmbedUnSafe(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

    def transform_inputs(self, x):
        import kornia.augmentation as K
        from kornia.constants import Resample

        self._transforms = CustomCompose(
            rescale_transform=K.RandomResizedCrop(
                (448, 448),
                ratio=(1.0, 1.0),
                scale=(1.0, 1.0),
                resample=Resample.BICUBIC.name,
            ),
            other_transforms=None,
            src_transform=K.Resize((224, 224)),
        )
        x, _, ratios, _, _ = self._transforms(x)
        input_res = ratios * self.base_resolution
        return x, input_res

    def convert_to_expected_dim(self, x):
        inputs_ = x[:, 1:, :]  # removing the class tokens
        # reshape the outputs to the desired shape (N x H*W x C -> N x H x W x C)
        rdim = inputs_.shape[1]
        dshape = int(rdim ** 0.5)  # take the square root of the token count to obtain the patch grid shape
        inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape))
        inputs_ = inputs_.permute(0, 3, 1, 2)
        return inputs_

    def forward_features(self, x):
        x, input_res = self.transform_inputs(x)

        B, _, h, w = x.shape
        x = self.patch_embed(x)

        num_patches = int((h * w) / (self.patch_embed.patch_size[0] * self.patch_embed.patch_size[1]))
        pos_embed = get_2d_sincos_pos_embed_with_resolution(
            x.shape[-1],
            int(num_patches ** 0.5),
            input_res,
            cls_token=True,
            device=x.device,
        )

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + pos_embed
        x = self.pos_drop(x)

        # chunks obtained for getting the projections in conjunction with the upsampling blocks
        _chunks = int(self.depth / 4)
        chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1]

        list_from_encoder = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in chunks_for_projection:
                list_from_encoder.append(self.convert_to_expected_dim(x))

        x = self.convert_to_expected_dim(x)

        return x, list_from_encoder

    def forward(self, x):
        x, list_from_encoder = self.forward_features(x)
        return x, list_from_encoder


def get_vision_transformer(backbone: str, model: str, img_size: int = 1024, **kwargs) -> nn.Module:
    """Get the vision transformer encoder.

    Args:
        backbone: The name of the vision transformer implementation. One of "sam", "sam2", "mae" or "scalemae".
        model: The name of the model. One of "vit_b", "vit_l" or "vit_h" for the "sam", "mae" and "scalemae"
            backbones, or one of "hvit_t", "hvit_s", "hvit_b" or "hvit_l" for the "sam2" backbone.
        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
        kwargs: Additional kwargs which can be expected by the vision transformer,
            e.g. 'base_resolution' for `ViT_ScaleMAE`.

    Returns:
        The vision transformer.
    """
    if backbone == "sam":
        if model == "vit_b":
            encoder = ViT_Sam(
                depth=12, embed_dim=768, img_size=1024, mlp_ratio=4,
                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
                num_heads=12, patch_size=16, qkv_bias=True, use_rel_pos=True,
                global_attn_indexes=[2, 5, 8, 11],
                window_size=14, out_chans=256,
            )
        elif model == "vit_l":
            encoder = ViT_Sam(
                depth=24, embed_dim=1024, img_size=1024, mlp_ratio=4,
                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
                num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True,
                global_attn_indexes=[5, 11, 17, 23],
                window_size=14, out_chans=256,
            )
        elif model == "vit_h":
            encoder = ViT_Sam(
                depth=32, embed_dim=1280, img_size=1024, mlp_ratio=4,
                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
                num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True,
                global_attn_indexes=[7, 15, 23, 31],
                window_size=14, out_chans=256,
            )
        else:
            raise ValueError(f"'{model}' is not supported by SAM. Currently, 'vit_b', 'vit_l', 'vit_h' are supported.")

    elif backbone == "sam2":
        if model == "hvit_t":
            encoder = ViT_Sam2(
                embed_dim=96, num_heads=1, stages=[1, 2, 7, 2], global_att_blocks=[5, 7, 9],
                window_pos_embed_bkg_spatial_size=[7, 7], backbone_channel_list=[768, 384, 192, 96],
            )
        elif model == "hvit_s":
            encoder = ViT_Sam2(
                embed_dim=96, num_heads=1, stages=[1, 2, 11, 2], global_att_blocks=[7, 10, 13],
                window_pos_embed_bkg_spatial_size=[7, 7], backbone_channel_list=[768, 384, 192, 96],
            )
        elif model == "hvit_b":
            encoder = ViT_Sam2(
                embed_dim=112, num_heads=2, backbone_channel_list=[896, 448, 224, 112],
            )
        elif model == "hvit_l":
            encoder = ViT_Sam2(
                embed_dim=144, num_heads=2, stages=[2, 6, 36, 4], global_att_blocks=[23, 33, 43],
                window_spec=[8, 4, 16, 8], backbone_channel_list=[1152, 576, 288, 144],
            )
        else:
            raise ValueError(
                f"'{model}' is not supported by SAM2. Currently, 'hvit_t', 'hvit_s', 'hvit_b', 'hvit_l' are supported."
            )

    elif backbone == "mae":
        if model == "vit_b":
            encoder = ViT_MAE(
                img_size=img_size, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
            )
        elif model == "vit_l":
            encoder = ViT_MAE(
                img_size=img_size, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
            )
        elif model == "vit_h":
            encoder = ViT_MAE(
                img_size=img_size, patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4,
                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
            )
        else:
            raise ValueError(f"'{model}' is not supported by MAE. Currently, 'vit_b', 'vit_l', 'vit_h' are supported.")

    elif backbone == "scalemae":
        base_resolution = kwargs.get("base_resolution", 2.5)

        if model == "vit_b":
            encoder = ViT_ScaleMAE(
                img_size=img_size, patch_size=8, embed_dim=768, depth=12, num_heads=12,
                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
                base_resolution=base_resolution,
            )
        elif model == "vit_l":
            encoder = ViT_ScaleMAE(
                img_size=img_size, patch_size=8, embed_dim=1024, depth=24, num_heads=16,
                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
                base_resolution=base_resolution,
            )
        elif model == "vit_h":
            encoder = ViT_ScaleMAE(
                img_size=img_size, patch_size=8, embed_dim=1280, depth=32, num_heads=16,
                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
                base_resolution=base_resolution,
            )
        else:
            raise ValueError(
                f"'{model}' is not supported by ScaleMAE. Currently, 'vit_b', 'vit_l' and 'vit_h' are supported."
            )

    else:
        raise ValueError(
            "The backbones supported by 'UNETR' are 'sam', 'sam2', 'mae' or 'scalemae'. Please choose one of them."
        )

    return encoder
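
A minimal usage sketch (an illustration, not part of the module; it assumes the optional segment-anything dependency is installed):

import torch
from torch_em.model.vit import get_vision_transformer

# Build the SAM ViT-B image encoder.
encoder = get_vision_transformer(backbone="sam", model="vit_b")

# The encoder returns the final feature map plus intermediate features from the global attention blocks.
x = torch.randn(1, 3, 1024, 1024)
features, intermediates = encoder(x)
print(features.shape)      # expected: torch.Size([1, 768, 64, 64])
print(len(intermediates))  # expected: 3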