torch_em.model.vit

  1import math
  2from functools import partial
  3from typing import Tuple, List
  4
  5import torch
  6import torch.nn as nn
  7
  8# we catch ImportErrors here because segment_anything, sam2, sam3, timm, dinov2 and dinov3 should
  9# only be optional dependencies for torch_em
 10try:
 11    from segment_anything.modeling import ImageEncoderViT
 12    _sam_import_success = True
 13except ImportError:
 14    ImageEncoderViT = object
 15    _sam_import_success = False
 16
 17try:
 18    from timm.models.vision_transformer import VisionTransformer, PatchEmbed
 19    _timm_import_success = True
 20except ImportError:
 21    VisionTransformer = object
 22    PatchEmbed = object
 23    _timm_import_success = False
 24
 25try:
 26    from sam2.modeling.backbones.hieradet import Hiera
 27    from sam2.modeling.position_encoding import PositionEmbeddingSine
 28    from sam2.modeling.backbones.image_encoder import ImageEncoder, FpnNeck
 29    _sam2_import_success = True
 30except ImportError:
 31    ImageEncoder = object
 32    _sam2_import_success = False
 33
 34try:
 35    from dinov2.models.vision_transformer import DinoVisionTransformer as DinoV2VisionTransformer
 36    from dinov2.layers import MemEffAttention, NestedTensorBlock as Block
 37    _dinov2_import_success = True
 38except ImportError:
 39    DinoV2VisionTransformer = object
 40    _dinov2_import_success = False
 41
 42try:
 43    from dinov3.models.vision_transformer import DinoVisionTransformer as DinoV3VisionTransformer
 44    _dinov3_import_success = True
 45except ImportError:
 46    DinoV3VisionTransformer = object
 47    _dinov3_import_success = False
 48
 49
 50try:
 51    from sam3.model.vitdet import ViT as SAM3ViT, get_abs_pos
 52    _sam3_import_success = True
 53except ImportError:
 54    SAM3ViT = object
 55    _sam3_import_success = False
 56
 57
 58class ViT_Sam(ImageEncoderViT):
 59    """Vision Transformer derived from the Segment Anything Codebase (https://arxiv.org/abs/2304.02643).
 60
 61    Based on:
 62    https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py
 63
 64    Args:
 65        in_chans: The number of input channels.
 66        embed_dim: The embedding dimension, corresponding to the number of output channels of the vision transformer.
 67        global_attn_indexes: The global attention indices.
 68        kwargs: Keyword arguments for the image encoder base class.
 69    """
 70    def __init__(
 71        self,
 72        in_chans: int = 3,
 73        embed_dim: int = 768,
 74        global_attn_indexes: Tuple[int, ...] = (2, 5, 8, 11),
 75        **kwargs,
 76    ) -> None:
 77        if not _sam_import_success:
 78            raise RuntimeError(
 79                "The vision transformer backend can only be initialized if segment anything is installed. "
 80                "Please install segment anything from https://github.com/facebookresearch/segment-anything "
 81                "and then rerun your code."
 82            )
 83
 84        super().__init__(embed_dim=embed_dim, global_attn_indexes=global_attn_indexes, **kwargs)
 85        self.chunks_for_projection = global_attn_indexes
 86        self.in_chans = in_chans
 87        self.embed_dim = embed_dim
 88
 89    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
 90        """Apply the vision transformer to input data.
 91
 92        Args:
 93            x: The input data.
 94
 95        Returns:
 96            The vision transformer output and the list of intermediate features from the global attention blocks.
 97        """
 98        x = self.patch_embed(x)
 99        if self.pos_embed is not None:
100            x = x + self.pos_embed
101
102        list_from_encoder = []
103        for i, blk in enumerate(self.blocks):
104            x = blk(x)
105            if i in self.chunks_for_projection:
106                list_from_encoder.append(x)
107
108        x = x.permute(0, 3, 1, 2)
109        list_from_encoder = [e.permute(0, 3, 1, 2) for e in list_from_encoder]
110        return x, list_from_encoder[:3]
111
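# Usage sketch (illustrative addition, not part of the upstream module): build the SAM ViT-B
# configuration, as done by `get_vision_transformer(backbone="sam", model="vit_b")` below, and
# run a forward pass. This assumes segment_anything is installed; the input size of 1024 is an
# example choice matching the default `img_size`.
def _example_vit_sam():
    encoder = ViT_Sam(
        depth=12, embed_dim=768, img_size=1024, mlp_ratio=4,
        norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
        num_heads=12, patch_size=16, qkv_bias=True, use_rel_pos=True,
        global_attn_indexes=[2, 5, 8, 11], window_size=14, out_chans=256,
    )
    x = torch.randn(1, 3, 1024, 1024)
    # The final feature map has shape (1, 768, 64, 64) for a 1024 input with patch size 16;
    # the list holds the features collected at the first three global attention blocks.
    features, intermediates = encoder(x)
    return features, intermediates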
112
113class ViT_MAE(VisionTransformer):
114    """Vision Transformer derived from the Masked Auto Encoder Codebase (https://arxiv.org/abs/2111.06377).
115
116    Based on:
117    https://github.com/facebookresearch/mae/blob/main/models_vit.py#L20-L53
118
119    Args:
120        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
121        in_chans: The number of input channels.
122        depth: The depth of the vision transformer.
123        kwargs: Additional keyword arguments for the vision transformer base class.
124    """
125    def __init__(
126        self,
127        img_size: int = 1024,  # chosen to match our experiments with segment anything
128        in_chans: int = 3,
129        depth: int = 12,
130        **kwargs
131    ):
132        if not _timm_import_success:
133            raise RuntimeError(
134                "The vision transformer backend can only be initialized if timm is installed. "
135                "Please install timm (e.g. via conda/mamba), which is required for https://github.com/facebookresearch/mae/ "
136                "and then rerun your code."
137            )
138        super().__init__(img_size=img_size, depth=depth, **kwargs)
139        self.img_size = img_size
140        self.in_chans = in_chans
141        self.depth = depth
142
143    def convert_to_expected_dim(self, inputs_):
144        """@private
145        """
146        inputs_ = inputs_[:, 1:, :]  # removing the class tokens
147        # reshape the outputs to desired shape (N x H*W X C -> N x H x W x C)
148        rdim = inputs_.shape[1]
149        dshape = int(rdim ** 0.5)  # finding the square root of the outputs for obtaining the patch shape
150        inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape))
151        inputs_ = inputs_.permute(0, 3, 1, 2)
152        return inputs_
153
154    def forward_features(self, x):
155        """@private
156        """
157        B = x.shape[0]
158        x = self.patch_embed(x)
159
160        cls_tokens = self.cls_token.expand(B, -1, -1)
161        x = torch.cat((cls_tokens, x), dim=1)
162
163        x = x + self.pos_embed
164        x = self.pos_drop(x)
165
166        # determine the blocks whose outputs are used as projections in conjunction with the upsampling blocks
167        _chunks = int(self.depth / 4)
168        chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1]
169
170        list_from_encoder = []
171        for i, blk in enumerate(self.blocks):
172            x = blk(x)
173            if i in chunks_for_projection:
174                list_from_encoder.append(self.convert_to_expected_dim(x))
175
176        x = self.convert_to_expected_dim(x)
177        return x, list_from_encoder[:3]
178
179    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
180        """Apply the vision transformer to input data.
181
182        Args:
183            x: The input data.
184
185        Returns:
186            The vision transformer output and the list of intermediate features.
187        """
188        x, list_from_encoder = self.forward_features(x)
189        return x, list_from_encoder
190
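# Usage sketch (illustrative addition): the MAE-style ViT-B encoder, mirroring
# `get_vision_transformer(backbone="mae", model="vit_b")` below. This assumes timm is installed.
def _example_vit_mae():
    encoder = ViT_MAE(
        img_size=1024, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    # Intermediate features are collected every depth // 4 blocks and reshaped to
    # (B, embed_dim, H / patch_size, W / patch_size), here (1, 768, 64, 64).
    features, intermediates = encoder(torch.randn(1, 3, 1024, 1024))
    return features, intermediates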
191
192class ViT_Sam2(ImageEncoder):
193    """Vision Transformer derived from the Segment Anything 2 Codebase (https://arxiv.org/abs/2408.00714).
194
195    Based on https://github.com/facebookresearch/sam2/blob/main/sam2/modeling/backbones/image_encoder.py.
196
197    Args:
198        backbone_channel_list: The channels throughout the entire backbone.
199        embed_dim: The initial embedding dimension.
200        num_heads: The initial number of heads.
201        stages: The number of blocks per stage.
202        global_att_blocks: The indices of the blocks that use global attention.
203        window_pos_embed_bkg_spatial_size: The spatial size of windowed positional embedding.
204        window_spec: The window size per stage, when not using global attention.
205        scalp: The number of lowest-resolution features to discard.
206    """
207    def __init__(
208        self,
209        backbone_channel_list: List[int],
210        img_size: int = 1024,
211        embed_dim: int = 96,
212        num_heads: int = 1,
213        stages: Tuple[int, ...] = (2, 3, 16, 3),
214        global_att_blocks: Tuple[int, ...] = (12, 16, 20),
215        window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
216        window_spec: Tuple[int, ...] = (8, 4, 14, 7),
217        scalp: int = 1,
218        **kwargs
219    ):
220        if not _sam2_import_success:
221            raise RuntimeError(
222                "The vision transformer backend can only be initialized if segment anything 2 is installed. "
223                "Please install segment anything 2 from https://github.com/facebookresearch/sam2 "
224                "and then rerun your code"
225            )
226
227        trunk = Hiera(
228            embed_dim=embed_dim,
229            num_heads=num_heads,
230            stages=stages,
231            global_att_blocks=global_att_blocks,
232            window_pos_embed_bkg_spatial_size=window_pos_embed_bkg_spatial_size,
233            window_spec=window_spec,
234        )
235        neck = FpnNeck(
236            position_encoding=PositionEmbeddingSine(num_pos_feats=256),
237            d_model=256,
238            backbone_channel_list=backbone_channel_list,
239            fpn_top_down_levels=[2, 3],
240            fpn_interp_model="nearest",
241        )
242
243        super().__init__(trunk=trunk, neck=neck, scalp=scalp, **kwargs)
244        self.scalp = scalp
245        self.embed_dim = embed_dim
246        self.img_size = img_size
247
248    def forward(self, x: torch.Tensor):
249        # The forward pass through the backbone.
250        features, pos = self.neck(self.trunk(x))
251        if self.scalp > 0:  # This discards the "scalp" lowest-resolution features.
252            features, pos = features[:-self.scalp], pos[:-self.scalp]
253
254        return features[-1], features
255
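# Usage sketch (illustrative addition): the SAM2 Hiera-tiny encoder, mirroring
# `get_vision_transformer(backbone="sam2", model="hvit_t")` below. This assumes sam2 is installed.
def _example_vit_sam2():
    encoder = ViT_Sam2(
        img_size=1024, embed_dim=96, num_heads=1, stages=[1, 2, 7, 2], global_att_blocks=[5, 7, 9],
        window_pos_embed_bkg_spatial_size=[7, 7], backbone_channel_list=[768, 384, 192, 96],
    )
    # Returns the feature map used as the image embedding together with all retained
    # multi-scale FPN features (the lowest-resolution level is discarded since scalp=1).
    image_embedding, multi_scale_features = encoder(torch.randn(1, 3, 1024, 1024))
    return image_embedding, multi_scale_features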
256
257class ViT_Sam3(SAM3ViT):
258    """Vision Transformer derived from the Segment Anything 3 Codebase (https://arxiv.org/abs/2511.16719).
259
260    Based on https://github.com/facebookresearch/sam3/blob/main/sam3/model/vitdet.py.
261
262    Args:
263        img_size: The input image size.
264        embed_dim: The embedding dimension, corresponding to the number of output channels of the vision transformer.
265        kwargs: Keyword arguments for the image encoder base class.
266    """
267    def __init__(self, img_size: int = 1024, embed_dim: int = 768, **kwargs):
268        if not _sam3_import_success:
269            raise RuntimeError(
270                "The vision transformer backend can only be initialized if segment anything 3 is installed. "
271                "Please install segment anything 3 from https://github.com/facebookresearch/sam3 "
272                "and then rerun your code"
273            )
274
275        super().__init__(img_size=img_size, embed_dim=embed_dim, **kwargs)
276        self.img_size = img_size
277        self.embed_dim = embed_dim
278
279    def forward_features(self, x):
280        """@private
281        """
282        x = self.patch_embed(x)
283        h, w = x.shape[1], x.shape[2]
284
285        s = 0
286        if self.retain_cls_token:
287            # If the 'cls_token' is retained, we don't maintain the spatial shape.
288            x = torch.cat([self.class_embedding, x.flatten(1, 2)], dim=1)
289            s = 1
290
291        if self.pos_embed is not None:
292            x = x + get_abs_pos(
293                self.pos_embed, self.pretrain_use_cls_token, (h, w), self.retain_cls_token, tiling=self.tile_abs_pos,
294            )
295
296        x = self.ln_pre(x)
297
298        list_from_encoder = []
299        for i, blk in enumerate(self.blocks):
300            if self.use_act_checkpoint and self.training:
301                x = torch.utils.checkpoint.checkpoint(blk, x, use_reentrant=False)
302            else:
303                x = blk(x)
304
305            x = self._convert_to_expected_dim(x, i, s)
306
307            if i in self.full_attn_ids:
308                list_from_encoder.append(x)
309
310        return x, list_from_encoder
311
312    def _convert_to_expected_dim(self, x, i, s):
313        if (i == self.full_attn_ids[-1]) or (
314            self.return_interm_layers and i in self.full_attn_ids
315        ):
316            if i == self.full_attn_ids[-1]:
317                x = self.ln_post(x)
318
319            feats = x[:, s:]
320            if feats.ndim == 4:
321                feats = feats.permute(0, 3, 1, 2)
322            else:
323                assert feats.ndim == 3
324                h = w = int(math.sqrt(feats.shape[1]))
325                feats = feats.reshape(feats.shape[0], h, w, feats.shape[-1]).permute(0, 3, 1, 2)
326            return feats
327
328        else:
329            return x
330
331    def forward(self, x: torch.Tensor):
332        """Apply the vision transformer to input data.
333
334        Args:
335            x: The input data.
336
337        Returns:
338            The vision transformer output and the list of intermediate features from the global attention blocks.
339        """
340        x, list_from_encoder = self.forward_features(x)
341        return x, list_from_encoder
342
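# Usage sketch (illustrative addition): the SAM3 perception encoder is most conveniently built
# through `get_vision_transformer` below, which fixes img_size=1008 and patch_size=14.
# This assumes sam3 is installed.
def _example_vit_sam3():
    encoder = get_vision_transformer(backbone="sam3", model="vit_pe")
    # For a 1008 input with patch size 14 the output is a 72 x 72 feature map with 1024 channels,
    # plus the features collected at the global attention blocks.
    features, intermediates = encoder(torch.randn(1, 3, 1008, 1008))
    return features, intermediates
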
343#
344# Utilities for ScaleMAE's ViT
345#
346
347
348class CustomCompose:
349    def __init__(self, rescale_transform, other_transforms, src_transform):
350        self.rescale_transform = rescale_transform
351        self.other_transforms = other_transforms
352        self.src_transform = src_transform
353
354    def __call__(self, x, valid_masks=None):
355        if valid_masks is not None:
356            nodata = (x * (1 - valid_masks.float())).max()
357        x_aug = self.rescale_transform(x)
358        parms = self.rescale_transform._params
359
360        # sanity check, comment if this is working
361        # valid_masks = self.rescale_transform(valid_masks.float(), params=parms)
362        # assert (x_aug==self.rescale_transform(x, params=parms)).all() #
363
364        if valid_masks is not None:
365            valid_masks = x_aug != nodata
366            _, c, h, w = x_aug.shape
367            zero_ratio = ((valid_masks == 0).sum((1, 2, 3)) / (h * w * c)).cpu().numpy()
368        else:
369            zero_ratio = -1
370
371        if self.other_transforms:
372            x_aug = self.other_transforms(x_aug)
373        x_src = self.src_transform(x_aug)
374        dx = parms["src"][:, 1, 0] - parms["src"][:, 0, 0]
375
376        # dy = (parms['src'][:,2,1] - parms['src'][:,1,1])
377        # assert (dx == dy).all()
378
379        h, w = x_aug.shape[-2:]
380        # assert h == w
381
382        return x_aug, x_src, dx / h, zero_ratio, valid_masks
383
384
385def get_2d_sincos_pos_embed_with_resolution(embed_dim, grid_size, res, cls_token=False, device="cpu"):
386    """
387    grid_size: int of the grid height and width
388    res: tensor of size n, representing the resolution of a pixel (say, in meters)
389    return:
390    pos_embed: [n,grid_size*grid_size, embed_dim] or [n,1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
391    """
392    # res = torch.FloatTensor(res).to(device)
393    res = res.to(device)
394    grid_h = torch.arange(grid_size, dtype=torch.float32, device=device)
395    grid_w = torch.arange(grid_size, dtype=torch.float32, device=device)
396    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")  # here h goes first; direction is reversed compared to numpy
397    grid = torch.stack(grid, dim=0)  # 2 x h x w
398
399    # grid = grid.reshape([2, 1, grid_size, grid_size])
400    grid = torch.einsum("chw,n->cnhw", grid, res)  # 2 x n x h x w
401    _, n, h, w = grid.shape
402    pos_embed = get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid)  # (n*H*W, D)
403    pos_embed = pos_embed.reshape(n, h * w, embed_dim)
404    if cls_token:
405        pos_embed = torch.cat(
406            [torch.zeros([n, 1, embed_dim], dtype=torch.float32, device=pos_embed.device), pos_embed], dim=1
407        )
408
409    return pos_embed
410
411
412def get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid):
413    assert embed_dim % 2 == 0
414
415    # use half of dimensions to encode grid_h
416    emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid[0])  # (H*W, D/2)
417    emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid[1])  # (H*W, D/2)
418
419    emb = torch.cat([emb_h, emb_w], dim=1)  # (H*W, D)
420    return emb
421
422
423def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos):
424    """
425    embed_dim: output dimension for each position
426    pos: a list of positions to be encoded: size (M,)
427    out: (M, D)
428    """
429    assert embed_dim % 2 == 0
430    # old_shape = pos
431    omega = torch.arange(embed_dim // 2, dtype=torch.float32, device=pos.device)
432    omega /= embed_dim / 2.0
433    omega = 1.0 / 10000**omega  # (D/2,)
434
435    pos = pos.reshape(-1)  # (M,)
436    out = torch.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
437
438    emb_sin = torch.sin(out)  # (M, D/2)
439    emb_cos = torch.cos(out)  # (M, D/2)
440
441    emb = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)
442    return emb
443
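# Usage sketch (illustrative addition): resolution-aware 2D sin-cos positional embeddings as
# used by ViT_ScaleMAE below. The grid size, embedding dimension and resolutions are arbitrary
# example values.
def _example_sincos_pos_embed():
    res = torch.tensor([1.0, 2.5])  # per-image ground resolution
    # With cls_token=True the result has shape (2, 1 + 14 * 14, 768).
    pos_embed = get_2d_sincos_pos_embed_with_resolution(768, 14, res, cls_token=True)
    return pos_embed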
444
445class PatchEmbedUnSafe(PatchEmbed):
446    """Image to Patch Embedding"""
447
448    def forward(self, x):
449        B, C, H, W = x.shape
450
451        # NOTE: Comment code from ScaleMAE: Dropped size check in timm
452        # assert H == self.img_size[0] and W == self.img_size[1], \
453        #     f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
454
455        x = self.proj(x).flatten(2).transpose(1, 2)
456        return x
457
458
459class ViT_ScaleMAE(VisionTransformer):
460    """Vision Transformer derived from the Scale Masked Auto Encoder codebase (TODO: paper and github link).
461
462    NOTE: For downstream tasks, the "base_resolution" parameter needs to be adjusted manually when using
463    the model on a dataset with a different zoom factor.
464    """
465
466    def __init__(
467        self, img_size=224, patch_size=16, in_chans=3, embed_dim=1024, depth=12, base_resolution=2.5, **kwargs
468    ):
469        super().__init__(img_size=img_size, embed_dim=embed_dim, depth=depth, **kwargs)
470        self.img_size = img_size
471        self.in_chans = in_chans
472        self.depth = depth
473        self.base_resolution = base_resolution
474
475        self.patch_embed = PatchEmbedUnSafe(
476            img_size=img_size,
477            patch_size=patch_size,
478            in_chans=in_chans,
479            embed_dim=embed_dim,
480        )
481
482    def transform_inputs(self, x):
483        import kornia.augmentation as K
484        from kornia.constants import Resample
485
486        self._transforms = CustomCompose(
487            rescale_transform=K.RandomResizedCrop(
488                (448, 448),
489                ratio=(1.0, 1.0),
490                scale=(1.0, 1.0),
491                resample=Resample.BICUBIC.name,
492            ),
493            other_transforms=None,
494            src_transform=K.Resize((224, 224)),
495        )
496        x, _, ratios, _, _ = self._transforms(x)
497        input_res = ratios * self.base_resolution
498        return x, input_res
499
500    def convert_to_expected_dim(self, x):
501        inputs_ = x[:, 1:, :]  # removing the class tokens
502        # reshape the outputs to desired shape (N X H*W X C -> N X H X W X C)
503        rdim = inputs_.shape[1]
504        dshape = int(rdim ** 0.5)  # finding square root of the outputs for obtaining the patch shape
505        inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape))
506        inputs_ = inputs_.permute(0, 3, 1, 2)
507        return inputs_
508
509    def forward_features(self, x):
510        x, input_res = self.transform_inputs(x)
511
512        B, _, h, w = x.shape
513        x = self.patch_embed(x)
514
515        num_patches = int((h * w) / (self.patch_embed.patch_size[0] * self.patch_embed.patch_size[1]))
516        pos_embed = get_2d_sincos_pos_embed_with_resolution(
517            x.shape[-1],
518            int(num_patches ** 0.5),
519            input_res,
520            cls_token=True,
521            device=x.device,
522        )
523
524        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
525        x = torch.cat((cls_tokens, x), dim=1)
526        x = x + pos_embed
527        x = self.pos_drop(x)
528
529        # determine the blocks whose outputs are used as projections in conjunction with the upsampling blocks
530        _chunks = int(self.depth / 4)
531        chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1]
532
533        list_from_encoder = []
534        for i, blk in enumerate(self.blocks):
535            x = blk(x)
536            if i in chunks_for_projection:
537                list_from_encoder.append(self.convert_to_expected_dim(x))
538
539        x = self.convert_to_expected_dim(x)
540
541        return x, list_from_encoder
542
543    def forward(self, x):
544        x, list_from_encoder = self.forward_features(x)
545        return x, list_from_encoder
546
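# Usage sketch (illustrative addition): the ScaleMAE ViT-B encoder, mirroring
# `get_vision_transformer(backbone="scalemae", model="vit_b")` below. This assumes timm and
# kornia are installed; inputs are internally resampled to 448 x 448 before patch embedding.
def _example_vit_scalemae():
    encoder = ViT_ScaleMAE(
        img_size=224, patch_size=8, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), base_resolution=2.5,
    )
    features, intermediates = encoder(torch.randn(1, 3, 224, 224))
    return features, intermediates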
547
548class ViT_DINOv2(DinoV2VisionTransformer):
549    """Vision Transformer derived from the DINOv2 Codebase (https://arxiv.org/abs/2304.07193).
550
551    Based on:
552    https://github.com/facebookresearch/dinov2/blob/main/dinov2/models/vision_transformer.py.
553
554    Args:
555        img_size: The input image size.
556        patch_size: The patch size.
557        depth: The depth of the network.
558        num_register_tokens: The number of register tokens added (in addition to the class token).
559            This is needed for ViTs trained with registers, so that the registers can be removed from the output.
560    """
561    def __init__(
562        self,
563        img_size: int = 224,
564        patch_size: int = 16,
565        depth: int = 12,
566        num_register_tokens: int = 0,
567        **kwargs
568    ):
569        if not _dinov2_import_success:
570            raise RuntimeError(
571                "The vision transformer backend can only be initialized if DINOv2 is installed. "
572                "Please install DINOv2 from https://github.com/facebookresearch/dinov2 "
573                "and then rerun your code."
574            )
575
576        super().__init__(
577            img_size=img_size,
578            depth=depth,
579            patch_size=patch_size,
580            num_register_tokens=num_register_tokens,
581            **kwargs
582        )
583
584        self.img_size = img_size
585        self.num_register_tokens = num_register_tokens
586        self.patch_size = patch_size
587        self.attn_outs = [i for i in range(depth) if i % 3 == 2]
588
589    def forward(self, x, masks=None) -> Tuple[torch.Tensor, List[torch.Tensor]]:
590
591        B = x.shape[0]
592
593        x = self.prepare_tokens_with_masks(x, masks)
594
595        list_of_encoder = []
596        for i, blk in enumerate(self.blocks):
597            x = blk(x)
598            if i in self.attn_outs:
599                list_of_encoder.append(x)
600
601        x = self.norm(x)
602        x = x[:, self.num_register_tokens + 1:].reshape(
603            B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
604        ).permute(0, 3, 1, 2).contiguous()
605
606        list_of_encoder = [
607            o[:, self.num_register_tokens + 1:].reshape(
608                B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
609            ).permute(0, 3, 1, 2).contiguous() for o in list_of_encoder
610        ]
611
612        return x, list_of_encoder[:3]
613
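# Usage sketch (illustrative addition): a DINOv2 ViT-S encoder with four register tokens,
# close to the "vit_s_reg4" configuration in `get_vision_transformer` below. This assumes
# dinov2 is installed and uses an input size that is a multiple of the patch size (14).
def _example_vit_dinov2():
    encoder = ViT_DINOv2(
        img_size=224, patch_size=14, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention), init_values=1e-5, block_chunks=0,
        num_register_tokens=4,
    )
    # The class and register tokens are stripped, so the output is (1, 384, 16, 16).
    features, intermediates = encoder(torch.randn(1, 3, 224, 224))
    return features, intermediates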
614
615class ViT_DINOv3(DinoV3VisionTransformer):
616    """Vision Transformer derived from the DINOv3 Codebase (https://arxiv.org/abs/2508.10104).
617
618    Based on:
619    https://github.com/facebookresearch/dinov3/blob/main/dinov3/models/vision_transformer.py.
620
621    Args:
622        img_size: The input image size.
623        patch_size: The patch size.
624        embed_dim: The embedding dimension.
625        depth: The depth of the network.
626        num_heads: The number of heads.
627        ffn_ratio: The FFN ratio.
628        n_storage_tokens: The number of storage tokens (similar to register tokens), which are removed from the output.
629        kwargs: Keyword arguments for the image encoder base class.
630    """
631    def __init__(
632        self,
633        in_chans: int = 3,
634        img_size: int = 224,
635        patch_size: int = 16,
636        embed_dim: int = 768,
637        depth: int = 12,
638        num_heads: int = 12,
639        ffn_ratio: float = 4.0,
640        n_storage_tokens: int = 0,
641        **kwargs
642    ):
643        if not _dinov3_import_success:
644            raise RuntimeError(
645                "The vision transformer backend can only be initialized if DINOv3 is installed. "
646                "Please install DINOv3 from https://github.com/facebookresearch/dinov3 "
647                "and then rerun your code."
648            )
649
650        super().__init__(
651            in_chans=in_chans,
652            img_size=img_size,
653            patch_size=patch_size,
654            embed_dim=embed_dim,
655            depth=depth,
656            num_heads=num_heads,
657            ffn_ratio=ffn_ratio,
658            n_storage_tokens=n_storage_tokens,
659            **kwargs
660        )
661
662        self.in_chans = in_chans
663        self.img_size = img_size
664        self.n_storage_tokens = n_storage_tokens
665        self.attn_outs = [i for i in range(depth) if i % 3 == 2]
666
667    def forward(self, x) -> Tuple[torch.Tensor, List[torch.Tensor]]:
668
669        B = x.shape[0]
670
671        x, hw_tuple = self.prepare_tokens_with_masks(x)
672
673        list_of_encoder = []
674        for i, blk in enumerate(self.blocks):
675            rope_sincos = self.rope_embed(H=hw_tuple[0], W=hw_tuple[1])
676            x = blk(x, rope_sincos)
677            if i in self.attn_outs:
678                list_of_encoder.append(x)
679
680        x = self.norm(x)
681        x = x[:, self.n_storage_tokens + 1:].reshape(
682            B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
683        ).permute(0, 3, 1, 2).contiguous()
684
685        list_of_encoder = [
686            o[:, self.n_storage_tokens + 1:].reshape(
687                B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
688            ).permute(0, 3, 1, 2).contiguous() for o in list_of_encoder
689        ]
690
691        return x, list_of_encoder[:3]
692
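# Usage sketch (illustrative addition): the DINOv3 ViT-S encoder via the factory below.
# This assumes dinov3 is installed; the input size of 224 (a multiple of the patch size 16)
# is an example choice.
def _example_vit_dinov3():
    encoder = get_vision_transformer(backbone="dinov3", model="vit_s", img_size=224)
    # The storage tokens and class token are stripped, so the output is (1, 384, 14, 14).
    features, intermediates = encoder(torch.randn(1, 3, 224, 224))
    return features, intermediates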
693
694def get_vision_transformer(backbone: str, model: str, img_size: int = 1024, **kwargs) -> nn.Module:
695    """Get vision transformer encoder.
696
697    Args:
698        backbone: The name of the vision transformer implementation. One of "sam", "sam2", "sam3", "mae", "scalemae", "dinov2" or "dinov3".
699        model: The name of the model configuration. The supported values depend on the backbone, e.g. "vit_b", "vit_l" or "vit_h" for "sam" and "mae".
700        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
701        kwargs: Additional kwargs which can be expected by the vision transformer,
702            e.g. 'base_resolution' for `ViT_ScaleMAE`.
703
704    Returns:
705        The vision transformer.
706    """
707    if backbone == "sam":
708        if model == "vit_b":
709            encoder = ViT_Sam(
710                depth=12, embed_dim=768, img_size=img_size, mlp_ratio=4,
711                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
712                num_heads=12, patch_size=16, qkv_bias=True, use_rel_pos=True,
713                global_attn_indexes=[2, 5, 8, 11],
714                window_size=14, out_chans=256,
715            )
716        elif model == "vit_l":
717            encoder = ViT_Sam(
718                depth=24, embed_dim=1024, img_size=img_size, mlp_ratio=4,
719                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
720                num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True,
721                global_attn_indexes=[5, 11, 17, 23],
722                window_size=14, out_chans=256,
723            )
724        elif model == "vit_h":
725            encoder = ViT_Sam(
726                depth=32, embed_dim=1280, img_size=img_size, mlp_ratio=4,
727                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
728                num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True,
729                global_attn_indexes=[7, 15, 23, 31],
730                window_size=14, out_chans=256,
731            )
732        else:
733            raise ValueError(f"'{model}' is not supported by SAM. Currently, 'vit_b', 'vit_l', 'vit_h' are supported.")
734
735    elif backbone == "sam2":
736        if model == "hvit_t":
737            encoder = ViT_Sam2(
738                img_size=img_size, embed_dim=96, num_heads=1, stages=[1, 2, 7, 2], global_att_blocks=[5, 7, 9],
739                window_pos_embed_bkg_spatial_size=[7, 7], backbone_channel_list=[768, 384, 192, 96],
740            )
741        elif model == "hvit_s":
742            encoder = ViT_Sam2(
743                img_size=img_size, embed_dim=96, num_heads=1, stages=[1, 2, 11, 2], global_att_blocks=[7, 10, 13],
744                window_pos_embed_bkg_spatial_size=[7, 7], backbone_channel_list=[768, 384, 192, 96],
745            )
746        elif model == "hvit_b":
747            encoder = ViT_Sam2(
748                img_size=img_size, embed_dim=112, num_heads=2, backbone_channel_list=[896, 448, 224, 112],
749            )
750        elif model == "hvit_l":
751            encoder = ViT_Sam2(
752                img_size=img_size, embed_dim=144, num_heads=2, stages=[2, 6, 36, 4], global_att_blocks=[23, 33, 43],
753                window_spec=[8, 4, 16, 8], backbone_channel_list=[1152, 576, 288, 144],
754            )
755        else:
756            raise ValueError(
757                f"'{model}' is not supported by SAM2. Currently, 'hvit_t', 'hvit_s', 'hvit_b', 'hvit_l' are supported."
758            )
759
760    elif backbone == "sam3":
761        if model != "vit_pe":
762            raise ValueError(
763                "'sam3' does not have multiple model configurations. Please use 'vit_pe' as the model configuration."
764            )
765
766        encoder = ViT_Sam3(
767            img_size=1008, pretrain_img_size=336, patch_size=14, embed_dim=1024, depth=32, num_heads=16,
768            mlp_ratio=4.625, norm_layer="LayerNorm", drop_path_rate=0.1, qkv_bias=True, use_abs_pos=True,
769            tile_abs_pos=True, global_att_blocks=(7, 15, 23, 31), rel_pos_blocks=(), use_rope=True,
770            use_interp_rope=True, window_size=24, pretrain_use_cls_token=True, retain_cls_token=False, ln_pre=True,
771            ln_post=False, return_interm_layers=False, bias_patch_embed=False, compile_mode=None,
772        )
773
774    elif backbone == "mae":
775        if model == "vit_b":
776            encoder = ViT_MAE(
777                img_size=img_size, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
778                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
779            )
780        elif model == "vit_l":
781            encoder = ViT_MAE(
782                img_size=img_size, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
783                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
784            )
785        elif model == "vit_h":
786            encoder = ViT_MAE(
787                img_size=img_size, patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4,
788                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
789            )
790        else:
791            raise ValueError(f"'{model}' is not supported by MAE. Currently, 'vit_b', 'vit_l', 'vit_h' are supported.")
792
793    elif backbone == "scalemae":
794        base_resolution = kwargs.get("base_resolution", 2.5)
795
796        if model == "vit_b":
797            encoder = ViT_ScaleMAE(
798                img_size=img_size, patch_size=8, embed_dim=768, depth=12, num_heads=12,
799                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
800                base_resolution=base_resolution,
801            )
802        elif model == "vit_l":
803            encoder = ViT_ScaleMAE(
804                img_size=img_size, patch_size=8, embed_dim=1024, depth=24, num_heads=16,
805                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
806                base_resolution=base_resolution,
807            )
808        elif model == "vit_h":
809            encoder = ViT_ScaleMAE(
810                img_size=img_size, patch_size=8, embed_dim=1280, depth=32, num_heads=16,
811                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
812                base_resolution=base_resolution,
813            )
814        else:
815            raise ValueError(
816                f"'{model}' is not supported by ScaleMAE. Currently, 'vit_b', 'vit_l' and 'vit_h' are supported."
817            )
818
819    elif backbone == "dinov2":
820        block_fn = partial(Block, attn_class=MemEffAttention)
821        msg = "The model name should be either 'vit_<X>' or 'vit_<X>_reg<Y>'."
822
823        if model.startswith("vit_s"):
824            assert model in ["vit_s", "vit_s_reg4"], msg
825            encoder = ViT_DINOv2(
826                img_size=img_size, patch_size=14, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
827                block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0,
828                num_register_tokens=4 if model.endswith("_reg4") else 0,
829            )
830        elif model.startswith("vit_b"):
831            assert model in ["vit_b", "vit_b_reg4"], msg
832            encoder = ViT_DINOv2(
833                img_size=img_size, patch_size=14, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
834                block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0,
835                num_register_tokens=4 if model.endswith("_reg4") else 0,
836            )
837        elif model.startswith("vit_l"):
838            assert model in ["vit_l", "vit_l_reg4"], msg
839            encoder = ViT_DINOv2(
840                img_size=img_size, patch_size=14, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
841                block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0,
842                num_register_tokens=4 if model.endswith("_reg4") else 0,
843            )
844        elif model.startswith("vit_g"):
845            assert model in ["vit_g", "vit_g_reg4"], msg
846            encoder = ViT_DINOv2(
847                img_size=img_size, patch_size=14, embed_dim=1536, depth=40, num_heads=24, mlp_ratio=4,
848                block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0,
849                num_register_tokens=4 if model.endswith("_reg4") else 0, ffn_layer="swiglu",
850            )
851        else:
852            raise ValueError(
853                f"'{model}' is not supported by DINOv2. Currently, 'vit_s', 'vit_b', 'vit_l' and 'vit_g' are supported."
854            )
855
856    elif backbone == "dinov3":
857
858        if model == "vit_s":
859            encoder = ViT_DINOv3(
860                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=384,
861                num_heads=6, layerscale_init=1.0e-05, norm_layer="layernormbf16", n_storage_tokens=4, mask_k_bias=True,
862            )
863        elif model == "vit_s+":
864            encoder = ViT_DINOv3(
865                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=384,
866                num_heads=6, ffn_ratio=6, layerscale_init=1.0e-05, norm_layer="layernormbf16",
867                ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True,
868            )
869
870        elif model == "vit_b":
871            encoder = ViT_DINOv3(
872                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32",
873                layerscale_init=1.0e-05, norm_layer="layernormbf16", n_storage_tokens=4, mask_k_bias=True,
874            )
875        elif model == "vit_l":
876            encoder = ViT_DINOv3(
877                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1024,
878                depth=24, num_heads=16, layerscale_init=1.0e-05, norm_layer="layernormbf16",
879                n_storage_tokens=4, mask_k_bias=True,
880            )
881        elif model == "vit_l+":
882            encoder = ViT_DINOv3(
883                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1024,
884                depth=24, num_heads=16, ffn_ratio=6.0, layerscale_init=1.0e-05, norm_layer="layernormbf16",
885                ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True,
886            )
887        elif model == "vit_h+":
888            encoder = ViT_DINOv3(
889                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1280,
890                depth=32, num_heads=20, ffn_ratio=6.0, layerscale_init=1.0e-05, norm_layer="layernormbf16",
891                ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True,
892            )
893        elif model == "vit_7b":
894            encoder = ViT_DINOv3(
895                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=4096,
896                depth=40, num_heads=32, ffn_ratio=3, qkv_bias=False, drop_path_rate=0.0, layerscale_init=1.0e-05,
897                norm_layer="layernormbf16", ffn_layer="swiglu64", n_storage_tokens=4, mask_k_bias=True,
898                untie_global_and_local_cls_norm=True,
899            )
900        else:
901            raise ValueError(
902                f"'{model}' is not supported by DINOv3. Currently, "
903                "'vit_s', 'vit_s+', 'vit_b', 'vit_l', 'vit_l+', 'vit_h+' and 'vit_7b' are supported."
904            )
905
906    else:
907        raise ValueError(
908            "The supported backbones are 'sam', 'sam2', 'sam3', 'mae', 'scalemae', 'dinov2' and 'dinov3'. "
909            "Please choose one of them."
910        )
911
912    return encoder
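
The typical entry point is `get_vision_transformer`. A minimal usage sketch, assuming the corresponding
optional dependency (here segment_anything for the "sam" backbone) is installed:

    import torch
    from torch_em.model.vit import get_vision_transformer

    encoder = get_vision_transformer(backbone="sam", model="vit_b", img_size=1024)
    features, intermediate_features = encoder(torch.randn(1, 3, 1024, 1024))

Each backbone needs its optional dependency: segment_anything for "sam", sam2 for "sam2", sam3 for "sam3",
timm for "mae" and "scalemae" (kornia is additionally required for "scalemae"), dinov2 for "dinov2" and
dinov3 for "dinov3".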
class ViT_Sam:
 59class ViT_Sam(ImageEncoderViT):
 60    """Vision Transformer derived from the Segment Anything Codebase (https://arxiv.org/abs/2304.02643).
 61
 62    Based on:
 63    https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py
 64
 65    Args:
 66        in_chans: The number of input channels.
 67        embed_dim: The embedding dimension, corresponding to the number of output channels of the vision transformer.
 68        global_attn_indexes: The global attention indices.
 69        kwargs: Keyword arguments for the image encoder base class.
 70    """
 71    def __init__(
 72        self,
 73        in_chans: int = 3,
 74        embed_dim: int = 768,
 75        global_attn_indexes: Tuple[int, ...] = [2, 5, 8, 11],
 76        **kwargs,
 77    ) -> None:
 78        if not _sam_import_success:
 79            raise RuntimeError(
 80                "The vision transformer backend can only be initialized if segment anything is installed. "
 81                "Please install segment anything from https://github.com/facebookresearch/segment-anything "
 82                "and then rerun your code."
 83            )
 84
 85        super().__init__(embed_dim=embed_dim, global_attn_indexes=global_attn_indexes, **kwargs)
 86        self.chunks_for_projection = global_attn_indexes
 87        self.in_chans = in_chans
 88        self.embed_dim = embed_dim
 89
 90    def forward(self, x: torch.Tensor) -> torch.Tensor:
 91        """Apply the vision transformer to input data.
 92
 93        Args:
 94            x: The input data.
 95
 96        Returns:
 97            The vision transformer output.
 98        """
 99        x = self.patch_embed(x)
100        if self.pos_embed is not None:
101            x = x + self.pos_embed
102
103        list_from_encoder = []
104        for i, blk in enumerate(self.blocks):
105            x = blk(x)
106            if i in self.chunks_for_projection:
107                list_from_encoder.append(x)
108
109        x = x.permute(0, 3, 1, 2)
110        list_from_encoder = [e.permute(0, 3, 1, 2) for e in list_from_encoder]
111        return x, list_from_encoder[:3]

Vision Transformer derived from the Segment Anything Codebase (https://arxiv.org/abs/2304.02643).

Based on: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py

Arguments:
  • in_chans: The number of input channels.
  • embed_dim: The embedding dimension, corresponding to the number of output channels of the vision transformer.
  • global_attn_indexes: The global attention indices.
  • kwargs: Keyword arguments for the image encoder base class.
ViT_Sam( in_chans: int = 3, embed_dim: int = 768, global_attn_indexes: Tuple[int, ...] = [2, 5, 8, 11], **kwargs)
71    def __init__(
72        self,
73        in_chans: int = 3,
74        embed_dim: int = 768,
75        global_attn_indexes: Tuple[int, ...] = [2, 5, 8, 11],
76        **kwargs,
77    ) -> None:
78        if not _sam_import_success:
79            raise RuntimeError(
80                "The vision transformer backend can only be initialized if segment anything is installed. "
81                "Please install segment anything from https://github.com/facebookresearch/segment-anything "
82                "and then rerun your code."
83            )
84
85        super().__init__(embed_dim=embed_dim, global_attn_indexes=global_attn_indexes, **kwargs)
86        self.chunks_for_projection = global_attn_indexes
87        self.in_chans = in_chans
88        self.embed_dim = embed_dim
chunks_for_projection
in_chans
embed_dim
def forward(self, x: torch.Tensor) -> torch.Tensor:
 90    def forward(self, x: torch.Tensor) -> torch.Tensor:
 91        """Apply the vision transformer to input data.
 92
 93        Args:
 94            x: The input data.
 95
 96        Returns:
 97            The vision transformer output.
 98        """
 99        x = self.patch_embed(x)
100        if self.pos_embed is not None:
101            x = x + self.pos_embed
102
103        list_from_encoder = []
104        for i, blk in enumerate(self.blocks):
105            x = blk(x)
106            if i in self.chunks_for_projection:
107                list_from_encoder.append(x)
108
109        x = x.permute(0, 3, 1, 2)
110        list_from_encoder = [e.permute(0, 3, 1, 2) for e in list_from_encoder]
111        return x, list_from_encoder[:3]

Apply the vision transformer to input data.

Arguments:
  • x: The input data.
Returns:

The vision transformer output.

class ViT_MAE:
114class ViT_MAE(VisionTransformer):
115    """Vision Transformer derived from the Masked Auto Encoder Codebase (https://arxiv.org/abs/2111.06377).
116
117    Based on:
118    https://github.com/facebookresearch/mae/blob/main/models_vit.py#L20-L53
119
120    Args:
121        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
122        in_chans: The number of input channels.
123        depth: The depth of the vision transformer.
124        kwargs: Additional keyword arguments for the vision transformer base class.
125    """
126    def __init__(
127        self,
128        img_size: int = 1024,  # chosen to match our experiments with segment anything
129        in_chans: int = 3,
130        depth: int = 12,
131        **kwargs
132    ):
133        if not _timm_import_success:
134            raise RuntimeError(
135                "The vision transformer backend can only be initialized if timm is installed. "
136                "Please install timm (using conda/mamba) for using https://github.com/facebookresearch/mae/ "
137                "and then rerun your code"
138            )
139        super().__init__(img_size=img_size, depth=depth, **kwargs)
140        self.img_size = img_size
141        self.in_chans = in_chans
142        self.depth = depth
143
144    def convert_to_expected_dim(self, inputs_):
145        """@private
146        """
147        inputs_ = inputs_[:, 1:, :]  # removing the class tokens
148        # reshape the outputs to desired shape (N x H*W X C -> N x H x W x C)
149        rdim = inputs_.shape[1]
150        dshape = int(rdim ** 0.5)  # finding the square root of the outputs for obtaining the patch shape
151        inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape))
152        inputs_ = inputs_.permute(0, 3, 1, 2)
153        return inputs_
154
155    def forward_features(self, x):
156        """@private
157        """
158        B = x.shape[0]
159        x = self.patch_embed(x)
160
161        cls_tokens = self.cls_token.expand(B, -1, -1)
162        x = torch.cat((cls_tokens, x), dim=1)
163
164        x = x + self.pos_embed
165        x = self.pos_drop(x)
166
167        # chunks obtained for getting the projections for conjuctions with upsampling blocks
168        _chunks = int(self.depth / 4)
169        chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1]
170
171        list_from_encoder = []
172        for i, blk in enumerate(self.blocks):
173            x = blk(x)
174            if i in chunks_for_projection:
175                list_from_encoder.append(self.convert_to_expected_dim(x))
176
177        x = self.convert_to_expected_dim(x)
178        return x, list_from_encoder[:3]
179
180    def forward(self, x: torch.Tensor) -> torch.Tensor:
181        """Apply the vision transformer to input data.
182
183        Args:
184            x: The input data.
185
186        Returns:
187            The vision transformer output.
188        """
189        x, list_from_encoder = self.forward_features(x)
190        return x, list_from_encoder

Vision Transformer derived from the Masked Auto Encoder Codebase (https://arxiv.org/abs/2111.06377).

Based on: https://github.com/facebookresearch/mae/blob/main/models_vit.py#L20-L53

Arguments:
  • img_size: The size of the input for the image encoder. Input images will be resized to match this size.
  • in_chans: The number of input channels.
  • depth: The depth of the vision transformer.
  • kwargs: Additional keyword arguments for the vision transformer base class.
ViT_MAE(img_size: int = 1024, in_chans: int = 3, depth: int = 12, **kwargs)
126    def __init__(
127        self,
128        img_size: int = 1024,  # chosen to match our experiments with segment anything
129        in_chans: int = 3,
130        depth: int = 12,
131        **kwargs
132    ):
133        if not _timm_import_success:
134            raise RuntimeError(
135                "The vision transformer backend can only be initialized if timm is installed. "
136                "Please install timm (using conda/mamba) for using https://github.com/facebookresearch/mae/ "
137                "and then rerun your code"
138            )
139        super().__init__(img_size=img_size, depth=depth, **kwargs)
140        self.img_size = img_size
141        self.in_chans = in_chans
142        self.depth = depth
img_size
in_chans
depth
def forward(self, x: torch.Tensor) -> torch.Tensor:
180    def forward(self, x: torch.Tensor) -> torch.Tensor:
181        """Apply the vision transformer to input data.
182
183        Args:
184            x: The input data.
185
186        Returns:
187            The vision transformer output.
188        """
189        x, list_from_encoder = self.forward_features(x)
190        return x, list_from_encoder

Apply the vision transformer to input data.

Arguments:
  • x: The input data.
Returns:

The vision transformer output.

class ViT_Sam2:
193class ViT_Sam2(ImageEncoder):
194    """Vision Transformer derived from the Segment Anything 2 Codebase (https://arxiv.org/abs/2408.00714).
195
196    Based on https://github.com/facebookresearch/sam2/blob/main/sam2/modeling/backbones/image_encoder.py.
197
198    Args:
199        backbone_channel_list: The channels throughout the entire backbone.
200        embed_dim: The initial embedding dimension.
201        num_heads: The initial number of heads.
202        stages: The number of blocks per stage.
203        global_att_blocks: The parameter to decide which blocks have global attention.
204        window_pos_embed_bkg_spatial_size: The spatial size of windowed positional embedding.
205        window_spec: The window size per stage, when not using global attention.
206        scalp: The count of lowest resolution features to discard.
207    """
208    def __init__(
209        self,
210        backbone_channel_list: List[int],
211        img_size: int = 1024,
212        embed_dim: int = 96,
213        num_heads: int = 1,
214        stages: Tuple[int, ...] = (2, 3, 16, 3),
215        global_att_blocks: Tuple[int, ...] = (12, 16, 20),
216        window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
217        window_spec: Tuple[int, ...] = (8, 4, 14, 7),
218        scalp: int = 1,
219        **kwargs
220    ):
221        if not _sam2_import_success:
222            raise RuntimeError(
223                "The vision transformer backend can only be initialized if segment anything 2 is installed. "
224                "Please install segment anything 2 from https://github.com/facebookresearch/sam2 "
225                "and then rerun your code"
226            )
227
228        trunk = Hiera(
229            embed_dim=embed_dim,
230            num_heads=num_heads,
231            stages=stages,
232            global_att_blocks=global_att_blocks,
233            window_pos_embed_bkg_spatial_size=window_pos_embed_bkg_spatial_size,
234            window_spec=window_spec,
235        )
236        neck = FpnNeck(
237            position_encoding=PositionEmbeddingSine(num_pos_feats=256),
238            d_model=256,
239            backbone_channel_list=backbone_channel_list,
240            fpn_top_down_levels=[2, 3],
241            fpn_interp_model="nearest",
242        )
243
244        super().__init__(trunk=trunk, neck=neck, scalp=scalp, **kwargs)
245        self.scalp = scalp
246        self.embed_dim = embed_dim
247        self.img_size = img_size
248
249    def forward(self, x: torch.Tensor):
250        # The forward pass throught the backbone.
251        features, pos = self.neck(self.trunk(x))
252        if self.scalp > 0:  # This discard the "n" lowest resolution features.
253            features, pos = features[:-self.scalp], pos[:-self.scalp]
254
255        return features[-1], features

Vision Transformer derived from the Segment Anything 2 Codebase (https://arxiv.org/abs/2408.00714).

Based on https://github.com/facebookresearch/sam2/blob/main/sam2/modeling/backbones/image_encoder.py.

Arguments:
  • backbone_channel_list: The channels throughout the entire backbone.
  • embed_dim: The initial embedding dimension.
  • num_heads: The initial number of heads.
  • stages: The number of blocks per stage.
  • global_att_blocks: The parameter to decide which blocks have global attention.
  • window_pos_embed_bkg_spatial_size: The spatial size of windowed positional embedding.
  • window_spec: The window size per stage, when not using global attention.
  • scalp: The count of lowest resolution features to discard.
ViT_Sam2( backbone_channel_list: List[int], img_size: int = 1024, embed_dim: int = 96, num_heads: int = 1, stages: Tuple[int, ...] = (2, 3, 16, 3), global_att_blocks: Tuple[int, ...] = (12, 16, 20), window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14), window_spec: Tuple[int, ...] = (8, 4, 14, 7), scalp: int = 1, **kwargs)
208    def __init__(
209        self,
210        backbone_channel_list: List[int],
211        img_size: int = 1024,
212        embed_dim: int = 96,
213        num_heads: int = 1,
214        stages: Tuple[int, ...] = (2, 3, 16, 3),
215        global_att_blocks: Tuple[int, ...] = (12, 16, 20),
216        window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
217        window_spec: Tuple[int, ...] = (8, 4, 14, 7),
218        scalp: int = 1,
219        **kwargs
220    ):
221        if not _sam2_import_success:
222            raise RuntimeError(
223                "The vision transformer backend can only be initialized if segment anything 2 is installed. "
224                "Please install segment anything 2 from https://github.com/facebookresearch/sam2 "
225                "and then rerun your code"
226            )
227
228        trunk = Hiera(
229            embed_dim=embed_dim,
230            num_heads=num_heads,
231            stages=stages,
232            global_att_blocks=global_att_blocks,
233            window_pos_embed_bkg_spatial_size=window_pos_embed_bkg_spatial_size,
234            window_spec=window_spec,
235        )
236        neck = FpnNeck(
237            position_encoding=PositionEmbeddingSine(num_pos_feats=256),
238            d_model=256,
239            backbone_channel_list=backbone_channel_list,
240            fpn_top_down_levels=[2, 3],
241            fpn_interp_model="nearest",
242        )
243
244        super().__init__(trunk=trunk, neck=neck, scalp=scalp, **kwargs)
245        self.scalp = scalp
246        self.embed_dim = embed_dim
247        self.img_size = img_size
scalp
embed_dim
img_size
def forward(self, x: torch.Tensor):
249    def forward(self, x: torch.Tensor):
250        # The forward pass throught the backbone.
251        features, pos = self.neck(self.trunk(x))
252        if self.scalp > 0:  # This discard the "n" lowest resolution features.
253            features, pos = features[:-self.scalp], pos[:-self.scalp]
254
255        return features[-1], features
class ViT_Sam3:
258class ViT_Sam3(SAM3ViT):
259    """Vision Transformer derived from the Segment Anything 3 Codebase (https://arxiv.org/abs/2511.16719).
260
261    Based on https://github.com/facebookresearch/sam3/blob/main/sam3/model/vitdet.py.
262
263    Args:
264        img_size: The input image size.
265        embed_dim: The embedding dimension, corresponding to the number of output channels of the vision transformer.
266        kwargs: Keyword arguments for the image encoder base class.
267    """
268    def __init__(self, img_size: int = 1024, embed_dim: int = 768, **kwargs):
269        if not _sam3_import_success:
270            raise RuntimeError(
271                "The vision transformer backend can only be initialized if segment anything 3 is installed. "
272                "Please install segment anything 3 from https://github.com/facebookresearch/sam3 "
273                "and then rerun your code"
274            )
275
276        super().__init__(img_size=img_size, embed_dim=embed_dim, **kwargs)
277        self.img_size = img_size
278        self.embed_dim = embed_dim
279
280    def forward_features(self, x):
281        """@private
282        """
283        x = self.patch_embed(x)
284        h, w = x.shape[1], x.shape[2]
285
286        s = 0
287        if self.retain_cls_token:
288            # If the 'cls_token' is retained, we don't maintain the spatial shape.
289            x = torch.cat([self.class_embedding, x.flatten(1, 2)], dim=1)
290            s = 1
291
292        if self.pos_embed is not None:
293            x = x + get_abs_pos(
294                self.pos_embed, self.pretrain_use_cls_token, (h, w), self.retain_cls_token, tiling=self.tile_abs_pos,
295            )
296
297        x = self.ln_pre(x)
298
299        list_from_encoder = []
300        for i, blk in enumerate(self.blocks):
301            if self.use_act_checkpoint and self.training:
302                x = torch.utils.checkpoint.checkpoint(blk, x, use_reentrant=False)
303            else:
304                x = blk(x)
305
306            x = self._convert_to_expected_dim(x, i, s)
307
308            if i in self.full_attn_ids:
309                list_from_encoder.append(x)
310
311        return x, list_from_encoder
312
313    def _convert_to_expected_dim(self, x, i, s):
314        if (i == self.full_attn_ids[-1]) or (
315            self.return_interm_layers and i in self.full_attn_ids
316        ):
317            if i == self.full_attn_ids[-1]:
318                x = self.ln_post(x)
319
320            feats = x[:, s:]
321            if feats.ndim == 4:
322                feats = feats.permute(0, 3, 1, 2)
323            else:
324                assert feats.ndim == 3
325                h = w = math.sqrt(feats.shape[1])
326                feats = feats.reshape(feats.shape[0], h, w, feats.shape[-1]).permute(0, 3, 1, 2)
327            return feats
328
329        else:
330            return x
331
332    def forward(self, x: torch.Tensor):
333        """Apply the vision transformer to input data.
334
335        Args:
336            x: The input data.
337
338        Returns:
339            The vision transformer output.
340        """
341        x, list_from_encoder = self.forward_features(x)
342        return x, list_from_encoder

Vision Transformer derived from the Segment Anything 3 Codebase (https://arxiv.org/abs/2511.16719).

Based on https://github.com/facebookresearch/sam3/blob/main/sam3/model/vitdet.py.

Arguments:
  • img_size: The input image size.
  • embed_dim: The embedding dimension, corresponding to the number of output channels of the vision transformer.
  • kwargs: Keyword arguments for the image encoder base class.
ViT_Sam3(img_size: int = 1024, embed_dim: int = 768, **kwargs)
268    def __init__(self, img_size: int = 1024, embed_dim: int = 768, **kwargs):
269        if not _sam3_import_success:
270            raise RuntimeError(
271                "The vision transformer backend can only be initialized if segment anything 3 is installed. "
272                "Please install segment anything 3 from https://github.com/facebookresearch/sam3 "
273                "and then rerun your code"
274            )
275
276        super().__init__(img_size=img_size, embed_dim=embed_dim, **kwargs)
277        self.img_size = img_size
278        self.embed_dim = embed_dim
img_size
embed_dim
def forward(self, x: torch.Tensor):
332    def forward(self, x: torch.Tensor):
333        """Apply the vision transformer to input data.
334
335        Args:
336            x: The input data.
337
338        Returns:
339            The vision transformer output.
340        """
341        x, list_from_encoder = self.forward_features(x)
342        return x, list_from_encoder

Apply the vision transformer to input data.

Arguments:
  • x: The input data.
Returns:

The vision transformer output.
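
A usage sketch, assuming the optional sam3 package is installed. The "vit_pe" configuration created by get_vision_transformer below fixes the input size to 1008 with a patch size of 14, so the final feature map has a 72 x 72 spatial grid.

import torch
from torch_em.model.vit import get_vision_transformer

encoder = get_vision_transformer(backbone="sam3", model="vit_pe")

x = torch.randn(1, 3, 1008, 1008)
final_features, intermediate_features = encoder(x)

print(final_features.shape)          # expected (1, 1024, 72, 72): embed_dim=1024 and 1008 / 14 = 72
print(len(intermediate_features))    # one entry per global attention block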

class CustomCompose:
349class CustomCompose:
350    def __init__(self, rescale_transform, other_transforms, src_transform):
351        self.rescale_transform = rescale_transform
352        self.other_transforms = other_transforms
353        self.src_transform = src_transform
354
355    def __call__(self, x, valid_masks=None):
356        if valid_masks is not None:
357            nodata = (x * (1 - valid_masks.float())).max()
358        x_aug = self.rescale_transform(x)
359        parms = self.rescale_transform._params
360
361        # sanity check, comment out once this is confirmed to work
362        # valid_masks = self.rescale_transform(valid_masks.float(), params=parms)
363        # assert (x_aug==self.rescale_transform(x, params=parms)).all() #
364
365        if valid_masks is not None:
366            valid_masks = x_aug != nodata
367            _, c, h, w = x_aug.shape
368            zero_ratio = ((valid_masks == 0).sum((1, 2, 3)) / (h * w * c)).cpu().numpy()
369        else:
370            zero_ratio = -1
371
372        if self.other_transforms:
373            x_aug = self.other_transforms(x_aug)
374        x_src = self.src_transform(x_aug)
375        dx = parms["src"][:, 1, 0] - parms["src"][:, 0, 0]
376
377        # dy = (parms['src'][:,2,1] - parms['src'][:,1,1])
378        # assert (dx == dy).all()
379
380        h, w = x_aug.shape[-2:]
381        # assert h == w
382
383        return x_aug, x_src, dx / h, zero_ratio, valid_masks
CustomCompose(rescale_transform, other_transforms, src_transform)
350    def __init__(self, rescale_transform, other_transforms, src_transform):
351        self.rescale_transform = rescale_transform
352        self.other_transforms = other_transforms
353        self.src_transform = src_transform
rescale_transform
other_transforms
src_transform
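
The sketch below shows how CustomCompose is wired up inside ViT_ScaleMAE.transform_inputs, assuming kornia is installed. The third return value, dx / h, is the rescaling ratio that ViT_ScaleMAE multiplies with base_resolution to obtain the per-sample input resolution.

import torch
import kornia.augmentation as K
from kornia.constants import Resample
from torch_em.model.vit import CustomCompose

transforms = CustomCompose(
    rescale_transform=K.RandomResizedCrop(
        (448, 448), ratio=(1.0, 1.0), scale=(1.0, 1.0), resample=Resample.BICUBIC.name,
    ),
    other_transforms=None,
    src_transform=K.Resize((224, 224)),
)

x = torch.randn(2, 3, 448, 448)
x_aug, x_src, ratios, zero_ratio, valid_masks = transforms(x)
print(x_aug.shape, x_src.shape, ratios.shape)  # (2, 3, 448, 448), (2, 3, 224, 224), (2,)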
def get_2d_sincos_pos_embed_with_resolution(embed_dim, grid_size, res, cls_token=False, device='cpu'):
386def get_2d_sincos_pos_embed_with_resolution(embed_dim, grid_size, res, cls_token=False, device="cpu"):
387    """
388    grid_size: int of the grid height and width
389    res: array of size n, representing the resolution of a pixel (say, in meters),
390    return:
391    pos_embed: [n,grid_size*grid_size, embed_dim] or [n,1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
392    """
393    # res = torch.FloatTensor(res).to(device)
394    res = res.to(device)
395    grid_h = torch.arange(grid_size, dtype=torch.float32, device=device)
396    grid_w = torch.arange(grid_size, dtype=torch.float32, device=device)
397    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")  # here h goes first, direction reversed compared to numpy
398    grid = torch.stack(grid, dim=0)  # 2 x h x w
399
400    # grid = grid.reshape([2, 1, grid_size, grid_size])
401    grid = torch.einsum("chw,n->cnhw", grid, res)  # 2 x n x h x w
402    _, n, h, w = grid.shape
403    pos_embed = get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid)  # (n*H*W, D)
404    pos_embed = pos_embed.reshape(n, h * w, embed_dim)
405    if cls_token:
406        pos_embed = torch.cat(
407            [torch.zeros([n, 1, embed_dim], dtype=torch.float32, device=pos_embed.device), pos_embed], dim=1
408        )
409
410    return pos_embed

  • grid_size: int of the grid height and width
  • res: array of size n, representing the resolution of a pixel (say, in meters)
  • returns pos_embed: [n, grid_size*grid_size, embed_dim] or [n, 1+grid_size*grid_size, embed_dim] (with or without cls_token)
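
A quick shape check for this function (a sketch; the resolution values below are illustrative):

import torch
from torch_em.model.vit import get_2d_sincos_pos_embed_with_resolution

embed_dim, grid_size = 768, 14
res = torch.tensor([1.0, 2.5])  # per-sample pixel resolution, e.g. in meters

pos_embed = get_2d_sincos_pos_embed_with_resolution(embed_dim, grid_size, res, cls_token=True)
print(pos_embed.shape)  # torch.Size([2, 197, 768]) -> (n, 1 + grid_size * grid_size, embed_dim)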

def get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid):
413def get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid):
414    assert embed_dim % 2 == 0
415
416    # use half of dimensions to encode grid_h
417    emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid[0])  # (H*W, D/2)
418    emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid[1])  # (H*W, D/2)
419
420    emb = torch.cat([emb_h, emb_w], dim=1)  # (H*W, D)
421    return emb
def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos):
424def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos):
425    """
426    embed_dim: output dimension for each position
427    pos: a list of positions to be encoded: size (M,)
428    out: (M, D)
429    """
430    assert embed_dim % 2 == 0
431    # old_shape = pos
432    omega = torch.arange(embed_dim // 2, dtype=torch.float32, device=pos.device)
433    omega /= embed_dim / 2.0
434    omega = 1.0 / 10000**omega  # (D/2,)
435
436    pos = pos.reshape(-1)  # (M,)
437    out = torch.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
438
439    emb_sin = torch.sin(out)  # (M, D/2)
440    emb_cos = torch.cos(out)  # (M, D/2)
441
442    emb = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)
443    return emb

  • embed_dim: output dimension for each position
  • pos: a list of positions to be encoded, of size (M,)
  • returns out: (M, D)
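
A small sanity check of the sin / cos layout (a sketch): the first half of the output holds sin(pos * omega), the second half cos(pos * omega).

import torch
from torch_em.model.vit import get_1d_sincos_pos_embed_from_grid_torch

pos = torch.arange(5, dtype=torch.float32)             # M = 5 positions
emb = get_1d_sincos_pos_embed_from_grid_torch(8, pos)  # D = 8
print(emb.shape)                                       # torch.Size([5, 8]) -> (M, D)

omega = 1.0 / 10000 ** (torch.arange(4, dtype=torch.float32) / 4.0)
print(torch.allclose(emb[:, :4], torch.sin(torch.outer(pos, omega))))  # True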

class PatchEmbedUnSafe:
446class PatchEmbedUnSafe(PatchEmbed):
447    """Image to Patch Embedding"""
448
449    def forward(self, x):
450        B, C, H, W = x.shape
451
452        # NOTE: commented-out code from ScaleMAE: the size check from timm is dropped.
453        # assert H == self.img_size[0] and W == self.img_size[1], \
454        #     f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
455
456        x = self.proj(x).flatten(2).transpose(1, 2)
457        return x

Image to Patch Embedding

def forward(self, x):
449    def forward(self, x):
450        B, C, H, W = x.shape
451
452        # NOTE: commented-out code from ScaleMAE: the size check from timm is dropped.
453        # assert H == self.img_size[0] and W == self.img_size[1], \
454        #     f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
455
456        x = self.proj(x).flatten(2).transpose(1, 2)
457        return x
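
The example below illustrates why the size check is dropped: unlike timm's PatchEmbed, this variant accepts inputs whose spatial size differs from the img_size it was constructed with (a sketch; requires timm).

import torch
from torch_em.model.vit import PatchEmbedUnSafe

patch_embed = PatchEmbedUnSafe(img_size=224, patch_size=8, in_chans=3, embed_dim=768)

x = torch.randn(1, 3, 448, 448)  # larger than the nominal img_size of 224
tokens = patch_embed(x)
print(tokens.shape)              # torch.Size([1, 3136, 768]) -> (B, (448 / 8) ** 2, embed_dim)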
class ViT_ScaleMAE:
460class ViT_ScaleMAE(VisionTransformer):
461    """Vision Transformer dervied from the Scale Masked Auto Encoder codebase (TODO: paper and github link).
462
463    NOTE: For downstream tasks, the "base_resoulution" parameter needs to be adjusted manually when using
464    the model on a different zoom factor dataset.
465    """
466
467    def __init__(
468        self, img_size=224, patch_size=16, in_chans=3, embed_dim=1024, depth=12, base_resolution=2.5, **kwargs
469    ):
470        super().__init__(img_size=img_size, embed_dim=embed_dim, **kwargs)
471        self.img_size = img_size
472        self.in_chans = in_chans
473        self.depth = depth
474        self.base_resolution = base_resolution
475
476        self.patch_embed = PatchEmbedUnSafe(
477            img_size=img_size,
478            patch_size=patch_size,
479            in_chans=in_chans,
480            embed_dim=embed_dim,
481        )
482
483    def transform_inputs(self, x):
484        import kornia.augmentation as K
485        from kornia.constants import Resample
486
487        self._transforms = CustomCompose(
488            rescale_transform=K.RandomResizedCrop(
489                (448, 448),
490                ratio=(1.0, 1.0),
491                scale=(1.0, 1.0),
492                resample=Resample.BICUBIC.name,
493            ),
494            other_transforms=None,
495            src_transform=K.Resize((224, 224)),
496        )
497        x, _, ratios, _, _ = self._transforms(x)
498        input_res = ratios * self.base_resolution
499        return x, input_res
500
501    def convert_to_expected_dim(self, x):
502        inputs_ = x[:, 1:, :]  # remove the class token
503        # reshape the outputs to the desired shape (N x H*W x C -> N x C x H x W)
504        rdim = inputs_.shape[1]
505        dshape = int(rdim ** 0.5)  # take the square root of the token count to obtain the patch grid shape
506        inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape))
507        inputs_ = inputs_.permute(0, 3, 1, 2)
508        return inputs_
509
510    def forward_features(self, x):
511        x, input_res = self.transform_inputs(x)
512
513        B, _, h, w = x.shape
514        x = self.patch_embed(x)
515
516        num_patches = int((h * w) / (self.patch_embed.patch_size[0] * self.patch_embed.patch_size[1]))
517        pos_embed = get_2d_sincos_pos_embed_with_resolution(
518            x.shape[-1],
519            int(num_patches ** 0.5),
520            input_res,
521            cls_token=True,
522            device=x.device,
523        )
524
525        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
526        x = torch.cat((cls_tokens, x), dim=1)
527        x = x + pos_embed
528        x = self.pos_drop(x)
529
530        # chunks used to obtain the projections that are combined with the upsampling blocks
531        _chunks = int(self.depth / 4)
532        chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1]
533
534        list_from_encoder = []
535        for i, blk in enumerate(self.blocks):
536            x = blk(x)
537            if i in chunks_for_projection:
538                list_from_encoder.append(self.convert_to_expected_dim(x))
539
540        x = self.convert_to_expected_dim(x)
541
542        return x, list_from_encoder
543
544    def forward(self, x):
545        x, list_from_encoder = self.forward_features(x)
546        return x, list_from_encoder

Vision Transformer derived from the Scale Masked Auto Encoder codebase (TODO: paper and github link).

NOTE: For downstream tasks, the "base_resolution" parameter needs to be adjusted manually when using the model on a dataset with a different zoom factor.

ViT_ScaleMAE( img_size=224, patch_size=16, in_chans=3, embed_dim=1024, depth=12, base_resolution=2.5, **kwargs)
467    def __init__(
468        self, img_size=224, patch_size=16, in_chans=3, embed_dim=1024, depth=12, base_resolution=2.5, **kwargs
469    ):
470        super().__init__(img_size=img_size, embed_dim=embed_dim, **kwargs)
471        self.img_size = img_size
472        self.in_chans = in_chans
473        self.depth = depth
474        self.base_resolution = base_resolution
475
476        self.patch_embed = PatchEmbedUnSafe(
477            img_size=img_size,
478            patch_size=patch_size,
479            in_chans=in_chans,
480            embed_dim=embed_dim,
481        )
img_size
in_chans
depth
base_resolution
patch_embed
def transform_inputs(self, x):
483    def transform_inputs(self, x):
484        import kornia.augmentation as K
485        from kornia.constants import Resample
486
487        self._transforms = CustomCompose(
488            rescale_transform=K.RandomResizedCrop(
489                (448, 448),
490                ratio=(1.0, 1.0),
491                scale=(1.0, 1.0),
492                resample=Resample.BICUBIC.name,
493            ),
494            other_transforms=None,
495            src_transform=K.Resize((224, 224)),
496        )
497        x, _, ratios, _, _ = self._transforms(x)
498        input_res = ratios * self.base_resolution
499        return x, input_res
def convert_to_expected_dim(self, x):
501    def convert_to_expected_dim(self, x):
502        inputs_ = x[:, 1:, :]  # remove the class token
503        # reshape the outputs to the desired shape (N x H*W x C -> N x C x H x W)
504        rdim = inputs_.shape[1]
505        dshape = int(rdim ** 0.5)  # take the square root of the token count to obtain the patch grid shape
506        inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape))
507        inputs_ = inputs_.permute(0, 3, 1, 2)
508        return inputs_
def forward_features(self, x):
510    def forward_features(self, x):
511        x, input_res = self.transform_inputs(x)
512
513        B, _, h, w = x.shape
514        x = self.patch_embed(x)
515
516        num_patches = int((h * w) / (self.patch_embed.patch_size[0] * self.patch_embed.patch_size[1]))
517        pos_embed = get_2d_sincos_pos_embed_with_resolution(
518            x.shape[-1],
519            int(num_patches ** 0.5),
520            input_res,
521            cls_token=True,
522            device=x.device,
523        )
524
525        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
526        x = torch.cat((cls_tokens, x), dim=1)
527        x = x + pos_embed
528        x = self.pos_drop(x)
529
530        # chunks used to obtain the projections that are combined with the upsampling blocks
531        _chunks = int(self.depth / 4)
532        chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1]
533
534        list_from_encoder = []
535        for i, blk in enumerate(self.blocks):
536            x = blk(x)
537            if i in chunks_for_projection:
538                list_from_encoder.append(self.convert_to_expected_dim(x))
539
540        x = self.convert_to_expected_dim(x)
541
542        return x, list_from_encoder
def forward(self, x):
544    def forward(self, x):
545        x, list_from_encoder = self.forward_features(x)
546        return x, list_from_encoder
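
A minimal sketch of running the ScaleMAE encoder, assuming timm and kornia are installed. Note that transform_inputs internally rescales the input to 448 x 448 before patch embedding, so the output grid size is determined by that rescaling (448 / 8 = 56) rather than by img_size.

import torch
from torch_em.model.vit import get_vision_transformer

encoder = get_vision_transformer(backbone="scalemae", model="vit_b", img_size=224, base_resolution=2.5)

x = torch.randn(1, 3, 224, 224)
features, list_from_encoder = encoder(x)

print(features.shape)          # e.g. (1, 768, 56, 56) with the internal 448 rescaling and patch size 8
print(len(list_from_encoder))  # four intermediate feature maps, one per quarter of the depth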
class ViT_DINOv2:
549class ViT_DINOv2(DinoV2VisionTransformer):
550    """Vision Transformer derived from the DINOv2 Codebase (https://arxiv.org/abs/2304.07193).
551
552    Based on:
553    https://github.com/facebookresearch/dinov2/blob/main/dinov2/models/vision_transformer.py.
554
555    Args:
556        img_size: The input image size.
557        patch_size: The patch size.
558        depth: The depth of the network.
559        num_register_tokens: The number of register tokens added (in addition to the class token).
560            This is important for ViTs trained with registers, so that the register tokens can be removed at the end.
561    """
562    def __init__(
563        self,
564        img_size: int = 224,
565        patch_size: int = 16,
566        depth: int = 12,
567        num_register_tokens: int = 0,
568        **kwargs
569    ):
570        if not _dinov2_import_success:
571            raise RuntimeError(
572                "The vision transformer backend can only be initialized if DINOv2 is installed. "
573                "Please install DINOv2 from https://github.com/facebookresearch/dinov2 "
574                "and then rerun your code."
575            )
576
577        super().__init__(
578            img_size=img_size,
579            depth=depth,
580            patch_size=patch_size,
581            num_register_tokens=num_register_tokens,
582            **kwargs
583        )
584
585        self.img_size = img_size
586        self.num_register_tokens = num_register_tokens
587        self.patch_size = patch_size
588        self.attn_outs = [i for i in range(depth) if i % 3 == 2]
589
590    def forward(self, x, masks=None) -> Tuple[torch.Tensor, List[torch.Tensor]]:
591
592        B = x.shape[0]
593
594        x = self.prepare_tokens_with_masks(x)
595
596        list_of_encoder = []
597        for i, blk in enumerate(self.blocks):
598            x = blk(x)
599            if i in self.attn_outs:
600                list_of_encoder.append(x)
601
602        x = self.norm(x)
603        x = x[:, self.num_register_tokens + 1:].reshape(
604            B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
605        ).permute(0, 3, 1, 2).contiguous()
606
607        list_of_encoder = [
608            o[:, self.num_register_tokens + 1:].reshape(
609                B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
610            ).permute(0, 3, 1, 2).contiguous() for o in list_of_encoder
611        ]
612
613        return x, list_of_encoder[:3]

Vision Transformer derived from the DINOv2 Codebase (https://arxiv.org/abs/2304.07193).

Based on: https://github.com/facebookresearch/dinov2/blob/main/dinov2/models/vision_transformer.py.

Arguments:
  • img_size: The input image size.
  • patch_size: The patch size.
  • depth: The depth of the network.
  • num_register_tokens: The number of register tokens added (in addition to the class token). This is important for ViTs trained with registers, so that the register tokens can be removed at the end.
ViT_DINOv2( img_size: int = 224, patch_size: int = 16, depth: int = 12, num_register_tokens: int = 0, **kwargs)
562    def __init__(
563        self,
564        img_size: int = 224,
565        patch_size: int = 16,
566        depth: int = 12,
567        num_register_tokens: int = 0,
568        **kwargs
569    ):
570        if not _dinov2_import_success:
571            raise RuntimeError(
572                "The vision transformer backend can only be initialized if DINOv2 is installed. "
573                "Please install DINOv2 from https://github.com/facebookresearch/dinov2 "
574                "and then rerun your code."
575            )
576
577        super().__init__(
578            img_size=img_size,
579            depth=depth,
580            patch_size=patch_size,
581            num_register_tokens=num_register_tokens,
582            **kwargs
583        )
584
585        self.img_size = img_size
586        self.num_register_tokens = num_register_tokens
587        self.patch_size = patch_size
588        self.attn_outs = [i for i in range(depth) if i % 3 == 2]
img_size
num_register_tokens
patch_size
attn_outs
def forward(self, x, masks=None) -> Tuple[torch.Tensor, List[torch.Tensor]]:
590    def forward(self, x, masks=None) -> Tuple[torch.Tensor, List[torch.Tensor]]:
591
592        B = x.shape[0]
593
594        x = self.prepare_tokens_with_masks(x)
595
596        list_of_encoder = []
597        for i, blk in enumerate(self.blocks):
598            x = blk(x)
599            if i in self.attn_outs:
600                list_of_encoder.append(x)
601
602        x = self.norm(x)
603        x = x[:, self.num_register_tokens + 1:].reshape(
604            B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
605        ).permute(0, 3, 1, 2).contiguous()
606
607        list_of_encoder = [
608            o[:, self.num_register_tokens + 1:].reshape(
609                B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
610            ).permute(0, 3, 1, 2).contiguous() for o in list_of_encoder
611        ]
612
613        return x, list_of_encoder[:3]
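
A usage sketch for the DINOv2 wrapper, assuming the dinov2 package is installed. With a patch size of 14 and img_size=224, the token grid is 16 x 16; class and register tokens are stripped from all returned feature maps.

import torch
from torch_em.model.vit import get_vision_transformer

encoder = get_vision_transformer(backbone="dinov2", model="vit_b_reg4", img_size=224)

x = torch.randn(1, 3, 224, 224)
features, list_of_encoder = encoder(x)

print(features.shape)        # (1, 768, 16, 16): 224 / 14 = 16 tokens per side
print(len(list_of_encoder))  # the first three of the selected intermediate feature maps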
class ViT_DINOv3:
616class ViT_DINOv3(DinoV3VisionTransformer):
617    """Vision Transformer derived from the DINOv3 Codebase (https://arxiv.org/abs/2508.10104).
618
619    Based on:
620    https://github.com/facebookresearch/dinov3/blob/main/dinov3/models/vision_transformer.py.
621
622    Args:
623        img_size: The input image size.
624        patch_size: The patch size.
625        embed_dim: The embedding dimension.
626        depth: The depth of the network.
627        num_heads: The number of heads.
628        ffn_ratio: The FFN ratio.
629        n_storage_tokens: The number of storage tokens (registers), which are removed from the output together with the class token.
630        kwargs: Keyword arguments for the image encoder base class.
631    """
632    def __init__(
633        self,
634        in_chans: int = 3,
635        img_size: int = 224,
636        patch_size: int = 16,
637        embed_dim: int = 768,
638        depth: int = 12,
639        num_heads: int = 12,
640        ffn_ratio: float = 4.0,
641        n_storage_tokens: int = 0,
642        **kwargs
643    ):
644        if not _dinov3_import_success:
645            raise RuntimeError(
646                "The vision transformer backend can only be initialized if DINOv3 is installed. "
647                "Please install DINOv3 from https://github.com/facebookresearch/dinov3 "
648                "and then rerun your code."
649            )
650
651        super().__init__(
652            in_chans=in_chans,
653            img_size=img_size,
654            patch_size=patch_size,
655            embed_dim=embed_dim,
656            depth=depth,
657            num_heads=num_heads,
658            ffn_ratio=ffn_ratio,
659            n_storage_tokens=n_storage_tokens,
660            **kwargs
661        )
662
663        self.in_chans = in_chans
664        self.img_size = img_size
665        self.n_storage_tokens = n_storage_tokens
666        self.attn_outs = [i for i in range(depth) if i % 3 == 2]
667
668    def forward(self, x) -> Tuple[torch.Tensor, List[torch.Tensor]]:
669
670        B = x.shape[0]
671
672        x, hw_tuple = self.prepare_tokens_with_masks(x)
673
674        list_of_encoder = []
675        for i, blk in enumerate(self.blocks):
676            rope_sincos = self.rope_embed(H=hw_tuple[0], W=hw_tuple[1])
677            x = blk(x, rope_sincos)
678            if i in self.attn_outs:
679                list_of_encoder.append(x)
680
681        x = self.norm(x)
682        x = x[:, self.n_storage_tokens + 1:].reshape(
683            B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
684        ).permute(0, 3, 1, 2).contiguous()
685
686        list_of_encoder = [
687            o[:, self.n_storage_tokens + 1:].reshape(
688                B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
689            ).permute(0, 3, 1, 2).contiguous() for o in list_of_encoder
690        ]
691
692        return x, list_of_encoder[:3]

Vision Transformer derived from the DINOv3 Codebase (https://arxiv.org/abs/2508.10104).

Based on: https://github.com/facebookresearch/dinov3/blob/main/dinov3/models/vision_transformer.py.

Arguments:
  • img_size: The input image size.
  • patch_size: The patch size.
  • embed_dim: The embedding dimension.
  • depth: The depth of the network.
  • num_heads: The number of heads.
  • ffn_ratio: The FFN ratio.
  • n_storage_tokens: The number of storage tokens (registers), which are removed from the output together with the class token.
  • kwargs: Keyword arguments for the image encoder base class.
ViT_DINOv3( in_chans: int = 3, img_size: int = 224, patch_size: int = 16, embed_dim: int = 768, depth: int = 12, num_heads: int = 12, ffn_ratio: float = 4.0, n_storage_tokens: int = 0, **kwargs)
632    def __init__(
633        self,
634        in_chans: int = 3,
635        img_size: int = 224,
636        patch_size: int = 16,
637        embed_dim: int = 768,
638        depth: int = 12,
639        num_heads: int = 12,
640        ffn_ratio: float = 4.0,
641        n_storage_tokens: int = 0,
642        **kwargs
643    ):
644        if not _dinov3_import_success:
645            raise RuntimeError(
646                "The vision transformer backend can only be initialized if DINOv3 is installed. "
647                "Please install DINOv3 from https://github.com/facebookresearch/dinov3 "
648                "and then rerun your code."
649            )
650
651        super().__init__(
652            in_chans=in_chans,
653            img_size=img_size,
654            patch_size=patch_size,
655            embed_dim=embed_dim,
656            depth=depth,
657            num_heads=num_heads,
658            ffn_ratio=ffn_ratio,
659            n_storage_tokens=n_storage_tokens,
660            **kwargs
661        )
662
663        self.in_chans = in_chans
664        self.img_size = img_size
665        self.n_storage_tokens = n_storage_tokens
666        self.attn_outs = [i for i in range(depth) if i % 3 == 2]
in_chans
img_size
n_storage_tokens
attn_outs
def forward(self, x) -> Tuple[torch.Tensor, List[torch.Tensor]]:
668    def forward(self, x) -> Tuple[torch.Tensor, List[torch.Tensor]]:
669
670        B = x.shape[0]
671
672        x, hw_tuple = self.prepare_tokens_with_masks(x)
673
674        list_of_encoder = []
675        for i, blk in enumerate(self.blocks):
676            rope_sincos = self.rope_embed(H=hw_tuple[0], W=hw_tuple[1])
677            x = blk(x, rope_sincos)
678            if i in self.attn_outs:
679                list_of_encoder.append(x)
680
681        x = self.norm(x)
682        x = x[:, self.n_storage_tokens + 1:].reshape(
683            B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
684        ).permute(0, 3, 1, 2).contiguous()
685
686        list_of_encoder = [
687            o[:, self.n_storage_tokens + 1:].reshape(
688                B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
689            ).permute(0, 3, 1, 2).contiguous() for o in list_of_encoder
690        ]
691
692        return x, list_of_encoder[:3]
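
An analogous sketch for the DINOv3 wrapper, assuming the dinov3 package is installed. The DINOv3 configurations in get_vision_transformer use a patch size of 16 and four storage tokens, which are removed from the output together with the class token.

import torch
from torch_em.model.vit import get_vision_transformer

encoder = get_vision_transformer(backbone="dinov3", model="vit_s", img_size=224)

features, intermediates = encoder(torch.randn(1, 3, 224, 224))
print(features.shape)  # (1, 384, 14, 14): embed_dim=384 and 224 / 16 = 14 tokens per side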
def get_vision_transformer( backbone: str, model: str, img_size: int = 1024, **kwargs) -> torch.nn.modules.module.Module:
695def get_vision_transformer(backbone: str, model: str, img_size: int = 1024, **kwargs) -> nn.Module:
696    """Get vision transformer encoder.
697
698    Args:
699        backbone: The name of the vision transformer implementation. One of "sam", "sam2", "sam3", "mae", "scalemae", "dinov2" or "dinov3".
700        model: The name of the model. The supported names depend on the backbone, e.g. "vit_b", "vit_l" or "vit_h".
701        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
702        kwargs: Additional kwargs which can be expected by the vision transformer,
703            e.g. 'base_resolution' for `ViT_ScaleMAE`.
704
705    Returns:
706        The vision transformer.
707    """
708    if backbone == "sam":
709        if model == "vit_b":
710            encoder = ViT_Sam(
711                depth=12, embed_dim=768, img_size=img_size, mlp_ratio=4,
712                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
713                num_heads=12, patch_size=16, qkv_bias=True, use_rel_pos=True,
714                global_attn_indexes=[2, 5, 8, 11],
715                window_size=14, out_chans=256,
716            )
717        elif model == "vit_l":
718            encoder = ViT_Sam(
719                depth=24, embed_dim=1024, img_size=img_size, mlp_ratio=4,
720                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
721                num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True,
722                global_attn_indexes=[5, 11, 17, 23],
723                window_size=14, out_chans=256,
724            )
725        elif model == "vit_h":
726            encoder = ViT_Sam(
727                depth=32, embed_dim=1280, img_size=img_size, mlp_ratio=4,
728                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
729                num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True,
730                global_attn_indexes=[7, 15, 23, 31],
731                window_size=14, out_chans=256,
732            )
733        else:
734            raise ValueError(f"'{model}' is not supported by SAM. Currently, 'vit_b', 'vit_l', 'vit_h' are supported.")
735
736    elif backbone == "sam2":
737        if model == "hvit_t":
738            encoder = ViT_Sam2(
739                img_size=img_size, embed_dim=96, num_heads=1, stages=[1, 2, 7, 2], global_att_blocks=[5, 7, 9],
740                window_pos_embed_bkg_spatial_size=[7, 7], backbone_channel_list=[768, 384, 192, 96],
741            )
742        elif model == "hvit_s":
743            encoder = ViT_Sam2(
744                img_size=img_size, embed_dim=96, num_heads=1, stages=[1, 2, 11, 2], global_att_blocks=[7, 10, 13],
745                window_pos_embed_bkg_spatial_size=[7, 7], backbone_channel_list=[768, 384, 192, 96],
746            )
747        elif model == "hvit_b":
748            encoder = ViT_Sam2(
749                img_size=img_size, embed_dim=112, num_heads=2, backbone_channel_list=[896, 448, 224, 112],
750            )
751        elif model == "hvit_l":
752            encoder = ViT_Sam2(
753                img_size=img_size, embed_dim=144, num_heads=2, stages=[2, 6, 36, 4], global_att_blocks=[23, 33, 43],
754                window_spec=[8, 4, 16, 8], backbone_channel_list=[1152, 576, 288, 144],
755            )
756        else:
757            raise ValueError(
758                f"'{model}' is not supported by SAM2. Currently, 'hvit_t', 'hvit_s', 'hvit_b', 'hvit_l' are supported."
759            )
760
761    elif backbone == "sam3":
762        if model != "vit_pe":
763            raise ValueError(
764                "'sam3' does not have multiple model configurations. Please use 'vit_pe' as the model configuration."
765            )
766
767        encoder = ViT_Sam3(
768            img_size=1008, pretrain_img_size=336, patch_size=14, embed_dim=1024, depth=32, num_heads=16,
769            mlp_ratio=4.625, norm_layer="LayerNorm", drop_path_rate=0.1, qkv_bias=True, use_abs_pos=True,
770            tile_abs_pos=True, global_att_blocks=(7, 15, 23, 31), rel_pos_blocks=(), use_rope=True,
771            use_interp_rope=True, window_size=24, pretrain_use_cls_token=True, retain_cls_token=False, ln_pre=True,
772            ln_post=False, return_interm_layers=False, bias_patch_embed=False, compile_mode=None,
773        )
774
775    elif backbone == "mae":
776        if model == "vit_b":
777            encoder = ViT_MAE(
778                img_size=img_size, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
779                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
780            )
781        elif model == "vit_l":
782            encoder = ViT_MAE(
783                img_size=img_size, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
784                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
785            )
786        elif model == "vit_h":
787            encoder = ViT_MAE(
788                img_size=img_size, patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4,
789                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
790            )
791        else:
792            raise ValueError(f"'{model}' is not supported by MAE. Currently, 'vit_b', 'vit_l', 'vit_h' are supported.")
793
794    elif backbone == "scalemae":
795        base_resolution = kwargs.get("base_resolution", 2.5)
796
797        if model == "vit_b":
798            encoder = ViT_ScaleMAE(
799                img_size=img_size, patch_size=8, embed_dim=768, depth=12, num_heads=12,
800                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
801                base_resolution=base_resolution,
802            )
803        elif model == "vit_l":
804            encoder = ViT_ScaleMAE(
805                img_size=img_size, patch_size=8, embed_dim=1024, depth=24, num_heads=16,
806                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
807                base_resolution=base_resolution,
808            )
809        elif model == "vit_h":
810            encoder = ViT_ScaleMAE(
811                img_size=img_size, patch_size=8, embed_dim=1280, depth=32, num_heads=16,
812                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
813                base_resolution=base_resolution,
814            )
815        else:
816            raise ValueError(
817                f"'{model}' is not supported by ScaleMAE. Currently, 'vit_b', 'vit_l' and 'vit_h' are supported."
818            )
819
820    elif backbone == "dinov2":
821        block_fn = partial(Block, attn_class=MemEffAttention)
822        msg = "The model name should be either 'vit_<X>' or 'vit_<X>_reg<Y>."
823
824        if model.startswith("vit_s"):
825            assert model in ["vit_s", "vit_s_reg4"], msg
826            encoder = ViT_DINOv2(
827                img_size=img_size, patch_size=14, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
828                block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0,
829                num_register_tokens=4 if model.endswith("_reg4") else 0,
830            )
831        elif model.startswith("vit_b"):
832            assert model in ["vit_b", "vit_b_reg4"], msg
833            encoder = ViT_DINOv2(
834                img_size=img_size, patch_size=14, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
835                block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0,
836                num_register_tokens=4 if model.endswith("_reg4") else 0,
837            )
838        elif model.startswith("vit_l"):
839            assert model in ["vit_l", "vit_l_reg4"], msg
840            encoder = ViT_DINOv2(
841                img_size=img_size, patch_size=14, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
842                block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0,
843                num_register_tokens=4 if model.endswith("_reg4") else 0,
844            )
845        elif model.startswith("vit_g"):
846            assert model in ["vit_g", "vit_g_reg4"], msg
847            encoder = ViT_DINOv2(
848                img_size=img_size, patch_size=14, embed_dim=1536, depth=40, num_heads=24, mlp_ratio=4,
849                block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0,
850                num_register_tokens=4 if model.endswith("_reg4") else 0, ffn_layer="swiglu",
851            )
852        else:
853            raise ValueError(
854                f"'{model}' is not supported by DINOv2. Currently, 'vit_s', 'vit_b', 'vit_l' and 'vit_g' are supported."
855            )
856
857    elif backbone == "dinov3":
858
859        if model == "vit_s":
860            encoder = ViT_DINOv3(
861                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=384,
862                num_heads=6, layerscale_init=1.0e-05, norm_layer="layernormbf16", n_storage_tokens=4, mask_k_bias=True,
863            )
864        elif model == "vit_s+":
865            encoder = ViT_DINOv3(
866                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=384,
867                num_heads=6, ffn_ratio=6, layerscale_init=1.0e-05, norm_layer="layernormbf16",
868                ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True,
869            )
870
871        elif model == "vit_b":
872            encoder = ViT_DINOv3(
873                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32",
874                layerscale_init=1.0e-05, norm_layer="layernormbf16", n_storage_tokens=4, mask_k_bias=True,
875            )
876        elif model == "vit_l":
877            encoder = ViT_DINOv3(
878                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1024,
879                depth=24, num_heads=16, layerscale_init=1.0e-05, norm_layer="layernormbf16",
880                n_storage_tokens=4, mask_k_bias=True,
881            )
882        elif model == "vit_l+":
883            encoder = ViT_DINOv3(
884                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1024,
885                depth=24, num_heads=16, ffn_ratio=6.0, layerscale_init=1.0e-05, norm_layer="layernormbf16",
886                ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True,
887            )
888        elif model == "vit_h+":
889            encoder = ViT_DINOv3(
890                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1280,
891                depth=32, num_heads=20, ffn_ratio=6.0, layerscale_init=1.0e-05, norm_layer="layernormbf16",
892                ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True,
893            )
894        elif model == "vit_7b":
895            encoder = ViT_DINOv3(
896                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=4096,
897                depth=40, num_heads=32, ffn_ratio=3, qkv_bias=False, drop_path_rate=0.0, layerscale_init=1.0e-05,
898                norm_layer="layernormbf16", ffn_layer="swiglu64", n_storage_tokens=4, mask_k_bias=True,
899                untie_global_and_local_cls_norm=True,
900            )
901        else:
902            raise ValueError(
903                f"'{model}' is not supported by DINOv3. Currently, "
904                " 'vit_s', 'vit_s+', 'vit_b', 'vit_l', 'vit_l+', 'vit_h+'. 'vit_7b' are supported."
905            )
906
907    else:
908        raise ValueError(
909            "The 'UNETR' supported backbones are 'sam', 'sam2', 'sam3', 'mae', 'scalemae', 'dinov2' or 'dinov3'. "
910            "Please choose one of them."
911        )
912
913    return encoder

Get vision transformer encoder.

Arguments:
  • backbone: The name of the vision transformer implementation. One of "sam", "sam2", "sam3", "mae", "scalemae", "dinov2" or "dinov3".
  • model: The name of the model. The supported names depend on the backbone, e.g. "vit_b", "vit_l" or "vit_h".
  • img_size: The size of the input for the image encoder. Input images will be resized to match this size.
  • kwargs: Additional kwargs which can be expected by the vision transformer, e.g. 'base_resolution' for ViT_ScaleMAE.
Returns:

The vision transformer.
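
Example usage (a sketch; the optional dependency for the chosen backbone must be installed, here segment_anything for backbone="sam"):

import torch
from torch_em.model.vit import get_vision_transformer

encoder = get_vision_transformer(backbone="sam", model="vit_b", img_size=1024)
print(sum(p.numel() for p in encoder.parameters()))  # parameter count of the ViT-B encoder

out = encoder(torch.randn(1, 3, 1024, 1024))  # the return structure follows the respective backbone's forward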