torch_em.model.vit

   1import math
   2from functools import partial
   3from typing import Tuple, List
   4
   5import torch
   6import torch.nn as nn
   7
   8# we catch ImportErrors here because segment_anything, timm, sam2, sam3, dinov2 and dinov3
   9# should only be optional dependencies for torch_em
  10try:
  11    from segment_anything.modeling import ImageEncoderViT
  12    _sam_import_success = True
  13except ImportError:
  14    ImageEncoderViT = object
  15    _sam_import_success = False
  16
  17try:
  18    from timm.models.vision_transformer import VisionTransformer, PatchEmbed
  19    _timm_import_success = True
  20except ImportError:
  21    VisionTransformer = object
  22    PatchEmbed = object
  23    _timm_import_success = False
  24
  25try:
  26    from sam2.modeling.backbones.hieradet import Hiera
  27    from sam2.modeling.position_encoding import PositionEmbeddingSine
  28    from sam2.modeling.backbones.image_encoder import ImageEncoder, FpnNeck
  29    _sam2_import_success = True
  30except ImportError:
  31    ImageEncoder = object
  32    _sam2_import_success = False
  33
  34try:
  35    from dinov2.models.vision_transformer import DinoVisionTransformer as DinoV2VisionTransformer
  36    from dinov2.layers import MemEffAttention, NestedTensorBlock as Block
  37    _dinov2_import_success = True
  38except ImportError:
  39    DinoV2VisionTransformer = object
  40    _dinov2_import_success = False
  41
  42try:
  43    from dinov3.models.vision_transformer import DinoVisionTransformer as DinoV3VisionTransformer
  44    _dinov3_import_success = True
  45except ImportError:
  46    DinoV3VisionTransformer = object
  47    _dinov3_import_success = False
  48
  49
  50try:
  51    from sam3.model.vitdet import ViT as SAM3ViT, get_abs_pos
  52    _sam3_import_success = True
  53except ImportError:
  54    SAM3ViT = object
  55    _sam3_import_success = False
  56
  57
  58class ViT_Sam(ImageEncoderViT):
  59    """Vision Transformer derived from the Segment Anything Codebase (https://arxiv.org/abs/2304.02643).
  60
  61    Based on:
  62    https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py
  63
  64    Args:
  65        in_chans: The number of input channels.
  66        embed_dim: The embedding dimension, corresponding to the number of output channels of the vision transformer.
  67        global_attn_indexes: The global attention indices.
  68        kwargs: Keyword arguments for the image encoder base class.
  69    """
  70    def __init__(
  71        self,
  72        in_chans: int = 3,
  73        embed_dim: int = 768,
  74        global_attn_indexes: Tuple[int, ...] = (2, 5, 8, 11),
  75        **kwargs,
  76    ) -> None:
  77        if not _sam_import_success:
  78            raise RuntimeError(
  79                "The vision transformer backend can only be initialized if segment anything is installed. "
  80                "Please install segment anything from https://github.com/facebookresearch/segment-anything "
  81                "and then rerun your code."
  82            )
  83
  84        super().__init__(embed_dim=embed_dim, global_attn_indexes=global_attn_indexes, **kwargs)
  85        self.chunks_for_projection = global_attn_indexes
  86        self.in_chans = in_chans
  87        self.embed_dim = embed_dim
  88
  89    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
  90        """Apply the vision transformer to input data.
  91
  92        Args:
  93            x: The input data.
  94
  95        Returns:
  96            The vision transformer features and the list of intermediate block features.
  97        """
  98        x = self.patch_embed(x)
  99        if self.pos_embed is not None:
 100            x = x + self.pos_embed
 101
 102        list_from_encoder = []
 103        for i, blk in enumerate(self.blocks):
 104            x = blk(x)
 105            if i in self.chunks_for_projection:
 106                list_from_encoder.append(x)
 107
 108        x = x.permute(0, 3, 1, 2)
 109        list_from_encoder = [e.permute(0, 3, 1, 2) for e in list_from_encoder]
 110        return x, list_from_encoder[:3]
 111
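A minimal usage sketch for ViT_Sam (illustrative, assuming segment_anything is installed and using the "vit_b" configuration created by get_vision_transformer below): global_attn_indexes is [2, 5, 8, 11], so the outputs of blocks 2, 5 and 8 are returned as intermediate features, while the [:3] slice drops the output of the last global-attention block.

    encoder = get_vision_transformer(backbone="sam", model="vit_b", img_size=1024)
    x = torch.randn(1, 3, 1024, 1024)
    features, intermediates = encoder(x)
    # With patch_size=16 the features have shape (1, 768, 64, 64); each intermediate tensor has the same shape.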
 112
 113class ViT_CellposeSAM(nn.Module):
 114    """Vision Transformer derived from the CellposeSAM Codebase (https://doi.org/10.1038/s41592-025-02595-x).
 115
 116    This replicates CellposeSAM's actual initialization: instantiate SAM's ``ImageEncoderViT`` via
 117    ``sam_model_registry``, then modify the patch embedding, position embeddings, and set global attention.
 118    This preserves SAM's original relative position bias sizes, enabling direct checkpoint loading
 119    without any interpolation.
 120
 121    Based on: https://github.com/MouseLand/cellpose/blob/main/cellpose/vit_sam.py
 122
 123    NOTE: The pretrained CellposeSAM model uses ``vit_l`` exclusively.
 124
 125    Args:
 126        ps: The patch size (default for CellposeSAM is 8).
 127        bsize: The input image size (default for CellposeSAM is 256).
 128    """
 129    def __init__(self, ps: int = 8, bsize: int = 256) -> None:
 130        super().__init__()
 131
 132        if not _sam_import_success:
 133            raise RuntimeError(
 134                "The vision transformer backend can only be initialized if segment anything is installed. "
 135                "Please install segment anything from https://github.com/facebookresearch/segment-anything "
 136                "and then rerun your code."
 137            )
 138
 139        from segment_anything import sam_model_registry
 140
 141        # Creates the SAM vit_l encoder and applies CellposeSAM's modifications (same as cellpose.vit_sam.Transformer).
 142        encoder = sam_model_registry["vit_l"](None).image_encoder
 143
 144        w = encoder.patch_embed.proj.weight.detach()
 145        nchan = w.shape[0]
 146
 147        # CellPoseSAM changes the patch size from 16 to 'ps'.
 148        encoder.patch_embed.proj = nn.Conv2d(3, nchan, stride=ps, kernel_size=ps)
 149        encoder.patch_embed.proj.weight.data = w[:, :, ::16 // ps, ::16 // ps]
 150
 151        # Next, they subsample position embeddings for the new patch size and input resolution.
 152        ds = (1024 // 16) // (bsize // ps)
 153        encoder.pos_embed = nn.Parameter(encoder.pos_embed[:, ::ds, ::ds], requires_grad=True)
 154
 155        # Finally, they set all blocks to global attention.
 156        for blk in encoder.blocks:
 157            blk.window_size = 0
 158
 159        # Store encoder submodules directly ('state_dict' keys match CellposeSAM after prefix stripping).
 160        self.patch_embed = encoder.patch_embed
 161        self.pos_embed = encoder.pos_embed
 162        self.blocks = encoder.blocks
 163        self.neck = encoder.neck
 164
 165        # Additional attributes expected by UNETR.
 166        self.embed_dim = nchan
 167        self.img_size = bsize
 168        self.in_chans = 3
 169
 170        # Feature extraction at evenly-spaced depths.
 171        depth = len(self.blocks)
 172        _chunks = depth // 4
 173        self.chunks_for_projection = [_chunks - 1, 2 * _chunks - 1, 3 * _chunks - 1, 4 * _chunks - 1]
 174
 175    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
 176        """Apply the vision transformer to input data.
 177
 178        Args:
 179            x: The input data.
 180
 181        Returns:
 182            The vision transformer features and the list of intermediate block features.
 183        """
 184        x = self.patch_embed(x)
 185        if self.pos_embed is not None:
 186            x = x + self.pos_embed
 187
 188        list_from_encoder = []
 189        for i, blk in enumerate(self.blocks):
 190            x = blk(x)
 191            if i in self.chunks_for_projection:
 192                list_from_encoder.append(x)
 193
 194        x = x.permute(0, 3, 1, 2)
 195        list_from_encoder = [e.permute(0, 3, 1, 2) for e in list_from_encoder]
 196        return x, list_from_encoder[:3]
 197
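The patch-embedding and position-embedding modifications above work out as follows for the defaults (a worked example, assuming the SAM vit_l image encoder with embed_dim 1024, patch size 16 and pretraining image size 1024):

    ps, bsize = 8, 256
    # w has shape (1024, 3, 16, 16); taking every second kernel element yields 8x8 kernels:
    # w[:, :, ::16 // ps, ::16 // ps]  ->  (1024, 3, 8, 8)
    ds = (1024 // 16) // (bsize // ps)  # 64 // 32 = 2
    # The 64x64 position-embedding grid is subsampled to 32x32,
    # which matches the 256 / 8 = 32 patches per side of the new configuration.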
 198
 199class ViT_MAE(VisionTransformer):
 200    """Vision Transformer derived from the Masked Auto Encoder Codebase (https://arxiv.org/abs/2111.06377).
 201
 202    Based on:
 203    https://github.com/facebookresearch/mae/blob/main/models_vit.py#L20-L53
 204
 205    Args:
 206        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
 207        in_chans: The number of input channels.
 208        depth: The depth of the vision transformer.
 209        kwargs: Additional keyword arguments for the vision transformer base class.
 210    """
 211    def __init__(
 212        self,
 213        img_size: int = 1024,  # chosen to match our experiments with segment anything
 214        in_chans: int = 3,
 215        depth: int = 12,
 216        **kwargs
 217    ):
 218        if not _timm_import_success:
 219            raise RuntimeError(
 220                "The vision transformer backend can only be initialized if timm is installed. "
 221                "Please install timm (using conda/mamba) for using https://github.com/facebookresearch/mae/ "
 222                "and then rerun your code"
 223            )
 224        super().__init__(img_size=img_size, depth=depth, **kwargs)
 225        self.img_size = img_size
 226        self.in_chans = in_chans
 227        self.depth = depth
 228
 229    def convert_to_expected_dim(self, inputs_):
 230        """@private
 231        """
 232        inputs_ = inputs_[:, 1:, :]  # removing the class tokens
 233        # reshape the outputs to desired shape (N x H*W X C -> N x H x W x C)
 234        rdim = inputs_.shape[1]
 235        dshape = int(rdim ** 0.5)  # finding the square root of the outputs for obtaining the patch shape
 236        inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape))
 237        inputs_ = inputs_.permute(0, 3, 1, 2)
 238        return inputs_
 239
 240    def forward_features(self, x):
 241        """@private
 242        """
 243        B = x.shape[0]
 244        x = self.patch_embed(x)
 245
 246        cls_tokens = self.cls_token.expand(B, -1, -1)
 247        x = torch.cat((cls_tokens, x), dim=1)
 248
 249        x = x + self.pos_embed
 250        x = self.pos_drop(x)
 251
 252        # chunks obtained for getting the projections in conjunction with the upsampling blocks
 253        _chunks = int(self.depth / 4)
 254        chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1]
 255
 256        list_from_encoder = []
 257        for i, blk in enumerate(self.blocks):
 258            x = blk(x)
 259            if i in chunks_for_projection:
 260                list_from_encoder.append(self.convert_to_expected_dim(x))
 261
 262        x = self.convert_to_expected_dim(x)
 263        return x, list_from_encoder[:3]
 264
 265    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
 266        """Apply the vision transformer to input data.
 267
 268        Args:
 269            x: The input data.
 270
 271        Returns:
 272            The vision transformer features and the list of intermediate block features.
 273        """
 274        x, list_from_encoder = self.forward_features(x)
 275        return x, list_from_encoder
 276
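A shape sketch for convert_to_expected_dim (illustrative, for the "vit_b" configuration with img_size=1024 and patch_size=16): the encoder produces (1024 // 16) ** 2 = 4096 patch tokens plus one class token, and the helper strips the class token and folds the remaining tokens back into a spatial grid.

    dshape = int(4096 ** 0.5)  # 64
    # x after the transformer blocks:              (N, 4097, 768)
    # after dropping the class token:              (N, 4096, 768)
    # after torch.unflatten(x, 1, (64, 64)):       (N, 64, 64, 768)
    # after the permute to channels-first layout:  (N, 768, 64, 64)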
 277
 278class ViT_Sam2(ImageEncoder):
 279    """Vision Transformer derived from the Segment Anything 2 Codebase (https://arxiv.org/abs/2408.00714).
 280
 281    Based on https://github.com/facebookresearch/sam2/blob/main/sam2/modeling/backbones/image_encoder.py.
 282
 283    Args:
 284        backbone_channel_list: The channels throughout the entire backbone.
 285        embed_dim: The initial embedding dimension.
 286        num_heads: The initial number of heads.
 287        stages: The number of blocks per stage.
 288        global_att_blocks: The indices of the blocks that use global attention.
 289        window_pos_embed_bkg_spatial_size: The spatial size of the background window position embedding.
 290        window_spec: The window size per stage for blocks that do not use global attention.
 291        scalp: The number of lowest-resolution feature maps to discard.
 292    """
 293    def __init__(
 294        self,
 295        backbone_channel_list: List[int],
 296        img_size: int = 1024,
 297        embed_dim: int = 96,
 298        num_heads: int = 1,
 299        stages: Tuple[int, ...] = (2, 3, 16, 3),
 300        global_att_blocks: Tuple[int, ...] = (12, 16, 20),
 301        window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
 302        window_spec: Tuple[int, ...] = (8, 4, 14, 7),
 303        scalp: int = 1,
 304        **kwargs
 305    ):
 306        if not _sam2_import_success:
 307            raise RuntimeError(
 308                "The vision transformer backend can only be initialized if segment anything 2 is installed. "
 309                "Please install segment anything 2 from https://github.com/facebookresearch/sam2 "
 310                "and then rerun your code"
 311            )
 312
 313        trunk = Hiera(
 314            embed_dim=embed_dim,
 315            num_heads=num_heads,
 316            stages=stages,
 317            global_att_blocks=global_att_blocks,
 318            window_pos_embed_bkg_spatial_size=window_pos_embed_bkg_spatial_size,
 319            window_spec=window_spec,
 320        )
 321        neck = FpnNeck(
 322            position_encoding=PositionEmbeddingSine(num_pos_feats=256),
 323            d_model=256,
 324            backbone_channel_list=backbone_channel_list,
 325            fpn_top_down_levels=[2, 3],
 326            fpn_interp_model="nearest",
 327        )
 328
 329        super().__init__(trunk=trunk, neck=neck, scalp=scalp, **kwargs)
 330        self.scalp = scalp
 331        self.embed_dim = embed_dim
 332        self.img_size = img_size
 333
 334    def forward(self, x: torch.Tensor):
 335        # The forward pass through the backbone.
 336        features, pos = self.neck(self.trunk(x))
 337        if self.scalp > 0:  # This discards the "n" lowest resolution features.
 338            features, pos = features[:-self.scalp], pos[:-self.scalp]
 339
 340        return features[-1], features
 341
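A usage sketch for the SAM2 backbone (illustrative, assuming sam2 is installed): the FpnNeck projects every backbone stage to d_model=256 channels, scalp=1 discards the lowest-resolution map, and the last remaining map is returned as the main output together with the full list.

    encoder = get_vision_transformer(backbone="sam2", model="hvit_b", img_size=1024)
    x = torch.randn(1, 3, 1024, 1024)
    main_feature, feature_list = encoder(x)
    # All entries of feature_list have 256 channels; main_feature is feature_list[-1].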
 342
 343class ViT_Sam3(SAM3ViT):
 344    """Vision Transformer derived from the Segment Anything 3 Codebase (https://arxiv.org/abs/2511.16719).
 345
 346    Based on https://github.com/facebookresearch/sam3/blob/main/sam3/model/vitdet.py.
 347
 348    Args:
 349        img_size: The input image size.
 350        embed_dim: The embedding dimension, corresponding to the number of output channels of the vision transformer.
 351        kwargs: Keyword arguments for the image encoder base class.
 352    """
 353    def __init__(self, img_size: int = 1024, embed_dim: int = 768, **kwargs):
 354        if not _sam3_import_success:
 355            raise RuntimeError(
 356                "The vision transformer backend can only be initialized if segment anything 3 is installed. "
 357                "Please install segment anything 3 from https://github.com/facebookresearch/sam3 "
 358                "and then rerun your code"
 359            )
 360
 361        super().__init__(img_size=img_size, embed_dim=embed_dim, **kwargs)
 362        self.img_size = img_size
 363        self.embed_dim = embed_dim
 364
 365    def forward_features(self, x):
 366        """@private
 367        """
 368        x = self.patch_embed(x)
 369        h, w = x.shape[1], x.shape[2]
 370
 371        s = 0
 372        if self.retain_cls_token:
 373            # If the 'cls_token' is retained, we don't maintain the spatial shape.
 374            x = torch.cat([self.class_embedding, x.flatten(1, 2)], dim=1)
 375            s = 1
 376
 377        if self.pos_embed is not None:
 378            x = x + get_abs_pos(
 379                self.pos_embed, self.pretrain_use_cls_token, (h, w), self.retain_cls_token, tiling=self.tile_abs_pos,
 380            )
 381
 382        x = self.ln_pre(x)
 383
 384        list_from_encoder = []
 385        for i, blk in enumerate(self.blocks):
 386            if self.use_act_checkpoint and self.training:
 387                x = torch.utils.checkpoint.checkpoint(blk, x, use_reentrant=False)
 388            else:
 389                x = blk(x)
 390
 391            x = self._convert_to_expected_dim(x, i, s)
 392
 393            if i in self.full_attn_ids:
 394                list_from_encoder.append(x)
 395
 396        return x, list_from_encoder
 397
 398    def _convert_to_expected_dim(self, x, i, s):
 399        if (i == self.full_attn_ids[-1]) or (
 400            self.return_interm_layers and i in self.full_attn_ids
 401        ):
 402            if i == self.full_attn_ids[-1]:
 403                x = self.ln_post(x)
 404
 405            feats = x[:, s:]
 406            if feats.ndim == 4:
 407                feats = feats.permute(0, 3, 1, 2)
 408            else:
 409                assert feats.ndim == 3
 410                h = w = int(math.sqrt(feats.shape[1]))
 411                feats = feats.reshape(feats.shape[0], h, w, feats.shape[-1]).permute(0, 3, 1, 2)
 412            return feats
 413
 414        else:
 415            return x
 416
 417    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
 418        """Apply the vision transformer to input data.
 419
 420        Args:
 421            x: The input data.
 422
 423        Returns:
 424            The vision transformer features and the list of intermediate block features.
 425        """
 426        x, list_from_encoder = self.forward_features(x)
 427        return x, list_from_encoder
 428
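A usage sketch for the SAM3 backbone (illustrative, assuming sam3 is installed): get_vision_transformer below configures it with a fixed input size of 1008 and a patch size of 14, so the returned feature grid is 72 x 72, and intermediate features are collected at the global-attention blocks.

    encoder = get_vision_transformer(backbone="sam3", model="vit_pe")
    x = torch.randn(1, 3, 1008, 1008)
    features, intermediates = encoder(x)
    # features has embed_dim=1024 channels on a 72 x 72 grid; intermediates holds the global-attention block outputs.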
 429#
 430# Utilities for ScaleMAE's ViT
 431#
 432
 433
 434class CustomCompose:
 435    def __init__(self, rescale_transform, other_transforms, src_transform):
 436        self.rescale_transform = rescale_transform
 437        self.other_transforms = other_transforms
 438        self.src_transform = src_transform
 439
 440    def __call__(self, x, valid_masks=None):
 441        if valid_masks is not None:
 442            nodata = (x * (1 - valid_masks.float())).max()
 443        x_aug = self.rescale_transform(x)
 444        parms = self.rescale_transform._params
 445
 446        # sanity check, comment if this is working
 447        # valid_masks = self.rescale_transform(valid_masks.float(), params=parms)
 448        # assert (x_aug==self.rescale_transform(x, params=parms)).all() #
 449
 450        if valid_masks is not None:
 451            valid_masks = x_aug != nodata
 452            _, c, h, w = x_aug.shape
 453            zero_ratio = ((valid_masks == 0).sum((1, 2, 3)) / (h * w * c)).cpu().numpy()
 454        else:
 455            zero_ratio = -1
 456
 457        if self.other_transforms:
 458            x_aug = self.other_transforms(x_aug)
 459        x_src = self.src_transform(x_aug)
 460        dx = parms["src"][:, 1, 0] - parms["src"][:, 0, 0]
 461
 462        # dy = (parms['src'][:,2,1] - parms['src'][:,1,1])
 463        # assert (dx == dy).all()
 464
 465        h, w = x_aug.shape[-2:]
 466        # assert h == w
 467
 468        return x_aug, x_src, dx / h, zero_ratio, valid_masks
 469
 470
 471def get_2d_sincos_pos_embed_with_resolution(embed_dim, grid_size, res, cls_token=False, device="cpu"):
 472    """
 473    grid_size: int of the grid height and width
 474    res: array of size n, representing the resolution of a pixel (say, in meters),
 475    return:
 476    pos_embed: [n,grid_size*grid_size, embed_dim] or [n,1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
 477    """
 478    # res = torch.FloatTensor(res).to(device)
 479    res = res.to(device)
 480    grid_h = torch.arange(grid_size, dtype=torch.float32, device=device)
 481    grid_w = torch.arange(grid_size, dtype=torch.float32, device=device)
 482    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")  # here h goes first, direction reversed for numpy
 483    grid = torch.stack(grid, dim=0)  # 2 x h x w
 484
 485    # grid = grid.reshape([2, 1, grid_size, grid_size])
 486    grid = torch.einsum("chw,n->cnhw", grid, res)  # 2 x n x h x w
 487    _, n, h, w = grid.shape
 488    pos_embed = get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid)  # (nxH*W, D/2)
 489    pos_embed = pos_embed.reshape(n, h * w, embed_dim)
 490    if cls_token:
 491        pos_embed = torch.cat(
 492            [torch.zeros([n, 1, embed_dim], dtype=torch.float32, device=pos_embed.device), pos_embed], dim=1
 493        )
 494
 495    return pos_embed
 496
 497
 498def get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid):
 499    assert embed_dim % 2 == 0
 500
 501    # use half of dimensions to encode grid_h
 502    emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid[0])  # (H*W, D/2)
 503    emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid[1])  # (H*W, D/2)
 504
 505    emb = torch.cat([emb_h, emb_w], dim=1)  # (H*W, D)
 506    return emb
 507
 508
 509def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos):
 510    """
 511    embed_dim: output dimension for each position
 512    pos: a list of positions to be encoded: size (M,)
 513    out: (M, D)
 514    """
 515    assert embed_dim % 2 == 0
 516    # old_shape = pos
 517    omega = torch.arange(embed_dim // 2, dtype=torch.float32, device=pos.device)
 518    omega /= embed_dim / 2.0
 519    omega = 1.0 / 10000**omega  # (D/2,)
 520
 521    pos = pos.reshape(-1)  # (M,)
 522    out = torch.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
 523
 524    emb_sin = torch.sin(out)  # (M, D/2)
 525    emb_cos = torch.cos(out)  # (M, D/2)
 526
 527    emb = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)
 528    return emb
 529
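The helpers above compute resolution-aware sin-cos position embeddings. A small shape sketch (illustrative): for embed_dim=768, grid_size=14 and a batch of two per-sample resolutions, each sample gets its own scaled grid and embedding.

    res = torch.tensor([1.0, 2.0])
    pos = get_2d_sincos_pos_embed_with_resolution(768, 14, res, cls_token=True)
    # pos has shape (2, 1 + 14 * 14, 768); the leading row of zeros per sample is the class-token embedding.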
 530
 531class PatchEmbedUnSafe(PatchEmbed):
 532    """Image to Patch Embedding"""
 533
 534    def forward(self, x):
 535        B, C, H, W = x.shape
 536
 537        # NOTE: Comment code from ScaleMAE: Dropped size check in timm
 538        # assert H == self.img_size[0] and W == self.img_size[1], \
 539        #     f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
 540
 541        x = self.proj(x).flatten(2).transpose(1, 2)
 542        return x
 543
 544
 545class ViT_ScaleMAE(VisionTransformer):
 546    """Vision Transformer derived from the Scale Masked Auto Encoder codebase (TODO: paper and github link).
 547
 548    NOTE: For downstream tasks, the "base_resolution" parameter needs to be adjusted manually when using
 549    the model on a dataset with a different zoom factor.
 550    """
 551
 552    def __init__(
 553        self, img_size=224, patch_size=16, in_chans=3, embed_dim=1024, depth=12, base_resolution=2.5, **kwargs
 554    ):
 555        super().__init__(img_size=img_size, embed_dim=embed_dim, depth=depth, **kwargs)
 556        self.img_size = img_size
 557        self.in_chans = in_chans
 558        self.depth = depth
 559        self.base_resolution = base_resolution
 560
 561        self.patch_embed = PatchEmbedUnSafe(
 562            img_size=img_size,
 563            patch_size=patch_size,
 564            in_chans=in_chans,
 565            embed_dim=embed_dim,
 566        )
 567
 568    def transform_inputs(self, x):
 569        import kornia.augmentation as K
 570        from kornia.constants import Resample
 571
 572        self._transforms = CustomCompose(
 573            rescale_transform=K.RandomResizedCrop(
 574                (448, 448),
 575                ratio=(1.0, 1.0),
 576                scale=(1.0, 1.0),
 577                resample=Resample.BICUBIC.name,
 578            ),
 579            other_transforms=None,
 580            src_transform=K.Resize((224, 224)),
 581        )
 582        x, _, ratios, _, _ = self._transforms(x)
 583        input_res = ratios * self.base_resolution
 584        return x, input_res
 585
 586    def convert_to_expected_dim(self, x):
 587        inputs_ = x[:, 1:, :]  # removing the class tokens
 588        # reshape the outputs to desired shape (N X H*W X C -> N X H X W X C)
 589        rdim = inputs_.shape[1]
 590        dshape = int(rdim ** 0.5)  # finding square root of the outputs for obtaining the patch shape
 591        inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape))
 592        inputs_ = inputs_.permute(0, 3, 1, 2)
 593        return inputs_
 594
 595    def forward_features(self, x):
 596        x, input_res = self.transform_inputs(x)
 597
 598        B, _, h, w = x.shape
 599        x = self.patch_embed(x)
 600
 601        num_patches = int((h * w) / (self.patch_embed.patch_size[0] * self.patch_embed.patch_size[1]))
 602        pos_embed = get_2d_sincos_pos_embed_with_resolution(
 603            x.shape[-1],
 604            int(num_patches ** 0.5),
 605            input_res,
 606            cls_token=True,
 607            device=x.device,
 608        )
 609
 610        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
 611        x = torch.cat((cls_tokens, x), dim=1)
 612        x = x + pos_embed
 613        x = self.pos_drop(x)
 614
 615        # chunks obtained for getting the projections in conjunction with the upsampling blocks
 616        _chunks = int(self.depth / 4)
 617        chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1]
 618
 619        list_from_encoder = []
 620        for i, blk in enumerate(self.blocks):
 621            x = blk(x)
 622            if i in chunks_for_projection:
 623                list_from_encoder.append(self.convert_to_expected_dim(x))
 624
 625        x = self.convert_to_expected_dim(x)
 626
 627        return x, list_from_encoder
 628
 629    def forward(self, x):
 630        x, list_from_encoder = self.forward_features(x)
 631        return x, list_from_encoder
 632
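A usage sketch for the ScaleMAE backbone (illustrative, assuming timm and kornia are installed): base_resolution scales the per-sample resolution fed into the sin-cos position embedding and is forwarded by get_vision_transformer as a keyword argument, so it can be adapted to the zoom factor of the downstream dataset.

    encoder = get_vision_transformer(backbone="scalemae", model="vit_b", img_size=224, base_resolution=2.5)
    x = torch.randn(1, 3, 224, 224)
    features, intermediates = encoder(x)
    # transform_inputs crops and rescales the batch internally before the patch embedding is applied.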
 633
 634class ViT_DINOv2(DinoV2VisionTransformer):
 635    """Vision Transformer derived from the DINOv2 Codebase (https://arxiv.org/abs/2304.07193).
 636
 637    Based on:
 638    https://github.com/facebookresearch/dinov2/blob/main/dinov2/models/vision_transformer.py.
 639
 640    Args:
 641        img_size: The input image size.
 642        patch_size: The patch size.
 643        depth: The depth of the network.
 644        num_register_tokens: The number of register tokens added (in addition to the class token).
 645            This must be set correctly for ViTs trained with registers, so they can be removed from the outputs.
 646    """
 647    def __init__(
 648        self,
 649        img_size: int = 224,
 650        patch_size: int = 16,
 651        depth: int = 12,
 652        num_register_tokens: int = 0,
 653        **kwargs
 654    ):
 655        if not _dinov2_import_success:
 656            raise RuntimeError(
 657                "The vision transformer backend can only be initialized if DINOv2 is installed. "
 658                "Please install DINOv2 from https://github.com/facebookresearch/dinov2 "
 659                "and then rerun your code."
 660            )
 661
 662        super().__init__(
 663            img_size=img_size,
 664            depth=depth,
 665            patch_size=patch_size,
 666            num_register_tokens=num_register_tokens,
 667            **kwargs
 668        )
 669
 670        self.img_size = img_size
 671        self.num_register_tokens = num_register_tokens
 672        self.patch_size = patch_size
 673        self.attn_outs = [i for i in range(depth) if i % 3 == 2]
 674
 675    def forward(self, x, masks=None) -> Tuple[torch.Tensor, List[torch.Tensor]]:
 676
 677        B = x.shape[0]
 678
 679        x = self.prepare_tokens_with_masks(x)
 680
 681        list_of_encoder = []
 682        for i, blk in enumerate(self.blocks):
 683            x = blk(x)
 684            if i in self.attn_outs:
 685                list_of_encoder.append(x)
 686
 687        x = self.norm(x)
 688        x = x[:, self.num_register_tokens + 1:].reshape(
 689            B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
 690        ).permute(0, 3, 1, 2).contiguous()
 691
 692        list_of_encoder = [
 693            o[:, self.num_register_tokens + 1:].reshape(
 694                B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
 695            ).permute(0, 3, 1, 2).contiguous() for o in list_of_encoder
 696        ]
 697
 698        return x, list_of_encoder[:3]
 699
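A usage sketch for the DINOv2 backbone (illustrative, assuming dinov2 is installed): intermediate features are taken from every third block (attn_outs), the class and register tokens are stripped, and the remaining tokens are reshaped to an img_size // patch_size grid.

    encoder = get_vision_transformer(backbone="dinov2", model="vit_b", img_size=224)
    x = torch.randn(1, 3, 224, 224)
    features, intermediates = encoder(x)
    # With patch_size=14 this yields a 16 x 16 grid: features has shape (1, 768, 16, 16).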
 700
 701class ViT_DINOv3(DinoV3VisionTransformer):
 702    """Vision Transformer derived from the DINOv3 Codebase (https://arxiv.org/abs/2508.10104).
 703
 704    Based on:
 705    https://github.com/facebookresearch/dinov3/blob/main/dinov3/models/vision_transformer.py.
 706
 707    Args:
 708        img_size: The input image size.
 709        patch_size: The patch size.
 710        embed_dim: The embedding dimension.
 711        depth: The depth of the network.
 712        num_heads: The number of heads.
 713        ffn_ratio: The FFN ratio.
 714        n_storage_tokens: The number of storage (register) tokens; stripped from the outputs together with the class token.
 715        kwargs: Keyword arguments for the image encoder base class.
 716    """
 717    def __init__(
 718        self,
 719        in_chans: int = 3,
 720        img_size: int = 224,
 721        patch_size: int = 16,
 722        embed_dim: int = 768,
 723        depth: int = 12,
 724        num_heads: int = 12,
 725        ffn_ratio: float = 4.0,
 726        n_storage_tokens: int = 0,
 727        **kwargs
 728    ):
 729        if not _dinov3_import_success:
 730            raise RuntimeError(
 731                "The vision transformer backend can only be initialized if DINOv3 is installed. "
 732                "Please install DINOv3 from https://github.com/facebookresearch/dinov3 "
 733                "and then rerun your code."
 734            )
 735
 736        super().__init__(
 737            in_chans=in_chans,
 738            img_size=img_size,
 739            patch_size=patch_size,
 740            embed_dim=embed_dim,
 741            depth=depth,
 742            num_heads=num_heads,
 743            ffn_ratio=ffn_ratio,
 744            n_storage_tokens=n_storage_tokens,
 745            **kwargs
 746        )
 747
 748        self.in_chans = in_chans
 749        self.img_size = img_size
 750        self.n_storage_tokens = n_storage_tokens
 751        self.attn_outs = [i for i in range(depth) if i % 3 == 2]
 752
 753    def forward(self, x) -> Tuple[torch.Tensor, List[torch.Tensor]]:
 754
 755        B = x.shape[0]
 756
 757        x, hw_tuple = self.prepare_tokens_with_masks(x)
 758
 759        list_of_encoder = []
 760        for i, blk in enumerate(self.blocks):
 761            rope_sincos = self.rope_embed(H=hw_tuple[0], W=hw_tuple[1])
 762            x = blk(x, rope_sincos)
 763            if i in self.attn_outs:
 764                list_of_encoder.append(x)
 765
 766        x = self.norm(x)
 767        x = x[:, self.n_storage_tokens + 1:].reshape(
 768            B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
 769        ).permute(0, 3, 1, 2).contiguous()
 770
 771        list_of_encoder = [
 772            o[:, self.n_storage_tokens + 1:].reshape(
 773                B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
 774            ).permute(0, 3, 1, 2).contiguous() for o in list_of_encoder
 775        ]
 776
 777        return x, list_of_encoder[:3]
 778
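The DINOv3 variant follows the same pattern (an illustrative sketch, assuming dinov3 is installed): every third block output is collected, the class token and the n_storage_tokens storage tokens are stripped, and the remaining tokens are reshaped to an img_size // patch_size grid.

    encoder = get_vision_transformer(backbone="dinov3", model="vit_b", img_size=224)
    x = torch.randn(1, 3, 224, 224)
    features, intermediates = encoder(x)
    # With the default patch_size of 16 this yields a 14 x 14 grid: features has shape (1, 768, 14, 14).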
 779
 780def get_vision_transformer(backbone: str, model: str, img_size: int = 1024, **kwargs) -> nn.Module:
 781    """Get vision transformer encoder.
 782
 783    Args:
 784        backbone: The name of the vision transformer implementation.
 785            One of "sam" / "cellpose_sam" / "sam2" / "sam3" / "mae" / "scalemae" / "dinov2" / "dinov3".
 786        model: The name of the model configuration; the valid names depend on the backbone, e.g. "vit_b", "vit_l" or "vit_h" for "sam".
 787        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
 788        kwargs: Additional kwargs which can be expected by the vision transformer,
 789            e.g. 'base_resolution' for `ViT_ScaleMAE`.
 790
 791    Returns:
 792        The vision transformer.
 793    """
 794    if backbone == "sam":
 795        if model == "vit_b":
 796            encoder = ViT_Sam(
 797                depth=12, embed_dim=768, img_size=img_size, mlp_ratio=4,
 798                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
 799                num_heads=12, patch_size=16, qkv_bias=True, use_rel_pos=True,
 800                global_attn_indexes=[2, 5, 8, 11],
 801                window_size=14, out_chans=256,
 802            )
 803        elif model == "vit_l":
 804            encoder = ViT_Sam(
 805                depth=24, embed_dim=1024, img_size=img_size, mlp_ratio=4,
 806                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
 807                num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True,
 808                global_attn_indexes=[5, 11, 17, 23],
 809                window_size=14, out_chans=256,
 810            )
 811        elif model == "vit_h":
 812            encoder = ViT_Sam(
 813                depth=32, embed_dim=1280, img_size=img_size, mlp_ratio=4,
 814                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
 815                num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True,
 816                global_attn_indexes=[7, 15, 23, 31],
 817                window_size=14, out_chans=256,
 818            )
 819        else:
 820            raise ValueError(f"'{model}' is not supported by SAM. Currently, 'vit_b', 'vit_l', 'vit_h' are supported.")
 821
 822    elif backbone == "cellpose_sam":
 823        if model != "vit_l":
 824            raise ValueError(f"'{model}' is not supported by CellposeSAM. Only 'vit_l' is supported.")
 825        encoder = ViT_CellposeSAM(ps=8, bsize=img_size)
 826
 827    elif backbone == "sam2":
 828        if model == "hvit_t":
 829            encoder = ViT_Sam2(
 830                img_size=img_size, embed_dim=96, num_heads=1, stages=[1, 2, 7, 2], global_att_blocks=[5, 7, 9],
 831                window_pos_embed_bkg_spatial_size=[7, 7], backbone_channel_list=[768, 384, 192, 96],
 832            )
 833        elif model == "hvit_s":
 834            encoder = ViT_Sam2(
 835                img_size=img_size, embed_dim=96, num_heads=1, stages=[1, 2, 11, 2], global_att_blocks=[7, 10, 13],
 836                window_pos_embed_bkg_spatial_size=[7, 7], backbone_channel_list=[768, 384, 192, 96],
 837            )
 838        elif model == "hvit_b":
 839            encoder = ViT_Sam2(
 840                img_size=img_size, embed_dim=112, num_heads=2, backbone_channel_list=[896, 448, 224, 112],
 841            )
 842        elif model == "hvit_l":
 843            encoder = ViT_Sam2(
 844                img_size=img_size, embed_dim=144, num_heads=2, stages=[2, 6, 36, 4], global_att_blocks=[23, 33, 43],
 845                window_spec=[8, 4, 16, 8], backbone_channel_list=[1152, 576, 288, 144],
 846            )
 847        else:
 848            raise ValueError(
 849                f"'{model}' is not supported by SAM2. Currently, 'hvit_t', 'hvit_s', 'hvit_b', 'hvit_l' are supported."
 850            )
 851
 852    elif backbone == "sam3":
 853        if model != "vit_pe":
 854            raise ValueError(
 855                "'sam3' does not have multiple model configurations. Please use 'vit_pe' as the model configuration."
 856            )
 857
 858        encoder = ViT_Sam3(
 859            img_size=1008, pretrain_img_size=336, patch_size=14, embed_dim=1024, depth=32, num_heads=16,
 860            mlp_ratio=4.625, norm_layer="LayerNorm", drop_path_rate=0.1, qkv_bias=True, use_abs_pos=True,
 861            tile_abs_pos=True, global_att_blocks=(7, 15, 23, 31), rel_pos_blocks=(), use_rope=True,
 862            use_interp_rope=True, window_size=24, pretrain_use_cls_token=True, retain_cls_token=False, ln_pre=True,
 863            ln_post=False, return_interm_layers=False, bias_patch_embed=False, compile_mode=None,
 864        )
 865
 866    elif backbone == "mae":
 867        if model == "vit_b":
 868            encoder = ViT_MAE(
 869                img_size=img_size, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
 870                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
 871            )
 872        elif model == "vit_l":
 873            encoder = ViT_MAE(
 874                img_size=img_size, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
 875                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
 876            )
 877        elif model == "vit_h":
 878            encoder = ViT_MAE(
 879                img_size=img_size, patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4,
 880                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
 881            )
 882        else:
 883            raise ValueError(f"'{model}' is not supported by MAE. Currently, 'vit_b', 'vit_l', 'vit_h' are supported.")
 884
 885    elif backbone == "scalemae":
 886        base_resolution = kwargs.get("base_resolution", 2.5)
 887
 888        if model == "vit_b":
 889            encoder = ViT_ScaleMAE(
 890                img_size=img_size, patch_size=8, embed_dim=768, depth=12, num_heads=12,
 891                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
 892                base_resolution=base_resolution,
 893            )
 894        elif model == "vit_l":
 895            encoder = ViT_ScaleMAE(
 896                img_size=img_size, patch_size=8, embed_dim=1024, depth=24, num_heads=16,
 897                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
 898                base_resolution=base_resolution,
 899            )
 900        elif model == "vit_h":
 901            encoder = ViT_ScaleMAE(
 902                img_size=img_size, patch_size=8, embed_dim=1280, depth=32, num_heads=16,
 903                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
 904                base_resolution=base_resolution,
 905            )
 906        else:
 907            raise ValueError(
 908                f"'{model}' is not supported by ScaleMAE. Currently, 'vit_b', 'vit_l' and 'vit_h' are supported."
 909            )
 910
 911    elif backbone == "dinov2":
 912        block_fn = partial(Block, attn_class=MemEffAttention)
 913        msg = "The model name should be either 'vit_<X>' or 'vit_<X>_reg<Y>'."
 914
 915        if model.startswith("vit_s"):
 916            assert model in ["vit_s", "vit_s_reg4"], msg
 917            encoder = ViT_DINOv2(
 918                img_size=img_size, patch_size=14, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
 919                block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0,
 920                num_register_tokens=4 if model.endswith("_reg4") else 0,
 921            )
 922        elif model.startswith("vit_b"):
 923            assert model in ["vit_b", "vit_b_reg4"], msg
 924            encoder = ViT_DINOv2(
 925                img_size=img_size, patch_size=14, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
 926                block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0,
 927                num_register_tokens=4 if model.endswith("_reg4") else 0,
 928            )
 929        elif model.startswith("vit_l"):
 930            assert model in ["vit_l", "vit_l_reg4"], msg
 931            encoder = ViT_DINOv2(
 932                img_size=img_size, patch_size=14, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
 933                block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0,
 934                num_register_tokens=4 if model.endswith("_reg4") else 0,
 935            )
 936        elif model.startswith("vit_g"):
 937            assert model in ["vit_g", "vit_g_reg4"], msg
 938            encoder = ViT_DINOv2(
 939                img_size=img_size, patch_size=14, embed_dim=1536, depth=40, num_heads=24, mlp_ratio=4,
 940                block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0,
 941                num_register_tokens=4 if model.endswith("_reg4") else 0, ffn_layer="swiglu",
 942            )
 943        else:
 944            raise ValueError(
 945                f"'{model}' is not supported by DINOv2. Currently, 'vit_s', 'vit_b', 'vit_l' and 'vit_g' are supported."
 946            )
 947
 948    elif backbone == "dinov3":
 949
 950        if model == "vit_s":
 951            encoder = ViT_DINOv3(
 952                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=384,
 953                num_heads=6, layerscale_init=1.0e-05, norm_layer="layernormbf16", n_storage_tokens=4, mask_k_bias=True,
 954            )
 955        elif model == "vit_s+":
 956            encoder = ViT_DINOv3(
 957                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=384,
 958                num_heads=6, ffn_ratio=6, layerscale_init=1.0e-05, norm_layer="layernormbf16",
 959                ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True,
 960            )
 961
 962        elif model == "vit_b":
 963            encoder = ViT_DINOv3(
 964                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32",
 965                layerscale_init=1.0e-05, norm_layer="layernormbf16", n_storage_tokens=4, mask_k_bias=True,
 966            )
 967        elif model == "vit_l":
 968            encoder = ViT_DINOv3(
 969                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1024,
 970                depth=24, num_heads=16, layerscale_init=1.0e-05, norm_layer="layernormbf16",
 971                n_storage_tokens=4, mask_k_bias=True,
 972            )
 973        elif model == "vit_l+":
 974            encoder = ViT_DINOv3(
 975                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1024,
 976                depth=24, num_heads=16, ffn_ratio=6.0, layerscale_init=1.0e-05, norm_layer="layernormbf16",
 977                ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True,
 978            )
 979        elif model == "vit_h+":
 980            encoder = ViT_DINOv3(
 981                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1280,
 982                depth=32, num_heads=20, ffn_ratio=6.0, layerscale_init=1.0e-05, norm_layer="layernormbf16",
 983                ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True,
 984            )
 985        elif model == "vit_7b":
 986            encoder = ViT_DINOv3(
 987                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=4096,
 988                depth=40, num_heads=32, ffn_ratio=3, qkv_bias=False, drop_path_rate=0.0, layerscale_init=1.0e-05,
 989                norm_layer="layernormbf16", ffn_layer="swiglu64", n_storage_tokens=4, mask_k_bias=True,
 990                untie_global_and_local_cls_norm=True,
 991            )
 992        else:
 993            raise ValueError(
 994                f"'{model}' is not supported by DINOv3. Currently, "
 995                "'vit_s', 'vit_s+', 'vit_b', 'vit_l', 'vit_l+', 'vit_h+', 'vit_7b' are supported."
 996            )
 997
 998    else:
 999        raise ValueError(
1000            "The backbones supported by 'UNETR' are 'sam', 'cellpose_sam', 'sam2', 'sam3', "
1001            "'mae', 'scalemae', 'dinov2' and 'dinov3'. Please choose one of them."
1002        )
1003
1004    return encoder
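A usage sketch for the dispatch function (illustrative): the backbone argument selects the implementation, model selects the architecture variant, and additional keyword arguments such as base_resolution are forwarded to encoders that accept them.

    sam_encoder = get_vision_transformer(backbone="sam", model="vit_h", img_size=1024)
    scalemae_encoder = get_vision_transformer(backbone="scalemae", model="vit_b", img_size=224, base_resolution=5.0)
    # Every returned encoder exposes forward(x) -> (features, list_of_intermediate_features).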
class ViT_Sam:
 59class ViT_Sam(ImageEncoderViT):
 60    """Vision Transformer derived from the Segment Anything Codebase (https://arxiv.org/abs/2304.02643).
 61
 62    Based on:
 63    https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py
 64
 65    Args:
 66        in_chans: The number of input channels.
 67        embed_dim: The embedding dimension, corresponding to the number of output channels of the vision transformer.
 68        global_attn_indexes: The global attention indices.
 69        kwargs: Keyword arguments for the image encoder base class.
 70    """
 71    def __init__(
 72        self,
 73        in_chans: int = 3,
 74        embed_dim: int = 768,
 75        global_attn_indexes: Tuple[int, ...] = [2, 5, 8, 11],
 76        **kwargs,
 77    ) -> None:
 78        if not _sam_import_success:
 79            raise RuntimeError(
 80                "The vision transformer backend can only be initialized if segment anything is installed. "
 81                "Please install segment anything from https://github.com/facebookresearch/segment-anything "
 82                "and then rerun your code."
 83            )
 84
 85        super().__init__(embed_dim=embed_dim, global_attn_indexes=global_attn_indexes, **kwargs)
 86        self.chunks_for_projection = global_attn_indexes
 87        self.in_chans = in_chans
 88        self.embed_dim = embed_dim
 89
 90    def forward(self, x: torch.Tensor) -> torch.Tensor:
 91        """Apply the vision transformer to input data.
 92
 93        Args:
 94            x: The input data.
 95
 96        Returns:
 97            The vision transformer output.
 98        """
 99        x = self.patch_embed(x)
100        if self.pos_embed is not None:
101            x = x + self.pos_embed
102
103        list_from_encoder = []
104        for i, blk in enumerate(self.blocks):
105            x = blk(x)
106            if i in self.chunks_for_projection:
107                list_from_encoder.append(x)
108
109        x = x.permute(0, 3, 1, 2)
110        list_from_encoder = [e.permute(0, 3, 1, 2) for e in list_from_encoder]
111        return x, list_from_encoder[:3]

Vision Transformer derived from the Segment Anything Codebase (https://arxiv.org/abs/2304.02643).

Based on: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py

Arguments:
  • in_chans: The number of input channels.
  • embed_dim: The embedding dimension, corresponding to the number of output channels of the vision transformer.
  • global_attn_indexes: The global attention indices.
  • kwargs: Keyword arguments for the image encoder base class.
ViT_Sam( in_chans: int = 3, embed_dim: int = 768, global_attn_indexes: Tuple[int, ...] = [2, 5, 8, 11], **kwargs)
71    def __init__(
72        self,
73        in_chans: int = 3,
74        embed_dim: int = 768,
75        global_attn_indexes: Tuple[int, ...] = [2, 5, 8, 11],
76        **kwargs,
77    ) -> None:
78        if not _sam_import_success:
79            raise RuntimeError(
80                "The vision transformer backend can only be initialized if segment anything is installed. "
81                "Please install segment anything from https://github.com/facebookresearch/segment-anything "
82                "and then rerun your code."
83            )
84
85        super().__init__(embed_dim=embed_dim, global_attn_indexes=global_attn_indexes, **kwargs)
86        self.chunks_for_projection = global_attn_indexes
87        self.in_chans = in_chans
88        self.embed_dim = embed_dim
chunks_for_projection
in_chans
embed_dim
def forward(self, x: torch.Tensor) -> torch.Tensor:
 90    def forward(self, x: torch.Tensor) -> torch.Tensor:
 91        """Apply the vision transformer to input data.
 92
 93        Args:
 94            x: The input data.
 95
 96        Returns:
 97            The vision transformer output.
 98        """
 99        x = self.patch_embed(x)
100        if self.pos_embed is not None:
101            x = x + self.pos_embed
102
103        list_from_encoder = []
104        for i, blk in enumerate(self.blocks):
105            x = blk(x)
106            if i in self.chunks_for_projection:
107                list_from_encoder.append(x)
108
109        x = x.permute(0, 3, 1, 2)
110        list_from_encoder = [e.permute(0, 3, 1, 2) for e in list_from_encoder]
111        return x, list_from_encoder[:3]

Apply the vision transformer to input data.

Arguments:
  • x: The input data.
Returns:

The vision transformer output.

class ViT_CellposeSAM(torch.nn.modules.module.Module):
114class ViT_CellposeSAM(nn.Module):
115    """Vision Transformer derived from the CellposeSAM Codebase (https://doi.org/10.1038/s41592-025-02595-x).
116
117    This replicates CellposeSAM's actual initialization: instantiate SAM's ``ImageEncoderViT`` via
118    ``sam_model_registry``, then modify the patch embedding, position embeddings, and set global attention.
119    This preserves SAM's original relative position bias sizes, enabling direct checkpoint loading
120    without any interpolation.
121
122    Based on: https://github.com/MouseLand/cellpose/blob/main/cellpose/vit_sam.py
123
124    NOTE: The pretrained CellposeSAM model uses ``vit_l`` exclusively.
125
126    Args:
127        ps: The patch size (default for CellposeSAM is 8).
128        bsize: The input image size (default for CellposeSAM is 256).
129    """
130    def __init__(self, ps: int = 8, bsize: int = 256) -> None:
131        super().__init__()
132
133        if not _sam_import_success:
134            raise RuntimeError(
135                "The vision transformer backend can only be initialized if segment anything is installed. "
136                "Please install segment anything from https://github.com/facebookresearch/segment-anything "
137                "and then rerun your code."
138            )
139
140        from segment_anything import sam_model_registry
141
142        # Creates the SAM vit_l encoder and applies CellposeSAM's modifications (same as cellpose.vit_sam.Transformer).
143        encoder = sam_model_registry["vit_l"](None).image_encoder
144
145        w = encoder.patch_embed.proj.weight.detach()
146        nchan = w.shape[0]
147
148        # CellPoseSAM changes the patch size from 16 to 'ps'.
149        encoder.patch_embed.proj = nn.Conv2d(3, nchan, stride=ps, kernel_size=ps)
150        encoder.patch_embed.proj.weight.data = w[:, :, ::16 // ps, ::16 // ps]
151
152        # Next, they subsample position embeddings for the new patch size and input resolution.
153        ds = (1024 // 16) // (bsize // ps)
154        encoder.pos_embed = nn.Parameter(encoder.pos_embed[:, ::ds, ::ds], requires_grad=True)
155
156        # Finally, they set all blocks to global attention.
157        for blk in encoder.blocks:
158            blk.window_size = 0
159
160        # Store encoder submodules directly ('state_dict' keys match CellposeSAM after prefix stripping).
161        self.patch_embed = encoder.patch_embed
162        self.pos_embed = encoder.pos_embed
163        self.blocks = encoder.blocks
164        self.neck = encoder.neck
165
166        # Additional attributes expected by UNETR.
167        self.embed_dim = nchan
168        self.img_size = bsize
169        self.in_chans = 3
170
171        # Feature extraction at evenly-spaced depths.
172        depth = len(self.blocks)
173        _chunks = depth // 4
174        self.chunks_for_projection = [_chunks - 1, 2 * _chunks - 1, 3 * _chunks - 1, 4 * _chunks - 1]
175
176    def forward(self, x: torch.Tensor) -> torch.Tensor:
177        """Apply the vision transformer to input data.
178
179        Args:
180            x: The input data.
181
182        Returns:
183            The vision transformer output.
184        """
185        x = self.patch_embed(x)
186        if self.pos_embed is not None:
187            x = x + self.pos_embed
188
189        list_from_encoder = []
190        for i, blk in enumerate(self.blocks):
191            x = blk(x)
192            if i in self.chunks_for_projection:
193                list_from_encoder.append(x)
194
195        x = x.permute(0, 3, 1, 2)
196        list_from_encoder = [e.permute(0, 3, 1, 2) for e in list_from_encoder]
197        return x, list_from_encoder[:3]

Vision Transformer derived from the CellposeSAM Codebase (https://doi.org/10.1038/s41592-025-02595-x).

This replicates CellposeSAM's actual initialization: instantiate SAM's ImageEncoderViT via sam_model_registry, then modify the patch embedding, position embeddings, and set global attention. This preserves SAM's original relative position bias sizes, enabling direct checkpoint loading without any interpolation.

Based on: https://github.com/MouseLand/cellpose/blob/main/cellpose/vit_sam.py

NOTE: The pretrained CellposeSAM model uses vit_l exclusively.

Arguments:
  • ps: The patch size (default for CellposeSAM is 8).
  • bsize: The input image size (default for CellposeSAM is 256).
ViT_CellposeSAM(ps: int = 8, bsize: int = 256)
130    def __init__(self, ps: int = 8, bsize: int = 256) -> None:
131        super().__init__()
132
133        if not _sam_import_success:
134            raise RuntimeError(
135                "The vision transformer backend can only be initialized if segment anything is installed. "
136                "Please install segment anything from https://github.com/facebookresearch/segment-anything "
137                "and then rerun your code."
138            )
139
140        from segment_anything import sam_model_registry
141
142        # Creates the SAM vit_l encoder and applies CellposeSAM's modifications (same as cellpose.vit_sam.Transformer).
143        encoder = sam_model_registry["vit_l"](None).image_encoder
144
145        w = encoder.patch_embed.proj.weight.detach()
146        nchan = w.shape[0]
147
148        # CellPoseSAM changes the patch size from 16 to 'ps'.
149        encoder.patch_embed.proj = nn.Conv2d(3, nchan, stride=ps, kernel_size=ps)
150        encoder.patch_embed.proj.weight.data = w[:, :, ::16 // ps, ::16 // ps]
151
152        # Next, they subsample position embeddings for the new patch size and input resolution.
153        ds = (1024 // 16) // (bsize // ps)
154        encoder.pos_embed = nn.Parameter(encoder.pos_embed[:, ::ds, ::ds], requires_grad=True)
155
156        # Finally, they set all blocks to global attention.
157        for blk in encoder.blocks:
158            blk.window_size = 0
159
160        # Store encoder submodules directly ('state_dict' keys match CellposeSAM after prefix stripping).
161        self.patch_embed = encoder.patch_embed
162        self.pos_embed = encoder.pos_embed
163        self.blocks = encoder.blocks
164        self.neck = encoder.neck
165
166        # Additional attributes expected by UNETR.
167        self.embed_dim = nchan
168        self.img_size = bsize
169        self.in_chans = 3
170
171        # Feature extraction at evenly-spaced depths.
172        depth = len(self.blocks)
173        _chunks = depth // 4
174        self.chunks_for_projection = [_chunks - 1, 2 * _chunks - 1, 3 * _chunks - 1, 4 * _chunks - 1]

Initialize internal Module state, shared by both nn.Module and ScriptModule.

patch_embed
pos_embed
blocks
neck
embed_dim
img_size
in_chans
chunks_for_projection
def forward(self, x: torch.Tensor) -> torch.Tensor:
176    def forward(self, x: torch.Tensor) -> torch.Tensor:
177        """Apply the vision transformer to input data.
178
179        Args:
180            x: The input data.
181
182        Returns:
183            The vision transformer output.
184        """
185        x = self.patch_embed(x)
186        if self.pos_embed is not None:
187            x = x + self.pos_embed
188
189        list_from_encoder = []
190        for i, blk in enumerate(self.blocks):
191            x = blk(x)
192            if i in self.chunks_for_projection:
193                list_from_encoder.append(x)
194
195        x = x.permute(0, 3, 1, 2)
196        list_from_encoder = [e.permute(0, 3, 1, 2) for e in list_from_encoder]
197        return x, list_from_encoder[:3]

Apply the vision transformer to input data.

Arguments:
  • x: The input data.
Returns:

The vision transformer output.

class ViT_MAE:
200class ViT_MAE(VisionTransformer):
201    """Vision Transformer derived from the Masked Auto Encoder Codebase (https://arxiv.org/abs/2111.06377).
202
203    Based on:
204    https://github.com/facebookresearch/mae/blob/main/models_vit.py#L20-L53
205
206    Args:
207        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
208        in_chans: The number of input channels.
209        depth: The depth of the vision transformer.
210        kwargs: Additional keyword arguments for the vision transformer base class.
211    """
212    def __init__(
213        self,
214        img_size: int = 1024,  # chosen to match our experiments with segment anything
215        in_chans: int = 3,
216        depth: int = 12,
217        **kwargs
218    ):
219        if not _timm_import_success:
220            raise RuntimeError(
221                "The vision transformer backend can only be initialized if timm is installed. "
222                "Please install timm (using conda/mamba) for using https://github.com/facebookresearch/mae/ "
223                "and then rerun your code"
224            )
225        super().__init__(img_size=img_size, depth=depth, **kwargs)
226        self.img_size = img_size
227        self.in_chans = in_chans
228        self.depth = depth
229
230    def convert_to_expected_dim(self, inputs_):
231        """@private
232        """
233        inputs_ = inputs_[:, 1:, :]  # removing the class tokens
234        # reshape the outputs to desired shape (N x H*W X C -> N x H x W x C)
235        rdim = inputs_.shape[1]
236        dshape = int(rdim ** 0.5)  # finding the square root of the outputs for obtaining the patch shape
237        inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape))
238        inputs_ = inputs_.permute(0, 3, 1, 2)
239        return inputs_
240
241    def forward_features(self, x):
242        """@private
243        """
244        B = x.shape[0]
245        x = self.patch_embed(x)
246
247        cls_tokens = self.cls_token.expand(B, -1, -1)
248        x = torch.cat((cls_tokens, x), dim=1)
249
250        x = x + self.pos_embed
251        x = self.pos_drop(x)
252
253        # chunks obtained for getting the projections in conjunction with the upsampling blocks
254        _chunks = int(self.depth / 4)
255        chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1]
256
257        list_from_encoder = []
258        for i, blk in enumerate(self.blocks):
259            x = blk(x)
260            if i in chunks_for_projection:
261                list_from_encoder.append(self.convert_to_expected_dim(x))
262
263        x = self.convert_to_expected_dim(x)
264        return x, list_from_encoder[:3]
265
266    def forward(self, x: torch.Tensor) -> torch.Tensor:
267        """Apply the vision transformer to input data.
268
269        Args:
270            x: The input data.
271
272        Returns:
273            The vision transformer output.
274        """
275        x, list_from_encoder = self.forward_features(x)
276        return x, list_from_encoder

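The token-to-grid conversion performed by convert_to_expected_dim can be illustrated on a dummy token sequence; this sketch only exercises the tensor manipulation and does not require timm:

    import torch

    tokens = torch.randn(2, 1 + 14 * 14, 768)         # class token followed by a 14 x 14 patch grid
    patches = tokens[:, 1:, :]                         # drop the class token -> (2, 196, 768)
    side = int(patches.shape[1] ** 0.5)                # 14, the side length of the patch grid
    grid = torch.unflatten(patches, 1, (side, side))   # (2, 14, 14, 768)
    grid = grid.permute(0, 3, 1, 2)                    # (2, 768, 14, 14), NCHW for the decoder
    print(grid.shape)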

class ViT_Sam2:
279class ViT_Sam2(ImageEncoder):
280    """Vision Transformer derived from the Segment Anything 2 Codebase (https://arxiv.org/abs/2408.00714).
281
282    Based on https://github.com/facebookresearch/sam2/blob/main/sam2/modeling/backbones/image_encoder.py.
283
284    Args:
285        backbone_channel_list: The channels throughout the entire backbone.
286        embed_dim: The initial embedding dimension.
287        num_heads: The initial number of heads.
288        stages: The number of blocks per stage.
289        global_att_blocks: The parameter to decide which blocks have global attention.
290        window_pos_embed_bkg_spatial_size: The spatial size of windowed positional embedding.
291        window_spec: The window size per stage, when not using global attention.
292        scalp: The count of lowest resolution features to discard.
293    """
294    def __init__(
295        self,
296        backbone_channel_list: List[int],
297        img_size: int = 1024,
298        embed_dim: int = 96,
299        num_heads: int = 1,
300        stages: Tuple[int, ...] = (2, 3, 16, 3),
301        global_att_blocks: Tuple[int, ...] = (12, 16, 20),
302        window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
303        window_spec: Tuple[int, ...] = (8, 4, 14, 7),
304        scalp: int = 1,
305        **kwargs
306    ):
307        if not _sam2_import_success:
308            raise RuntimeError(
309                "The vision transformer backend can only be initialized if segment anything 2 is installed. "
310                "Please install segment anything 2 from https://github.com/facebookresearch/sam2 "
311                "and then rerun your code"
312            )
313
314        trunk = Hiera(
315            embed_dim=embed_dim,
316            num_heads=num_heads,
317            stages=stages,
318            global_att_blocks=global_att_blocks,
319            window_pos_embed_bkg_spatial_size=window_pos_embed_bkg_spatial_size,
320            window_spec=window_spec,
321        )
322        neck = FpnNeck(
323            position_encoding=PositionEmbeddingSine(num_pos_feats=256),
324            d_model=256,
325            backbone_channel_list=backbone_channel_list,
326            fpn_top_down_levels=[2, 3],
327            fpn_interp_model="nearest",
328        )
329
330        super().__init__(trunk=trunk, neck=neck, scalp=scalp, **kwargs)
331        self.scalp = scalp
332        self.embed_dim = embed_dim
333        self.img_size = img_size
334
335    def forward(self, x: torch.Tensor):
336        # The forward pass through the backbone.
337        features, pos = self.neck(self.trunk(x))
338        if self.scalp > 0:  # This discards the "n" lowest resolution features.
339            features, pos = features[:-self.scalp], pos[:-self.scalp]
340
341        return features[-1], features

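A minimal usage sketch, assuming the sam2 package is installed (random weights):

    import torch
    from torch_em.model.vit import get_vision_transformer

    encoder = get_vision_transformer(backbone="sam2", model="hvit_t", img_size=1024)
    x = torch.randn(1, 3, 1024, 1024)
    image_embedding, feature_maps = encoder(x)
    # 'feature_maps' are the FPN levels retained after discarding 'scalp' low-resolution levels;
    # all of them have 256 channels, since the neck above uses d_model=256.
    print(image_embedding.shape)
    print([f.shape for f in feature_maps])
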
class ViT_Sam3:
344class ViT_Sam3(SAM3ViT):
345    """Vision Transformer derived from the Segment Anything 3 Codebase (https://arxiv.org/abs/2511.16719).
346
347    Based on https://github.com/facebookresearch/sam3/blob/main/sam3/model/vitdet.py.
348
349    Args:
350        img_size: The input image size.
351        embed_dim: The embedding dimension, corresponding to the number of output channels of the vision transformer.
352        kwargs: Keyword arguments for the image encoder base class.
353    """
354    def __init__(self, img_size: int = 1024, embed_dim: int = 768, **kwargs):
355        if not _sam3_import_success:
356            raise RuntimeError(
357                "The vision transformer backend can only be initialized if segment anything 3 is installed. "
358                "Please install segment anything 3 from https://github.com/facebookresearch/sam3 "
359                "and then rerun your code"
360            )
361
362        super().__init__(img_size=img_size, embed_dim=embed_dim, **kwargs)
363        self.img_size = img_size
364        self.embed_dim = embed_dim
365
366    def forward_features(self, x):
367        """@private
368        """
369        x = self.patch_embed(x)
370        h, w = x.shape[1], x.shape[2]
371
372        s = 0
373        if self.retain_cls_token:
374            # If the 'cls_token' is retained, we don't maintain the spatial shape.
375            x = torch.cat([self.class_embedding, x.flatten(1, 2)], dim=1)
376            s = 1
377
378        if self.pos_embed is not None:
379            x = x + get_abs_pos(
380                self.pos_embed, self.pretrain_use_cls_token, (h, w), self.retain_cls_token, tiling=self.tile_abs_pos,
381            )
382
383        x = self.ln_pre(x)
384
385        list_from_encoder = []
386        for i, blk in enumerate(self.blocks):
387            if self.use_act_checkpoint and self.training:
388                x = torch.utils.checkpoint.checkpoint(blk, x, use_reentrant=False)
389            else:
390                x = blk(x)
391
392            x = self._convert_to_expected_dim(x, i, s)
393
394            if i in self.full_attn_ids:
395                list_from_encoder.append(x)
396
397        return x, list_from_encoder
398
399    def _convert_to_expected_dim(self, x, i, s):
400        if (i == self.full_attn_ids[-1]) or (
401            self.return_interm_layers and i in self.full_attn_ids
402        ):
403            if i == self.full_attn_ids[-1]:
404                x = self.ln_post(x)
405
406            feats = x[:, s:]
407            if feats.ndim == 4:
408                feats = feats.permute(0, 3, 1, 2)
409            else:
410                assert feats.ndim == 3
411                h = w = int(math.sqrt(feats.shape[1]))  # the token count is assumed to be a perfect square
412                feats = feats.reshape(feats.shape[0], h, w, feats.shape[-1]).permute(0, 3, 1, 2)
413            return feats
414
415        else:
416            return x
417
418    def forward(self, x: torch.Tensor):
419        """Apply the vision transformer to input data.
420
421        Args:
422            x: The input data.
423
424        Returns:
425            The vision transformer output.
426        """
427        x, list_from_encoder = self.forward_features(x)
428        return x, list_from_encoder

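A minimal usage sketch, assuming the sam3 package is installed; note that this instantiates a large (ViT-PE) model with random weights:

    import torch
    from torch_em.model.vit import get_vision_transformer

    encoder = get_vision_transformer(backbone="sam3", model="vit_pe")
    x = torch.randn(1, 3, 1008, 1008)  # the SAM3 configuration below fixes the input resolution to 1008
    features, list_from_encoder = encoder(x)
    print(features.shape)                        # final feature map in NCHW layout
    print([f.shape for f in list_from_encoder])  # features collected at the global-attention blocks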

class CustomCompose:
435class CustomCompose:
436    def __init__(self, rescale_transform, other_transforms, src_transform):
437        self.rescale_transform = rescale_transform
438        self.other_transforms = other_transforms
439        self.src_transform = src_transform
440
441    def __call__(self, x, valid_masks=None):
442        if valid_masks is not None:
443            nodata = (x * (1 - valid_masks.float())).max()
444        x_aug = self.rescale_transform(x)
445        parms = self.rescale_transform._params
446
447        # sanity check, comment if this is working
448        # valid_masks = self.rescale_transform(valid_masks.float(), params=parms)
449        # assert (x_aug==self.rescale_transform(x, params=parms)).all() #
450
451        if valid_masks is not None:
452            valid_masks = x_aug != nodata
453            _, c, h, w = x_aug.shape
454            zero_ratio = ((valid_masks == 0).sum((1, 2, 3)) / (h * w * c)).cpu().numpy()
455        else:
456            zero_ratio = -1
457
458        if self.other_transforms:
459            x_aug = self.other_transforms(x_aug)
460        x_src = self.src_transform(x_aug)
461        dx = parms["src"][:, 1, 0] - parms["src"][:, 0, 0]
462
463        # dy = (parms['src'][:,2,1] - parms['src'][:,1,1])
464        # assert (dx == dy).all()
465
466        h, w = x_aug.shape[-2:]
467        # assert h == w
468
469        return x_aug, x_src, dx / h, zero_ratio, valid_masks
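
A sketch of how CustomCompose is used in this module (it mirrors ViT_ScaleMAE.transform_inputs further below and requires kornia):

    import torch
    import kornia.augmentation as K
    from kornia.constants import Resample
    from torch_em.model.vit import CustomCompose

    compose = CustomCompose(
        rescale_transform=K.RandomResizedCrop(
            (448, 448), ratio=(1.0, 1.0), scale=(1.0, 1.0), resample=Resample.BICUBIC.name,
        ),
        other_transforms=None,
        src_transform=K.Resize((224, 224)),
    )
    x = torch.randn(2, 3, 448, 448)
    x_aug, x_src, scale_ratio, zero_ratio, valid_masks = compose(x)
    # 'scale_ratio' (crop width in input pixels divided by the output width) is multiplied with
    # 'base_resolution' in ViT_ScaleMAE to obtain the effective resolution per sample.
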
def get_2d_sincos_pos_embed_with_resolution(embed_dim, grid_size, res, cls_token=False, device='cpu'):
472def get_2d_sincos_pos_embed_with_resolution(embed_dim, grid_size, res, cls_token=False, device="cpu"):
473    """
474    grid_size: int of the grid height and width
475    res: array of size n, representing the resolution of a pixel (say, in meters),
476    return:
477    pos_embed: [n,grid_size*grid_size, embed_dim] or [n,1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
478    """
479    # res = torch.FloatTensor(res).to(device)
480    res = res.to(device)
481    grid_h = torch.arange(grid_size, dtype=torch.float32, device=device)
482    grid_w = torch.arange(grid_size, dtype=torch.float32, device=device)
483    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")  # here h goes first,direction reversed for numpy
484    grid = torch.stack(grid, dim=0)  # 2 x h x w
485
486    # grid = grid.reshape([2, 1, grid_size, grid_size])
487    grid = torch.einsum("chw,n->cnhw", grid, res)  # 2 x n x h x w
488    _, n, h, w = grid.shape
489    pos_embed = get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid)  # (nxH*W, D/2)
490    pos_embed = pos_embed.reshape(n, h * w, embed_dim)
491    if cls_token:
492        pos_embed = torch.cat(
493            [torch.zeros([n, 1, embed_dim], dtype=torch.float32, device=pos_embed.device), pos_embed], dim=1
494        )
495
496    return pos_embed

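For example, the returned positional embedding has one entry per resolution value (a small shape check):

    import torch
    from torch_em.model.vit import get_2d_sincos_pos_embed_with_resolution

    res = torch.tensor([1.0, 2.5])  # one resolution value (e.g. meters per pixel) per sample
    pos_embed = get_2d_sincos_pos_embed_with_resolution(embed_dim=768, grid_size=14, res=res, cls_token=True)
    print(pos_embed.shape)  # (2, 1 + 14 * 14, 768), including the zero embedding for the class token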

def get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid):
499def get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid):
500    assert embed_dim % 2 == 0
501
502    # use half of dimensions to encode grid_h
503    emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid[0])  # (H*W, D/2)
504    emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid[1])  # (H*W, D/2)
505
506    emb = torch.cat([emb_h, emb_w], dim=1)  # (H*W, D)
507    return emb
def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos):
510def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos):
511    """
512    embed_dim: output dimension for each position
513    pos: a list of positions to be encoded: size (M,)
514    out: (M, D)
515    """
516    assert embed_dim % 2 == 0
517    # old_shape = pos
518    omega = torch.arange(embed_dim // 2, dtype=torch.float32, device=pos.device)
519    omega /= embed_dim / 2.0
520    omega = 1.0 / 10000**omega  # (D/2,)
521
522    pos = pos.reshape(-1)  # (M,)
523    out = torch.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
524
525    emb_sin = torch.sin(out)  # (M, D/2)
526    emb_cos = torch.cos(out)  # (M, D/2)
527
528    emb = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)
529    return emb

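A small shape check for the 1D building block:

    import torch
    from torch_em.model.vit import get_1d_sincos_pos_embed_from_grid_torch

    pos = torch.arange(4, dtype=torch.float32)
    emb = get_1d_sincos_pos_embed_from_grid_torch(8, pos)
    print(emb.shape)  # (4, 8): sine terms in the first half of the last dimension, cosine terms in the second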

class PatchEmbedUnSafe:
532class PatchEmbedUnSafe(PatchEmbed):
533    """Image to Patch Embedding"""
534
535    def forward(self, x):
536        B, C, H, W = x.shape
537
538        # NOTE: Comment code from ScaleMAE: Dropped size check in timm
539        # assert H == self.img_size[0] and W == self.img_size[1], \
540        #     f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
541
542        x = self.proj(x).flatten(2).transpose(1, 2)
543        return x

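Since the size check is dropped, the patch embedding accepts inputs whose spatial size differs from img_size; a minimal sketch, assuming timm is installed:

    import torch
    from torch_em.model.vit import PatchEmbedUnSafe

    patch_embed = PatchEmbedUnSafe(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
    tokens = patch_embed(torch.randn(1, 3, 448, 448))  # larger than img_size, but accepted
    print(tokens.shape)  # (1, 784, 768): (448 // 16) ** 2 patch tokens
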
class ViT_ScaleMAE:
546class ViT_ScaleMAE(VisionTransformer):
547    """Vision Transformer derived from the Scale Masked Auto Encoder codebase (TODO: paper and github link).
548
549    NOTE: For downstream tasks, the "base_resolution" parameter needs to be adjusted manually when using
550    the model on a dataset with a different zoom factor.
551    """
552
553    def __init__(
554        self, img_size=224, patch_size=16, in_chans=3, embed_dim=1024, depth=12, base_resolution=2.5, **kwargs
555    ):
556        super().__init__(img_size=img_size, embed_dim=embed_dim, **kwargs)
557        self.img_size = img_size
558        self.in_chans = in_chans
559        self.depth = depth
560        self.base_resolution = base_resolution
561
562        self.patch_embed = PatchEmbedUnSafe(
563            img_size=img_size,
564            patch_size=patch_size,
565            in_chans=in_chans,
566            embed_dim=embed_dim,
567        )
568
569    def transform_inputs(self, x):
570        import kornia.augmentation as K
571        from kornia.constants import Resample
572
573        self._transforms = CustomCompose(
574            rescale_transform=K.RandomResizedCrop(
575                (448, 448),
576                ratio=(1.0, 1.0),
577                scale=(1.0, 1.0),
578                resample=Resample.BICUBIC.name,
579            ),
580            other_transforms=None,
581            src_transform=K.Resize((224, 224)),
582        )
583        x, _, ratios, _, _ = self._transforms(x)
584        input_res = ratios * self.base_resolution
585        return x, input_res
586
587    def convert_to_expected_dim(self, x):
588        inputs_ = x[:, 1:, :]  # removing the class tokens
589        # reshape the outputs to desired shape (N X H*W X C -> N X H X W X C)
590        rdim = inputs_.shape[1]
591        dshape = int(rdim ** 0.5)  # finding square root of the outputs for obtaining the patch shape
592        inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape))
593        inputs_ = inputs_.permute(0, 3, 1, 2)
594        return inputs_
595
596    def forward_features(self, x):
597        x, input_res = self.transform_inputs(x)
598
599        B, _, h, w = x.shape
600        x = self.patch_embed(x)
601
602        num_patches = int((h * w) / (self.patch_embed.patch_size[0] * self.patch_embed.patch_size[1]))
603        pos_embed = get_2d_sincos_pos_embed_with_resolution(
604            x.shape[-1],
605            int(num_patches ** 0.5),
606            input_res,
607            cls_token=True,
608            device=x.device,
609        )
610
611        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
612        x = torch.cat((cls_tokens, x), dim=1)
613        x = x + pos_embed
614        x = self.pos_drop(x)
615
616        # chunks obtained for getting the projections in conjunction with the upsampling blocks
617        _chunks = int(self.depth / 4)
618        chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1]
619
620        list_from_encoder = []
621        for i, blk in enumerate(self.blocks):
622            x = blk(x)
623            if i in chunks_for_projection:
624                list_from_encoder.append(self.convert_to_expected_dim(x))
625
626        x = self.convert_to_expected_dim(x)
627
628        return x, list_from_encoder
629
630    def forward(self, x):
631        x, list_from_encoder = self.forward_features(x)
632        return x, list_from_encoder

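A minimal usage sketch, assuming timm and kornia are installed (random weights); the base_resolution keyword is forwarded by get_vision_transformer below:

    import torch
    from torch_em.model.vit import get_vision_transformer

    encoder = get_vision_transformer(backbone="scalemae", model="vit_b", img_size=224, base_resolution=2.5)
    x = torch.randn(1, 3, 224, 224)
    features, list_from_encoder = encoder(x)
    # Inputs are internally cropped/rescaled to 448 x 448 before patch embedding (patch size 8),
    # so the final NCHW feature map has a spatial size of 56 x 56.
    print(features.shape)
    print([f.shape for f in list_from_encoder])
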
class ViT_DINOv2:
635class ViT_DINOv2(DinoV2VisionTransformer):
636    """Vision Transformer derived from the DINOv2 Codebase (https://arxiv.org/abs/2304.07193).
637
638    Based on:
639    https://github.com/facebookresearch/dinov2/blob/main/dinov2/models/vision_transformer.py.
640
641    Args:
642        img_size: The input image size.
643        patch_size: The patch size.
644        depth: The depth of the network.
645        num_register_tokens: The number of registers added (in addition to the class tokens).
646            It's important to know for ViTs trained with registers, to remove them at the end.
647    """
648    def __init__(
649        self,
650        img_size: int = 224,
651        patch_size: int = 16,
652        depth: int = 12,
653        num_register_tokens: int = 0,
654        **kwargs
655    ):
656        if not _dinov2_import_success:
657            raise RuntimeError(
658                "The vision transformer backend can only be initialized if DINOv2 is installed. "
659                "Please install DINOv2 from https://github.com/facebookresearch/dinov2 "
660                "and then rerun your code."
661            )
662
663        super().__init__(
664            img_size=img_size,
665            depth=depth,
666            patch_size=patch_size,
667            num_register_tokens=num_register_tokens,
668            **kwargs
669        )
670
671        self.img_size = img_size
672        self.num_register_tokens = num_register_tokens
673        self.patch_size = patch_size
674        self.attn_outs = [i for i in range(depth) if i % 3 == 2]
675
676    def forward(self, x, masks=None) -> torch.Tensor:
677
678        B = x.shape[0]
679
680        x = self.prepare_tokens_with_masks(x)
681
682        list_of_encoder = []
683        for i, blk in enumerate(self.blocks):
684            x = blk(x)
685            if i in self.attn_outs:
686                list_of_encoder.append(x)
687
688        x = self.norm(x)
689        x = x[:, self.num_register_tokens + 1:].reshape(
690            B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
691        ).permute(0, 3, 1, 2).contiguous()
692
693        list_of_encoder = [
694            o[:, self.num_register_tokens + 1:].reshape(
695                B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
696            ).permute(0, 3, 1, 2).contiguous() for o in list_of_encoder
697        ]
698
699        return x, list_of_encoder[:3]

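A minimal usage sketch, assuming the dinov2 package is installed (random weights):

    import torch
    from torch_em.model.vit import get_vision_transformer

    encoder = get_vision_transformer(backbone="dinov2", model="vit_s_reg4", img_size=224)
    x = torch.randn(1, 3, 224, 224)
    features, list_of_encoder = encoder(x)
    print(features.shape)                      # (1, 384, 16, 16): patch size 14, so 224 // 14 = 16
    print([f.shape for f in list_of_encoder])  # three intermediate maps, registers and class token removed
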
class ViT_DINOv3:
702class ViT_DINOv3(DinoV3VisionTransformer):
703    """Vision Transformer derived from the DINOv3 Codebase (https://arxiv.org/abs/2508.10104).
704
705    Based on:
706    https://github.com/facebookresearch/dinov3/blob/main/dinov3/models/vision_transformer.py.
707
708    Args:
709        img_size: The input image size.
710        patch_size: The patch size.
711        embed_dim: The embedding dimension.
712        depth: The depth of the network.
713        num_heads: The number of heads.
714        ffn_ratio: The FFN ratio.
715        n_storage_tokens: The number of storage (class) tokens to remove.
716        kwargs: Keyword arguments for the image encoder base class.
717    """
718    def __init__(
719        self,
720        in_chans: int = 3,
721        img_size: int = 224,
722        patch_size: int = 16,
723        embed_dim: int = 768,
724        depth: int = 12,
725        num_heads: int = 12,
726        ffn_ratio: float = 4.0,
727        n_storage_tokens: int = 0,
728        **kwargs
729    ):
730        if not _dinov3_import_success:
731            raise RuntimeError(
732                "The vision transformer backend can only be initialized if DINOv3 is installed. "
733                "Please install DINOv3 from https://github.com/facebookresearch/dinov3 "
734                "and then rerun your code."
735            )
736
737        super().__init__(
738            in_chans=in_chans,
739            img_size=img_size,
740            patch_size=patch_size,
741            embed_dim=embed_dim,
742            depth=depth,
743            num_heads=num_heads,
744            ffn_ratio=ffn_ratio,
745            n_storage_tokens=n_storage_tokens,
746            **kwargs
747        )
748
749        self.in_chans = in_chans
750        self.img_size = img_size
751        self.n_storage_tokens = n_storage_tokens
752        self.attn_outs = [i for i in range(depth) if i % 3 == 2]
753
754    def forward(self, x) -> torch.Tensor:
755
756        B = x.shape[0]
757
758        x, hw_tuple = self.prepare_tokens_with_masks(x)
759
760        list_of_encoder = []
761        for i, blk in enumerate(self.blocks):
762            rope_sincos = self.rope_embed(H=hw_tuple[0], W=hw_tuple[1])
763            x = blk(x, rope_sincos)
764            if i in self.attn_outs:
765                list_of_encoder.append(x)
766
767        x = self.norm(x)
768        x = x[:, self.n_storage_tokens + 1:].reshape(
769            B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
770        ).permute(0, 3, 1, 2).contiguous()
771
772        list_of_encoder = [
773            o[:, self.n_storage_tokens + 1:].reshape(
774                B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
775            ).permute(0, 3, 1, 2).contiguous() for o in list_of_encoder
776        ]
777
778        return x, list_of_encoder[:3]

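A minimal usage sketch, assuming the dinov3 package is installed (random weights):

    import torch
    from torch_em.model.vit import get_vision_transformer

    encoder = get_vision_transformer(backbone="dinov3", model="vit_s", img_size=224)
    x = torch.randn(1, 3, 224, 224)
    features, list_of_encoder = encoder(x)
    print(features.shape)                      # (1, 384, 14, 14): patch size 16, so 224 // 16 = 14
    print([f.shape for f in list_of_encoder])  # three intermediate maps, storage and class tokens removed
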
def get_vision_transformer(backbone: str, model: str, img_size: int = 1024, **kwargs) -> nn.Module:
 781def get_vision_transformer(backbone: str, model: str, img_size: int = 1024, **kwargs) -> nn.Module:
 782    """Get vision transformer encoder.
 783
 784    Args:
 785        backbone: The name of the vision transformer implementation.
 786            One of "sam" / "cellpose_sam" / "sam2" / "sam3" / "mae" / "scalemae" / "dinov2" / "dinov3".
 787        model: The name of the model configuration; the supported names depend on the backbone, e.g. "vit_b", "vit_l" or "vit_h".
 788        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
 789        kwargs: Additional kwargs which can be expected by the vision transformer,
 790            e.g. 'base_resolution' for `ViT_ScaleMAE`.
 791
 792    Returns:
 793        The vision transformer.
 794    """
 795    if backbone == "sam":
 796        if model == "vit_b":
 797            encoder = ViT_Sam(
 798                depth=12, embed_dim=768, img_size=img_size, mlp_ratio=4,
 799                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
 800                num_heads=12, patch_size=16, qkv_bias=True, use_rel_pos=True,
 801                global_attn_indexes=[2, 5, 8, 11],
 802                window_size=14, out_chans=256,
 803            )
 804        elif model == "vit_l":
 805            encoder = ViT_Sam(
 806                depth=24, embed_dim=1024, img_size=img_size, mlp_ratio=4,
 807                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
 808                num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True,
 809                global_attn_indexes=[5, 11, 17, 23],
 810                window_size=14, out_chans=256,
 811            )
 812        elif model == "vit_h":
 813            encoder = ViT_Sam(
 814                depth=32, embed_dim=1280, img_size=img_size, mlp_ratio=4,
 815                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
 816                num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True,
 817                global_attn_indexes=[7, 15, 23, 31],
 818                window_size=14, out_chans=256,
 819            )
 820        else:
 821            raise ValueError(f"'{model}' is not supported by SAM. Currently, 'vit_b', 'vit_l', 'vit_h' are supported.")
 822
 823    elif backbone == "cellpose_sam":
 824        if model != "vit_l":
 825            raise ValueError(f"'{model}' is not supported by CellposeSAM. Only 'vit_l' is supported.")
 826        encoder = ViT_CellposeSAM(ps=8, bsize=img_size)
 827
 828    elif backbone == "sam2":
 829        if model == "hvit_t":
 830            encoder = ViT_Sam2(
 831                img_size=img_size, embed_dim=96, num_heads=1, stages=[1, 2, 7, 2], global_att_blocks=[5, 7, 9],
 832                window_pos_embed_bkg_spatial_size=[7, 7], backbone_channel_list=[768, 384, 192, 96],
 833            )
 834        elif model == "hvit_s":
 835            encoder = ViT_Sam2(
 836                img_size=img_size, embed_dim=96, num_heads=1, stages=[1, 2, 11, 2], global_att_blocks=[7, 10, 13],
 837                window_pos_embed_bkg_spatial_size=[7, 7], backbone_channel_list=[768, 384, 192, 96],
 838            )
 839        elif model == "hvit_b":
 840            encoder = ViT_Sam2(
 841                img_size=img_size, embed_dim=112, num_heads=2, backbone_channel_list=[896, 448, 224, 112],
 842            )
 843        elif model == "hvit_l":
 844            encoder = ViT_Sam2(
 845                img_size=img_size, embed_dim=144, num_heads=2, stages=[2, 6, 36, 4], global_att_blocks=[23, 33, 43],
 846                window_spec=[8, 4, 16, 8], backbone_channel_list=[1152, 576, 288, 144],
 847            )
 848        else:
 849            raise ValueError(
 850                f"'{model}' is not supported by SAM2. Currently, 'hvit_t', 'hvit_s', 'hvit_b', 'hvit_l' are supported."
 851            )
 852
 853    elif backbone == "sam3":
 854        if model != "vit_pe":
 855            raise ValueError(
 856                "'sam3' does not have multiple model configurations. Please use 'vit_pe' as the model configuration."
 857            )
 858
 859        encoder = ViT_Sam3(
 860            img_size=1008, pretrain_img_size=336, patch_size=14, embed_dim=1024, depth=32, num_heads=16,
 861            mlp_ratio=4.625, norm_layer="LayerNorm", drop_path_rate=0.1, qkv_bias=True, use_abs_pos=True,
 862            tile_abs_pos=True, global_att_blocks=(7, 15, 23, 31), rel_pos_blocks=(), use_rope=True,
 863            use_interp_rope=True, window_size=24, pretrain_use_cls_token=True, retain_cls_token=False, ln_pre=True,
 864            ln_post=False, return_interm_layers=False, bias_patch_embed=False, compile_mode=None,
 865        )
 866
 867    elif backbone == "mae":
 868        if model == "vit_b":
 869            encoder = ViT_MAE(
 870                img_size=img_size, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
 871                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
 872            )
 873        elif model == "vit_l":
 874            encoder = ViT_MAE(
 875                img_size=img_size, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
 876                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
 877            )
 878        elif model == "vit_h":
 879            encoder = ViT_MAE(
 880                img_size=img_size, patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4,
 881                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
 882            )
 883        else:
 884            raise ValueError(f"'{model}' is not supported by MAE. Currently, 'vit_b', 'vit_l', 'vit_h' are supported.")
 885
 886    elif backbone == "scalemae":
 887        base_resolution = kwargs.get("base_resolution", 2.5)
 888
 889        if model == "vit_b":
 890            encoder = ViT_ScaleMAE(
 891                img_size=img_size, patch_size=8, embed_dim=768, depth=12, num_heads=12,
 892                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
 893                base_resolution=base_resolution,
 894            )
 895        elif model == "vit_l":
 896            encoder = ViT_ScaleMAE(
 897                img_size=img_size, patch_size=8, embed_dim=1024, depth=24, num_heads=16,
 898                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
 899                base_resolution=base_resolution,
 900            )
 901        elif model == "vit_h":
 902            encoder = ViT_ScaleMAE(
 903                img_size=img_size, patch_size=8, embed_dim=1280, depth=32, num_heads=16,
 904                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
 905                base_resolution=base_resolution,
 906            )
 907        else:
 908            raise ValueError(
 909                f"'{model}' is not supported by ScaleMAE. Currently, 'vit_b', 'vit_l' and 'vit_h' are supported."
 910            )
 911
 912    elif backbone == "dinov2":
 913        block_fn = partial(Block, attn_class=MemEffAttention)
 914        msg = "The model name should be either 'vit_<X>' or 'vit_<X>_reg<Y>'."
 915
 916        if model.startswith("vit_s"):
 917            assert model in ["vit_s", "vit_s_reg4"], msg
 918            encoder = ViT_DINOv2(
 919                img_size=img_size, patch_size=14, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
 920                block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0,
 921                num_register_tokens=4 if model.endswith("_reg4") else 0,
 922            )
 923        elif model.startswith("vit_b"):
 924            assert model in ["vit_b", "vit_b_reg4"], msg
 925            encoder = ViT_DINOv2(
 926                img_size=img_size, patch_size=14, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
 927                block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0,
 928                num_register_tokens=4 if model.endswith("_reg4") else 0,
 929            )
 930        elif model.startswith("vit_l"):
 931            assert model in ["vit_l", "vit_l_reg4"], msg
 932            encoder = ViT_DINOv2(
 933                img_size=img_size, patch_size=14, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
 934                block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0,
 935                num_register_tokens=4 if model.endswith("_reg4") else 0,
 936            )
 937        elif model.startswith("vit_g"):
 938            assert model in ["vit_g", "vit_g_reg4"], msg
 939            encoder = ViT_DINOv2(
 940                img_size=img_size, patch_size=14, embed_dim=1536, depth=40, num_heads=24, mlp_ratio=4,
 941                block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0,
 942                num_register_tokens=4 if model.endswith("_reg4") else 0, ffn_layer="swiglu",
 943            )
 944        else:
 945            raise ValueError(
 946                f"'{model}' is not supported by DINOv2. Currently, 'vit_s', 'vit_b', 'vit_l' and 'vit_g' are supported."
 947            )
 948
 949    elif backbone == "dinov3":
 950
 951        if model == "vit_s":
 952            encoder = ViT_DINOv3(
 953                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=384,
 954                num_heads=6, layerscale_init=1.0e-05, norm_layer="layernormbf16", n_storage_tokens=4, mask_k_bias=True,
 955            )
 956        elif model == "vit_s+":
 957            encoder = ViT_DINOv3(
 958                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=384,
 959                num_heads=6, ffn_ratio=6, layerscale_init=1.0e-05, norm_layer="layernormbf16",
 960                ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True,
 961            )
 962
 963        elif model == "vit_b":
 964            encoder = ViT_DINOv3(
 965                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32",
 966                layerscale_init=1.0e-05, norm_layer="layernormbf16", n_storage_tokens=4, mask_k_bias=True,
 967            )
 968        elif model == "vit_l":
 969            encoder = ViT_DINOv3(
 970                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1024,
 971                depth=24, num_heads=16, layerscale_init=1.0e-05, norm_layer="layernormbf16",
 972                n_storage_tokens=4, mask_k_bias=True,
 973            )
 974        elif model == "vit_l+":
 975            encoder = ViT_DINOv3(
 976                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1024,
 977                depth=24, num_heads=16, ffn_ratio=6.0, layerscale_init=1.0e-05, norm_layer="layernormbf16",
 978                ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True,
 979            )
 980        elif model == "vit_h+":
 981            encoder = ViT_DINOv3(
 982                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1280,
 983                depth=32, num_heads=20, ffn_ratio=6.0, layerscale_init=1.0e-05, norm_layer="layernormbf16",
 984                ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True,
 985            )
 986        elif model == "vit_7b":
 987            encoder = ViT_DINOv3(
 988                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=4096,
 989                depth=40, num_heads=32, ffn_ratio=3, qkv_bias=False, drop_path_rate=0.0, layerscale_init=1.0e-05,
 990                norm_layer="layernormbf16", ffn_layer="swiglu64", n_storage_tokens=4, mask_k_bias=True,
 991                untie_global_and_local_cls_norm=True,
 992            )
 993        else:
 994            raise ValueError(
 995                f"'{model}' is not supported by DINOv3. Currently, "
 996                "'vit_s', 'vit_s+', 'vit_b', 'vit_l', 'vit_l+', 'vit_h+' and 'vit_7b' are supported."
 997            )
 998
 999    else:
1000        raise ValueError(
1001            "The 'UNETR' supported backbones are 'sam', 'cellpose_sam', 'sam2', 'sam3', "
1002            "'mae', 'scalemae', 'dinov2' or 'dinov3'. Please choose one of them."
1003        )
1004
1005    return encoder

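A closing usage sketch for the dispatcher (the chosen backbone's optional dependency must be installed; weights are random until a checkpoint is loaded):

    import torch
    from torch_em.model.vit import get_vision_transformer

    encoder = get_vision_transformer(backbone="sam", model="vit_b", img_size=512)
    features, intermediates = encoder(torch.randn(1, 3, 512, 512))
    # Every encoder returned here yields a (final features, intermediate features) pair,
    # which is the interface consumed by the 'UNETR' mentioned in the error message above.
    print(features.shape)
    print([f.shape for f in intermediates])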