torch_em.model.vit

from functools import partial
from typing import Tuple, List

import torch
import torch.nn as nn

# We catch ImportErrors here because segment_anything, timm, sam2, dinov2 and dinov3
# should only be optional dependencies for torch_em.
try:
    from segment_anything.modeling import ImageEncoderViT
    _sam_import_success = True
except ImportError:
    ImageEncoderViT = object
    _sam_import_success = False

try:
    from timm.models.vision_transformer import VisionTransformer, PatchEmbed
    _timm_import_success = True
except ImportError:
    VisionTransformer = object
    PatchEmbed = object
    _timm_import_success = False

try:
    from sam2.modeling.backbones.hieradet import Hiera
    from sam2.modeling.position_encoding import PositionEmbeddingSine
    from sam2.modeling.backbones.image_encoder import ImageEncoder, FpnNeck
    _sam2_import_success = True
except ImportError:
    ImageEncoder = object
    _sam2_import_success = False

try:
    from dinov2.models.vision_transformer import DinoVisionTransformer as DinoV2VisionTransformer
    from dinov2.layers import MemEffAttention, NestedTensorBlock as Block
    _dinov2_import_success = True
except ImportError:
    DinoV2VisionTransformer = object
    # Fallbacks so that `get_vision_transformer` raises the informative error below
    # instead of a NameError when DINOv2 is not installed.
    MemEffAttention, Block = None, None
    _dinov2_import_success = False

try:
    from dinov3.models.vision_transformer import DinoVisionTransformer as DinoV3VisionTransformer
    _dinov3_import_success = True
except ImportError:
    DinoV3VisionTransformer = object
    _dinov3_import_success = False


class ViT_Sam(ImageEncoderViT):
    """Vision Transformer derived from the Segment Anything Codebase (https://arxiv.org/abs/2304.02643).

    Based on:
    https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py

    Args:
        in_chans: The number of input channels.
        embed_dim: The embedding dimension, corresponding to the number of output channels of the vision transformer.
        global_attn_indexes: The global attention indices.
        kwargs: Keyword arguments for the image encoder base class.
    """
    def __init__(
        self,
        in_chans: int = 3,
        embed_dim: int = 768,
        global_attn_indexes: Tuple[int, ...] = (2, 5, 8, 11),
        **kwargs,
    ) -> None:
        if not _sam_import_success:
            raise RuntimeError(
                "The vision transformer backend can only be initialized if segment anything is installed. "
                "Please install segment anything from https://github.com/facebookresearch/segment-anything "
                "and then rerun your code."
            )

        super().__init__(embed_dim=embed_dim, global_attn_indexes=global_attn_indexes, **kwargs)
        self.chunks_for_projection = global_attn_indexes
        self.in_chans = in_chans
        self.embed_dim = embed_dim

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Apply the vision transformer to input data.

        Args:
            x: The input data.

        Returns:
            The vision transformer output and the intermediate features from the global attention blocks.
        """
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            x = x + self.pos_embed

        list_from_encoder = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in self.chunks_for_projection:
                list_from_encoder.append(x)

        x = x.permute(0, 3, 1, 2)
        list_from_encoder = [e.permute(0, 3, 1, 2) for e in list_from_encoder]
        return x, list_from_encoder[:3]


class ViT_MAE(VisionTransformer):
    """Vision Transformer derived from the Masked Auto Encoder Codebase (https://arxiv.org/abs/2111.06377).

    Based on:
    https://github.com/facebookresearch/mae/blob/main/models_vit.py#L20-L53

    Args:
        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
        in_chans: The number of input channels.
        depth: The depth of the vision transformer.
        kwargs: Additional keyword arguments for the vision transformer base class.
    """
    def __init__(
        self,
        img_size: int = 1024,  # chosen to match our experiments with segment anything
        in_chans: int = 3,
        depth: int = 12,
        **kwargs
    ):
        if not _timm_import_success:
            raise RuntimeError(
                "The vision transformer backend can only be initialized if timm is installed. "
                "Please install timm (using conda/mamba) for using https://github.com/facebookresearch/mae/ "
                "and then rerun your code."
            )
        super().__init__(img_size=img_size, depth=depth, **kwargs)
        self.img_size = img_size
        self.in_chans = in_chans
        self.depth = depth

    def convert_to_expected_dim(self, inputs_):
        """@private
        """
        inputs_ = inputs_[:, 1:, :]  # remove the class token
        # reshape the outputs to the desired shape (N x H*W x C -> N x C x H x W)
        rdim = inputs_.shape[1]
        dshape = int(rdim ** 0.5)  # take the square root of the number of tokens to get the patch grid shape
        inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape))
        inputs_ = inputs_.permute(0, 3, 1, 2)
        return inputs_

    def forward_features(self, x):
        """@private
        """
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        x = x + self.pos_embed
        x = self.pos_drop(x)

        # block indices at which to extract features for use in conjunction with the upsampling blocks
        _chunks = int(self.depth / 4)
        chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1]

        list_from_encoder = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in chunks_for_projection:
                list_from_encoder.append(self.convert_to_expected_dim(x))

        x = self.convert_to_expected_dim(x)
        return x, list_from_encoder[:3]

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Apply the vision transformer to input data.

        Args:
            x: The input data.

        Returns:
            The vision transformer output and the intermediate features from the selected blocks.
        """
        x, list_from_encoder = self.forward_features(x)
        return x, list_from_encoder


class ViT_Sam2(ImageEncoder):
    """Vision Transformer derived from the Segment Anything 2 Codebase (https://arxiv.org/abs/2408.00714).

    Based on https://github.com/facebookresearch/sam2/blob/main/sam2/modeling/backbones/image_encoder.py.

    Args:
        backbone_channel_list: The channels throughout the entire backbone.
        img_size: The size of the input for the image encoder.
        embed_dim: The initial embedding dimension.
        num_heads: The initial number of heads.
        stages: The number of blocks per stage.
        global_att_blocks: The parameter to decide which blocks have global attention.
        window_pos_embed_bkg_spatial_size: The spatial size of the windowed positional embedding.
        window_spec: The window size per stage, when not using global attention.
        scalp: The number of lowest resolution features to discard.
    """
    def __init__(
        self,
        backbone_channel_list: List[int],
        img_size: int = 1024,
        embed_dim: int = 96,
        num_heads: int = 1,
        stages: Tuple[int, ...] = (2, 3, 16, 3),
        global_att_blocks: Tuple[int, ...] = (12, 16, 20),
        window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
        window_spec: Tuple[int, ...] = (8, 4, 14, 7),
        scalp: int = 1,
    ):
        if not _sam2_import_success:
            raise RuntimeError(
                "The vision transformer backend can only be initialized if segment anything 2 is installed. "
                "Please install segment anything 2 from https://github.com/facebookresearch/sam2 "
                "and then rerun your code."
            )

        trunk = Hiera(
            embed_dim=embed_dim,
            num_heads=num_heads,
            stages=stages,
            global_att_blocks=global_att_blocks,
            window_pos_embed_bkg_spatial_size=window_pos_embed_bkg_spatial_size,
            window_spec=window_spec,
        )
        neck = FpnNeck(
            position_encoding=PositionEmbeddingSine(num_pos_feats=256),
            d_model=256,
            backbone_channel_list=backbone_channel_list,
            fpn_top_down_levels=[2, 3],
            fpn_interp_model="nearest",
        )

        super().__init__(trunk=trunk, neck=neck, scalp=scalp)
        self.scalp = scalp
        self.embed_dim = embed_dim
        self.img_size = img_size

    def forward(self, x: torch.Tensor):
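        """Apply the image encoder to input data.

        Args:
            x: The input data.

        Returns:
            The last feature map of the FPN neck and the list of multi-scale features.
        """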
        # The forward pass through the backbone.
        features, pos = self.neck(self.trunk(x))
        if self.scalp > 0:  # This discards the "n" lowest resolution features.
            features, pos = features[:-self.scalp], pos[:-self.scalp]

        return features[-1], features


#
# Utilities for ScaleMAE's ViT
#


class CustomCompose:
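    """Composition of transforms used by the ScaleMAE ViT: rescales the input, optionally applies
    further transforms, and computes the source view together with the relative resolution change.
    """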
    def __init__(self, rescale_transform, other_transforms, src_transform):
        self.rescale_transform = rescale_transform
        self.other_transforms = other_transforms
        self.src_transform = src_transform

    def __call__(self, x, valid_masks=None):
        if valid_masks is not None:
            nodata = (x * (1 - valid_masks.float())).max()
        x_aug = self.rescale_transform(x)
        parms = self.rescale_transform._params

        # sanity check, comment out if this is working
        # valid_masks = self.rescale_transform(valid_masks.float(), params=parms)
        # assert (x_aug==self.rescale_transform(x, params=parms)).all() #

        if valid_masks is not None:
            valid_masks = x_aug != nodata
            _, c, h, w = x_aug.shape
            zero_ratio = ((valid_masks == 0).sum((1, 2, 3)) / (h * w * c)).cpu().numpy()
        else:
            zero_ratio = -1

        if self.other_transforms:
            x_aug = self.other_transforms(x_aug)
        x_src = self.src_transform(x_aug)
        dx = parms["src"][:, 1, 0] - parms["src"][:, 0, 0]

        # dy = (parms['src'][:,2,1] - parms['src'][:,1,1])
        # assert (dx == dy).all()

        h, w = x_aug.shape[-2:]
        # assert h == w

        return x_aug, x_src, dx / h, zero_ratio, valid_masks


def get_2d_sincos_pos_embed_with_resolution(embed_dim, grid_size, res, cls_token=False, device="cpu"):
    """
    grid_size: int of the grid height and width
    res: array of size n, representing the resolution of a pixel (say, in meters)
    return:
    pos_embed: [n, grid_size*grid_size, embed_dim] or [n, 1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    # res = torch.FloatTensor(res).to(device)
    res = res.to(device)
    grid_h = torch.arange(grid_size, dtype=torch.float32, device=device)
    grid_w = torch.arange(grid_size, dtype=torch.float32, device=device)
    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")  # here h goes first, direction reversed compared to numpy
    grid = torch.stack(grid, dim=0)  # 2 x h x w

    # grid = grid.reshape([2, 1, grid_size, grid_size])
    grid = torch.einsum("chw,n->cnhw", grid, res)  # 2 x n x h x w
    _, n, h, w = grid.shape
    pos_embed = get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid)  # (n*H*W, D)
    pos_embed = pos_embed.reshape(n, h * w, embed_dim)
    if cls_token:
        pos_embed = torch.cat(
            [torch.zeros([n, 1, embed_dim], dtype=torch.float32, device=pos_embed.device), pos_embed], dim=1
        )

    return pos_embed


def get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of the dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = torch.cat([emb_h, emb_w], dim=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    # old_shape = pos
    omega = torch.arange(embed_dim // 2, dtype=torch.float32, device=pos.device)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = torch.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = torch.sin(out)  # (M, D/2)
    emb_cos = torch.cos(out)  # (M, D/2)

    emb = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)
    return emb


class PatchEmbedUnSafe(PatchEmbed):
    """Image to Patch Embedding"""

    def forward(self, x):
        B, C, H, W = x.shape

        # NOTE: Commented code from ScaleMAE: Dropped size check in timm
        # assert H == self.img_size[0] and W == self.img_size[1], \
        #     f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."

        x = self.proj(x).flatten(2).transpose(1, 2)
        return x


class ViT_ScaleMAE(VisionTransformer):
    """Vision Transformer derived from the Scale Masked Auto Encoder codebase (TODO: paper and github link).

    NOTE: For downstream tasks, the "base_resolution" parameter needs to be adjusted manually when using
    the model on a dataset with a different zoom factor.
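
    Args:
        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
        patch_size: The patch size.
        in_chans: The number of input channels.
        embed_dim: The embedding dimension.
        depth: The depth of the vision transformer.
        base_resolution: The base input resolution used to scale the sine-cosine positional embedding.
        kwargs: Additional keyword arguments for the vision transformer base class.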
    """

    def __init__(
        self, img_size=224, patch_size=16, in_chans=3, embed_dim=1024, depth=12, base_resolution=2.5, **kwargs
    ):
        super().__init__(img_size=img_size, embed_dim=embed_dim, **kwargs)
        self.img_size = img_size
        self.in_chans = in_chans
        self.depth = depth
        self.base_resolution = base_resolution

        self.patch_embed = PatchEmbedUnSafe(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

    def transform_inputs(self, x):
        import kornia.augmentation as K
        from kornia.constants import Resample

        self._transforms = CustomCompose(
            rescale_transform=K.RandomResizedCrop(
                (448, 448),
                ratio=(1.0, 1.0),
                scale=(1.0, 1.0),
                resample=Resample.BICUBIC.name,
            ),
            other_transforms=None,
            src_transform=K.Resize((224, 224)),
        )
        x, _, ratios, _, _ = self._transforms(x)
        input_res = ratios * self.base_resolution
        return x, input_res

    def convert_to_expected_dim(self, x):
        inputs_ = x[:, 1:, :]  # remove the class token
        # reshape the outputs to the desired shape (N x H*W x C -> N x C x H x W)
        rdim = inputs_.shape[1]
        dshape = int(rdim ** 0.5)  # take the square root of the number of tokens to get the patch grid shape
        inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape))
        inputs_ = inputs_.permute(0, 3, 1, 2)
        return inputs_

    def forward_features(self, x):
        x, input_res = self.transform_inputs(x)

        B, _, h, w = x.shape
        x = self.patch_embed(x)

        num_patches = int((h * w) / (self.patch_embed.patch_size[0] * self.patch_embed.patch_size[1]))
        pos_embed = get_2d_sincos_pos_embed_with_resolution(
            x.shape[-1],
            int(num_patches ** 0.5),
            input_res,
            cls_token=True,
            device=x.device,
        )

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + pos_embed
        x = self.pos_drop(x)

        # block indices at which to extract features for use in conjunction with the upsampling blocks
        _chunks = int(self.depth / 4)
        chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1]

        list_from_encoder = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in chunks_for_projection:
                list_from_encoder.append(self.convert_to_expected_dim(x))

        x = self.convert_to_expected_dim(x)

        return x, list_from_encoder

    def forward(self, x):
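        """Apply the vision transformer to input data.

        Args:
            x: The input data.

        Returns:
            The vision transformer output and the intermediate features from the selected blocks.
        """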
        x, list_from_encoder = self.forward_features(x)
        return x, list_from_encoder


class ViT_DINOv2(DinoV2VisionTransformer):
    """Vision Transformer derived from the DINOv2 Codebase (https://arxiv.org/abs/2304.07193).

    Based on:
    https://github.com/facebookresearch/dinov2/blob/main/dinov2/models/vision_transformer.py.
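
    Args:
        img_size: The input image size.
        patch_size: The patch size.
        depth: The depth of the vision transformer.
        num_register_tokens: The number of register tokens.
        kwargs: Keyword arguments for the vision transformer base class.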
    """
    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 16,
        depth: int = 12,
        num_register_tokens: int = 0,
        **kwargs
    ):
        if not _dinov2_import_success:
            raise RuntimeError(
                "The vision transformer backend can only be initialized if DINOv2 is installed. "
                "Please install DINOv2 from https://github.com/facebookresearch/dinov2 "
                "and then rerun your code."
            )

        super().__init__(
            img_size=img_size,
            depth=depth,
            patch_size=patch_size,
            num_register_tokens=num_register_tokens,
            **kwargs
        )

        self.img_size = img_size
        self.num_register_tokens = num_register_tokens
        self.patch_size = patch_size
        self.attn_outs = [i for i in range(depth) if i % 3 == 2]

    def forward(self, x, masks=None) -> Tuple[torch.Tensor, List[torch.Tensor]]:
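        """Apply the vision transformer to input data.

        Args:
            x: The input data.
            masks: Optional token masks. Not used in this forward implementation.

        Returns:
            The vision transformer output and the intermediate features from the selected blocks.
        """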
        B = x.shape[0]

        x = self.prepare_tokens_with_masks(x)

        list_of_encoder = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in self.attn_outs:
                list_of_encoder.append(x)

        x = self.norm(x)
        x = x[:, self.num_register_tokens + 1:].reshape(
            B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
        ).permute(0, 3, 1, 2).contiguous()

        list_of_encoder = [
            o[:, self.num_register_tokens + 1:].reshape(
                B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
            ).permute(0, 3, 1, 2).contiguous() for o in list_of_encoder
        ]

        return x, list_of_encoder[:3]


class ViT_DINOv3(DinoV3VisionTransformer):
    """Vision Transformer derived from the DINOv3 Codebase (https://arxiv.org/abs/2508.10104).

    Based on:
    https://github.com/facebookresearch/dinov3/blob/main/dinov3/models/vision_transformer.py.

    Args:
        in_chans: The number of input channels.
        img_size: The input image size.
        patch_size: The patch size.
        embed_dim: The embedding dimension.
        depth: The depth of the network.
        num_heads: The number of heads.
        ffn_ratio: The FFN ratio.
        n_storage_tokens: The number of storage tokens. These are stripped from the output
            together with the class token.
        kwargs: Keyword arguments for the image encoder base class.
    """
    def __init__(
        self,
        in_chans: int = 3,
        img_size: int = 224,
        patch_size: int = 16,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        ffn_ratio: float = 4.0,
        n_storage_tokens: int = 0,
        **kwargs
    ):
        if not _dinov3_import_success:
            raise RuntimeError(
                "The vision transformer backend can only be initialized if DINOv3 is installed. "
                "Please install DINOv3 from https://github.com/facebookresearch/dinov3 "
                "and then rerun your code."
            )

        super().__init__(
            in_chans=in_chans,
            img_size=img_size,
            patch_size=patch_size,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            ffn_ratio=ffn_ratio,
            n_storage_tokens=n_storage_tokens,
            **kwargs
        )

        self.in_chans = in_chans
        self.img_size = img_size
        self.n_storage_tokens = n_storage_tokens
        self.attn_outs = [i for i in range(depth) if i % 3 == 2]

    def forward(self, x) -> Tuple[torch.Tensor, List[torch.Tensor]]:
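        """Apply the vision transformer to input data.

        Args:
            x: The input data.

        Returns:
            The vision transformer output and the intermediate features from the selected blocks.
        """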
        B = x.shape[0]

        x, hw_tuple = self.prepare_tokens_with_masks(x)

        list_of_encoder = []
        for i, blk in enumerate(self.blocks):
            rope_sincos = self.rope_embed(H=hw_tuple[0], W=hw_tuple[1])
            x = blk(x, rope_sincos)
            if i in self.attn_outs:
                list_of_encoder.append(x)

        x = self.norm(x)
        x = x[:, self.n_storage_tokens + 1:].reshape(
            B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
        ).permute(0, 3, 1, 2).contiguous()

        list_of_encoder = [
            o[:, self.n_storage_tokens + 1:].reshape(
                B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
            ).permute(0, 3, 1, 2).contiguous() for o in list_of_encoder
        ]

        return x, list_of_encoder[:3]


def get_vision_transformer(backbone: str, model: str, img_size: int = 1024, **kwargs) -> nn.Module:
    """Get vision transformer encoder.

    Args:
        backbone: The name of the vision transformer implementation.
            One of "sam", "sam2", "mae", "scalemae", "dinov2" or "dinov3".
        model: The name of the model. The supported names depend on the backbone,
            e.g. "vit_b", "vit_l" or "vit_h" for "sam" and "mae", or "hvit_t" to "hvit_l" for "sam2".
        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
        kwargs: Additional kwargs which can be expected by the vision transformer,
            e.g. 'base_resolution' for `ViT_ScaleMAE`.

    Returns:
        The vision transformer.
    """
    if backbone == "sam":
        if model == "vit_b":
            encoder = ViT_Sam(
                depth=12, embed_dim=768, img_size=img_size, mlp_ratio=4,
                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
                num_heads=12, patch_size=16, qkv_bias=True, use_rel_pos=True,
                global_attn_indexes=[2, 5, 8, 11],
                window_size=14, out_chans=256,
            )
        elif model == "vit_l":
            encoder = ViT_Sam(
                depth=24, embed_dim=1024, img_size=img_size, mlp_ratio=4,
                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
                num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True,
                global_attn_indexes=[5, 11, 17, 23],
                window_size=14, out_chans=256,
            )
        elif model == "vit_h":
            encoder = ViT_Sam(
                depth=32, embed_dim=1280, img_size=img_size, mlp_ratio=4,
                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
                num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True,
                global_attn_indexes=[7, 15, 23, 31],
                window_size=14, out_chans=256,
            )
        else:
            raise ValueError(f"'{model}' is not supported by SAM. Currently, 'vit_b', 'vit_l', 'vit_h' are supported.")

    elif backbone == "sam2":
        if model == "hvit_t":
            encoder = ViT_Sam2(
                img_size=img_size, embed_dim=96, num_heads=1, stages=[1, 2, 7, 2], global_att_blocks=[5, 7, 9],
                window_pos_embed_bkg_spatial_size=[7, 7], backbone_channel_list=[768, 384, 192, 96],
            )
        elif model == "hvit_s":
            encoder = ViT_Sam2(
                img_size=img_size, embed_dim=96, num_heads=1, stages=[1, 2, 11, 2], global_att_blocks=[7, 10, 13],
                window_pos_embed_bkg_spatial_size=[7, 7], backbone_channel_list=[768, 384, 192, 96],
            )
        elif model == "hvit_b":
            encoder = ViT_Sam2(
                img_size=img_size, embed_dim=112, num_heads=2, backbone_channel_list=[896, 448, 224, 112],
            )
        elif model == "hvit_l":
            encoder = ViT_Sam2(
                img_size=img_size, embed_dim=144, num_heads=2, stages=[2, 6, 36, 4], global_att_blocks=[23, 33, 43],
                window_spec=[8, 4, 16, 8], backbone_channel_list=[1152, 576, 288, 144],
            )
        else:
            raise ValueError(
                f"'{model}' is not supported by SAM2. Currently, 'hvit_t', 'hvit_s', 'hvit_b', 'hvit_l' are supported."
            )

    elif backbone == "mae":
        if model == "vit_b":
            encoder = ViT_MAE(
                img_size=img_size, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
            )
        elif model == "vit_l":
            encoder = ViT_MAE(
                img_size=img_size, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
            )
        elif model == "vit_h":
            encoder = ViT_MAE(
                img_size=img_size, patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4,
                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
            )
        else:
            raise ValueError(f"'{model}' is not supported by MAE. Currently, 'vit_b', 'vit_l', 'vit_h' are supported.")

    elif backbone == "scalemae":
        base_resolution = kwargs.get("base_resolution", 2.5)

        if model == "vit_b":
            encoder = ViT_ScaleMAE(
                img_size=img_size, patch_size=8, embed_dim=768, depth=12, num_heads=12,
                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
                base_resolution=base_resolution,
            )
        elif model == "vit_l":
            encoder = ViT_ScaleMAE(
                img_size=img_size, patch_size=8, embed_dim=1024, depth=24, num_heads=16,
                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
                base_resolution=base_resolution,
            )
        elif model == "vit_h":
            encoder = ViT_ScaleMAE(
                img_size=img_size, patch_size=8, embed_dim=1280, depth=32, num_heads=16,
                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
                base_resolution=base_resolution,
            )
        else:
            raise ValueError(
                f"'{model}' is not supported by ScaleMAE. Currently, 'vit_b', 'vit_l' and 'vit_h' are supported."
            )

    elif backbone == "dinov2":
        block_fn = partial(Block, attn_class=MemEffAttention)
        msg = "The model name should be either 'vit_<X>' or 'vit_<X>_reg<Y>'."

        if model.startswith("vit_s"):
            assert model in ["vit_s", "vit_s_reg4"], msg
            encoder = ViT_DINOv2(
                img_size=img_size, patch_size=14, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
                block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0,
                num_register_tokens=4 if model.endswith("_reg4") else 0,
            )
        elif model.startswith("vit_b"):
            assert model in ["vit_b", "vit_b_reg4"], msg
            encoder = ViT_DINOv2(
                img_size=img_size, patch_size=14, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
                block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0,
                num_register_tokens=4 if model.endswith("_reg4") else 0,
            )
        elif model.startswith("vit_l"):
            assert model in ["vit_l", "vit_l_reg4"], msg
            encoder = ViT_DINOv2(
                img_size=img_size, patch_size=14, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
                block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0,
                num_register_tokens=4 if model.endswith("_reg4") else 0,
            )
        elif model.startswith("vit_g"):
            assert model in ["vit_g", "vit_g_reg4"], msg
            encoder = ViT_DINOv2(
                img_size=img_size, patch_size=14, embed_dim=1536, depth=40, num_heads=24, mlp_ratio=4,
                block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0,
                num_register_tokens=4 if model.endswith("_reg4") else 0, ffn_layer="swiglu",
            )
        else:
            raise ValueError(
                f"'{model}' is not supported by DINOv2. Currently, 'vit_s', 'vit_b', 'vit_l' and 'vit_g' are supported."
            )

    elif backbone == "dinov3":
        if model == "vit_s":
            encoder = ViT_DINOv3(
                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=384,
                num_heads=6, layerscale_init=1.0e-05, norm_layer="layernormbf16", n_storage_tokens=4, mask_k_bias=True,
            )
        elif model == "vit_s+":
            encoder = ViT_DINOv3(
                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=384,
                num_heads=6, ffn_ratio=6, layerscale_init=1.0e-05, norm_layer="layernormbf16",
                ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True,
            )
        elif model == "vit_b":
            encoder = ViT_DINOv3(
                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32",
                layerscale_init=1.0e-05, norm_layer="layernormbf16", n_storage_tokens=4, mask_k_bias=True,
            )
        elif model == "vit_l":
            encoder = ViT_DINOv3(
                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1024,
                depth=24, num_heads=16, layerscale_init=1.0e-05, norm_layer="layernormbf16",
                n_storage_tokens=4, mask_k_bias=True,
            )
        elif model == "vit_l+":
            encoder = ViT_DINOv3(
                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1024,
                depth=24, num_heads=16, ffn_ratio=6.0, layerscale_init=1.0e-05, norm_layer="layernormbf16",
                ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True,
            )
        elif model == "vit_h+":
            encoder = ViT_DINOv3(
                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1280,
                depth=32, num_heads=20, ffn_ratio=6.0, layerscale_init=1.0e-05, norm_layer="layernormbf16",
                ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True,
            )
        elif model == "vit_7b":
            encoder = ViT_DINOv3(
                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=4096,
                depth=40, num_heads=32, ffn_ratio=3, qkv_bias=False, drop_path_rate=0.0, layerscale_init=1.0e-05,
                norm_layer="layernormbf16", ffn_layer="swiglu64", n_storage_tokens=4, mask_k_bias=True,
                untie_global_and_local_cls_norm=True,
            )
        else:
            raise ValueError(
                f"'{model}' is not supported by DINOv3. Currently, "
                "'vit_s', 'vit_s+', 'vit_b', 'vit_l', 'vit_l+', 'vit_h+', 'vit_7b' are supported."
            )

    else:
        raise ValueError(
            "The 'UNETR' supported backbones are 'sam', 'sam2', 'mae', 'scalemae', 'dinov2' or 'dinov3'. "
            "Please choose one of them."
        )

    return encoder
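
Example usage (a minimal sketch, assuming the optional segment-anything dependency is installed so that the "sam" backbone can be built):

import torch
from torch_em.model.vit import get_vision_transformer

# Build a SAM ViT-B encoder for 1024 x 1024 inputs.
encoder = get_vision_transformer(backbone="sam", model="vit_b", img_size=1024)

# The encoders return the final feature map and a list of intermediate feature maps.
x = torch.randn(1, 3, 1024, 1024)
features, intermediates = encoder(x)
print(features.shape)                        # torch.Size([1, 768, 64, 64]) for ViT-B with patch size 16
print([tuple(f.shape) for f in intermediates])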
class ViT_Sam:
 50class ViT_Sam(ImageEncoderViT):
 51    """Vision Transformer derived from the Segment Anything Codebase (https://arxiv.org/abs/2304.02643).
 52
 53    Based on:
 54    https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py
 55
 56    Args:
 57        in_chans: The number of input channels.
 58        embed_dim: The embedding dimension, corresponding to the number of output channels of the vision transformer.
 59        global_attn_indexes: The global attention indices.
 60        kwargs: Keyword arguments for the image encoder base class.
 61    """
 62    def __init__(
 63        self,
 64        in_chans: int = 3,
 65        embed_dim: int = 768,
 66        global_attn_indexes: Tuple[int, ...] = [2, 5, 8, 11],
 67        **kwargs,
 68    ) -> None:
 69        if not _sam_import_success:
 70            raise RuntimeError(
 71                "The vision transformer backend can only be initialized if segment anything is installed. "
 72                "Please install segment anything from https://github.com/facebookresearch/segment-anything "
 73                "and then rerun your code."
 74            )
 75
 76        super().__init__(embed_dim=embed_dim, global_attn_indexes=global_attn_indexes, **kwargs)
 77        self.chunks_for_projection = global_attn_indexes
 78        self.in_chans = in_chans
 79        self.embed_dim = embed_dim
 80
 81    def forward(self, x: torch.Tensor) -> torch.Tensor:
 82        """Apply the vision transformer to input data.
 83
 84        Args:
 85            x: The input data.
 86
 87        Returns:
 88            The vision transformer output.
 89        """
 90        x = self.patch_embed(x)
 91        if self.pos_embed is not None:
 92            x = x + self.pos_embed
 93
 94        list_from_encoder = []
 95        for i, blk in enumerate(self.blocks):
 96            x = blk(x)
 97            if i in self.chunks_for_projection:
 98                list_from_encoder.append(x)
 99
100        x = x.permute(0, 3, 1, 2)
101        list_from_encoder = [e.permute(0, 3, 1, 2) for e in list_from_encoder]
102        return x, list_from_encoder[:3]

Vision Transformer derived from the Segment Anything Codebase (https://arxiv.org/abs/2304.02643).

Based on: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py

Arguments:
  • in_chans: The number of input channels.
  • embed_dim: The embedding dimension, corresponding to the number of output channels of the vision transformer.
  • global_attn_indexes: The global attention indices.
  • kwargs: Keyword arguments for the image encoder base class.
ViT_Sam( in_chans: int = 3, embed_dim: int = 768, global_attn_indexes: Tuple[int, ...] = [2, 5, 8, 11], **kwargs)
62    def __init__(
63        self,
64        in_chans: int = 3,
65        embed_dim: int = 768,
66        global_attn_indexes: Tuple[int, ...] = [2, 5, 8, 11],
67        **kwargs,
68    ) -> None:
69        if not _sam_import_success:
70            raise RuntimeError(
71                "The vision transformer backend can only be initialized if segment anything is installed. "
72                "Please install segment anything from https://github.com/facebookresearch/segment-anything "
73                "and then rerun your code."
74            )
75
76        super().__init__(embed_dim=embed_dim, global_attn_indexes=global_attn_indexes, **kwargs)
77        self.chunks_for_projection = global_attn_indexes
78        self.in_chans = in_chans
79        self.embed_dim = embed_dim
chunks_for_projection
in_chans
embed_dim
def forward(self, x: torch.Tensor) -> torch.Tensor:
 81    def forward(self, x: torch.Tensor) -> torch.Tensor:
 82        """Apply the vision transformer to input data.
 83
 84        Args:
 85            x: The input data.
 86
 87        Returns:
 88            The vision transformer output.
 89        """
 90        x = self.patch_embed(x)
 91        if self.pos_embed is not None:
 92            x = x + self.pos_embed
 93
 94        list_from_encoder = []
 95        for i, blk in enumerate(self.blocks):
 96            x = blk(x)
 97            if i in self.chunks_for_projection:
 98                list_from_encoder.append(x)
 99
100        x = x.permute(0, 3, 1, 2)
101        list_from_encoder = [e.permute(0, 3, 1, 2) for e in list_from_encoder]
102        return x, list_from_encoder[:3]

Apply the vision transformer to input data.

Arguments:
  • x: The input data.
Returns:

The vision transformer output.

class ViT_MAE:
105class ViT_MAE(VisionTransformer):
106    """Vision Transformer derived from the Masked Auto Encoder Codebase (https://arxiv.org/abs/2111.06377).
107
108    Based on:
109    https://github.com/facebookresearch/mae/blob/main/models_vit.py#L20-L53
110
111    Args:
112        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
113        in_chans: The number of input channels.
114        depth: The depth of the vision transformer.
115        kwargs: Additional keyword arguments for the vision transformer base class.
116    """
117    def __init__(
118        self,
119        img_size: int = 1024,  # chosen to match our experiments with segment anything
120        in_chans: int = 3,
121        depth: int = 12,
122        **kwargs
123    ):
124        if not _timm_import_success:
125            raise RuntimeError(
126                "The vision transformer backend can only be initialized if timm is installed. "
127                "Please install timm (using conda/mamba) for using https://github.com/facebookresearch/mae/ "
128                "and then rerun your code"
129            )
130        super().__init__(img_size=img_size, depth=depth, **kwargs)
131        self.img_size = img_size
132        self.in_chans = in_chans
133        self.depth = depth
134
135    def convert_to_expected_dim(self, inputs_):
136        """@private
137        """
138        inputs_ = inputs_[:, 1:, :]  # removing the class tokens
139        # reshape the outputs to desired shape (N x H*W X C -> N x H x W x C)
140        rdim = inputs_.shape[1]
141        dshape = int(rdim ** 0.5)  # finding the square root of the outputs for obtaining the patch shape
142        inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape))
143        inputs_ = inputs_.permute(0, 3, 1, 2)
144        return inputs_
145
146    def forward_features(self, x):
147        """@private
148        """
149        B = x.shape[0]
150        x = self.patch_embed(x)
151
152        cls_tokens = self.cls_token.expand(B, -1, -1)
153        x = torch.cat((cls_tokens, x), dim=1)
154
155        x = x + self.pos_embed
156        x = self.pos_drop(x)
157
158        # chunks obtained for getting the projections for conjuctions with upsampling blocks
159        _chunks = int(self.depth / 4)
160        chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1]
161
162        list_from_encoder = []
163        for i, blk in enumerate(self.blocks):
164            x = blk(x)
165            if i in chunks_for_projection:
166                list_from_encoder.append(self.convert_to_expected_dim(x))
167
168        x = self.convert_to_expected_dim(x)
169        return x, list_from_encoder[:3]
170
171    def forward(self, x: torch.Tensor) -> torch.Tensor:
172        """Apply the vision transformer to input data.
173
174        Args:
175            x: The input data.
176
177        Returns:
178            The vision transformer output.
179        """
180        x, list_from_encoder = self.forward_features(x)
181        return x, list_from_encoder

Vision Transformer derived from the Masked Auto Encoder Codebase (https://arxiv.org/abs/2111.06377).

Based on: https://github.com/facebookresearch/mae/blob/main/models_vit.py#L20-L53

Arguments:
  • img_size: The size of the input for the image encoder. Input images will be resized to match this size.
  • in_chans: The number of input channels.
  • depth: The depth of the vision transformer.
  • kwargs: Additional keyword arguments for the vision transformer base class.
ViT_MAE(img_size: int = 1024, in_chans: int = 3, depth: int = 12, **kwargs)
117    def __init__(
118        self,
119        img_size: int = 1024,  # chosen to match our experiments with segment anything
120        in_chans: int = 3,
121        depth: int = 12,
122        **kwargs
123    ):
124        if not _timm_import_success:
125            raise RuntimeError(
126                "The vision transformer backend can only be initialized if timm is installed. "
127                "Please install timm (using conda/mamba) for using https://github.com/facebookresearch/mae/ "
128                "and then rerun your code"
129            )
130        super().__init__(img_size=img_size, depth=depth, **kwargs)
131        self.img_size = img_size
132        self.in_chans = in_chans
133        self.depth = depth
img_size
in_chans
depth
def forward(self, x: torch.Tensor) -> torch.Tensor:
171    def forward(self, x: torch.Tensor) -> torch.Tensor:
172        """Apply the vision transformer to input data.
173
174        Args:
175            x: The input data.
176
177        Returns:
178            The vision transformer output.
179        """
180        x, list_from_encoder = self.forward_features(x)
181        return x, list_from_encoder

Apply the vision transformer to input data.

Arguments:
  • x: The input data.
Returns:

The vision transformer output.

class ViT_Sam2:
184class ViT_Sam2(ImageEncoder):
185    """Vision Transformer derived from the Segment Anything 2 Codebase (https://arxiv.org/abs/2408.00714).
186
187    Based on https://github.com/facebookresearch/sam2/blob/main/sam2/modeling/backbones/image_encoder.py.
188
189    Args:
190        backbone_channel_list: The channels throughout the entire backbone.
191        embed_dim: The initial embedding dimension.
192        num_heads: The initial number of heads.
193        stages: The number of blocks per stage.
194        global_att_blocks: The parameter to decide which blocks have global attention.
195        window_pos_embed_bkg_spatial_size: The spatial size of windowed positional embedding.
196        window_spec: The window size per stage, when not using global attention.
197        scalp: The count of lowest resolution features to discard.
198    """
199    def __init__(
200        self,
201        backbone_channel_list: List[int],
202        img_size: int = 1024,
203        embed_dim: int = 96,
204        num_heads: int = 1,
205        stages: Tuple[int, ...] = (2, 3, 16, 3),
206        global_att_blocks: Tuple[int, ...] = (12, 16, 20),
207        window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
208        window_spec: Tuple[int, ...] = (8, 4, 14, 7),
209        scalp: int = 1,
210    ):
211        if not _sam2_import_success:
212            raise RuntimeError(
213                "The vision transformer backend can only be initialized if segment anything 2 is installed. "
214                "Please install segment anything 2 from https://github.com/facebookresearch/sam2 "
215                "and then rerun your code"
216            )
217
218        trunk = Hiera(
219            embed_dim=embed_dim,
220            num_heads=num_heads,
221            stages=stages,
222            global_att_blocks=global_att_blocks,
223            window_pos_embed_bkg_spatial_size=window_pos_embed_bkg_spatial_size,
224            window_spec=window_spec,
225        )
226        neck = FpnNeck(
227            position_encoding=PositionEmbeddingSine(num_pos_feats=256),
228            d_model=256,
229            backbone_channel_list=backbone_channel_list,
230            fpn_top_down_levels=[2, 3],
231            fpn_interp_model="nearest",
232        )
233
234        super().__init__(trunk=trunk, neck=neck, scalp=scalp)
235        self.scalp = scalp
236        self.embed_dim = embed_dim
237        self.img_size = img_size
238
239    def forward(self, x: torch.Tensor):
240        # The forward pass throught the backbone.
241        features, pos = self.neck(self.trunk(x))
242        if self.scalp > 0:  # This discard the "n" lowest resolution features.
243            features, pos = features[:-self.scalp], pos[:-self.scalp]
244
245        return features[-1], features

Vision Transformer derived from the Segment Anything 2 Codebase (https://arxiv.org/abs/2408.00714).

Based on https://github.com/facebookresearch/sam2/blob/main/sam2/modeling/backbones/image_encoder.py.

Arguments:
  • backbone_channel_list: The channels throughout the entire backbone.
  • embed_dim: The initial embedding dimension.
  • num_heads: The initial number of heads.
  • stages: The number of blocks per stage.
  • global_att_blocks: The parameter to decide which blocks have global attention.
  • window_pos_embed_bkg_spatial_size: The spatial size of windowed positional embedding.
  • window_spec: The window size per stage, when not using global attention.
  • scalp: The count of lowest resolution features to discard.
ViT_Sam2( backbone_channel_list: List[int], img_size: int = 1024, embed_dim: int = 96, num_heads: int = 1, stages: Tuple[int, ...] = (2, 3, 16, 3), global_att_blocks: Tuple[int, ...] = (12, 16, 20), window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14), window_spec: Tuple[int, ...] = (8, 4, 14, 7), scalp: int = 1)
199    def __init__(
200        self,
201        backbone_channel_list: List[int],
202        img_size: int = 1024,
203        embed_dim: int = 96,
204        num_heads: int = 1,
205        stages: Tuple[int, ...] = (2, 3, 16, 3),
206        global_att_blocks: Tuple[int, ...] = (12, 16, 20),
207        window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
208        window_spec: Tuple[int, ...] = (8, 4, 14, 7),
209        scalp: int = 1,
210    ):
211        if not _sam2_import_success:
212            raise RuntimeError(
213                "The vision transformer backend can only be initialized if segment anything 2 is installed. "
214                "Please install segment anything 2 from https://github.com/facebookresearch/sam2 "
215                "and then rerun your code"
216            )
217
218        trunk = Hiera(
219            embed_dim=embed_dim,
220            num_heads=num_heads,
221            stages=stages,
222            global_att_blocks=global_att_blocks,
223            window_pos_embed_bkg_spatial_size=window_pos_embed_bkg_spatial_size,
224            window_spec=window_spec,
225        )
226        neck = FpnNeck(
227            position_encoding=PositionEmbeddingSine(num_pos_feats=256),
228            d_model=256,
229            backbone_channel_list=backbone_channel_list,
230            fpn_top_down_levels=[2, 3],
231            fpn_interp_model="nearest",
232        )
233
234        super().__init__(trunk=trunk, neck=neck, scalp=scalp)
235        self.scalp = scalp
236        self.embed_dim = embed_dim
237        self.img_size = img_size
scalp
embed_dim
img_size
def forward(self, x: torch.Tensor):
239    def forward(self, x: torch.Tensor):
240        # The forward pass throught the backbone.
241        features, pos = self.neck(self.trunk(x))
242        if self.scalp > 0:  # This discard the "n" lowest resolution features.
243            features, pos = features[:-self.scalp], pos[:-self.scalp]
244
245        return features[-1], features
class CustomCompose:
253class CustomCompose:
254    def __init__(self, rescale_transform, other_transforms, src_transform):
255        self.rescale_transform = rescale_transform
256        self.other_transforms = other_transforms
257        self.src_transform = src_transform
258
259    def __call__(self, x, valid_masks=None):
260        if valid_masks is not None:
261            nodata = (x * (1 - valid_masks.float())).max()
262        x_aug = self.rescale_transform(x)
263        parms = self.rescale_transform._params
264
265        # sanity check, comment if this is working
266        # valid_masks = self.rescale_transform(valid_masks.float(), params=parms)
267        # assert (x_aug==self.rescale_transform(x, params=parms)).all() #
268
269        if valid_masks is not None:
270            valid_masks = x_aug != nodata
271            _, c, h, w = x_aug.shape
272            zero_ratio = ((valid_masks == 0).sum((1, 2, 3)) / (h * w * c)).cpu().numpy()
273        else:
274            zero_ratio = -1
275
276        if self.other_transforms:
277            x_aug = self.other_transforms(x_aug)
278        x_src = self.src_transform(x_aug)
279        dx = parms["src"][:, 1, 0] - parms["src"][:, 0, 0]
280
281        # dy = (parms['src'][:,2,1] - parms['src'][:,1,1])
282        # assert (dx == dy).all()
283
284        h, w = x_aug.shape[-2:]
285        # assert h == w
286
287        return x_aug, x_src, dx / h, zero_ratio, valid_masks
CustomCompose(rescale_transform, other_transforms, src_transform)
254    def __init__(self, rescale_transform, other_transforms, src_transform):
255        self.rescale_transform = rescale_transform
256        self.other_transforms = other_transforms
257        self.src_transform = src_transform
rescale_transform
other_transforms
src_transform
def get_2d_sincos_pos_embed_with_resolution(embed_dim, grid_size, res, cls_token=False, device='cpu'):
290def get_2d_sincos_pos_embed_with_resolution(embed_dim, grid_size, res, cls_token=False, device="cpu"):
291    """
292    grid_size: int of the grid height and width
293    res: array of size n, representing the resolution of a pixel (say, in meters),
294    return:
295    pos_embed: [n,grid_size*grid_size, embed_dim] or [n,1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
296    """
297    # res = torch.FloatTensor(res).to(device)
298    res = res.to(device)
299    grid_h = torch.arange(grid_size, dtype=torch.float32, device=device)
300    grid_w = torch.arange(grid_size, dtype=torch.float32, device=device)
301    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")  # here h goes first,direction reversed for numpy
302    grid = torch.stack(grid, dim=0)  # 2 x h x w
303
304    # grid = grid.reshape([2, 1, grid_size, grid_size])
305    grid = torch.einsum("chw,n->cnhw", grid, res)  # 2 x n x h x w
306    _, n, h, w = grid.shape
307    pos_embed = get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid)  # (nxH*W, D/2)
308    pos_embed = pos_embed.reshape(n, h * w, embed_dim)
309    if cls_token:
310        pos_embed = torch.cat(
311            [torch.zeros([n, 1, embed_dim], dtype=torch.float32, device=pos_embed.device), pos_embed], dim=1
312        )
313
314    return pos_embed

grid_size: int of the grid height and width res: array of size n, representing the resolution of a pixel (say, in meters), return: pos_embed: [n,grid_size*grid_size, embed_dim] or [n,1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)

def get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid):
317def get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid):
318    assert embed_dim % 2 == 0
319
320    # use half of dimensions to encode grid_h
321    emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid[0])  # (H*W, D/2)
322    emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid[1])  # (H*W, D/2)
323
324    emb = torch.cat([emb_h, emb_w], dim=1)  # (H*W, D)
325    return emb
def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos):
328def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos):
329    """
330    embed_dim: output dimension for each position
331    pos: a list of positions to be encoded: size (M,)
332    out: (M, D)
333    """
334    assert embed_dim % 2 == 0
335    # old_shape = pos
336    omega = torch.arange(embed_dim // 2, dtype=torch.float32, device=pos.device)
337    omega /= embed_dim / 2.0
338    omega = 1.0 / 10000**omega  # (D/2,)
339
340    pos = pos.reshape(-1)  # (M,)
341    out = torch.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
342
343    emb_sin = torch.sin(out)  # (M, D/2)
344    emb_cos = torch.cos(out)  # (M, D/2)
345
346    emb = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)
347    return emb

embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)

class PatchEmbedUnSafe:
350class PatchEmbedUnSafe(PatchEmbed):
351    """Image to Patch Embedding"""
352
353    def forward(self, x):
354        B, C, H, W = x.shape
355
356        # NOTE: Comment code from ScaleMAE: Dropped size check in timm
357        # assert H == self.img_size[0] and W == self.img_size[1], \
358        #     f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
359
360        x = self.proj(x).flatten(2).transpose(1, 2)
361        return x

Image to Patch Embedding

def forward(self, x):
353    def forward(self, x):
354        B, C, H, W = x.shape
355
356        # NOTE: Comment code from ScaleMAE: Dropped size check in timm
357        # assert H == self.img_size[0] and W == self.img_size[1], \
358        #     f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
359
360        x = self.proj(x).flatten(2).transpose(1, 2)
361        return x
class ViT_ScaleMAE:
364class ViT_ScaleMAE(VisionTransformer):
365    """Vision Transformer dervied from the Scale Masked Auto Encoder codebase (TODO: paper and github link).
366
367    NOTE: For downstream tasks, the "base_resoulution" parameter needs to be adjusted manually when using
368    the model on a different zoom factor dataset.
369    """
370
371    def __init__(
372        self, img_size=224, patch_size=16, in_chans=3, embed_dim=1024, depth=12, base_resolution=2.5, **kwargs
373    ):
374        super().__init__(img_size=img_size, embed_dim=embed_dim, **kwargs)
375        self.img_size = img_size
376        self.in_chans = in_chans
377        self.depth = depth
378        self.base_resolution = base_resolution
379
380        self.patch_embed = PatchEmbedUnSafe(
381            img_size=img_size,
382            patch_size=patch_size,
383            in_chans=in_chans,
384            embed_dim=embed_dim,
385        )
386
387    def transform_inputs(self, x):
388        import kornia.augmentation as K
389        from kornia.constants import Resample
390
391        self._transforms = CustomCompose(
392            rescale_transform=K.RandomResizedCrop(
393                (448, 448),
394                ratio=(1.0, 1.0),
395                scale=(1.0, 1.0),
396                resample=Resample.BICUBIC.name,
397            ),
398            other_transforms=None,
399            src_transform=K.Resize((224, 224)),
400        )
401        x, _, ratios, _, _ = self._transforms(x)
402        input_res = ratios * self.base_resolution
403        return x, input_res
404
405    def convert_to_expected_dim(self, x):
406        inputs_ = x[:, 1:, :]  # removing the class tokens
407        # reshape the outputs to desired shape (N X H*W X C -> N X H X W X C)
408        rdim = inputs_.shape[1]
409        dshape = int(rdim ** 0.5)  # finding square root of the outputs for obtaining the patch shape
410        inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape))
411        inputs_ = inputs_.permute(0, 3, 1, 2)
412        return inputs_
413
414    def forward_features(self, x):
415        x, input_res = self.transform_inputs(x)
416
417        B, _, h, w = x.shape
418        x = self.patch_embed(x)
419
420        num_patches = int((h * w) / (self.patch_embed.patch_size[0] * self.patch_embed.patch_size[1]))
421        pos_embed = get_2d_sincos_pos_embed_with_resolution(
422            x.shape[-1],
423            int(num_patches ** 0.5),
424            input_res,
425            cls_token=True,
426            device=x.device,
427        )
428
429        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
430        x = torch.cat((cls_tokens, x), dim=1)
431        x = x + pos_embed
432        x = self.pos_drop(x)
433
434        # chunk indices at which projections are taken, in conjunction with the upsampling blocks
435        _chunks = int(self.depth / 4)
436        chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1]
437
438        list_from_encoder = []
439        for i, blk in enumerate(self.blocks):
440            x = blk(x)
441            if i in chunks_for_projection:
442                list_from_encoder.append(self.convert_to_expected_dim(x))
443
444        x = self.convert_to_expected_dim(x)
445
446        return x, list_from_encoder
447
448    def forward(self, x):
449        x, list_from_encoder = self.forward_features(x)
450        return x, list_from_encoder

Vision Transformer derived from the Scale Masked Auto Encoder codebase (TODO: paper and github link).

NOTE: For downstream tasks, the "base_resolution" parameter needs to be adjusted manually when using the model on a dataset with a different zoom factor.
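A minimal usage sketch, assuming timm and kornia are installed; "base_resolution" is forwarded through get_vision_transformer below and should be adjusted to the zoom factor of the target data:

    import torch
    from torch_em.model.vit import get_vision_transformer

    encoder = get_vision_transformer(backbone="scalemae", model="vit_b", img_size=224, base_resolution=1.0)
    x = torch.randn(1, 3, 224, 224)
    features, intermediates = encoder(x)
    # 'features' is the final B x C x H x W feature map; 'intermediates' holds the feature maps
    # tapped at the four chunk positions, e.g. for the upsampling blocks of a UNETR decoder.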

ViT_ScaleMAE( img_size=224, patch_size=16, in_chans=3, embed_dim=1024, depth=12, base_resolution=2.5, **kwargs)
371    def __init__(
372        self, img_size=224, patch_size=16, in_chans=3, embed_dim=1024, depth=12, base_resolution=2.5, **kwargs
373    ):
374        super().__init__(img_size=img_size, embed_dim=embed_dim, **kwargs)
375        self.img_size = img_size
376        self.in_chans = in_chans
377        self.depth = depth
378        self.base_resolution = base_resolution
379
380        self.patch_embed = PatchEmbedUnSafe(
381            img_size=img_size,
382            patch_size=patch_size,
383            in_chans=in_chans,
384            embed_dim=embed_dim,
385        )
Instance attributes: img_size, in_chans, depth, base_resolution, patch_embed
def transform_inputs(self, x):
387    def transform_inputs(self, x):
388        import kornia.augmentation as K
389        from kornia.constants import Resample
390
391        self._transforms = CustomCompose(
392            rescale_transform=K.RandomResizedCrop(
393                (448, 448),
394                ratio=(1.0, 1.0),
395                scale=(1.0, 1.0),
396                resample=Resample.BICUBIC.name,
397            ),
398            other_transforms=None,
399            src_transform=K.Resize((224, 224)),
400        )
401        x, _, ratios, _, _ = self._transforms(x)
402        input_res = ratios * self.base_resolution
403        return x, input_res
def convert_to_expected_dim(self, x):
405    def convert_to_expected_dim(self, x):
406        inputs_ = x[:, 1:, :]  # removing the class tokens
407        # reshape the outputs to desired shape (N X H*W X C -> N X H X W X C)
408        rdim = inputs_.shape[1]
409        dshape = int(rdim ** 0.5)  # square root of the token count gives the (square) patch grid size
410        inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape))
411        inputs_ = inputs_.permute(0, 3, 1, 2)
412        return inputs_
def forward_features(self, x):
414    def forward_features(self, x):
415        x, input_res = self.transform_inputs(x)
416
417        B, _, h, w = x.shape
418        x = self.patch_embed(x)
419
420        num_patches = int((h * w) / (self.patch_embed.patch_size[0] * self.patch_embed.patch_size[1]))
421        pos_embed = get_2d_sincos_pos_embed_with_resolution(
422            x.shape[-1],
423            int(num_patches ** 0.5),
424            input_res,
425            cls_token=True,
426            device=x.device,
427        )
428
429        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
430        x = torch.cat((cls_tokens, x), dim=1)
431        x = x + pos_embed
432        x = self.pos_drop(x)
433
434        # chunk indices at which projections are taken, in conjunction with the upsampling blocks
435        _chunks = int(self.depth / 4)
436        chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1]
437
438        list_from_encoder = []
439        for i, blk in enumerate(self.blocks):
440            x = blk(x)
441            if i in chunks_for_projection:
442                list_from_encoder.append(self.convert_to_expected_dim(x))
443
444        x = self.convert_to_expected_dim(x)
445
446        return x, list_from_encoder
def forward(self, x):
448    def forward(self, x):
449        x, list_from_encoder = self.forward_features(x)
450        return x, list_from_encoder
class ViT_DINOv2:
453class ViT_DINOv2(DinoV2VisionTransformer):
454    """Vision Transformer derived from the DINOv2 Codebase (https://arxiv.org/abs/2304.07193).
455
456    Based on:
457    https://github.com/facebookresearch/dinov2/blob/main/dinov2/models/vision_transformer.py.
458    """
459    def __init__(
460        self,
461        img_size: int = 224,
462        patch_size: int = 16,
463        depth: int = 12,
464        num_register_tokens: int = 0,
465        **kwargs
466    ):
467        if not _dinov2_import_success:
468            raise RuntimeError(
469                "The vision transformer backend can only be initialized if DINOv2 is installed. "
470                "Please install DINOv2 from https://github.com/facebookresearch/dinov2 "
471                "and then rerun your code."
472            )
473
474        super().__init__(
475            img_size=img_size,
476            depth=depth,
477            patch_size=patch_size,
478            num_register_tokens=num_register_tokens,
479            **kwargs
480        )
481
482        self.img_size = img_size
483        self.num_register_tokens = num_register_tokens
484        self.patch_size = patch_size
485        self.attn_outs = [i for i in range(depth) if i % 3 == 2]
486
487    def forward(self, x, masks=None) -> Tuple[torch.Tensor, List[torch.Tensor]]:
488
489        B = x.shape[0]
490
491        x = self.prepare_tokens_with_masks(x)
492
493        list_of_encoder = []
494        for i, blk in enumerate(self.blocks):
495            x = blk(x)
496            if i in self.attn_outs:
497                list_of_encoder.append(x)
498
499        x = self.norm(x)
500        x = x[:, self.num_register_tokens + 1:].reshape(
501            B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
502        ).permute(0, 3, 1, 2).contiguous()
503
504        list_of_encoder = [
505            o[:, self.num_register_tokens + 1:].reshape(
506                B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
507            ).permute(0, 3, 1, 2).contiguous() for o in list_of_encoder
508        ]
509
510        return x, list_of_encoder[:3]

Vision Transformer derived from the DINOv2 Codebase (https://arxiv.org/abs/2304.07193).

Based on: https://github.com/facebookresearch/dinov2/blob/main/dinov2/models/vision_transformer.py.

ViT_DINOv2( img_size: int = 224, patch_size: int = 16, depth: int = 12, num_register_tokens: int = 0, **kwargs)
459    def __init__(
460        self,
461        img_size: int = 224,
462        patch_size: int = 16,
463        depth: int = 12,
464        num_register_tokens: int = 0,
465        **kwargs
466    ):
467        if not _dinov2_import_success:
468            raise RuntimeError(
469                "The vision transformer backend can only be initialized if DINOv2 is installed. "
470                "Please install DINOv2 from https://github.com/facebookresearch/dinov2 "
471                "and then rerun your code."
472            )
473
474        super().__init__(
475            img_size=img_size,
476            depth=depth,
477            patch_size=patch_size,
478            num_register_tokens=num_register_tokens,
479            **kwargs
480        )
481
482        self.img_size = img_size
483        self.num_register_tokens = num_register_tokens
484        self.patch_size = patch_size
485        self.attn_outs = [i for i in range(depth) if i % 3 == 2]
Instance attributes: img_size, num_register_tokens, patch_size, attn_outs
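The attn_outs rule simply taps every third block for intermediate features; for the default depth of 12 this selects blocks 2, 5, 8 and 11, of which the forward pass below returns the first three:

    depth = 12
    attn_outs = [i for i in range(depth) if i % 3 == 2]
    print(attn_outs)  # [2, 5, 8, 11]; forward returns list_of_encoder[:3], i.e. blocks 2, 5 and 8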
def forward(self, x, masks=None) -> Tuple[torch.Tensor, List[torch.Tensor]]:
487    def forward(self, x, masks=None) -> Tuple[torch.Tensor, List[torch.Tensor]]:
488
489        B = x.shape[0]
490
491        x = self.prepare_tokens_with_masks(x)
492
493        list_of_encoder = []
494        for i, blk in enumerate(self.blocks):
495            x = blk(x)
496            if i in self.attn_outs:
497                list_of_encoder.append(x)
498
499        x = self.norm(x)
500        x = x[:, self.num_register_tokens + 1:].reshape(
501            B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
502        ).permute(0, 3, 1, 2).contiguous()
503
504        list_of_encoder = [
505            o[:, self.num_register_tokens + 1:].reshape(
506                B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
507            ).permute(0, 3, 1, 2).contiguous() for o in list_of_encoder
508        ]
509
510        return x, list_of_encoder[:3]
class ViT_DINOv3:
513class ViT_DINOv3(DinoV3VisionTransformer):
514    """Vision Transformer derived from the DINOv3 Codebase (https://arxiv.org/abs/2508.10104).
515
516    Based on:
517    https://github.com/facebookresearch/dinov3/blob/main/dinov3/models/vision_transformer.py.
518
519    Args:
520        img_size: The input image size.
521        patch_size: The patch size.
522        embed_dim: The embedding dimension.
523        depth: The depth of the network.
524        num_heads: The number of heads.
525        ffn_ratio: The FFN ratio.
526        n_storage_tokens: The number of storage tokens; they are stripped from the outputs together with the class token.
527        kwargs: Keyword arguments for the image encoder base class.
528    """
529    def __init__(
530        self,
531        in_chans: int = 3,
532        img_size: int = 224,
533        patch_size: int = 16,
534        embed_dim: int = 768,
535        depth: int = 12,
536        num_heads: int = 12,
537        ffn_ratio: float = 4.0,
538        n_storage_tokens: int = 0,
539        **kwargs
540    ):
541        if not _dinov3_import_success:
542            raise RuntimeError(
543                "The vision transformer backend can only be initialized if DINOv3 is installed. "
544                "Please install DINOv3 from https://github.com/facebookresearch/dinov3 "
545                "and then rerun your code."
546            )
547
548        super().__init__(
549            in_chans=in_chans,
550            img_size=img_size,
551            patch_size=patch_size,
552            embed_dim=embed_dim,
553            depth=depth,
554            num_heads=num_heads,
555            ffn_ratio=ffn_ratio,
556            n_storage_tokens=n_storage_tokens,
557            **kwargs
558        )
559
560        self.in_chans = in_chans
561        self.img_size = img_size
562        self.n_storage_tokens = n_storage_tokens
563        self.attn_outs = [i for i in range(depth) if i % 3 == 2]
564
565    def forward(self, x) -> Tuple[torch.Tensor, List[torch.Tensor]]:
566
567        B = x.shape[0]
568
569        x, hw_tuple = self.prepare_tokens_with_masks(x)
570
571        list_of_encoder = []
572        for i, blk in enumerate(self.blocks):
573            rope_sincos = self.rope_embed(H=hw_tuple[0], W=hw_tuple[1])
574            x = blk(x, rope_sincos)
575            if i in self.attn_outs:
576                list_of_encoder.append(x)
577
578        x = self.norm(x)
579        x = x[:, self.n_storage_tokens + 1:].reshape(
580            B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
581        ).permute(0, 3, 1, 2).contiguous()
582
583        list_of_encoder = [
584            o[:, self.n_storage_tokens + 1:].reshape(
585                B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
586            ).permute(0, 3, 1, 2).contiguous() for o in list_of_encoder
587        ]
588
589        return x, list_of_encoder[:3]

Vision Transformer derived from the DINOv3 Codebase (https://arxiv.org/abs/2508.10104).

Based on: https://github.com/facebookresearch/dinov3/blob/main/dinov3/models/vision_transformer.py.

Arguments:
  • img_size: The input image size.
  • patch_size: The patch size.
  • embed_dim: The embedding dimension.
  • depth: The depth of the network.
  • num_heads: The number of heads.
  • ffn_ratio: The FFN ratio.
  • n_storage_tokens: The number of storage tokens; they are stripped from the outputs together with the class token (see the sketch after this list).
  • kwargs: Keyword arguments for the image encoder base class.
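A small sketch, with hypothetical shapes, of how the class and storage tokens are stripped and the remaining patch tokens are reshaped into a feature map, mirroring the forward pass below:

    import torch

    B, embed_dim, n_storage_tokens, img_size, patch_size = 2, 768, 4, 224, 16
    grid = img_size // patch_size                                           # 14 patches per side
    tokens = torch.randn(B, 1 + n_storage_tokens + grid * grid, embed_dim)  # class + storage + patch tokens

    # the same reshape as in ViT_DINOv3.forward: drop the leading tokens, restore the 2D grid
    fmap = tokens[:, n_storage_tokens + 1:].reshape(B, grid, grid, -1).permute(0, 3, 1, 2).contiguous()
    print(fmap.shape)  # torch.Size([2, 768, 14, 14])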
ViT_DINOv3( in_chans: int = 3, img_size: int = 224, patch_size: int = 16, embed_dim: int = 768, depth: int = 12, num_heads: int = 12, ffn_ratio: float = 4.0, n_storage_tokens: int = 0, **kwargs)
529    def __init__(
530        self,
531        in_chans: int = 3,
532        img_size: int = 224,
533        patch_size: int = 16,
534        embed_dim: int = 768,
535        depth: int = 12,
536        num_heads: int = 12,
537        ffn_ratio: float = 4.0,
538        n_storage_tokens: int = 0,
539        **kwargs
540    ):
541        if not _dinov3_import_success:
542            raise RuntimeError(
543                "The vision transformer backend can only be initialized if DINOv3 is installed. "
544                "Please install DINOv3 from https://github.com/facebookresearch/dinov3 "
545                "and then rerun your code."
546            )
547
548        super().__init__(
549            in_chans=in_chans,
550            img_size=img_size,
551            patch_size=patch_size,
552            embed_dim=embed_dim,
553            depth=depth,
554            num_heads=num_heads,
555            ffn_ratio=ffn_ratio,
556            n_storage_tokens=n_storage_tokens,
557            **kwargs
558        )
559
560        self.in_chans = in_chans
561        self.img_size = img_size
562        self.n_storage_tokens = n_storage_tokens
563        self.attn_outs = [i for i in range(depth) if i % 3 == 2]
Instance attributes: in_chans, img_size, n_storage_tokens, attn_outs
def forward(self, x) -> Tuple[torch.Tensor, List[torch.Tensor]]:
565    def forward(self, x) -> Tuple[torch.Tensor, List[torch.Tensor]]:
566
567        B = x.shape[0]
568
569        x, hw_tuple = self.prepare_tokens_with_masks(x)
570
571        list_of_encoder = []
572        for i, blk in enumerate(self.blocks):
573            rope_sincos = self.rope_embed(H=hw_tuple[0], W=hw_tuple[1])
574            x = blk(x, rope_sincos)
575            if i in self.attn_outs:
576                list_of_encoder.append(x)
577
578        x = self.norm(x)
579        x = x[:, self.n_storage_tokens + 1:].reshape(
580            B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
581        ).permute(0, 3, 1, 2).contiguous()
582
583        list_of_encoder = [
584            o[:, self.n_storage_tokens + 1:].reshape(
585                B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
586            ).permute(0, 3, 1, 2).contiguous() for o in list_of_encoder
587        ]
588
589        return x, list_of_encoder[:3]
def get_vision_transformer( backbone: str, model: str, img_size: int = 1024, **kwargs) -> torch.nn.modules.module.Module:
592def get_vision_transformer(backbone: str, model: str, img_size: int = 1024, **kwargs) -> nn.Module:
593    """Get vision transformer encoder.
594
595    Args:
596        backbone: The name of the vision transformer implementation. One of "sam", "sam2", "mae", "scalemae", "dinov2" or "dinov3".
597        model: The name of the model; the supported names depend on the backbone, e.g. "vit_b", "vit_l" or "vit_h" for "sam", "mae" and "scalemae".
598        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
599        kwargs: Additional kwargs which can be expected by the vision transformer,
600            e.g. 'base_resolution' for `ViT_ScaleMAE`.
601
602    Returns:
603        The vision transformer.
604    """
605    if backbone == "sam":
606        if model == "vit_b":
607            encoder = ViT_Sam(
608                depth=12, embed_dim=768, img_size=img_size, mlp_ratio=4,
609                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
610                num_heads=12, patch_size=16, qkv_bias=True, use_rel_pos=True,
611                global_attn_indexes=[2, 5, 8, 11],
612                window_size=14, out_chans=256,
613            )
614        elif model == "vit_l":
615            encoder = ViT_Sam(
616                depth=24, embed_dim=1024, img_size=img_size, mlp_ratio=4,
617                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
618                num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True,
619                global_attn_indexes=[5, 11, 17, 23],
620                window_size=14, out_chans=256,
621            )
622        elif model == "vit_h":
623            encoder = ViT_Sam(
624                depth=32, embed_dim=1280, img_size=img_size, mlp_ratio=4,
625                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
626                num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True,
627                global_attn_indexes=[7, 15, 23, 31],
628                window_size=14, out_chans=256,
629            )
630        else:
631            raise ValueError(f"'{model}' is not supported by SAM. Currently, 'vit_b', 'vit_l', 'vit_h' are supported.")
632
633    elif backbone == "sam2":
634        if model == "hvit_t":
635            encoder = ViT_Sam2(
636                img_size=img_size, embed_dim=96, num_heads=1, stages=[1, 2, 7, 2], global_att_blocks=[5, 7, 9],
637                window_pos_embed_bkg_spatial_size=[7, 7], backbone_channel_list=[768, 384, 192, 96],
638            )
639        elif model == "hvit_s":
640            encoder = ViT_Sam2(
641                img_size=img_size, embed_dim=96, num_heads=1, stages=[1, 2, 11, 2], global_att_blocks=[7, 10, 13],
642                window_pos_embed_bkg_spatial_size=[7, 7], backbone_channel_list=[768, 384, 192, 96],
643            )
644        elif model == "hvit_b":
645            encoder = ViT_Sam2(
646                img_size=img_size, embed_dim=112, num_heads=2, backbone_channel_list=[896, 448, 224, 112],
647            )
648        elif model == "hvit_l":
649            encoder = ViT_Sam2(
650                img_size=img_size, embed_dim=144, num_heads=2, stages=[2, 6, 36, 4], global_att_blocks=[23, 33, 43],
651                window_spec=[8, 4, 16, 8], backbone_channel_list=[1152, 576, 288, 144],
652            )
653        else:
654            raise ValueError(
655                f"'{model}' is not supported by SAM2. Currently, 'hvit_t', 'hvit_s', 'hvit_b', 'hvit_l' are supported."
656            )
657
658    elif backbone == "mae":
659        if model == "vit_b":
660            encoder = ViT_MAE(
661                img_size=img_size, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
662                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
663            )
664        elif model == "vit_l":
665            encoder = ViT_MAE(
666                img_size=img_size, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
667                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
668            )
669        elif model == "vit_h":
670            encoder = ViT_MAE(
671                img_size=img_size, patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4,
672                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
673            )
674        else:
675            raise ValueError(f"'{model}' is not supported by MAE. Currently, 'vit_b', 'vit_l', 'vit_h' are supported.")
676
677    elif backbone == "scalemae":
678        base_resolution = kwargs.get("base_resolution", 2.5)
679
680        if model == "vit_b":
681            encoder = ViT_ScaleMAE(
682                img_size=img_size, patch_size=8, embed_dim=768, depth=12, num_heads=12,
683                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
684                base_resolution=base_resolution,
685            )
686        elif model == "vit_l":
687            encoder = ViT_ScaleMAE(
688                img_size=img_size, patch_size=8, embed_dim=1024, depth=24, num_heads=16,
689                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
690                base_resolution=base_resolution,
691            )
692        elif model == "vit_h":
693            encoder = ViT_ScaleMAE(
694                img_size=img_size, patch_size=8, embed_dim=1280, depth=32, num_heads=16,
695                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
696                base_resolution=base_resolution,
697            )
698        else:
699            raise ValueError(
700                f"'{model}' is not supported by ScaleMAE. Currently, 'vit_b', 'vit_l' and 'vit_h' are supported."
701            )
702
703    elif backbone == "dinov2":
704        block_fn = partial(Block, attn_class=MemEffAttention)
705        msg = "The model name should be either 'vit_<X>' or 'vit_<X>_reg<Y>'."
706
707        if model.startswith("vit_s"):
708            assert model in ["vit_s", "vit_s_reg4"], msg
709            encoder = ViT_DINOv2(
710                img_size=img_size, patch_size=14, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
711                block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0,
712                num_register_tokens=4 if model.endswith("_reg4") else 0,
713            )
714        elif model.startswith("vit_b"):
715            assert model in ["vit_b", "vit_b_reg4"], msg
716            encoder = ViT_DINOv2(
717                img_size=img_size, patch_size=14, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
718                block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0,
719                num_register_tokens=4 if model.endswith("_reg4") else 0,
720            )
721        elif model.startswith("vit_l"):
722            assert model in ["vit_l", "vit_l_reg4"], msg
723            encoder = ViT_DINOv2(
724                img_size=img_size, patch_size=14, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
725                block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0,
726                num_register_tokens=4 if model.endswith("_reg4") else 0,
727            )
728        elif model.startswith("vit_g"):
729            assert model in ["vit_g", "vit_g_reg4"], msg
730            encoder = ViT_DINOv2(
731                img_size=img_size, patch_size=14, embed_dim=1536, depth=40, num_heads=24, mlp_ratio=4,
732                block_fn=block_fn, in_chans=3, channel_adaptive=False, init_values=1e-5, block_chunks=0,
733                num_register_tokens=4 if model.endswith("_reg4") else 0, ffn_layer="swiglu",
734            )
735        else:
736            raise ValueError(
737                f"'{model}' is not supported by DINOv2. Currently, 'vit_s', 'vit_b', 'vit_l' and 'vit_g' are supported."
738            )
739
740    elif backbone == "dinov3":
741
742        if model == "vit_s":
743            encoder = ViT_DINOv3(
744                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=384,
745                num_heads=6, layerscale_init=1.0e-05, norm_layer="layernormbf16", n_storage_tokens=4, mask_k_bias=True,
746            )
747        elif model == "vit_s+":
748            encoder = ViT_DINOv3(
749                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=384,
750                num_heads=6, ffn_ratio=6, layerscale_init=1.0e-05, norm_layer="layernormbf16",
751                ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True,
752            )
753
754        elif model == "vit_b":
755            encoder = ViT_DINOv3(
756                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32",
757                layerscale_init=1.0e-05, norm_layer="layernormbf16", n_storage_tokens=4, mask_k_bias=True,
758            )
759        elif model == "vit_l":
760            encoder = ViT_DINOv3(
761                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1024,
762                depth=24, num_heads=16, layerscale_init=1.0e-05, norm_layer="layernormbf16",
763                n_storage_tokens=4, mask_k_bias=True,
764            )
765        elif model == "vit_l+":
766            encoder = ViT_DINOv3(
767                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1024,
768                depth=24, num_heads=16, ffn_ratio=6.0, layerscale_init=1.0e-05, norm_layer="layernormbf16",
769                ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True,
770            )
771        elif model == "vit_h+":
772            encoder = ViT_DINOv3(
773                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1280,
774                depth=32, num_heads=20, ffn_ratio=6.0, layerscale_init=1.0e-05, norm_layer="layernormbf16",
775                ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True,
776            )
777        elif model == "vit_7b":
778            encoder = ViT_DINOv3(
779                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=4096,
780                depth=40, num_heads=32, ffn_ratio=3, qkv_bias=False, drop_path_rate=0.0, layerscale_init=1.0e-05,
781                norm_layer="layernormbf16", ffn_layer="swiglu64", n_storage_tokens=4, mask_k_bias=True,
782                untie_global_and_local_cls_norm=True,
783            )
784        else:
785            raise ValueError(
786                f"'{model}' is not supported by DINOv3. Currently, "
787                " 'vit_s', 'vit_s+', 'vit_b', 'vit_l', 'vit_l+', 'vit_h+'. 'vit_7b' are supported."
788            )
789
790    else:
791        raise ValueError(
792            "The 'UNETR' supported backbones are 'sam', 'sam2', 'mae', 'scalemae' or 'dinov3'. "
793            "Please choose one of them."
794        )
795
796    return encoder

Get vision transformer encoder.

Arguments:
  • backbone: The name of the vision transformer implementation. One of "sam", "sam2", "mae", "scalemae", "dinov2" or "dinov3".
  • model: The name of the model; the supported names depend on the backbone, e.g. "vit_b", "vit_l" or "vit_h" for "sam", "mae" and "scalemae".
  • img_size: The size of the input for the image encoder. Input images will be resized to match this size.
  • kwargs: Additional kwargs which can be expected by the vision transformer, e.g. 'base_resolution' for ViT_ScaleMAE.
Returns:

The vision transformer.
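A minimal usage sketch; the optional dependency of the chosen backbone (e.g. dinov3 for "dinov3") has to be installed, and the output shapes depend on the chosen configuration:

    import torch
    from torch_em.model.vit import get_vision_transformer

    encoder = get_vision_transformer(backbone="dinov3", model="vit_b", img_size=224)
    x = torch.randn(1, 3, 224, 224)
    features, intermediates = encoder(x)
    # 'features' is the final B x C x H x W feature map and 'intermediates' the tapped encoder
    # outputs, as consumed e.g. by the UNETR decoder in torch_em.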