torch_em.model.vit

   1import math
   2from functools import partial
   3from typing import Tuple, List
   4
   5import torch
   6import torch.nn as nn
   7
# We catch ImportErrors here because segment_anything, timm, sam2, sam3, dinov2 and dinov3
# should only be optional dependencies for torch_em.
  10try:
  11    from segment_anything.modeling import ImageEncoderViT
  12    _sam_import_success = True
  13except ImportError:
  14    ImageEncoderViT = object
  15    _sam_import_success = False
  16
  17try:
  18    from timm.models.vision_transformer import VisionTransformer, PatchEmbed
  19    _timm_import_success = True
  20except ImportError:
  21    VisionTransformer = object
  22    PatchEmbed = object
  23    _timm_import_success = False
  24
  25try:
  26    from sam2.modeling.backbones.hieradet import Hiera
  27    from sam2.modeling.position_encoding import PositionEmbeddingSine
  28    from sam2.modeling.backbones.image_encoder import ImageEncoder, FpnNeck
  29    _sam2_import_success = True
  30except ImportError:
  31    ImageEncoder = object
  32    _sam2_import_success = False
  33
  34try:
  35    from dinov2.models.vision_transformer import DinoVisionTransformer as DinoV2VisionTransformer
  36    from dinov2.layers import MemEffAttention, NestedTensorBlock as Block
  37    _dinov2_import_success = True
  38except ImportError:
  39    DinoV2VisionTransformer = object
  40    _dinov2_import_success = False
  41
  42try:
  43    from dinov3.models.vision_transformer import DinoVisionTransformer as DinoV3VisionTransformer
  44    _dinov3_import_success = True
  45except ImportError:
  46    DinoV3VisionTransformer = object
  47    _dinov3_import_success = False
  48
  49
  50try:
  51    from sam3.model.vitdet import ViT as SAM3ViT, get_abs_pos
  52    _sam3_import_success = True
  53except ImportError:
  54    SAM3ViT = object
  55    _sam3_import_success = False
  56
  57
  58class ViT_Sam(ImageEncoderViT):
  59    """Vision Transformer derived from the Segment Anything Codebase (https://arxiv.org/abs/2304.02643).
  60
  61    Based on:
  62    https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py
  63
  64    Args:
  65        in_chans: The number of input channels.
  66        embed_dim: The embedding dimension, corresponding to the number of output channels of the vision transformer.
  67        global_attn_indexes: The global attention indices.
  68        apply_neck: Whether to apply the convolutional bottleneck after outputs of the last attention head.
  69        kwargs: Keyword arguments for the image encoder base class.
  70    """
  71    def __init__(
  72        self,
  73        in_chans: int = 3,
  74        embed_dim: int = 768,
  75        global_attn_indexes: Tuple[int, ...] = [2, 5, 8, 11],
  76        apply_neck: bool = False,
  77        **kwargs,
  78    ) -> None:
  79        if not _sam_import_success:
  80            raise RuntimeError(
  81                "The vision transformer backend can only be initialized if segment anything is installed. "
  82                "Please install segment anything from https://github.com/facebookresearch/segment-anything "
  83                "and then rerun your code."
  84            )
  85
  86        super().__init__(embed_dim=embed_dim, global_attn_indexes=global_attn_indexes, **kwargs)
  87        self.chunks_for_projection = global_attn_indexes
  88        self.in_chans = in_chans
  89        self.embed_dim = embed_dim
  90        self.apply_neck = apply_neck
  91
  92    def forward(self, x: torch.Tensor) -> torch.Tensor:
  93        """Apply the vision transformer to input data.
  94
  95        Args:
  96            x: The input data.
  97
  98        Returns:
  99            The vision transformer output.
 100        """
 101        x = self.patch_embed(x)
 102        if self.pos_embed is not None:
 103            x = x + self.pos_embed
 104
 105        list_from_encoder = []
 106        for i, blk in enumerate(self.blocks):
 107            x = blk(x)
 108            if i in self.chunks_for_projection:
 109                list_from_encoder.append(x)
 110
 111        x = x.permute(0, 3, 1, 2)
 112
 113        if self.apply_neck:
 114            x = self.neck(x)
 115
 116        list_from_encoder = [e.permute(0, 3, 1, 2) for e in list_from_encoder]
 117        return x, list_from_encoder[:3]
 118
 119
class ViT_CellposeSAM(nn.Module):
    """Vision Transformer derived from the CellposeSAM Codebase (https://doi.org/10.1038/s41592-025-02595-x).

    This replicates CellposeSAM's actual initialization: instantiate SAM's ``ImageEncoderViT`` via
    ``sam_model_registry``, then modify the patch embedding, position embeddings, and set global attention.
    This preserves SAM's original relative position bias sizes, enabling direct checkpoint loading
    without any interpolation.

    Based on: https://github.com/MouseLand/cellpose/blob/main/cellpose/vit_sam.py

    NOTE: The pretrained CellposeSAM model uses ``vit_l`` exclusively.

    Args:
        ps: The patch size (default for CellposeSAM is 8).
        bsize: The input image size (default for CellposeSAM is 256).
        apply_neck: Whether to apply the convolutional bottleneck after outputs of the last attention head.
    """
    def __init__(self, ps: int = 8, bsize: int = 256, apply_neck: bool = True) -> None:
        super().__init__()

        if not _sam_import_success:
            raise RuntimeError(
                "The vision transformer backend can only be initialized if segment anything is installed. "
                "Please install segment anything from https://github.com/facebookresearch/segment-anything "
                "and then rerun your code."
            )

        # Imported lazily because segment_anything is an optional dependency.
        from segment_anything import sam_model_registry

        # Creates the SAM vit_l encoder and applies CellposeSAM's modifications (same as cellpose.vit_sam.Transformer).
        encoder = sam_model_registry["vit_l"](None).image_encoder

        # Keep the pretrained patch-embedding weights so they can be subsampled below.
        w = encoder.patch_embed.proj.weight.detach()
        nchan = w.shape[0]

        # CellPoseSAM changes the patch size from 16 to 'ps'.
        encoder.patch_embed.proj = nn.Conv2d(3, nchan, stride=ps, kernel_size=ps)
        # Subsample the pretrained 16x16 kernels down to ps x ps by strided slicing (no interpolation).
        encoder.patch_embed.proj.weight.data = w[:, :, ::16 // ps, ::16 // ps]

        # Next, they subsample position embeddings for the new patch size and input resolution.
        # 1024 // 16 is SAM's original token-grid side length; bsize // ps is the new one.
        ds = (1024 // 16) // (bsize // ps)
        encoder.pos_embed = nn.Parameter(encoder.pos_embed[:, ::ds, ::ds], requires_grad=True)

        # Finally, they set all blocks to global attention.
        for blk in encoder.blocks:
            blk.window_size = 0

        # Store encoder submodules directly ('state_dict' keys match CellposeSAM after prefix stripping).
        self.patch_embed = encoder.patch_embed
        self.pos_embed = encoder.pos_embed
        self.blocks = encoder.blocks
        self.neck = encoder.neck

        # Additional attributes expected by UNETR.
        self.embed_dim = nchan
        self.img_size = bsize
        self.in_chans = 3
        self.apply_neck = apply_neck

        # Feature extraction at evenly-spaced depths.
        depth = len(self.blocks)
        _chunks = depth // 4
        self.chunks_for_projection = [_chunks - 1, 2 * _chunks - 1, 3 * _chunks - 1, 4 * _chunks - 1]

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Apply the vision transformer to input data.

        Args:
            x: The input data.

        Returns:
            The vision transformer output and the intermediate features (in (N, C, H, W) layout)
            from the first three tapped blocks.
        """
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            x = x + self.pos_embed

        list_from_encoder = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in self.chunks_for_projection:
                list_from_encoder.append(x)

        # The encoder blocks operate in (N, H, W, C) layout; convert to (N, C, H, W).
        x = x.permute(0, 3, 1, 2)

        if self.apply_neck:
            x = self.neck(x)

        list_from_encoder = [e.permute(0, 3, 1, 2) for e in list_from_encoder]
        return x, list_from_encoder[:3]
 210
 211
 212class ViT_MAE(VisionTransformer):
 213    """Vision Transformer derived from the Masked Auto Encoder Codebase (https://arxiv.org/abs/2111.06377).
 214
 215    Based on:
 216    https://github.com/facebookresearch/mae/blob/main/models_vit.py#L20-L53
 217
 218    Args:
 219        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
 220        in_chans: The number of input channels.
 221        depth: The depth of the vision transformer.
 222        kwargs: Additional keyword arguments for the vision transformer base class.
 223    """
 224    def __init__(
 225        self,
 226        img_size: int = 1024,  # chosen to match our experiments with segment anything
 227        in_chans: int = 3,
 228        depth: int = 12,
 229        **kwargs
 230    ):
 231        if not _timm_import_success:
 232            raise RuntimeError(
 233                "The vision transformer backend can only be initialized if timm is installed. "
 234                "Please install timm (using conda/mamba) for using https://github.com/facebookresearch/mae/ "
 235                "and then rerun your code"
 236            )
 237        super().__init__(img_size=img_size, depth=depth, **kwargs)
 238        self.img_size = img_size
 239        self.in_chans = in_chans
 240        self.depth = depth
 241
 242    def convert_to_expected_dim(self, inputs_):
 243        """@private
 244        """
 245        inputs_ = inputs_[:, 1:, :]  # removing the class tokens
 246        # reshape the outputs to desired shape (N x H*W X C -> N x H x W x C)
 247        rdim = inputs_.shape[1]
 248        dshape = int(rdim ** 0.5)  # finding the square root of the outputs for obtaining the patch shape
 249        inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape))
 250        inputs_ = inputs_.permute(0, 3, 1, 2)
 251        return inputs_
 252
 253    def forward_features(self, x):
 254        """@private
 255        """
 256        B = x.shape[0]
 257        x = self.patch_embed(x)
 258
 259        cls_tokens = self.cls_token.expand(B, -1, -1)
 260        x = torch.cat((cls_tokens, x), dim=1)
 261
 262        x = x + self.pos_embed
 263        x = self.pos_drop(x)
 264
 265        # chunks obtained for getting the projections for conjuctions with upsampling blocks
 266        _chunks = int(self.depth / 4)
 267        chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1]
 268
 269        list_from_encoder = []
 270        for i, blk in enumerate(self.blocks):
 271            x = blk(x)
 272            if i in chunks_for_projection:
 273                list_from_encoder.append(self.convert_to_expected_dim(x))
 274
 275        x = self.convert_to_expected_dim(x)
 276        return x, list_from_encoder[:3]
 277
 278    def forward(self, x: torch.Tensor) -> torch.Tensor:
 279        """Apply the vision transformer to input data.
 280
 281        Args:
 282            x: The input data.
 283
 284        Returns:
 285            The vision transformer output.
 286        """
 287        x, list_from_encoder = self.forward_features(x)
 288        return x, list_from_encoder
 289
 290
 291class ViT_Sam2(ImageEncoder):
 292    """Vision Transformer derived from the Segment Anything 2 Codebase (https://arxiv.org/abs/2408.00714).
 293
 294    Based on https://github.com/facebookresearch/sam2/blob/main/sam2/modeling/backbones/image_encoder.py.
 295
 296    Args:
 297        backbone_channel_list: The channels throughout the entire backbone.
 298        embed_dim: The initial embedding dimension.
 299        num_heads: The initial number of heads.
 300        stages: The number of blocks per stage.
 301        global_att_blocks: The parameter to decide which blocks have global attention.
 302        window_pos_embed_bkg_spatial_size: The spatial size of windowed positional embedding.
 303        window_spec: The window size per stage, when not using global attention.
 304        scalp: The count of lowest resolution features to discard.
 305    """
 306    def __init__(
 307        self,
 308        backbone_channel_list: List[int],
 309        img_size: int = 1024,
 310        embed_dim: int = 96,
 311        num_heads: int = 1,
 312        stages: Tuple[int, ...] = (2, 3, 16, 3),
 313        global_att_blocks: Tuple[int, ...] = (12, 16, 20),
 314        window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
 315        window_spec: Tuple[int, ...] = (8, 4, 14, 7),
 316        scalp: int = 1,
 317        **kwargs
 318    ):
 319        if not _sam2_import_success:
 320            raise RuntimeError(
 321                "The vision transformer backend can only be initialized if segment anything 2 is installed. "
 322                "Please install segment anything 2 from https://github.com/facebookresearch/sam2 "
 323                "and then rerun your code"
 324            )
 325
 326        trunk = Hiera(
 327            embed_dim=embed_dim,
 328            num_heads=num_heads,
 329            stages=stages,
 330            global_att_blocks=global_att_blocks,
 331            window_pos_embed_bkg_spatial_size=window_pos_embed_bkg_spatial_size,
 332            window_spec=window_spec,
 333        )
 334        neck = FpnNeck(
 335            position_encoding=PositionEmbeddingSine(num_pos_feats=256),
 336            d_model=256,
 337            backbone_channel_list=backbone_channel_list,
 338            fpn_top_down_levels=[2, 3],
 339            fpn_interp_model="nearest",
 340        )
 341
 342        super().__init__(trunk=trunk, neck=neck, scalp=scalp, **kwargs)
 343        self.scalp = scalp
 344        self.embed_dim = embed_dim
 345        self.img_size = img_size
 346
 347    def forward(self, x: torch.Tensor):
 348        # The forward pass throught the backbone.
 349        features, pos = self.neck(self.trunk(x))
 350        if self.scalp > 0:  # This discard the "n" lowest resolution features.
 351            features, pos = features[:-self.scalp], pos[:-self.scalp]
 352
 353        return features[-1], features
 354
 355
 356class ViT_Sam3(SAM3ViT):
 357    """Vision Transformer derived from the Segment Anything 3 Codebase (https://arxiv.org/abs/2511.16719).
 358
 359    Based on https://github.com/facebookresearch/sam3/blob/main/sam3/model/vitdet.py.
 360
 361    Args:
 362        img_size: The input image size.
 363        embed_dim: The embedding dimension, corresponding to the number of output channels of the vision transformer.
 364        kwargs: Keyword arguments for the image encoder base class.
 365    """
 366    def __init__(self, img_size: int = 1024, embed_dim: int = 768, **kwargs):
 367        if not _sam3_import_success:
 368            raise RuntimeError(
 369                "The vision transformer backend can only be initialized if segment anything 3 is installed. "
 370                "Please install segment anything 3 from https://github.com/facebookresearch/sam3 "
 371                "and then rerun your code"
 372            )
 373
 374        super().__init__(img_size=img_size, embed_dim=embed_dim, **kwargs)
 375        self.img_size = img_size
 376        self.embed_dim = embed_dim
 377
 378    def forward_features(self, x):
 379        """@private
 380        """
 381        x = self.patch_embed(x)
 382        h, w = x.shape[1], x.shape[2]
 383
 384        s = 0
 385        if self.retain_cls_token:
 386            # If the 'cls_token' is retained, we don't maintain the spatial shape.
 387            x = torch.cat([self.class_embedding, x.flatten(1, 2)], dim=1)
 388            s = 1
 389
 390        if self.pos_embed is not None:
 391            x = x + get_abs_pos(
 392                self.pos_embed, self.pretrain_use_cls_token, (h, w), self.retain_cls_token, tiling=self.tile_abs_pos,
 393            )
 394
 395        x = self.ln_pre(x)
 396
 397        list_from_encoder = []
 398        for i, blk in enumerate(self.blocks):
 399            if self.use_act_checkpoint and self.training:
 400                x = torch.utils.checkpoint.checkpoint(blk, x, use_reentrant=False)
 401            else:
 402                x = blk(x)
 403
 404            x = self._convert_to_expected_dim(x, i, s)
 405
 406            if i in self.full_attn_ids:
 407                list_from_encoder.append(x)
 408
 409        return x, list_from_encoder
 410
 411    def _convert_to_expected_dim(self, x, i, s):
 412        if (i == self.full_attn_ids[-1]) or (
 413            self.return_interm_layers and i in self.full_attn_ids
 414        ):
 415            if i == self.full_attn_ids[-1]:
 416                x = self.ln_post(x)
 417
 418            feats = x[:, s:]
 419            if feats.ndim == 4:
 420                feats = feats.permute(0, 3, 1, 2)
 421            else:
 422                assert feats.ndim == 3
 423                h = w = math.sqrt(feats.shape[1])
 424                feats = feats.reshape(feats.shape[0], h, w, feats.shape[-1]).permute(0, 3, 1, 2)
 425            return feats
 426
 427        else:
 428            return x
 429
 430    def forward(self, x: torch.Tensor):
 431        """Apply the vision transformer to input data.
 432
 433        Args:
 434            x: The input data.
 435
 436        Returns:
 437            The vision transformer output.
 438        """
 439        x, list_from_encoder = self.forward_features(x)
 440        return x, list_from_encoder
 441
 442#
 443# Utilities for ScaleMAE's ViT
 444#
 445
 446
 447class CustomCompose:
 448    def __init__(self, rescale_transform, other_transforms, src_transform):
 449        self.rescale_transform = rescale_transform
 450        self.other_transforms = other_transforms
 451        self.src_transform = src_transform
 452
 453    def __call__(self, x, valid_masks=None):
 454        if valid_masks is not None:
 455            nodata = (x * (1 - valid_masks.float())).max()
 456        x_aug = self.rescale_transform(x)
 457        parms = self.rescale_transform._params
 458
 459        # sanity check, comment if this is working
 460        # valid_masks = self.rescale_transform(valid_masks.float(), params=parms)
 461        # assert (x_aug==self.rescale_transform(x, params=parms)).all() #
 462
 463        if valid_masks is not None:
 464            valid_masks = x_aug != nodata
 465            _, c, h, w = x_aug.shape
 466            zero_ratio = ((valid_masks == 0).sum((1, 2, 3)) / (h * w * c)).cpu().numpy()
 467        else:
 468            zero_ratio = -1
 469
 470        if self.other_transforms:
 471            x_aug = self.other_transforms(x_aug)
 472        x_src = self.src_transform(x_aug)
 473        dx = parms["src"][:, 1, 0] - parms["src"][:, 0, 0]
 474
 475        # dy = (parms['src'][:,2,1] - parms['src'][:,1,1])
 476        # assert (dx == dy).all()
 477
 478        h, w = x_aug.shape[-2:]
 479        # assert h == w
 480
 481        return x_aug, x_src, dx / h, zero_ratio, valid_masks
 482
 483
 484def get_2d_sincos_pos_embed_with_resolution(embed_dim, grid_size, res, cls_token=False, device="cpu"):
 485    """
 486    grid_size: int of the grid height and width
 487    res: array of size n, representing the resolution of a pixel (say, in meters),
 488    return:
 489    pos_embed: [n,grid_size*grid_size, embed_dim] or [n,1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
 490    """
 491    # res = torch.FloatTensor(res).to(device)
 492    res = res.to(device)
 493    grid_h = torch.arange(grid_size, dtype=torch.float32, device=device)
 494    grid_w = torch.arange(grid_size, dtype=torch.float32, device=device)
 495    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")  # here h goes first,direction reversed for numpy
 496    grid = torch.stack(grid, dim=0)  # 2 x h x w
 497
 498    # grid = grid.reshape([2, 1, grid_size, grid_size])
 499    grid = torch.einsum("chw,n->cnhw", grid, res)  # 2 x n x h x w
 500    _, n, h, w = grid.shape
 501    pos_embed = get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid)  # (nxH*W, D/2)
 502    pos_embed = pos_embed.reshape(n, h * w, embed_dim)
 503    if cls_token:
 504        pos_embed = torch.cat(
 505            [torch.zeros([n, 1, embed_dim], dtype=torch.float32, device=pos_embed.device), pos_embed], dim=1
 506        )
 507
 508    return pos_embed
 509
 510
 511def get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid):
 512    assert embed_dim % 2 == 0
 513
 514    # use half of dimensions to encode grid_h
 515    emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid[0])  # (H*W, D/2)
 516    emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid[1])  # (H*W, D/2)
 517
 518    emb = torch.cat([emb_h, emb_w], dim=1)  # (H*W, D)
 519    return emb
 520
 521
 522def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos):
 523    """
 524    embed_dim: output dimension for each position
 525    pos: a list of positions to be encoded: size (M,)
 526    out: (M, D)
 527    """
 528    assert embed_dim % 2 == 0
 529    # old_shape = pos
 530    omega = torch.arange(embed_dim // 2, dtype=torch.float32, device=pos.device)
 531    omega /= embed_dim / 2.0
 532    omega = 1.0 / 10000**omega  # (D/2,)
 533
 534    pos = pos.reshape(-1)  # (M,)
 535    out = torch.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
 536
 537    emb_sin = torch.sin(out)  # (M, D/2)
 538    emb_cos = torch.cos(out)  # (M, D/2)
 539
 540    emb = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)
 541    return emb
 542
 543
 544class PatchEmbedUnSafe(PatchEmbed):
 545    """Image to Patch Embedding"""
 546
 547    def forward(self, x):
 548        B, C, H, W = x.shape
 549
 550        # NOTE: Comment code from ScaleMAE: Dropped size check in timm
 551        # assert H == self.img_size[0] and W == self.img_size[1], \
 552        #     f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
 553
 554        x = self.proj(x).flatten(2).transpose(1, 2)
 555        return x
 556
 557
 558class ViT_ScaleMAE(VisionTransformer):
 559    """Vision Transformer dervied from the Scale Masked Auto Encoder codebase (TODO: paper and github link).
 560
 561    NOTE: For downstream tasks, the "base_resoulution" parameter needs to be adjusted manually when using
 562    the model on a different zoom factor dataset.
 563    """
 564
 565    def __init__(
 566        self, img_size=224, patch_size=16, in_chans=3, embed_dim=1024, depth=12, base_resolution=2.5, **kwargs
 567    ):
 568        super().__init__(img_size=img_size, embed_dim=embed_dim, **kwargs)
 569        self.img_size = img_size
 570        self.in_chans = in_chans
 571        self.depth = depth
 572        self.base_resolution = base_resolution
 573
 574        self.patch_embed = PatchEmbedUnSafe(
 575            img_size=img_size,
 576            patch_size=patch_size,
 577            in_chans=in_chans,
 578            embed_dim=embed_dim,
 579        )
 580
 581    def transform_inputs(self, x):
 582        import kornia.augmentation as K
 583        from kornia.constants import Resample
 584
 585        self._transforms = CustomCompose(
 586            rescale_transform=K.RandomResizedCrop(
 587                (448, 448),
 588                ratio=(1.0, 1.0),
 589                scale=(1.0, 1.0),
 590                resample=Resample.BICUBIC.name,
 591            ),
 592            other_transforms=None,
 593            src_transform=K.Resize((224, 224)),
 594        )
 595        x, _, ratios, _, _ = self._transforms(x)
 596        input_res = ratios * self.base_resolution
 597        return x, input_res
 598
 599    def convert_to_expected_dim(self, x):
 600        inputs_ = x[:, 1:, :]  # removing the class tokens
 601        # reshape the outputs to desired shape (N X H*W X C -> N X H X W X C)
 602        rdim = inputs_.shape[1]
 603        dshape = int(rdim ** 0.5)  # finding square root of the outputs for obtaining the patch shape
 604        inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape))
 605        inputs_ = inputs_.permute(0, 3, 1, 2)
 606        return inputs_
 607
 608    def forward_features(self, x):
 609        x, input_res = self.transform_inputs(x)
 610
 611        B, _, h, w = x.shape
 612        x = self.patch_embed(x)
 613
 614        num_patches = int((h * w) / (self.patch_embed.patch_size[0] * self.patch_embed.patch_size[1]))
 615        pos_embed = get_2d_sincos_pos_embed_with_resolution(
 616            x.shape[-1],
 617            int(num_patches ** 0.5),
 618            input_res,
 619            cls_token=True,
 620            device=x.device,
 621        )
 622
 623        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
 624        x = torch.cat((cls_tokens, x), dim=1)
 625        x = x + pos_embed
 626        x = self.pos_drop(x)
 627
 628        # chunks obtained for getting the projections for conjuctions with upsampling blocks
 629        _chunks = int(self.depth / 4)
 630        chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1]
 631
 632        list_from_encoder = []
 633        for i, blk in enumerate(self.blocks):
 634            x = blk(x)
 635            if i in chunks_for_projection:
 636                list_from_encoder.append(self.convert_to_expected_dim(x))
 637
 638        x = self.convert_to_expected_dim(x)
 639
 640        return x, list_from_encoder
 641
 642    def forward(self, x):
 643        x, list_from_encoder = self.forward_features(x)
 644        return x, list_from_encoder
 645
 646
 647class ViT_DINOv2(DinoV2VisionTransformer):
 648    """Vision Transformer derived from the DINOv2 Codebase (https://arxiv.org/abs/2304.07193).
 649
 650    Based on:
 651    https://github.com/facebookresearch/dinov2/blob/main/dinov2/models/vision_transformer.py.
 652
 653    Args:
 654        img_size: The input image size.
 655        patch_size: The patch size.
 656        depth: The depth of the network.
 657        num_register_tokens: The number of registers added (in addition to the class tokens).
 658            It's important to know for ViTs trained with registers, to remove them at the end.
 659    """
 660    def __init__(
 661        self,
 662        img_size: int = 224,
 663        patch_size: int = 16,
 664        depth: int = 12,
 665        num_register_tokens: int = 0,
 666        **kwargs
 667    ):
 668        if not _dinov2_import_success:
 669            raise RuntimeError(
 670                "The vision transformer backend can only be initialized if DINOv2 is installed. "
 671                "Please install DINOv2 from https://github.com/facebookresearch/dinov2 "
 672                "and then rerun your code."
 673            )
 674
 675        super().__init__(
 676            img_size=img_size,
 677            depth=depth,
 678            patch_size=patch_size,
 679            num_register_tokens=num_register_tokens,
 680            **kwargs
 681        )
 682
 683        self.img_size = img_size
 684        self.num_register_tokens = num_register_tokens
 685        self.patch_size = patch_size
 686        self.attn_outs = [i for i in range(depth) if i % 3 == 2]
 687
 688    def forward(self, x, masks=None) -> torch.Tensor:
 689
 690        B = x.shape[0]
 691
 692        x = self.prepare_tokens_with_masks(x)
 693
 694        list_of_encoder = []
 695        for i, blk in enumerate(self.blocks):
 696            x = blk(x)
 697            if i in self.attn_outs:
 698                list_of_encoder.append(x)
 699
 700        x = self.norm(x)
 701        x = x[:, self.num_register_tokens + 1:].reshape(
 702            B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
 703        ).permute(0, 3, 1, 2).contiguous()
 704
 705        list_of_encoder = [
 706            o[:, self.num_register_tokens + 1:].reshape(
 707                B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
 708            ).permute(0, 3, 1, 2).contiguous() for o in list_of_encoder
 709        ]
 710
 711        return x, list_of_encoder[:3]
 712
 713
 714class ViT_DINOv3(DinoV3VisionTransformer):
 715    """Vision Transformer derived from the DINOv3 Codebase (https://arxiv.org/abs/2508.10104).
 716
 717    Based on:
 718    https://github.com/facebookresearch/dinov3/blob/main/dinov3/models/vision_transformer.py.
 719
 720    Args:
 721        img_size: The input image size.
 722        patch_size: The patch size.
 723        embed_dim: The embedding dimension.
 724        depth: The depth of the network.
 725        num_heads: The number of heads.
 726        ffn_ratio: The FFN rato.
 727        n_storage_tokens: The number of storage (class) tokens to remove.
 728        kwargs: Keyword arguments for the image encoder base class.
 729    """
 730    def __init__(
 731        self,
 732        in_chans: int = 3,
 733        img_size: int = 224,
 734        patch_size: int = 16,
 735        embed_dim: int = 768,
 736        depth: int = 12,
 737        num_heads: int = 12,
 738        ffn_ratio: float = 4.0,
 739        n_storage_tokens: int = 0,
 740        **kwargs
 741    ):
 742        if not _dinov3_import_success:
 743            raise RuntimeError(
 744                "The vision transformer backend can only be initialized if DINOv3 is installed. "
 745                "Please install DINOv3 from https://github.com/facebookresearch/dinov3 "
 746                "and then rerun your code."
 747            )
 748
 749        super().__init__(
 750            in_chans=in_chans,
 751            img_size=img_size,
 752            patch_size=patch_size,
 753            embed_dim=embed_dim,
 754            depth=depth,
 755            num_heads=num_heads,
 756            ffn_ratio=ffn_ratio,
 757            n_storage_tokens=n_storage_tokens,
 758            **kwargs
 759        )
 760
 761        self.in_chans = in_chans
 762        self.img_size = img_size
 763        self.n_storage_tokens = n_storage_tokens
 764        self.attn_outs = [i for i in range(depth) if i % 3 == 2]
 765
 766    def forward(self, x) -> torch.Tensor:
 767
 768        B = x.shape[0]
 769
 770        x, hw_tuple = self.prepare_tokens_with_masks(x)
 771
 772        list_of_encoder = []
 773        for i, blk in enumerate(self.blocks):
 774            rope_sincos = self.rope_embed(H=hw_tuple[0], W=hw_tuple[1])
 775            x = blk(x, rope_sincos)
 776            if i in self.attn_outs:
 777                list_of_encoder.append(x)
 778
 779        x = self.norm(x)
 780        x = x[:, self.n_storage_tokens + 1:].reshape(
 781            B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
 782        ).permute(0, 3, 1, 2).contiguous()
 783
 784        list_of_encoder = [
 785            o[:, self.n_storage_tokens + 1:].reshape(
 786                B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
 787            ).permute(0, 3, 1, 2).contiguous() for o in list_of_encoder
 788        ]
 789
 790        return x, list_of_encoder[:3]
 791
 792
 793def get_vision_transformer(backbone: str, model: str, img_size: int = 1024, **kwargs) -> nn.Module:
 794    """Get vision transformer encoder.
 795
 796    Args:
 797        backbone: The name of the vision transformer implementation.
 798            One of "sam" / "cellpose_sam" / "sam2" / "sam3" / "mae" / "scalemae" / "dinov2" / "dinov3".
 799        model: The name of the model. One of "vit_b", "vit_l" or "vit_h".
 800        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
 801        kwargs: Additional kwargs which can be expected by the vision transformer,
 802            e.g. 'base_resolution' for `ViT_ScaleMAE`.
 803
 804    Returns:
 805        The vision transformer.
 806    """
 807    if backbone == "sam":
 808        if model == "vit_b":
 809            encoder = ViT_Sam(
 810                depth=12, embed_dim=768, img_size=img_size, mlp_ratio=4,
 811                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
 812                num_heads=12, patch_size=16, qkv_bias=True, use_rel_pos=True,
 813                global_attn_indexes=[2, 5, 8, 11],
 814                window_size=14, out_chans=256,
 815            )
 816        elif model == "vit_l":
 817            encoder = ViT_Sam(
 818                depth=24, embed_dim=1024, img_size=img_size, mlp_ratio=4,
 819                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
 820                num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True,
 821                global_attn_indexes=[5, 11, 17, 23],
 822                window_size=14, out_chans=256,
 823            )
 824        elif model == "vit_h":
 825            encoder = ViT_Sam(
 826                depth=32, embed_dim=1280, img_size=img_size, mlp_ratio=4,
 827                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
 828                num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True,
 829                global_attn_indexes=[7, 15, 23, 31],
 830                window_size=14, out_chans=256,
 831            )
 832        else:
 833            raise ValueError(f"'{model}' is not supported by SAM. Currently, 'vit_b', 'vit_l', 'vit_h' are supported.")
 834
 835    elif backbone == "cellpose_sam":
 836        if model != "vit_l":
 837            raise ValueError(f"'{model}' is not supported by CellposeSAM. Only 'vit_l' is supported.")
 838        encoder = ViT_CellposeSAM(ps=8, bsize=img_size)
 839
 840    elif backbone == "sam2":
 841        if model == "hvit_t":
 842            encoder = ViT_Sam2(
 843                img_size=img_size, embed_dim=96, num_heads=1, stages=[1, 2, 7, 2], global_att_blocks=[5, 7, 9],
 844                window_pos_embed_bkg_spatial_size=[7, 7], backbone_channel_list=[768, 384, 192, 96],
 845            )
 846        elif model == "hvit_s":
 847            encoder = ViT_Sam2(
 848                img_size=img_size, embed_dim=96, num_heads=1, stages=[1, 2, 11, 2], global_att_blocks=[7, 10, 13],
 849                window_pos_embed_bkg_spatial_size=[7, 7], backbone_channel_list=[768, 384, 192, 96],
 850            )
 851        elif model == "hvit_b":
 852            encoder = ViT_Sam2(
 853                img_size=img_size, embed_dim=112, num_heads=2, backbone_channel_list=[896, 448, 224, 112],
 854            )
 855        elif model == "hvit_l":
 856            encoder = ViT_Sam2(
 857                img_size=img_size, embed_dim=144, num_heads=2, stages=[2, 6, 36, 4], global_att_blocks=[23, 33, 43],
 858                window_spec=[8, 4, 16, 8], backbone_channel_list=[1152, 576, 288, 144],
 859            )
 860        else:
 861            raise ValueError(
 862                f"'{model}' is not supported by SAM2. Currently, 'hvit_t', 'hvit_s', 'hvit_b', 'hvit_l' are supported."
 863            )
 864
 865    elif backbone == "sam3":
 866        if model != "vit_pe":
 867            raise ValueError(
 868                "'sam3' does not have multiple model configurations. Please use 'vit_pe' as the model configuration."
 869            )
 870
 871        encoder = ViT_Sam3(
 872            img_size=1008, pretrain_img_size=336, patch_size=14, embed_dim=1024, depth=32, num_heads=16,
 873            mlp_ratio=4.625, norm_layer="LayerNorm", drop_path_rate=0.1, qkv_bias=True, use_abs_pos=True,
 874            tile_abs_pos=True, global_att_blocks=(7, 15, 23, 31), rel_pos_blocks=(), use_rope=True,
 875            use_interp_rope=True, window_size=24, pretrain_use_cls_token=True, retain_cls_token=False, ln_pre=True,
 876            ln_post=False, return_interm_layers=False, bias_patch_embed=False, compile_mode=None,
 877        )
 878
 879    elif backbone == "mae":
 880        if model == "vit_b":
 881            encoder = ViT_MAE(
 882                img_size=img_size, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
 883                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
 884            )
 885        elif model == "vit_l":
 886            encoder = ViT_MAE(
 887                img_size=img_size, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
 888                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
 889            )
 890        elif model == "vit_h":
 891            encoder = ViT_MAE(
 892                img_size=img_size, patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4,
 893                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
 894            )
 895        else:
 896            raise ValueError(f"'{model}' is not supported by MAE. Currently, 'vit_b', 'vit_l', 'vit_h' are supported.")
 897
 898    elif backbone == "scalemae":
 899        base_resolution = kwargs.get("base_resolution", 2.5)
 900
 901        if model == "vit_b":
 902            encoder = ViT_ScaleMAE(
 903                img_size=img_size, patch_size=8, embed_dim=768, depth=12, num_heads=12,
 904                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
 905                base_resolution=base_resolution,
 906            )
 907        elif model == "vit_l":
 908            encoder = ViT_ScaleMAE(
 909                img_size=img_size, patch_size=8, embed_dim=1024, depth=24, num_heads=16,
 910                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
 911                base_resolution=base_resolution,
 912            )
 913        elif model == "vit_h":
 914            encoder = ViT_ScaleMAE(
 915                img_size=img_size, patch_size=8, embed_dim=1280, depth=32, num_heads=16,
 916                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
 917                base_resolution=base_resolution,
 918            )
 919        else:
 920            raise ValueError(
 921                f"'{model}' is not supported by ScaleMAE. Currently, 'vit_b', 'vit_l' and 'vit_h' are supported."
 922            )
 923
 924    elif backbone == "dinov2":
 925        block_fn = partial(Block, attn_class=MemEffAttention)
 926        msg = "The model name should be either 'vit_<X>' or 'vit_<X>_reg<Y>'."
 927
 928        if model.startswith("vit_s"):
 929            assert model in ["vit_s", "vit_s_reg4"], msg
 930            encoder = ViT_DINOv2(
 931                img_size=img_size, patch_size=14, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
 932                block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0,
 933                num_register_tokens=4 if model.endswith("_reg4") else 0,
 934            )
 935        elif model.startswith("vit_b"):
 936            assert model in ["vit_b", "vit_b_reg4"], msg
 937            encoder = ViT_DINOv2(
 938                img_size=img_size, patch_size=14, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
 939                block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0,
 940                num_register_tokens=4 if model.endswith("_reg4") else 0,
 941            )
 942        elif model.startswith("vit_l"):
 943            assert model in ["vit_l", "vit_l_reg4"], msg
 944            encoder = ViT_DINOv2(
 945                img_size=img_size, patch_size=14, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
 946                block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0,
 947                num_register_tokens=4 if model.endswith("_reg4") else 0,
 948            )
 949        elif model.startswith("vit_g"):
 950            assert model in ["vit_g", "vit_g_reg4"], msg
 951            encoder = ViT_DINOv2(
 952                img_size=img_size, patch_size=14, embed_dim=1536, depth=40, num_heads=24, mlp_ratio=4,
 953                block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0,
 954                num_register_tokens=4 if model.endswith("_reg4") else 0, ffn_layer="swiglu",
 955            )
 956        else:
 957            raise ValueError(
 958                f"'{model}' is not supported by DINOv2. Currently, 'vit_s', 'vit_b', 'vit_l' and 'vit_g' are supported."
 959            )
 960
 961    elif backbone == "dinov3":
 962
 963        if model == "vit_s":
 964            encoder = ViT_DINOv3(
 965                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=384,
 966                num_heads=6, layerscale_init=1.0e-05, norm_layer="layernormbf16", n_storage_tokens=4, mask_k_bias=True,
 967            )
 968        elif model == "vit_s+":
 969            encoder = ViT_DINOv3(
 970                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=384,
 971                num_heads=6, ffn_ratio=6, layerscale_init=1.0e-05, norm_layer="layernormbf16",
 972                ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True,
 973            )
 974
 975        elif model == "vit_b":
 976            encoder = ViT_DINOv3(
 977                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32",
 978                layerscale_init=1.0e-05, norm_layer="layernormbf16", n_storage_tokens=4, mask_k_bias=True,
 979            )
 980        elif model == "vit_l":
 981            encoder = ViT_DINOv3(
 982                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1024,
 983                depth=24, num_heads=16, layerscale_init=1.0e-05, norm_layer="layernormbf16",
 984                n_storage_tokens=4, mask_k_bias=True,
 985            )
 986        elif model == "vit_l+":
 987            encoder = ViT_DINOv3(
 988                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1024,
 989                depth=24, num_heads=16, ffn_ratio=6.0, layerscale_init=1.0e-05, norm_layer="layernormbf16",
 990                ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True,
 991            )
 992        elif model == "vit_h+":
 993            encoder = ViT_DINOv3(
 994                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1280,
 995                depth=32, num_heads=20, ffn_ratio=6.0, layerscale_init=1.0e-05, norm_layer="layernormbf16",
 996                ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True,
 997            )
 998        elif model == "vit_7b":
 999            encoder = ViT_DINOv3(
1000                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=4096,
1001                depth=40, num_heads=32, ffn_ratio=3, qkv_bias=False, drop_path_rate=0.0, layerscale_init=1.0e-05,
1002                norm_layer="layernormbf16", ffn_layer="swiglu64", n_storage_tokens=4, mask_k_bias=True,
1003                untie_global_and_local_cls_norm=True,
1004            )
1005        else:
1006            raise ValueError(
 1007                f"'{model}' is not supported by DINOv3. Currently, "
 1008                "'vit_s', 'vit_s+', 'vit_b', 'vit_l', 'vit_l+', 'vit_h+' and 'vit_7b' are supported."
1009            )
1010
1011    else:
1012        raise ValueError(
1013            "The 'UNETR' supported backbones are 'sam', 'cellpose_sam', 'sam2', 'sam3', "
1014            "'mae', 'scalemae', 'dinov2' or 'dinov3'. Please choose one of them."
1015        )
1016
1017    return encoder
class ViT_Sam:
 59class ViT_Sam(ImageEncoderViT):
 60    """Vision Transformer derived from the Segment Anything Codebase (https://arxiv.org/abs/2304.02643).
 61
 62    Based on:
 63    https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py
 64
 65    Args:
 66        in_chans: The number of input channels.
 67        embed_dim: The embedding dimension, corresponding to the number of output channels of the vision transformer.
 68        global_attn_indexes: The global attention indices.
 69        apply_neck: Whether to apply the convolutional bottleneck after outputs of the last attention head.
 70        kwargs: Keyword arguments for the image encoder base class.
 71    """
 72    def __init__(
 73        self,
 74        in_chans: int = 3,
 75        embed_dim: int = 768,
 76        global_attn_indexes: Tuple[int, ...] = [2, 5, 8, 11],
 77        apply_neck: bool = False,
 78        **kwargs,
 79    ) -> None:
 80        if not _sam_import_success:
 81            raise RuntimeError(
 82                "The vision transformer backend can only be initialized if segment anything is installed. "
 83                "Please install segment anything from https://github.com/facebookresearch/segment-anything "
 84                "and then rerun your code."
 85            )
 86
 87        super().__init__(embed_dim=embed_dim, global_attn_indexes=global_attn_indexes, **kwargs)
 88        self.chunks_for_projection = global_attn_indexes
 89        self.in_chans = in_chans
 90        self.embed_dim = embed_dim
 91        self.apply_neck = apply_neck
 92
 93    def forward(self, x: torch.Tensor) -> torch.Tensor:
 94        """Apply the vision transformer to input data.
 95
 96        Args:
 97            x: The input data.
 98
 99        Returns:
100            The vision transformer output.
101        """
102        x = self.patch_embed(x)
103        if self.pos_embed is not None:
104            x = x + self.pos_embed
105
106        list_from_encoder = []
107        for i, blk in enumerate(self.blocks):
108            x = blk(x)
109            if i in self.chunks_for_projection:
110                list_from_encoder.append(x)
111
112        x = x.permute(0, 3, 1, 2)
113
114        if self.apply_neck:
115            x = self.neck(x)
116
117        list_from_encoder = [e.permute(0, 3, 1, 2) for e in list_from_encoder]
118        return x, list_from_encoder[:3]

Vision Transformer derived from the Segment Anything Codebase (https://arxiv.org/abs/2304.02643).

Based on: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py

Arguments:
  • in_chans: The number of input channels.
  • embed_dim: The embedding dimension, corresponding to the number of output channels of the vision transformer.
  • global_attn_indexes: The global attention indices.
  • apply_neck: Whether to apply the convolutional bottleneck after outputs of the last attention head.
  • kwargs: Keyword arguments for the image encoder base class.
ViT_Sam(in_chans: int = 3, embed_dim: int = 768, global_attn_indexes: Tuple[int, ...] = [2, 5, 8, 11], apply_neck: bool = False, **kwargs)
72    def __init__(
73        self,
74        in_chans: int = 3,
75        embed_dim: int = 768,
76        global_attn_indexes: Tuple[int, ...] = [2, 5, 8, 11],
77        apply_neck: bool = False,
78        **kwargs,
79    ) -> None:
80        if not _sam_import_success:
81            raise RuntimeError(
82                "The vision transformer backend can only be initialized if segment anything is installed. "
83                "Please install segment anything from https://github.com/facebookresearch/segment-anything "
84                "and then rerun your code."
85            )
86
87        super().__init__(embed_dim=embed_dim, global_attn_indexes=global_attn_indexes, **kwargs)
88        self.chunks_for_projection = global_attn_indexes
89        self.in_chans = in_chans
90        self.embed_dim = embed_dim
91        self.apply_neck = apply_neck
chunks_for_projection
in_chans
embed_dim
apply_neck
def forward(self, x: torch.Tensor) -> torch.Tensor:
 93    def forward(self, x: torch.Tensor) -> torch.Tensor:
 94        """Apply the vision transformer to input data.
 95
 96        Args:
 97            x: The input data.
 98
 99        Returns:
100            The vision transformer output.
101        """
102        x = self.patch_embed(x)
103        if self.pos_embed is not None:
104            x = x + self.pos_embed
105
106        list_from_encoder = []
107        for i, blk in enumerate(self.blocks):
108            x = blk(x)
109            if i in self.chunks_for_projection:
110                list_from_encoder.append(x)
111
112        x = x.permute(0, 3, 1, 2)
113
114        if self.apply_neck:
115            x = self.neck(x)
116
117        list_from_encoder = [e.permute(0, 3, 1, 2) for e in list_from_encoder]
118        return x, list_from_encoder[:3]

Apply the vision transformer to input data.

Arguments:
  • x: The input data.
Returns:

The vision transformer output.

class ViT_CellposeSAM(torch.nn.modules.module.Module):
121class ViT_CellposeSAM(nn.Module):
122    """Vision Transformer derived from the CellposeSAM Codebase (https://doi.org/10.1038/s41592-025-02595-x).
123
124    This replicates CellposeSAM's actual initialization: instantiate SAM's ``ImageEncoderViT`` via
125    ``sam_model_registry``, then modify the patch embedding, position embeddings, and set global attention.
126    This preserves SAM's original relative position bias sizes, enabling direct checkpoint loading
127    without any interpolation.
128
129    Based on: https://github.com/MouseLand/cellpose/blob/main/cellpose/vit_sam.py
130
131    NOTE: The pretrained CellposeSAM model uses ``vit_l`` exclusively.
132
133    Args:
134        ps: The patch size (default for CellposeSAM is 8).
135        bsize: The input image size (default for CellposeSAM is 256).
136        apply_neck: Whether to apply the convolutional bottleneck after outputs of the last attention head.
137    """
138    def __init__(self, ps: int = 8, bsize: int = 256, apply_neck: bool = True) -> None:
139        super().__init__()
140
141        if not _sam_import_success:
142            raise RuntimeError(
143                "The vision transformer backend can only be initialized if segment anything is installed. "
144                "Please install segment anything from https://github.com/facebookresearch/segment-anything "
145                "and then rerun your code."
146            )
147
148        from segment_anything import sam_model_registry
149
150        # Creates the SAM vit_l encoder and applies CellposeSAM's modifications (same as cellpose.vit_sam.Transformer).
151        encoder = sam_model_registry["vit_l"](None).image_encoder
152
153        w = encoder.patch_embed.proj.weight.detach()
154        nchan = w.shape[0]
155
 156        # CellposeSAM changes the patch size from 16 to 'ps'.
157        encoder.patch_embed.proj = nn.Conv2d(3, nchan, stride=ps, kernel_size=ps)
158        encoder.patch_embed.proj.weight.data = w[:, :, ::16 // ps, ::16 // ps]
159
160        # Next, they subsample position embeddings for the new patch size and input resolution.
161        ds = (1024 // 16) // (bsize // ps)
162        encoder.pos_embed = nn.Parameter(encoder.pos_embed[:, ::ds, ::ds], requires_grad=True)
163
164        # Finally, they set all blocks to global attention.
165        for blk in encoder.blocks:
166            blk.window_size = 0
167
168        # Store encoder submodules directly ('state_dict' keys match CellposeSAM after prefix stripping).
169        self.patch_embed = encoder.patch_embed
170        self.pos_embed = encoder.pos_embed
171        self.blocks = encoder.blocks
172        self.neck = encoder.neck
173
174        # Additional attributes expected by UNETR.
175        self.embed_dim = nchan
176        self.img_size = bsize
177        self.in_chans = 3
178        self.apply_neck = apply_neck
179
180        # Feature extraction at evenly-spaced depths.
181        depth = len(self.blocks)
182        _chunks = depth // 4
183        self.chunks_for_projection = [_chunks - 1, 2 * _chunks - 1, 3 * _chunks - 1, 4 * _chunks - 1]
184
185    def forward(self, x: torch.Tensor) -> torch.Tensor:
186        """Apply the vision transformer to input data.
187
188        Args:
189            x: The input data.
190
191        Returns:
192            The vision transformer output.
193        """
194        x = self.patch_embed(x)
195        if self.pos_embed is not None:
196            x = x + self.pos_embed
197
198        list_from_encoder = []
199        for i, blk in enumerate(self.blocks):
200            x = blk(x)
201            if i in self.chunks_for_projection:
202                list_from_encoder.append(x)
203
204        x = x.permute(0, 3, 1, 2)
205
206        if self.apply_neck:
207            x = self.neck(x)
208
209        list_from_encoder = [e.permute(0, 3, 1, 2) for e in list_from_encoder]
210        return x, list_from_encoder[:3]

Vision Transformer derived from the CellposeSAM Codebase (https://doi.org/10.1038/s41592-025-02595-x).

This replicates CellposeSAM's actual initialization: instantiate SAM's ImageEncoderViT via sam_model_registry, then modify the patch embedding, position embeddings, and set global attention. This preserves SAM's original relative position bias sizes, enabling direct checkpoint loading without any interpolation.

Based on: https://github.com/MouseLand/cellpose/blob/main/cellpose/vit_sam.py

NOTE: The pretrained CellposeSAM model uses vit_l exclusively.

Arguments:
  • ps: The patch size (default for CellposeSAM is 8).
  • bsize: The input image size (default for CellposeSAM is 256).
  • apply_neck: Whether to apply the convolutional bottleneck after outputs of the last attention head.
ViT_CellposeSAM(ps: int = 8, bsize: int = 256, apply_neck: bool = True)
138    def __init__(self, ps: int = 8, bsize: int = 256, apply_neck: bool = True) -> None:
139        super().__init__()
140
141        if not _sam_import_success:
142            raise RuntimeError(
143                "The vision transformer backend can only be initialized if segment anything is installed. "
144                "Please install segment anything from https://github.com/facebookresearch/segment-anything "
145                "and then rerun your code."
146            )
147
148        from segment_anything import sam_model_registry
149
150        # Creates the SAM vit_l encoder and applies CellposeSAM's modifications (same as cellpose.vit_sam.Transformer).
151        encoder = sam_model_registry["vit_l"](None).image_encoder
152
153        w = encoder.patch_embed.proj.weight.detach()
154        nchan = w.shape[0]
155
 156        # CellposeSAM changes the patch size from 16 to 'ps'.
157        encoder.patch_embed.proj = nn.Conv2d(3, nchan, stride=ps, kernel_size=ps)
158        encoder.patch_embed.proj.weight.data = w[:, :, ::16 // ps, ::16 // ps]
159
160        # Next, they subsample position embeddings for the new patch size and input resolution.
161        ds = (1024 // 16) // (bsize // ps)
162        encoder.pos_embed = nn.Parameter(encoder.pos_embed[:, ::ds, ::ds], requires_grad=True)
163
164        # Finally, they set all blocks to global attention.
165        for blk in encoder.blocks:
166            blk.window_size = 0
167
168        # Store encoder submodules directly ('state_dict' keys match CellposeSAM after prefix stripping).
169        self.patch_embed = encoder.patch_embed
170        self.pos_embed = encoder.pos_embed
171        self.blocks = encoder.blocks
172        self.neck = encoder.neck
173
174        # Additional attributes expected by UNETR.
175        self.embed_dim = nchan
176        self.img_size = bsize
177        self.in_chans = 3
178        self.apply_neck = apply_neck
179
180        # Feature extraction at evenly-spaced depths.
181        depth = len(self.blocks)
182        _chunks = depth // 4
183        self.chunks_for_projection = [_chunks - 1, 2 * _chunks - 1, 3 * _chunks - 1, 4 * _chunks - 1]

Initialize internal Module state, shared by both nn.Module and ScriptModule.

patch_embed
pos_embed
blocks
neck
embed_dim
img_size
in_chans
apply_neck
chunks_for_projection
def forward(self, x: torch.Tensor) -> torch.Tensor:
185    def forward(self, x: torch.Tensor) -> torch.Tensor:
186        """Apply the vision transformer to input data.
187
188        Args:
189            x: The input data.
190
191        Returns:
192            The vision transformer output.
193        """
194        x = self.patch_embed(x)
195        if self.pos_embed is not None:
196            x = x + self.pos_embed
197
198        list_from_encoder = []
199        for i, blk in enumerate(self.blocks):
200            x = blk(x)
201            if i in self.chunks_for_projection:
202                list_from_encoder.append(x)
203
204        x = x.permute(0, 3, 1, 2)
205
206        if self.apply_neck:
207            x = self.neck(x)
208
209        list_from_encoder = [e.permute(0, 3, 1, 2) for e in list_from_encoder]
210        return x, list_from_encoder[:3]

Apply the vision transformer to input data.

Arguments:
  • x: The input data.
Returns:

The vision transformer output.

class ViT_MAE:
213class ViT_MAE(VisionTransformer):
214    """Vision Transformer derived from the Masked Auto Encoder Codebase (https://arxiv.org/abs/2111.06377).
215
216    Based on:
217    https://github.com/facebookresearch/mae/blob/main/models_vit.py#L20-L53
218
219    Args:
220        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
221        in_chans: The number of input channels.
222        depth: The depth of the vision transformer.
223        kwargs: Additional keyword arguments for the vision transformer base class.
224    """
225    def __init__(
226        self,
227        img_size: int = 1024,  # chosen to match our experiments with segment anything
228        in_chans: int = 3,
229        depth: int = 12,
230        **kwargs
231    ):
232        if not _timm_import_success:
233            raise RuntimeError(
234                "The vision transformer backend can only be initialized if timm is installed. "
235                "Please install timm (using conda/mamba) for using https://github.com/facebookresearch/mae/ "
236                "and then rerun your code"
237            )
238        super().__init__(img_size=img_size, depth=depth, **kwargs)
239        self.img_size = img_size
240        self.in_chans = in_chans
241        self.depth = depth
242
243    def convert_to_expected_dim(self, inputs_):
244        """@private
245        """
246        inputs_ = inputs_[:, 1:, :]  # removing the class tokens
247        # reshape the outputs to desired shape (N x H*W X C -> N x H x W x C)
248        rdim = inputs_.shape[1]
249        dshape = int(rdim ** 0.5)  # finding the square root of the outputs for obtaining the patch shape
250        inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape))
251        inputs_ = inputs_.permute(0, 3, 1, 2)
252        return inputs_
253
254    def forward_features(self, x):
255        """@private
256        """
257        B = x.shape[0]
258        x = self.patch_embed(x)
259
260        cls_tokens = self.cls_token.expand(B, -1, -1)
261        x = torch.cat((cls_tokens, x), dim=1)
262
263        x = x + self.pos_embed
264        x = self.pos_drop(x)
265
 266        # chunks obtained for getting the projections for conjunctions with upsampling blocks
267        _chunks = int(self.depth / 4)
268        chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1]
269
270        list_from_encoder = []
271        for i, blk in enumerate(self.blocks):
272            x = blk(x)
273            if i in chunks_for_projection:
274                list_from_encoder.append(self.convert_to_expected_dim(x))
275
276        x = self.convert_to_expected_dim(x)
277        return x, list_from_encoder[:3]
278
279    def forward(self, x: torch.Tensor) -> torch.Tensor:
280        """Apply the vision transformer to input data.
281
282        Args:
283            x: The input data.
284
285        Returns:
286            The vision transformer output.
287        """
288        x, list_from_encoder = self.forward_features(x)
289        return x, list_from_encoder

Vision Transformer derived from the Masked Auto Encoder Codebase (https://arxiv.org/abs/2111.06377).

Based on: https://github.com/facebookresearch/mae/blob/main/models_vit.py#L20-L53

Arguments:
  • img_size: The size of the input for the image encoder. Input images will be resized to match this size.
  • in_chans: The number of input channels.
  • depth: The depth of the vision transformer.
  • kwargs: Additional keyword arguments for the vision transformer base class.
ViT_MAE(img_size: int = 1024, in_chans: int = 3, depth: int = 12, **kwargs)
225    def __init__(
226        self,
227        img_size: int = 1024,  # chosen to match our experiments with segment anything
228        in_chans: int = 3,
229        depth: int = 12,
230        **kwargs
231    ):
232        if not _timm_import_success:
233            raise RuntimeError(
234                "The vision transformer backend can only be initialized if timm is installed. "
235                "Please install timm (using conda/mamba) for using https://github.com/facebookresearch/mae/ "
236                "and then rerun your code"
237            )
238        super().__init__(img_size=img_size, depth=depth, **kwargs)
239        self.img_size = img_size
240        self.in_chans = in_chans
241        self.depth = depth
img_size
in_chans
depth
def forward(self, x: torch.Tensor) -> torch.Tensor:
279    def forward(self, x: torch.Tensor) -> torch.Tensor:
280        """Apply the vision transformer to input data.
281
282        Args:
283            x: The input data.
284
285        Returns:
286            The vision transformer output.
287        """
288        x, list_from_encoder = self.forward_features(x)
289        return x, list_from_encoder

Apply the vision transformer to input data.

Arguments:
  • x: The input data.
Returns:

The vision transformer output.

class ViT_Sam2:
292class ViT_Sam2(ImageEncoder):
293    """Vision Transformer derived from the Segment Anything 2 Codebase (https://arxiv.org/abs/2408.00714).
294
295    Based on https://github.com/facebookresearch/sam2/blob/main/sam2/modeling/backbones/image_encoder.py.
296
297    Args:
298        backbone_channel_list: The channels throughout the entire backbone.
299        embed_dim: The initial embedding dimension.
300        num_heads: The initial number of heads.
301        stages: The number of blocks per stage.
302        global_att_blocks: The parameter to decide which blocks have global attention.
303        window_pos_embed_bkg_spatial_size: The spatial size of windowed positional embedding.
304        window_spec: The window size per stage, when not using global attention.
305        scalp: The count of lowest resolution features to discard.
306    """
307    def __init__(
308        self,
309        backbone_channel_list: List[int],
310        img_size: int = 1024,
311        embed_dim: int = 96,
312        num_heads: int = 1,
313        stages: Tuple[int, ...] = (2, 3, 16, 3),
314        global_att_blocks: Tuple[int, ...] = (12, 16, 20),
315        window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
316        window_spec: Tuple[int, ...] = (8, 4, 14, 7),
317        scalp: int = 1,
318        **kwargs
319    ):
320        if not _sam2_import_success:
321            raise RuntimeError(
322                "The vision transformer backend can only be initialized if segment anything 2 is installed. "
323                "Please install segment anything 2 from https://github.com/facebookresearch/sam2 "
324                "and then rerun your code"
325            )
326
327        trunk = Hiera(
328            embed_dim=embed_dim,
329            num_heads=num_heads,
330            stages=stages,
331            global_att_blocks=global_att_blocks,
332            window_pos_embed_bkg_spatial_size=window_pos_embed_bkg_spatial_size,
333            window_spec=window_spec,
334        )
335        neck = FpnNeck(
336            position_encoding=PositionEmbeddingSine(num_pos_feats=256),
337            d_model=256,
338            backbone_channel_list=backbone_channel_list,
339            fpn_top_down_levels=[2, 3],
340            fpn_interp_model="nearest",
341        )
342
343        super().__init__(trunk=trunk, neck=neck, scalp=scalp, **kwargs)
344        self.scalp = scalp
345        self.embed_dim = embed_dim
346        self.img_size = img_size
347
348    def forward(self, x: torch.Tensor):
349        # The forward pass through the backbone.
350        features, pos = self.neck(self.trunk(x))
351        if self.scalp > 0:  # This discards the "n" lowest resolution features.
352            features, pos = features[:-self.scalp], pos[:-self.scalp]
353
354        return features[-1], features

Vision Transformer derived from the Segment Anything 2 Codebase (https://arxiv.org/abs/2408.00714).

Based on https://github.com/facebookresearch/sam2/blob/main/sam2/modeling/backbones/image_encoder.py.

Arguments:
  • backbone_channel_list: The channels throughout the entire backbone.
  • embed_dim: The initial embedding dimension.
  • num_heads: The initial number of heads.
  • stages: The number of blocks per stage.
  • global_att_blocks: The parameter to decide which blocks have global attention.
  • window_pos_embed_bkg_spatial_size: The spatial size of windowed positional embedding.
  • window_spec: The window size per stage, when not using global attention.
  • scalp: The count of lowest resolution features to discard.
ViT_Sam2( backbone_channel_list: List[int], img_size: int = 1024, embed_dim: int = 96, num_heads: int = 1, stages: Tuple[int, ...] = (2, 3, 16, 3), global_att_blocks: Tuple[int, ...] = (12, 16, 20), window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14), window_spec: Tuple[int, ...] = (8, 4, 14, 7), scalp: int = 1, **kwargs)
307    def __init__(
308        self,
309        backbone_channel_list: List[int],
310        img_size: int = 1024,
311        embed_dim: int = 96,
312        num_heads: int = 1,
313        stages: Tuple[int, ...] = (2, 3, 16, 3),
314        global_att_blocks: Tuple[int, ...] = (12, 16, 20),
315        window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
316        window_spec: Tuple[int, ...] = (8, 4, 14, 7),
317        scalp: int = 1,
318        **kwargs
319    ):
320        if not _sam2_import_success:
321            raise RuntimeError(
322                "The vision transformer backend can only be initialized if segment anything 2 is installed. "
323                "Please install segment anything 2 from https://github.com/facebookresearch/sam2 "
324                "and then rerun your code"
325            )
326
327        trunk = Hiera(
328            embed_dim=embed_dim,
329            num_heads=num_heads,
330            stages=stages,
331            global_att_blocks=global_att_blocks,
332            window_pos_embed_bkg_spatial_size=window_pos_embed_bkg_spatial_size,
333            window_spec=window_spec,
334        )
335        neck = FpnNeck(
336            position_encoding=PositionEmbeddingSine(num_pos_feats=256),
337            d_model=256,
338            backbone_channel_list=backbone_channel_list,
339            fpn_top_down_levels=[2, 3],
340            fpn_interp_model="nearest",
341        )
342
343        super().__init__(trunk=trunk, neck=neck, scalp=scalp, **kwargs)
344        self.scalp = scalp
345        self.embed_dim = embed_dim
346        self.img_size = img_size
scalp
embed_dim
img_size
def forward(self, x: torch.Tensor):
348    def forward(self, x: torch.Tensor):
349        # The forward pass through the backbone.
350        features, pos = self.neck(self.trunk(x))
351        if self.scalp > 0:  # This discards the "n" lowest resolution features.
352            features, pos = features[:-self.scalp], pos[:-self.scalp]
353
354        return features[-1], features
class ViT_Sam3:
357class ViT_Sam3(SAM3ViT):
358    """Vision Transformer derived from the Segment Anything 3 Codebase (https://arxiv.org/abs/2511.16719).
359
360    Based on https://github.com/facebookresearch/sam3/blob/main/sam3/model/vitdet.py.
361
362    Args:
363        img_size: The input image size.
364        embed_dim: The embedding dimension, corresponding to the number of output channels of the vision transformer.
365        kwargs: Keyword arguments for the image encoder base class.
366    """
367    def __init__(self, img_size: int = 1024, embed_dim: int = 768, **kwargs):
368        if not _sam3_import_success:
369            raise RuntimeError(
370                "The vision transformer backend can only be initialized if segment anything 3 is installed. "
371                "Please install segment anything 3 from https://github.com/facebookresearch/sam3 "
372                "and then rerun your code"
373            )
374
375        super().__init__(img_size=img_size, embed_dim=embed_dim, **kwargs)
376        self.img_size = img_size
377        self.embed_dim = embed_dim
378
379    def forward_features(self, x):
380        """@private
381        """
382        x = self.patch_embed(x)
383        h, w = x.shape[1], x.shape[2]
384
385        s = 0
386        if self.retain_cls_token:
387            # If the 'cls_token' is retained, we don't maintain the spatial shape.
388            x = torch.cat([self.class_embedding, x.flatten(1, 2)], dim=1)
389            s = 1
390
391        if self.pos_embed is not None:
392            x = x + get_abs_pos(
393                self.pos_embed, self.pretrain_use_cls_token, (h, w), self.retain_cls_token, tiling=self.tile_abs_pos,
394            )
395
396        x = self.ln_pre(x)
397
398        list_from_encoder = []
399        for i, blk in enumerate(self.blocks):
400            if self.use_act_checkpoint and self.training:
401                x = torch.utils.checkpoint.checkpoint(blk, x, use_reentrant=False)
402            else:
403                x = blk(x)
404
405            x = self._convert_to_expected_dim(x, i, s)
406
407            if i in self.full_attn_ids:
408                list_from_encoder.append(x)
409
410        return x, list_from_encoder
411
412    def _convert_to_expected_dim(self, x, i, s):
413        if (i == self.full_attn_ids[-1]) or (
414            self.return_interm_layers and i in self.full_attn_ids
415        ):
416            if i == self.full_attn_ids[-1]:
417                x = self.ln_post(x)
418
419            feats = x[:, s:]
420            if feats.ndim == 4:
421                feats = feats.permute(0, 3, 1, 2)
422            else:
423                assert feats.ndim == 3
424                h = w = math.sqrt(feats.shape[1])
425                feats = feats.reshape(feats.shape[0], h, w, feats.shape[-1]).permute(0, 3, 1, 2)
426            return feats
427
428        else:
429            return x
430
431    def forward(self, x: torch.Tensor):
432        """Apply the vision transformer to input data.
433
434        Args:
435            x: The input data.
436
437        Returns:
438            The vision transformer output.
439        """
440        x, list_from_encoder = self.forward_features(x)
441        return x, list_from_encoder

Vision Transformer derived from the Segment Anything 3 Codebase (https://arxiv.org/abs/2511.16719).

Based on https://github.com/facebookresearch/sam3/blob/main/sam3/model/vitdet.py.

Arguments:
  • img_size: The input image size.
  • embed_dim: The embedding dimension, corresponding to the number of output channels of the vision transformer.
  • kwargs: Keyword arguments for the image encoder base class.
ViT_Sam3(img_size: int = 1024, embed_dim: int = 768, **kwargs)
367    def __init__(self, img_size: int = 1024, embed_dim: int = 768, **kwargs):
368        if not _sam3_import_success:
369            raise RuntimeError(
370                "The vision transformer backend can only be initialized if segment anything 3 is installed. "
371                "Please install segment anything 3 from https://github.com/facebookresearch/sam3 "
372                "and then rerun your code"
373            )
374
375        super().__init__(img_size=img_size, embed_dim=embed_dim, **kwargs)
376        self.img_size = img_size
377        self.embed_dim = embed_dim
img_size
embed_dim
def forward(self, x: torch.Tensor):
431    def forward(self, x: torch.Tensor):
432        """Apply the vision transformer to input data.
433
434        Args:
435            x: The input data.
436
437        Returns:
438            The vision transformer output.
439        """
440        x, list_from_encoder = self.forward_features(x)
441        return x, list_from_encoder

Apply the vision transformer to input data.

Arguments:
  • x: The input data.
Returns:

The vision transformer output.

class CustomCompose:
448class CustomCompose:
449    def __init__(self, rescale_transform, other_transforms, src_transform):
450        self.rescale_transform = rescale_transform
451        self.other_transforms = other_transforms
452        self.src_transform = src_transform
453
454    def __call__(self, x, valid_masks=None):
455        if valid_masks is not None:
456            nodata = (x * (1 - valid_masks.float())).max()
457        x_aug = self.rescale_transform(x)
458        parms = self.rescale_transform._params
459
460        # sanity check, comment if this is working
461        # valid_masks = self.rescale_transform(valid_masks.float(), params=parms)
462        # assert (x_aug==self.rescale_transform(x, params=parms)).all() #
463
464        if valid_masks is not None:
465            valid_masks = x_aug != nodata
466            _, c, h, w = x_aug.shape
467            zero_ratio = ((valid_masks == 0).sum((1, 2, 3)) / (h * w * c)).cpu().numpy()
468        else:
469            zero_ratio = -1
470
471        if self.other_transforms:
472            x_aug = self.other_transforms(x_aug)
473        x_src = self.src_transform(x_aug)
474        dx = parms["src"][:, 1, 0] - parms["src"][:, 0, 0]
475
476        # dy = (parms['src'][:,2,1] - parms['src'][:,1,1])
477        # assert (dx == dy).all()
478
479        h, w = x_aug.shape[-2:]
480        # assert h == w
481
482        return x_aug, x_src, dx / h, zero_ratio, valid_masks
CustomCompose(rescale_transform, other_transforms, src_transform)
449    def __init__(self, rescale_transform, other_transforms, src_transform):
450        self.rescale_transform = rescale_transform
451        self.other_transforms = other_transforms
452        self.src_transform = src_transform
rescale_transform
other_transforms
src_transform
def get_2d_sincos_pos_embed_with_resolution(embed_dim, grid_size, res, cls_token=False, device='cpu'):
485def get_2d_sincos_pos_embed_with_resolution(embed_dim, grid_size, res, cls_token=False, device="cpu"):
486    """
487    grid_size: int of the grid height and width
488    res: array of size n, representing the resolution of a pixel (say, in meters),
489    return:
490    pos_embed: [n,grid_size*grid_size, embed_dim] or [n,1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
491    """
492    # res = torch.FloatTensor(res).to(device)
493    res = res.to(device)
494    grid_h = torch.arange(grid_size, dtype=torch.float32, device=device)
495    grid_w = torch.arange(grid_size, dtype=torch.float32, device=device)
496    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")  # here h goes first, direction reversed for numpy
497    grid = torch.stack(grid, dim=0)  # 2 x h x w
498
499    # grid = grid.reshape([2, 1, grid_size, grid_size])
500    grid = torch.einsum("chw,n->cnhw", grid, res)  # 2 x n x h x w
501    _, n, h, w = grid.shape
502    pos_embed = get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid)  # (nxH*W, D/2)
503    pos_embed = pos_embed.reshape(n, h * w, embed_dim)
504    if cls_token:
505        pos_embed = torch.cat(
506            [torch.zeros([n, 1, embed_dim], dtype=torch.float32, device=pos_embed.device), pos_embed], dim=1
507        )
508
509    return pos_embed

grid_size: int of the grid height and width res: array of size n, representing the resolution of a pixel (say, in meters), return: pos_embed: [n, grid_size*grid_size, embed_dim] or [n, 1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)

def get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid):
512def get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid):
513    assert embed_dim % 2 == 0
514
515    # use half of dimensions to encode grid_h
516    emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid[0])  # (H*W, D/2)
517    emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid[1])  # (H*W, D/2)
518
519    emb = torch.cat([emb_h, emb_w], dim=1)  # (H*W, D)
520    return emb
def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos):
523def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos):
524    """
525    embed_dim: output dimension for each position
526    pos: a list of positions to be encoded: size (M,)
527    out: (M, D)
528    """
529    assert embed_dim % 2 == 0
530    # old_shape = pos
531    omega = torch.arange(embed_dim // 2, dtype=torch.float32, device=pos.device)
532    omega /= embed_dim / 2.0
533    omega = 1.0 / 10000**omega  # (D/2,)
534
535    pos = pos.reshape(-1)  # (M,)
536    out = torch.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
537
538    emb_sin = torch.sin(out)  # (M, D/2)
539    emb_cos = torch.cos(out)  # (M, D/2)
540
541    emb = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)
542    return emb

embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)

class PatchEmbedUnSafe:
545class PatchEmbedUnSafe(PatchEmbed):
546    """Image to Patch Embedding"""
547
548    def forward(self, x):
549        B, C, H, W = x.shape
550
551        # NOTE: Comment code from ScaleMAE: Dropped size check in timm
552        # assert H == self.img_size[0] and W == self.img_size[1], \
553        #     f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
554
555        x = self.proj(x).flatten(2).transpose(1, 2)
556        return x

Image to Patch Embedding

def forward(self, x):
548    def forward(self, x):
549        B, C, H, W = x.shape
550
551        # NOTE: Comment code from ScaleMAE: Dropped size check in timm
552        # assert H == self.img_size[0] and W == self.img_size[1], \
553        #     f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
554
555        x = self.proj(x).flatten(2).transpose(1, 2)
556        return x
class ViT_ScaleMAE:
559class ViT_ScaleMAE(VisionTransformer):
560    """Vision Transformer derived from the Scale Masked Auto Encoder codebase (TODO: paper and github link).
561
562    NOTE: For downstream tasks, the "base_resolution" parameter needs to be adjusted manually when using
563    the model on a different zoom factor dataset.
564    """
565
566    def __init__(
567        self, img_size=224, patch_size=16, in_chans=3, embed_dim=1024, depth=12, base_resolution=2.5, **kwargs
568    ):
569        super().__init__(img_size=img_size, embed_dim=embed_dim, **kwargs)
570        self.img_size = img_size
571        self.in_chans = in_chans
572        self.depth = depth
573        self.base_resolution = base_resolution
574
575        self.patch_embed = PatchEmbedUnSafe(
576            img_size=img_size,
577            patch_size=patch_size,
578            in_chans=in_chans,
579            embed_dim=embed_dim,
580        )
581
582    def transform_inputs(self, x):
583        import kornia.augmentation as K
584        from kornia.constants import Resample
585
586        self._transforms = CustomCompose(
587            rescale_transform=K.RandomResizedCrop(
588                (448, 448),
589                ratio=(1.0, 1.0),
590                scale=(1.0, 1.0),
591                resample=Resample.BICUBIC.name,
592            ),
593            other_transforms=None,
594            src_transform=K.Resize((224, 224)),
595        )
596        x, _, ratios, _, _ = self._transforms(x)
597        input_res = ratios * self.base_resolution
598        return x, input_res
599
600    def convert_to_expected_dim(self, x):
601        inputs_ = x[:, 1:, :]  # removing the class tokens
602        # reshape the outputs to desired shape (N X H*W X C -> N X H X W X C)
603        rdim = inputs_.shape[1]
604        dshape = int(rdim ** 0.5)  # finding square root of the outputs for obtaining the patch shape
605        inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape))
606        inputs_ = inputs_.permute(0, 3, 1, 2)
607        return inputs_
608
609    def forward_features(self, x):
610        x, input_res = self.transform_inputs(x)
611
612        B, _, h, w = x.shape
613        x = self.patch_embed(x)
614
615        num_patches = int((h * w) / (self.patch_embed.patch_size[0] * self.patch_embed.patch_size[1]))
616        pos_embed = get_2d_sincos_pos_embed_with_resolution(
617            x.shape[-1],
618            int(num_patches ** 0.5),
619            input_res,
620            cls_token=True,
621            device=x.device,
622        )
623
624        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
625        x = torch.cat((cls_tokens, x), dim=1)
626        x = x + pos_embed
627        x = self.pos_drop(x)
628
629        # chunks obtained for getting the projections for conjunctions with upsampling blocks
630        _chunks = int(self.depth / 4)
631        chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1]
632
633        list_from_encoder = []
634        for i, blk in enumerate(self.blocks):
635            x = blk(x)
636            if i in chunks_for_projection:
637                list_from_encoder.append(self.convert_to_expected_dim(x))
638
639        x = self.convert_to_expected_dim(x)
640
641        return x, list_from_encoder
642
643    def forward(self, x):
644        x, list_from_encoder = self.forward_features(x)
645        return x, list_from_encoder

Vision Transformer derived from the Scale Masked Auto Encoder codebase (TODO: paper and github link).

NOTE: For downstream tasks, the "base_resolution" parameter needs to be adjusted manually when using the model on a different zoom factor dataset.

ViT_ScaleMAE( img_size=224, patch_size=16, in_chans=3, embed_dim=1024, depth=12, base_resolution=2.5, **kwargs)
566    def __init__(
567        self, img_size=224, patch_size=16, in_chans=3, embed_dim=1024, depth=12, base_resolution=2.5, **kwargs
568    ):
569        super().__init__(img_size=img_size, embed_dim=embed_dim, **kwargs)
570        self.img_size = img_size
571        self.in_chans = in_chans
572        self.depth = depth
573        self.base_resolution = base_resolution
574
575        self.patch_embed = PatchEmbedUnSafe(
576            img_size=img_size,
577            patch_size=patch_size,
578            in_chans=in_chans,
579            embed_dim=embed_dim,
580        )
img_size
in_chans
depth
base_resolution
patch_embed
def transform_inputs(self, x):
582    def transform_inputs(self, x):
583        import kornia.augmentation as K
584        from kornia.constants import Resample
585
586        self._transforms = CustomCompose(
587            rescale_transform=K.RandomResizedCrop(
588                (448, 448),
589                ratio=(1.0, 1.0),
590                scale=(1.0, 1.0),
591                resample=Resample.BICUBIC.name,
592            ),
593            other_transforms=None,
594            src_transform=K.Resize((224, 224)),
595        )
596        x, _, ratios, _, _ = self._transforms(x)
597        input_res = ratios * self.base_resolution
598        return x, input_res
def convert_to_expected_dim(self, x):
600    def convert_to_expected_dim(self, x):
601        inputs_ = x[:, 1:, :]  # removing the class tokens
602        # reshape the outputs to desired shape (N X H*W X C -> N X H X W X C)
603        rdim = inputs_.shape[1]
604        dshape = int(rdim ** 0.5)  # finding square root of the outputs for obtaining the patch shape
605        inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape))
606        inputs_ = inputs_.permute(0, 3, 1, 2)
607        return inputs_
def forward_features(self, x):
609    def forward_features(self, x):
610        x, input_res = self.transform_inputs(x)
611
612        B, _, h, w = x.shape
613        x = self.patch_embed(x)
614
615        num_patches = int((h * w) / (self.patch_embed.patch_size[0] * self.patch_embed.patch_size[1]))
616        pos_embed = get_2d_sincos_pos_embed_with_resolution(
617            x.shape[-1],
618            int(num_patches ** 0.5),
619            input_res,
620            cls_token=True,
621            device=x.device,
622        )
623
624        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
625        x = torch.cat((cls_tokens, x), dim=1)
626        x = x + pos_embed
627        x = self.pos_drop(x)
628
629        # chunks obtained for getting the projections for conjunctions with upsampling blocks
630        _chunks = int(self.depth / 4)
631        chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1]
632
633        list_from_encoder = []
634        for i, blk in enumerate(self.blocks):
635            x = blk(x)
636            if i in chunks_for_projection:
637                list_from_encoder.append(self.convert_to_expected_dim(x))
638
639        x = self.convert_to_expected_dim(x)
640
641        return x, list_from_encoder
def forward(self, x):
643    def forward(self, x):
644        x, list_from_encoder = self.forward_features(x)
645        return x, list_from_encoder
class ViT_DINOv2:
648class ViT_DINOv2(DinoV2VisionTransformer):
649    """Vision Transformer derived from the DINOv2 Codebase (https://arxiv.org/abs/2304.07193).
650
651    Based on:
652    https://github.com/facebookresearch/dinov2/blob/main/dinov2/models/vision_transformer.py.
653
654    Args:
655        img_size: The input image size.
656        patch_size: The patch size.
657        depth: The depth of the network.
658        num_register_tokens: The number of registers added (in addition to the class tokens).
659            It's important to know for ViTs trained with registers, to remove them at the end.
660    """
661    def __init__(
662        self,
663        img_size: int = 224,
664        patch_size: int = 16,
665        depth: int = 12,
666        num_register_tokens: int = 0,
667        **kwargs
668    ):
669        if not _dinov2_import_success:
670            raise RuntimeError(
671                "The vision transformer backend can only be initialized if DINOv2 is installed. "
672                "Please install DINOv2 from https://github.com/facebookresearch/dinov2 "
673                "and then rerun your code."
674            )
675
676        super().__init__(
677            img_size=img_size,
678            depth=depth,
679            patch_size=patch_size,
680            num_register_tokens=num_register_tokens,
681            **kwargs
682        )
683
684        self.img_size = img_size
685        self.num_register_tokens = num_register_tokens
686        self.patch_size = patch_size
687        self.attn_outs = [i for i in range(depth) if i % 3 == 2]
688
689    def forward(self, x, masks=None) -> torch.Tensor:
690
691        B = x.shape[0]
692
693        x = self.prepare_tokens_with_masks(x)
694
695        list_of_encoder = []
696        for i, blk in enumerate(self.blocks):
697            x = blk(x)
698            if i in self.attn_outs:
699                list_of_encoder.append(x)
700
701        x = self.norm(x)
702        x = x[:, self.num_register_tokens + 1:].reshape(
703            B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
704        ).permute(0, 3, 1, 2).contiguous()
705
706        list_of_encoder = [
707            o[:, self.num_register_tokens + 1:].reshape(
708                B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
709            ).permute(0, 3, 1, 2).contiguous() for o in list_of_encoder
710        ]
711
712        return x, list_of_encoder[:3]

Vision Transformer derived from the DINOv2 Codebase (https://arxiv.org/abs/2304.07193).

Based on: https://github.com/facebookresearch/dinov2/blob/main/dinov2/models/vision_transformer.py.

Arguments:
  • img_size: The input image size.
  • patch_size: The patch size.
  • depth: The depth of the network.
  • num_register_tokens: The number of registers added (in addition to the class tokens). It's important to know for ViTs trained with registers, to remove them at the end.
ViT_DINOv2( img_size: int = 224, patch_size: int = 16, depth: int = 12, num_register_tokens: int = 0, **kwargs)
661    def __init__(
662        self,
663        img_size: int = 224,
664        patch_size: int = 16,
665        depth: int = 12,
666        num_register_tokens: int = 0,
667        **kwargs
668    ):
669        if not _dinov2_import_success:
670            raise RuntimeError(
671                "The vision transformer backend can only be initialized if DINOv2 is installed. "
672                "Please install DINOv2 from https://github.com/facebookresearch/dinov2 "
673                "and then rerun your code."
674            )
675
676        super().__init__(
677            img_size=img_size,
678            depth=depth,
679            patch_size=patch_size,
680            num_register_tokens=num_register_tokens,
681            **kwargs
682        )
683
684        self.img_size = img_size
685        self.num_register_tokens = num_register_tokens
686        self.patch_size = patch_size
687        self.attn_outs = [i for i in range(depth) if i % 3 == 2]
img_size
num_register_tokens
patch_size
attn_outs
def forward(self, x, masks=None) -> torch.Tensor:
689    def forward(self, x, masks=None) -> torch.Tensor:
690
691        B = x.shape[0]
692
693        x = self.prepare_tokens_with_masks(x)
694
695        list_of_encoder = []
696        for i, blk in enumerate(self.blocks):
697            x = blk(x)
698            if i in self.attn_outs:
699                list_of_encoder.append(x)
700
701        x = self.norm(x)
702        x = x[:, self.num_register_tokens + 1:].reshape(
703            B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
704        ).permute(0, 3, 1, 2).contiguous()
705
706        list_of_encoder = [
707            o[:, self.num_register_tokens + 1:].reshape(
708                B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
709            ).permute(0, 3, 1, 2).contiguous() for o in list_of_encoder
710        ]
711
712        return x, list_of_encoder[:3]
class ViT_DINOv3:
715class ViT_DINOv3(DinoV3VisionTransformer):
716    """Vision Transformer derived from the DINOv3 Codebase (https://arxiv.org/abs/2508.10104).
717
718    Based on:
719    https://github.com/facebookresearch/dinov3/blob/main/dinov3/models/vision_transformer.py.
720
721    Args:
722        img_size: The input image size.
723        patch_size: The patch size.
724        embed_dim: The embedding dimension.
725        depth: The depth of the network.
726        num_heads: The number of heads.
727        ffn_ratio: The FFN ratio.
728        n_storage_tokens: The number of storage (class) tokens to remove.
729        kwargs: Keyword arguments for the image encoder base class.
730    """
731    def __init__(
732        self,
733        in_chans: int = 3,
734        img_size: int = 224,
735        patch_size: int = 16,
736        embed_dim: int = 768,
737        depth: int = 12,
738        num_heads: int = 12,
739        ffn_ratio: float = 4.0,
740        n_storage_tokens: int = 0,
741        **kwargs
742    ):
743        if not _dinov3_import_success:
744            raise RuntimeError(
745                "The vision transformer backend can only be initialized if DINOv3 is installed. "
746                "Please install DINOv3 from https://github.com/facebookresearch/dinov3 "
747                "and then rerun your code."
748            )
749
750        super().__init__(
751            in_chans=in_chans,
752            img_size=img_size,
753            patch_size=patch_size,
754            embed_dim=embed_dim,
755            depth=depth,
756            num_heads=num_heads,
757            ffn_ratio=ffn_ratio,
758            n_storage_tokens=n_storage_tokens,
759            **kwargs
760        )
761
762        self.in_chans = in_chans
763        self.img_size = img_size
764        self.n_storage_tokens = n_storage_tokens
765        self.attn_outs = [i for i in range(depth) if i % 3 == 2]
766
767    def forward(self, x) -> torch.Tensor:
768
769        B = x.shape[0]
770
771        x, hw_tuple = self.prepare_tokens_with_masks(x)
772
773        list_of_encoder = []
774        for i, blk in enumerate(self.blocks):
775            rope_sincos = self.rope_embed(H=hw_tuple[0], W=hw_tuple[1])
776            x = blk(x, rope_sincos)
777            if i in self.attn_outs:
778                list_of_encoder.append(x)
779
780        x = self.norm(x)
781        x = x[:, self.n_storage_tokens + 1:].reshape(
782            B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
783        ).permute(0, 3, 1, 2).contiguous()
784
785        list_of_encoder = [
786            o[:, self.n_storage_tokens + 1:].reshape(
787                B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
788            ).permute(0, 3, 1, 2).contiguous() for o in list_of_encoder
789        ]
790
791        return x, list_of_encoder[:3]

Vision Transformer derived from the DINOv3 Codebase (https://arxiv.org/abs/2508.10104).

Based on: https://github.com/facebookresearch/dinov3/blob/main/dinov3/models/vision_transformer.py.

Arguments:
  • img_size: The input image size.
  • patch_size: The patch size.
  • embed_dim: The embedding dimension.
  • depth: The depth of the network.
  • num_heads: The number of heads.
  • ffn_ratio: The FFN ratio.
  • n_storage_tokens: The number of storage (class) tokens to remove.
  • kwargs: Keyword arguments for the image encoder base class.
ViT_DINOv3( in_chans: int = 3, img_size: int = 224, patch_size: int = 16, embed_dim: int = 768, depth: int = 12, num_heads: int = 12, ffn_ratio: float = 4.0, n_storage_tokens: int = 0, **kwargs)
731    def __init__(
732        self,
733        in_chans: int = 3,
734        img_size: int = 224,
735        patch_size: int = 16,
736        embed_dim: int = 768,
737        depth: int = 12,
738        num_heads: int = 12,
739        ffn_ratio: float = 4.0,
740        n_storage_tokens: int = 0,
741        **kwargs
742    ):
743        if not _dinov3_import_success:
744            raise RuntimeError(
745                "The vision transformer backend can only be initialized if DINOv3 is installed. "
746                "Please install DINOv3 from https://github.com/facebookresearch/dinov3 "
747                "and then rerun your code."
748            )
749
750        super().__init__(
751            in_chans=in_chans,
752            img_size=img_size,
753            patch_size=patch_size,
754            embed_dim=embed_dim,
755            depth=depth,
756            num_heads=num_heads,
757            ffn_ratio=ffn_ratio,
758            n_storage_tokens=n_storage_tokens,
759            **kwargs
760        )
761
762        self.in_chans = in_chans
763        self.img_size = img_size
764        self.n_storage_tokens = n_storage_tokens
765        self.attn_outs = [i for i in range(depth) if i % 3 == 2]
in_chans
img_size
n_storage_tokens
attn_outs
def forward(self, x) -> torch.Tensor:
def forward(self, x) -> Tuple[torch.Tensor, List[torch.Tensor]]:
    """Apply the DINOv3 encoder and return spatial feature maps.

    Args:
        x: The input image batch; assumed to be (B, C, H, W) with
            H == W == self.img_size (TODO confirm against callers).

    Returns:
        A tuple of the final feature map with shape (B, C', H/p, W/p) and a
        list of up to three intermediate feature maps (one per tapped block).
    """
    B = x.shape[0]

    # Patch-embed the image and prepend the cls / storage tokens.
    x, hw_tuple = self.prepare_tokens_with_masks(x)

    # The rotary position embedding depends only on the fixed patch grid,
    # so compute it once instead of once per block.
    rope_sincos = self.rope_embed(H=hw_tuple[0], W=hw_tuple[1])

    list_of_encoder = []
    for i, blk in enumerate(self.blocks):
        x = blk(x, rope_sincos)
        if i in self.attn_outs:
            list_of_encoder.append(x)

    x = self.norm(x)

    grid = self.img_size // self.patch_size

    def _tokens_to_map(tokens):
        # Drop the cls token and the storage tokens, then fold the remaining
        # patch tokens back into a (B, C, H, W) feature map.
        return tokens[:, self.n_storage_tokens + 1:].reshape(
            B, grid, grid, -1
        ).permute(0, 3, 1, 2).contiguous()

    # Only the first three tapped outputs are returned; reshape just those.
    return _tokens_to_map(x), [_tokens_to_map(o) for o in list_of_encoder[:3]]
def get_vision_transformer( backbone: str, model: str, img_size: int = 1024, **kwargs) -> torch.nn.modules.module.Module:
 794def get_vision_transformer(backbone: str, model: str, img_size: int = 1024, **kwargs) -> nn.Module:
 795    """Get vision transformer encoder.
 796
 797    Args:
 798        backbone: The name of the vision transformer implementation.
 799            One of "sam" / "cellpose_sam" / "sam2" / "sam3" / "mae" / "scalemae" / "dinov2" / "dinov3".
 800        model: The name of the model. One of "vit_b", "vit_l" or "vit_h".
 801        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
 802        kwargs: Additional kwargs which can be expected by the vision transformer,
 803            e.g. 'base_resolution' for `ViT_ScaleMAE`.
 804
 805    Returns:
 806        The vision transformer.
 807    """
 808    if backbone == "sam":
 809        if model == "vit_b":
 810            encoder = ViT_Sam(
 811                depth=12, embed_dim=768, img_size=img_size, mlp_ratio=4,
 812                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
 813                num_heads=12, patch_size=16, qkv_bias=True, use_rel_pos=True,
 814                global_attn_indexes=[2, 5, 8, 11],
 815                window_size=14, out_chans=256,
 816            )
 817        elif model == "vit_l":
 818            encoder = ViT_Sam(
 819                depth=24, embed_dim=1024, img_size=img_size, mlp_ratio=4,
 820                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
 821                num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True,
 822                global_attn_indexes=[5, 11, 17, 23],
 823                window_size=14, out_chans=256,
 824            )
 825        elif model == "vit_h":
 826            encoder = ViT_Sam(
 827                depth=32, embed_dim=1280, img_size=img_size, mlp_ratio=4,
 828                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
 829                num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True,
 830                global_attn_indexes=[7, 15, 23, 31],
 831                window_size=14, out_chans=256,
 832            )
 833        else:
 834            raise ValueError(f"'{model}' is not supported by SAM. Currently, 'vit_b', 'vit_l', 'vit_h' are supported.")
 835
 836    elif backbone == "cellpose_sam":
 837        if model != "vit_l":
 838            raise ValueError(f"'{model}' is not supported by CellposeSAM. Only 'vit_l' is supported.")
 839        encoder = ViT_CellposeSAM(ps=8, bsize=img_size)
 840
 841    elif backbone == "sam2":
 842        if model == "hvit_t":
 843            encoder = ViT_Sam2(
 844                img_size=img_size, embed_dim=96, num_heads=1, stages=[1, 2, 7, 2], global_att_blocks=[5, 7, 9],
 845                window_pos_embed_bkg_spatial_size=[7, 7], backbone_channel_list=[768, 384, 192, 96],
 846            )
 847        elif model == "hvit_s":
 848            encoder = ViT_Sam2(
 849                img_size=img_size, embed_dim=96, num_heads=1, stages=[1, 2, 11, 2], global_att_blocks=[7, 10, 13],
 850                window_pos_embed_bkg_spatial_size=[7, 7], backbone_channel_list=[768, 384, 192, 96],
 851            )
 852        elif model == "hvit_b":
 853            encoder = ViT_Sam2(
 854                img_size=img_size, embed_dim=112, num_heads=2, backbone_channel_list=[896, 448, 224, 112],
 855            )
 856        elif model == "hvit_l":
 857            encoder = ViT_Sam2(
 858                img_size=img_size, embed_dim=144, num_heads=2, stages=[2, 6, 36, 4], global_att_blocks=[23, 33, 43],
 859                window_spec=[8, 4, 16, 8], backbone_channel_list=[1152, 576, 288, 144],
 860            )
 861        else:
 862            raise ValueError(
 863                f"'{model}' is not supported by SAM2. Currently, 'hvit_t', 'hvit_s', 'hvit_b', 'hvit_l' are supported."
 864            )
 865
 866    elif backbone == "sam3":
 867        if model != "vit_pe":
 868            raise ValueError(
 869                "'sam3' does not have multiple model configurations. Please use 'vit_pe' as the model configuration."
 870            )
 871
 872        encoder = ViT_Sam3(
 873            img_size=1008, pretrain_img_size=336, patch_size=14, embed_dim=1024, depth=32, num_heads=16,
 874            mlp_ratio=4.625, norm_layer="LayerNorm", drop_path_rate=0.1, qkv_bias=True, use_abs_pos=True,
 875            tile_abs_pos=True, global_att_blocks=(7, 15, 23, 31), rel_pos_blocks=(), use_rope=True,
 876            use_interp_rope=True, window_size=24, pretrain_use_cls_token=True, retain_cls_token=False, ln_pre=True,
 877            ln_post=False, return_interm_layers=False, bias_patch_embed=False, compile_mode=None,
 878        )
 879
 880    elif backbone == "mae":
 881        if model == "vit_b":
 882            encoder = ViT_MAE(
 883                img_size=img_size, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
 884                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
 885            )
 886        elif model == "vit_l":
 887            encoder = ViT_MAE(
 888                img_size=img_size, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
 889                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
 890            )
 891        elif model == "vit_h":
 892            encoder = ViT_MAE(
 893                img_size=img_size, patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4,
 894                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
 895            )
 896        else:
 897            raise ValueError(f"'{model}' is not supported by MAE. Currently, 'vit_b', 'vit_l', 'vit_h' are supported.")
 898
 899    elif backbone == "scalemae":
 900        base_resolution = kwargs.get("base_resolution", 2.5)
 901
 902        if model == "vit_b":
 903            encoder = ViT_ScaleMAE(
 904                img_size=img_size, patch_size=8, embed_dim=768, depth=12, num_heads=12,
 905                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
 906                base_resolution=base_resolution,
 907            )
 908        elif model == "vit_l":
 909            encoder = ViT_ScaleMAE(
 910                img_size=img_size, patch_size=8, embed_dim=1024, depth=24, num_heads=16,
 911                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
 912                base_resolution=base_resolution,
 913            )
 914        elif model == "vit_h":
 915            encoder = ViT_ScaleMAE(
 916                img_size=img_size, patch_size=8, embed_dim=1280, depth=32, num_heads=16,
 917                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
 918                base_resolution=base_resolution,
 919            )
 920        else:
 921            raise ValueError(
 922                f"'{model}' is not supported by ScaleMAE. Currently, 'vit_b', 'vit_l' and 'vit_h' are supported."
 923            )
 924
 925    elif backbone == "dinov2":
 926        block_fn = partial(Block, attn_class=MemEffAttention)
 927        msg = "The model name should be either 'vit_<X>' or 'vit_<X>_reg<Y>."
 928
 929        if model.startswith("vit_s"):
 930            assert model in ["vit_s", "vit_s_reg4"], msg
 931            encoder = ViT_DINOv2(
 932                img_size=img_size, patch_size=14, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
 933                block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0,
 934                num_register_tokens=4 if model.endswith("_reg4") else 0,
 935            )
 936        elif model.startswith("vit_b"):
 937            assert model in ["vit_b", "vit_b_reg4"], msg
 938            encoder = ViT_DINOv2(
 939                img_size=img_size, patch_size=14, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
 940                block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0,
 941                num_register_tokens=4 if model.endswith("_reg4") else 0,
 942            )
 943        elif model.startswith("vit_l"):
 944            assert model in ["vit_l", "vit_l_reg4"], msg
 945            encoder = ViT_DINOv2(
 946                img_size=img_size, patch_size=14, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
 947                block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0,
 948                num_register_tokens=4 if model.endswith("_reg4") else 0,
 949            )
 950        elif model.startswith("vit_g"):
 951            assert model in ["vit_g", "vit_g_reg4"], msg
 952            encoder = ViT_DINOv2(
 953                img_size=img_size, patch_size=14, embed_dim=1536, depth=40, num_heads=24, mlp_ratio=4,
 954                block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0,
 955                num_register_tokens=4 if model.endswith("_reg4") else 0, ffn_layer="swiglu",
 956            )
 957        else:
 958            raise ValueError(
 959                f"'{model}' is not supported by DINOv2. Currently, 'vit_s', 'vit_b', 'vit_l' and 'vit_g' are supported."
 960            )
 961
 962    elif backbone == "dinov3":
 963
 964        if model == "vit_s":
 965            encoder = ViT_DINOv3(
 966                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=384,
 967                num_heads=6, layerscale_init=1.0e-05, norm_layer="layernormbf16", n_storage_tokens=4, mask_k_bias=True,
 968            )
 969        elif model == "vit_s+":
 970            encoder = ViT_DINOv3(
 971                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=384,
 972                num_heads=6, ffn_ratio=6, layerscale_init=1.0e-05, norm_layer="layernormbf16",
 973                ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True,
 974            )
 975
 976        elif model == "vit_b":
 977            encoder = ViT_DINOv3(
 978                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32",
 979                layerscale_init=1.0e-05, norm_layer="layernormbf16", n_storage_tokens=4, mask_k_bias=True,
 980            )
 981        elif model == "vit_l":
 982            encoder = ViT_DINOv3(
 983                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1024,
 984                depth=24, num_heads=16, layerscale_init=1.0e-05, norm_layer="layernormbf16",
 985                n_storage_tokens=4, mask_k_bias=True,
 986            )
 987        elif model == "vit_l+":
 988            encoder = ViT_DINOv3(
 989                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1024,
 990                depth=24, num_heads=16, ffn_ratio=6.0, layerscale_init=1.0e-05, norm_layer="layernormbf16",
 991                ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True,
 992            )
 993        elif model == "vit_h+":
 994            encoder = ViT_DINOv3(
 995                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1280,
 996                depth=32, num_heads=20, ffn_ratio=6.0, layerscale_init=1.0e-05, norm_layer="layernormbf16",
 997                ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True,
 998            )
 999        elif model == "vit_7b":
1000            encoder = ViT_DINOv3(
1001                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=4096,
1002                depth=40, num_heads=32, ffn_ratio=3, qkv_bias=False, drop_path_rate=0.0, layerscale_init=1.0e-05,
1003                norm_layer="layernormbf16", ffn_layer="swiglu64", n_storage_tokens=4, mask_k_bias=True,
1004                untie_global_and_local_cls_norm=True,
1005            )
1006        else:
1007            raise ValueError(
1008                f"'{model}' is not supported by DINOv3. Currently, "
1009                " 'vit_s', 'vit_s+', 'vit_b', 'vit_l', 'vit_l+', 'vit_h+'. 'vit_7b' are supported."
1010            )
1011
1012    else:
1013        raise ValueError(
1014            "The 'UNETR' supported backbones are 'sam', 'cellpose_sam', 'sam2', 'sam3', "
1015            "'mae', 'scalemae', 'dinov2' or 'dinov3'. Please choose one of them."
1016        )
1017
1018    return encoder

Get vision transformer encoder.

Arguments:
  • backbone: The name of the vision transformer implementation. One of "sam" / "cellpose_sam" / "sam2" / "sam3" / "mae" / "scalemae" / "dinov2" / "dinov3".
  • model: The name of the model configuration. The supported values depend on the chosen backbone (e.g. "vit_b", "vit_l" or "vit_h" for "sam").
  • img_size: The size of the input for the image encoder. Input images will be resized to match this size.
  • kwargs: Additional kwargs which can be expected by the vision transformer, e.g. 'base_resolution' for ViT_ScaleMAE.
Returns:

The vision transformer.