torch_em.model.vit

from functools import partial
from typing import Tuple, List

import torch
import torch.nn as nn

# We catch ImportErrors here because segment_anything, timm, sam2 and dinov3 should
# only be optional dependencies for torch_em.
try:
    from segment_anything.modeling import ImageEncoderViT
    _sam_import_success = True
except ImportError:
    ImageEncoderViT = object
    _sam_import_success = False

try:
    from timm.models.vision_transformer import VisionTransformer, PatchEmbed
    _timm_import_success = True
except ImportError:
    VisionTransformer = object
    PatchEmbed = object
    _timm_import_success = False

try:
    from sam2.modeling.backbones.hieradet import Hiera
    from sam2.modeling.position_encoding import PositionEmbeddingSine
    from sam2.modeling.backbones.image_encoder import ImageEncoder, FpnNeck
    _sam2_import_success = True
except ImportError:
    ImageEncoder = object
    _sam2_import_success = False


try:
    from dinov3.models.vision_transformer import DinoVisionTransformer
    _dinov3_import_success = True
except ImportError:
    DinoVisionTransformer = object
    _dinov3_import_success = False


class ViT_Sam(ImageEncoderViT):
    """Vision Transformer derived from the Segment Anything Codebase (https://arxiv.org/abs/2304.02643).

    Based on:
    https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py

    Args:
        in_chans: The number of input channels.
        embed_dim: The embedding dimension, corresponding to the number of output channels of the vision transformer.
        global_attn_indexes: The global attention indices.
        kwargs: Keyword arguments for the image encoder base class.
    """
    def __init__(
        self,
        in_chans: int = 3,
        embed_dim: int = 768,
        global_attn_indexes: Tuple[int, ...] = (2, 5, 8, 11),
        **kwargs,
    ) -> None:
        if not _sam_import_success:
            raise RuntimeError(
                "The vision transformer backend can only be initialized if segment anything is installed. "
                "Please install segment anything from https://github.com/facebookresearch/segment-anything "
                "and then rerun your code."
            )

        super().__init__(embed_dim=embed_dim, global_attn_indexes=global_attn_indexes, **kwargs)
        self.chunks_for_projection = global_attn_indexes
        self.in_chans = in_chans
        self.embed_dim = embed_dim

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Apply the vision transformer to input data.

        Args:
            x: The input data.

        Returns:
            The vision transformer output and the intermediate features from the global attention blocks.
        """
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            x = x + self.pos_embed

        list_from_encoder = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in self.chunks_for_projection:
                list_from_encoder.append(x)

        x = x.permute(0, 3, 1, 2)
        list_from_encoder = [e.permute(0, 3, 1, 2) for e in list_from_encoder]
        return x, list_from_encoder[:3]


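# Example (illustrative sketch, not part of the original module): constructing the SAM image
# encoder with the "vit_b" configuration used in `get_vision_transformer` below.
# Assumes segment_anything is installed; shapes refer to a 1024 x 1024 input.
#
#   encoder = ViT_Sam(
#       depth=12, embed_dim=768, img_size=1024, mlp_ratio=4,
#       norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
#       num_heads=12, patch_size=16, qkv_bias=True, use_rel_pos=True,
#       global_attn_indexes=[2, 5, 8, 11], window_size=14, out_chans=256,
#   )
#   x = torch.randn(1, 3, 1024, 1024)
#   out, from_encoder = encoder(x)  # out: (1, 768, 64, 64) plus 3 intermediate feature maps

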
class ViT_MAE(VisionTransformer):
    """Vision Transformer derived from the Masked Auto Encoder Codebase (https://arxiv.org/abs/2111.06377).

    Based on:
    https://github.com/facebookresearch/mae/blob/main/models_vit.py#L20-L53

    Args:
        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
        in_chans: The number of input channels.
        depth: The depth of the vision transformer.
        kwargs: Additional keyword arguments for the vision transformer base class.
    """
    def __init__(
        self,
        img_size: int = 1024,  # chosen to match our experiments with segment anything
        in_chans: int = 3,
        depth: int = 12,
        **kwargs
    ):
        if not _timm_import_success:
            raise RuntimeError(
                "The vision transformer backend can only be initialized if timm is installed. "
                "Please install timm (using conda/mamba) for using https://github.com/facebookresearch/mae/ "
                "and then rerun your code."
            )
        super().__init__(img_size=img_size, depth=depth, **kwargs)
        self.img_size = img_size
        self.in_chans = in_chans
        self.depth = depth

    def convert_to_expected_dim(self, inputs_):
        """@private
        """
        inputs_ = inputs_[:, 1:, :]  # remove the class token
        # reshape the outputs to the desired shape (N x H*W x C -> N x C x H x W)
        rdim = inputs_.shape[1]
        dshape = int(rdim ** 0.5)  # take the square root of the token count to obtain the patch grid shape
        inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape))
        inputs_ = inputs_.permute(0, 3, 1, 2)
        return inputs_

    def forward_features(self, x):
        """@private
        """
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        x = x + self.pos_embed
        x = self.pos_drop(x)

        # chunks obtained for getting the projections in conjunction with the upsampling blocks
        _chunks = int(self.depth / 4)
        chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1]

        list_from_encoder = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in chunks_for_projection:
                list_from_encoder.append(self.convert_to_expected_dim(x))

        x = self.convert_to_expected_dim(x)
        return x, list_from_encoder[:3]

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Apply the vision transformer to input data.

        Args:
            x: The input data.

        Returns:
            The vision transformer output and the intermediate features from the encoder.
        """
        x, list_from_encoder = self.forward_features(x)
        return x, list_from_encoder


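# Example (illustrative sketch, not part of the original module): the MAE ViT-B encoder
# as configured in `get_vision_transformer` below. Assumes timm is installed.
#
#   encoder = ViT_MAE(
#       img_size=1024, patch_size=16, embed_dim=768, depth=12, num_heads=12,
#       mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
#   )
#   x = torch.randn(1, 3, 1024, 1024)
#   out, from_encoder = encoder(x)  # out: (1, 768, 64, 64)

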
class ViT_Sam2(ImageEncoder):
    """Vision Transformer derived from the Segment Anything 2 Codebase (https://arxiv.org/abs/2408.00714).

    Based on https://github.com/facebookresearch/sam2/blob/main/sam2/modeling/backbones/image_encoder.py.

    Args:
        backbone_channel_list: The channels throughout the entire backbone.
        img_size: The size of the input for the image encoder.
        embed_dim: The initial embedding dimension.
        num_heads: The initial number of heads.
        stages: The number of blocks per stage.
        global_att_blocks: The parameter to decide which blocks have global attention.
        window_pos_embed_bkg_spatial_size: The spatial size of the windowed positional embedding.
        window_spec: The window size per stage, when not using global attention.
        scalp: The number of lowest resolution features to discard.
    """
    def __init__(
        self,
        backbone_channel_list: List[int],
        img_size: int = 1024,
        embed_dim: int = 96,
        num_heads: int = 1,
        stages: Tuple[int, ...] = (2, 3, 16, 3),
        global_att_blocks: Tuple[int, ...] = (12, 16, 20),
        window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
        window_spec: Tuple[int, ...] = (8, 4, 14, 7),
        scalp: int = 1,
    ):
        if not _sam2_import_success:
            raise RuntimeError(
                "The vision transformer backend can only be initialized if segment anything 2 is installed. "
                "Please install segment anything 2 from https://github.com/facebookresearch/sam2 "
                "and then rerun your code."
            )

        trunk = Hiera(
            embed_dim=embed_dim,
            num_heads=num_heads,
            stages=stages,
            global_att_blocks=global_att_blocks,
            window_pos_embed_bkg_spatial_size=window_pos_embed_bkg_spatial_size,
            window_spec=window_spec,
        )
        neck = FpnNeck(
            position_encoding=PositionEmbeddingSine(num_pos_feats=256),
            d_model=256,
            backbone_channel_list=backbone_channel_list,
            fpn_top_down_levels=[2, 3],
            fpn_interp_model="nearest",
        )

        super().__init__(trunk=trunk, neck=neck, scalp=scalp)
        self.scalp = scalp
        self.embed_dim = embed_dim
        self.img_size = img_size

    def forward(self, x: torch.Tensor):
        # The forward pass through the backbone.
        features, pos = self.neck(self.trunk(x))
        if self.scalp > 0:  # This discards the 'scalp' lowest resolution features.
            features, pos = features[:-self.scalp], pos[:-self.scalp]

        return features[-1], features


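# Example (illustrative sketch, not part of the original module): the SAM2 Hiera-B+ encoder
# as configured in `get_vision_transformer` below. Assumes sam2 is installed.
#
#   encoder = ViT_Sam2(
#       img_size=1024, embed_dim=112, num_heads=2, backbone_channel_list=[896, 448, 224, 112],
#   )
#   x = torch.randn(1, 3, 1024, 1024)
#   out, features = encoder(x)  # out is the last of the 256-channel FPN feature maps kept after 'scalp'

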
#
# Utilities for ScaleMAE's ViT
#


class CustomCompose:
    def __init__(self, rescale_transform, other_transforms, src_transform):
        self.rescale_transform = rescale_transform
        self.other_transforms = other_transforms
        self.src_transform = src_transform

    def __call__(self, x, valid_masks=None):
        if valid_masks is not None:
            nodata = (x * (1 - valid_masks.float())).max()
        x_aug = self.rescale_transform(x)
        parms = self.rescale_transform._params

        # sanity check, comment if this is working
        # valid_masks = self.rescale_transform(valid_masks.float(), params=parms)
        # assert (x_aug==self.rescale_transform(x, params=parms)).all() #

        if valid_masks is not None:
            valid_masks = x_aug != nodata
            _, c, h, w = x_aug.shape
            zero_ratio = ((valid_masks == 0).sum((1, 2, 3)) / (h * w * c)).cpu().numpy()
        else:
            zero_ratio = -1

        if self.other_transforms:
            x_aug = self.other_transforms(x_aug)
        x_src = self.src_transform(x_aug)
        dx = parms["src"][:, 1, 0] - parms["src"][:, 0, 0]

        # dy = (parms['src'][:,2,1] - parms['src'][:,1,1])
        # assert (dx == dy).all()

        h, w = x_aug.shape[-2:]
        # assert h == w

        return x_aug, x_src, dx / h, zero_ratio, valid_masks


def get_2d_sincos_pos_embed_with_resolution(embed_dim, grid_size, res, cls_token=False, device="cpu"):
    """
    grid_size: int of the grid height and width
    res: array of size n, representing the resolution of a pixel (say, in meters)
    return:
    pos_embed: [n, grid_size*grid_size, embed_dim] or [n, 1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    # res = torch.FloatTensor(res).to(device)
    res = res.to(device)
    grid_h = torch.arange(grid_size, dtype=torch.float32, device=device)
    grid_w = torch.arange(grid_size, dtype=torch.float32, device=device)
    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")  # here h goes first, direction reversed for numpy
    grid = torch.stack(grid, dim=0)  # 2 x h x w

    # grid = grid.reshape([2, 1, grid_size, grid_size])
    grid = torch.einsum("chw,n->cnhw", grid, res)  # 2 x n x h x w
    _, n, h, w = grid.shape
    pos_embed = get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid)  # (n*H*W, D)
    pos_embed = pos_embed.reshape(n, h * w, embed_dim)
    if cls_token:
        pos_embed = torch.cat(
            [torch.zeros([n, 1, embed_dim], dtype=torch.float32, device=pos_embed.device), pos_embed], dim=1
        )

    return pos_embed


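# Example (illustrative sketch, not part of the original module): shapes as implied by the code above.
# For embed_dim=768, grid_size=14 and a per-sample resolution tensor of size n=2:
#
#   res = torch.tensor([1.0, 2.5])
#   pe = get_2d_sincos_pos_embed_with_resolution(768, 14, res, cls_token=True)
#   # pe.shape == (2, 1 + 14 * 14, 768); the first token is an all-zero cls embedding

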
def get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = torch.cat([emb_h, emb_w], dim=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    # old_shape = pos
    omega = torch.arange(embed_dim // 2, dtype=torch.float32, device=pos.device)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = torch.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = torch.sin(out)  # (M, D/2)
    emb_cos = torch.cos(out)  # (M, D/2)

    emb = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)
    return emb


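# In closed form (read off the code above): for a position p and 0 <= i < D/2,
#   emb[p, i]       = sin(p / 10000**(2i / D))
#   emb[p, D/2 + i] = cos(p / 10000**(2i / D))
# i.e. the standard fixed sinusoidal embedding, evaluated on (possibly fractional) positions.

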
class PatchEmbedUnSafe(PatchEmbed):
    """Image to Patch Embedding"""

    def forward(self, x):
        B, C, H, W = x.shape

        # NOTE: Comment code from ScaleMAE: Dropped size check in timm
        # assert H == self.img_size[0] and W == self.img_size[1], \
        #     f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."

        x = self.proj(x).flatten(2).transpose(1, 2)
        return x


class ViT_ScaleMAE(VisionTransformer):
    """Vision Transformer derived from the Scale Masked Auto Encoder codebase (TODO: paper and github link).

    NOTE: For downstream tasks, the "base_resolution" parameter needs to be adjusted manually when using
    the model on a dataset with a different zoom factor.
    """

    def __init__(
        self, img_size=224, patch_size=16, in_chans=3, embed_dim=1024, depth=12, base_resolution=2.5, **kwargs
    ):
        super().__init__(img_size=img_size, embed_dim=embed_dim, **kwargs)
        self.img_size = img_size
        self.in_chans = in_chans
        self.depth = depth
        self.base_resolution = base_resolution

        self.patch_embed = PatchEmbedUnSafe(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

    def transform_inputs(self, x):
        import kornia.augmentation as K
        from kornia.constants import Resample

        self._transforms = CustomCompose(
            rescale_transform=K.RandomResizedCrop(
                (448, 448),
                ratio=(1.0, 1.0),
                scale=(1.0, 1.0),
                resample=Resample.BICUBIC.name,
            ),
            other_transforms=None,
            src_transform=K.Resize((224, 224)),
        )
        x, _, ratios, _, _ = self._transforms(x)
        input_res = ratios * self.base_resolution
        return x, input_res

    def convert_to_expected_dim(self, x):
        inputs_ = x[:, 1:, :]  # remove the class token
        # reshape the outputs to the desired shape (N x H*W x C -> N x C x H x W)
        rdim = inputs_.shape[1]
        dshape = int(rdim ** 0.5)  # take the square root of the token count to obtain the patch grid shape
        inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape))
        inputs_ = inputs_.permute(0, 3, 1, 2)
        return inputs_

    def forward_features(self, x):
        x, input_res = self.transform_inputs(x)

        B, _, h, w = x.shape
        x = self.patch_embed(x)

        num_patches = int((h * w) / (self.patch_embed.patch_size[0] * self.patch_embed.patch_size[1]))
        pos_embed = get_2d_sincos_pos_embed_with_resolution(
            x.shape[-1],
            int(num_patches ** 0.5),
            input_res,
            cls_token=True,
            device=x.device,
        )

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + pos_embed
        x = self.pos_drop(x)

        # chunks obtained for getting the projections in conjunction with the upsampling blocks
        _chunks = int(self.depth / 4)
        chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1]

        list_from_encoder = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in chunks_for_projection:
                list_from_encoder.append(self.convert_to_expected_dim(x))

        x = self.convert_to_expected_dim(x)

        return x, list_from_encoder

    def forward(self, x):
        x, list_from_encoder = self.forward_features(x)
        return x, list_from_encoder


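# Example (illustrative sketch, not part of the original module): the ScaleMAE ViT-B encoder
# as configured in `get_vision_transformer` below. Assumes timm and kornia are installed.
#
#   encoder = ViT_ScaleMAE(
#       img_size=224, patch_size=8, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
#       qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), base_resolution=2.5,
#   )
#   x = torch.randn(1, 3, 224, 224)
#   out, from_encoder = encoder(x)  # the input is internally rescaled before patch embedding

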
class ViT_DINOv3(DinoVisionTransformer):
    """Vision Transformer derived from the DINOv3 Codebase (https://arxiv.org/abs/2508.10104).

    Based on:
    https://github.com/facebookresearch/dinov3/blob/main/dinov3/models/vision_transformer.py.

    Args:
        in_chans: The number of input channels.
        img_size: The input image size.
        patch_size: The patch size.
        embed_dim: The embedding dimension.
        depth: The depth of the network.
        num_heads: The number of heads.
        ffn_ratio: The FFN ratio.
        n_storage_tokens: The number of storage tokens. They are removed from the output together with the class token.
        kwargs: Keyword arguments for the image encoder base class.
    """
    def __init__(
        self,
        in_chans: int = 3,
        img_size: int = 224,
        patch_size: int = 16,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        ffn_ratio: float = 4.0,
        n_storage_tokens: int = 0,
        **kwargs
    ):
        if not _dinov3_import_success:
            raise RuntimeError(
                "The vision transformer backend can only be initialized if DINOv3 is installed. "
                "Please install DINOv3 from https://github.com/facebookresearch/dinov3 "
                "and then rerun your code."
            )

        super().__init__(
            in_chans=in_chans,
            img_size=img_size,
            patch_size=patch_size,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            ffn_ratio=ffn_ratio,
            n_storage_tokens=n_storage_tokens,
            **kwargs
        )

        self.in_chans = in_chans
        self.img_size = img_size
        self.n_storage_tokens = n_storage_tokens
        self.attn_outs = [i for i in range(depth) if i % 3 == 2]

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Apply the vision transformer to input data.

        Args:
            x: The input data.

        Returns:
            The vision transformer output and the intermediate features from the encoder.
        """
        B = x.shape[0]

        x, hw_tuple = self.prepare_tokens_with_masks(x)

        list_of_encoder = []
        for i, blk in enumerate(self.blocks):
            rope_sincos = self.rope_embed(H=hw_tuple[0], W=hw_tuple[1])
            x = blk(x, rope_sincos)
            if i in self.attn_outs:
                list_of_encoder.append(x)

        x = self.norm(x)
        # Drop the class and storage tokens and reshape the patch tokens into an image-like feature map.
        x = x[:, self.n_storage_tokens + 1:].reshape(
            B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
        ).permute(0, 3, 1, 2).contiguous()

        list_of_encoder = [
            o[:, self.n_storage_tokens + 1:].reshape(
                B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
            ).permute(0, 3, 1, 2).contiguous() for o in list_of_encoder
        ]

        return x, list_of_encoder[:3]


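# Example (illustrative sketch, not part of the original module): the DINOv3 ViT-B encoder
# as configured in `get_vision_transformer` below. Assumes dinov3 is installed.
#
#   encoder = get_vision_transformer(backbone="dinov3", model="vit_b", img_size=224)
#   x = torch.randn(1, 3, 224, 224)
#   out, from_encoder = encoder(x)  # out: (1, 768, 14, 14) for patch_size=16

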
def get_vision_transformer(backbone: str, model: str, img_size: int = 1024, **kwargs) -> nn.Module:
    """Get vision transformer encoder.

    Args:
        backbone: The name of the vision transformer implementation.
            One of "sam", "sam2", "mae", "scalemae" or "dinov3".
        model: The name of the model. The supported models depend on the backbone,
            e.g. "vit_b", "vit_l" or "vit_h" for the "sam" backbone.
        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
        kwargs: Additional kwargs which can be expected by the vision transformer,
            e.g. 'base_resolution' for `ViT_ScaleMAE`.

    Returns:
        The vision transformer.
    """
    if backbone == "sam":
        if model == "vit_b":
            encoder = ViT_Sam(
                depth=12, embed_dim=768, img_size=img_size, mlp_ratio=4,
                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
                num_heads=12, patch_size=16, qkv_bias=True, use_rel_pos=True,
                global_attn_indexes=[2, 5, 8, 11],
                window_size=14, out_chans=256,
            )
        elif model == "vit_l":
            encoder = ViT_Sam(
                depth=24, embed_dim=1024, img_size=img_size, mlp_ratio=4,
                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
                num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True,
                global_attn_indexes=[5, 11, 17, 23],
                window_size=14, out_chans=256,
            )
        elif model == "vit_h":
            encoder = ViT_Sam(
                depth=32, embed_dim=1280, img_size=img_size, mlp_ratio=4,
                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
                num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True,
                global_attn_indexes=[7, 15, 23, 31],
                window_size=14, out_chans=256,
            )
        else:
            raise ValueError(f"'{model}' is not supported by SAM. Currently, 'vit_b', 'vit_l', 'vit_h' are supported.")

    elif backbone == "sam2":
        if model == "hvit_t":
            encoder = ViT_Sam2(
                img_size=img_size, embed_dim=96, num_heads=1, stages=[1, 2, 7, 2], global_att_blocks=[5, 7, 9],
                window_pos_embed_bkg_spatial_size=[7, 7], backbone_channel_list=[768, 384, 192, 96],
            )
        elif model == "hvit_s":
            encoder = ViT_Sam2(
                img_size=img_size, embed_dim=96, num_heads=1, stages=[1, 2, 11, 2], global_att_blocks=[7, 10, 13],
                window_pos_embed_bkg_spatial_size=[7, 7], backbone_channel_list=[768, 384, 192, 96],
            )
        elif model == "hvit_b":
            encoder = ViT_Sam2(
                img_size=img_size, embed_dim=112, num_heads=2, backbone_channel_list=[896, 448, 224, 112],
            )
        elif model == "hvit_l":
            encoder = ViT_Sam2(
                img_size=img_size, embed_dim=144, num_heads=2, stages=[2, 6, 36, 4], global_att_blocks=[23, 33, 43],
                window_spec=[8, 4, 16, 8], backbone_channel_list=[1152, 576, 288, 144],
            )
        else:
            raise ValueError(
                f"'{model}' is not supported by SAM2. Currently, 'hvit_t', 'hvit_s', 'hvit_b', 'hvit_l' are supported."
            )

    elif backbone == "mae":
        if model == "vit_b":
            encoder = ViT_MAE(
                img_size=img_size, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
            )
        elif model == "vit_l":
            encoder = ViT_MAE(
                img_size=img_size, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
            )
        elif model == "vit_h":
            encoder = ViT_MAE(
                img_size=img_size, patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4,
                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
            )
        else:
            raise ValueError(f"'{model}' is not supported by MAE. Currently, 'vit_b', 'vit_l', 'vit_h' are supported.")

    elif backbone == "scalemae":
        base_resolution = kwargs.get("base_resolution", 2.5)

        if model == "vit_b":
            encoder = ViT_ScaleMAE(
                img_size=img_size, patch_size=8, embed_dim=768, depth=12, num_heads=12,
                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
                base_resolution=base_resolution,
            )
        elif model == "vit_l":
            encoder = ViT_ScaleMAE(
                img_size=img_size, patch_size=8, embed_dim=1024, depth=24, num_heads=16,
                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
                base_resolution=base_resolution,
            )
        elif model == "vit_h":
            encoder = ViT_ScaleMAE(
                img_size=img_size, patch_size=8, embed_dim=1280, depth=32, num_heads=16,
                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
                base_resolution=base_resolution,
            )
        else:
            raise ValueError(
                f"'{model}' is not supported by ScaleMAE. Currently, 'vit_b', 'vit_l' and 'vit_h' are supported."
            )

    elif backbone == "dinov3":
        if model == "vit_s":
            encoder = ViT_DINOv3(
                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=384,
                num_heads=6, layerscale_init=1.0e-05, norm_layer="layernormbf16", n_storage_tokens=4, mask_k_bias=True,
            )
        elif model == "vit_s+":
            encoder = ViT_DINOv3(
                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=384,
                num_heads=6, ffn_ratio=6, layerscale_init=1.0e-05, norm_layer="layernormbf16",
                ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True,
            )
        elif model == "vit_b":
            encoder = ViT_DINOv3(
                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32",
                layerscale_init=1.0e-05, norm_layer="layernormbf16", n_storage_tokens=4, mask_k_bias=True,
            )
        elif model == "vit_l":
            encoder = ViT_DINOv3(
                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1024,
                depth=24, num_heads=16, layerscale_init=1.0e-05, norm_layer="layernormbf16",
                n_storage_tokens=4, mask_k_bias=True,
            )
        elif model == "vit_l+":
            encoder = ViT_DINOv3(
                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1024,
                depth=24, num_heads=16, ffn_ratio=6.0, layerscale_init=1.0e-05, norm_layer="layernormbf16",
                ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True,
            )
        elif model == "vit_h+":
            encoder = ViT_DINOv3(
                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1280,
                depth=32, num_heads=20, ffn_ratio=6.0, layerscale_init=1.0e-05, norm_layer="layernormbf16",
                ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True,
            )
        elif model == "vit_7b":
            encoder = ViT_DINOv3(
                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=4096,
                depth=40, num_heads=32, ffn_ratio=3, qkv_bias=False, drop_path_rate=0.0, layerscale_init=1.0e-05,
                norm_layer="layernormbf16", ffn_layer="swiglu64", n_storage_tokens=4, mask_k_bias=True,
                untie_global_and_local_cls_norm=True,
            )
        else:
            raise ValueError(
                f"'{model}' is not supported by DINOv3. Currently, 'vit_s', 'vit_s+', 'vit_b', 'vit_l', "
                "'vit_l+', 'vit_h+' and 'vit_7b' are supported."
            )

    else:
        raise ValueError(
            "The supported backbones for 'UNETR' are 'sam', 'sam2', 'mae', 'scalemae' or 'dinov3'. "
            "Please choose one of them."
        )

    return encoder
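

# Usage sketch (not part of the original module): a minimal smoke test of the factory function.
# It assumes the optional segment_anything dependency is installed for the "sam" backbone.
if __name__ == "__main__":
    encoder = get_vision_transformer(backbone="sam", model="vit_b", img_size=1024)
    image = torch.randn(1, 3, 1024, 1024)
    output, intermediate_features = encoder(image)
    print("Output shape:", output.shape)  # expected: (1, 768, 64, 64)
    print("Number of intermediate feature maps:", len(intermediate_features))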
class ViT_Sam:
43class ViT_Sam(ImageEncoderViT):
44    """Vision Transformer derived from the Segment Anything Codebase (https://arxiv.org/abs/2304.02643).
45
46    Based on:
47    https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py
48
49    Args:
50        in_chans: The number of input channels.
51        embed_dim: The embedding dimension, corresponding to the number of output channels of the vision transformer.
52        global_attn_indexes: The global attention indices.
53        kwargs: Keyword arguments for the image encoder base class.
54    """
55    def __init__(
56        self,
57        in_chans: int = 3,
58        embed_dim: int = 768,
59        global_attn_indexes: Tuple[int, ...] = [2, 5, 8, 11],
60        **kwargs,
61    ) -> None:
62        if not _sam_import_success:
63            raise RuntimeError(
64                "The vision transformer backend can only be initialized if segment anything is installed. "
65                "Please install segment anything from https://github.com/facebookresearch/segment-anything "
66                "and then rerun your code."
67            )
68
69        super().__init__(embed_dim=embed_dim, global_attn_indexes=global_attn_indexes, **kwargs)
70        self.chunks_for_projection = global_attn_indexes
71        self.in_chans = in_chans
72        self.embed_dim = embed_dim
73
74    def forward(self, x: torch.Tensor) -> torch.Tensor:
75        """Apply the vision transformer to input data.
76
77        Args:
78            x: The input data.
79
80        Returns:
81            The vision transformer output.
82        """
83        x = self.patch_embed(x)
84        if self.pos_embed is not None:
85            x = x + self.pos_embed
86
87        list_from_encoder = []
88        for i, blk in enumerate(self.blocks):
89            x = blk(x)
90            if i in self.chunks_for_projection:
91                list_from_encoder.append(x)
92
93        x = x.permute(0, 3, 1, 2)
94        list_from_encoder = [e.permute(0, 3, 1, 2) for e in list_from_encoder]
95        return x, list_from_encoder[:3]

Vision Transformer derived from the Segment Anything Codebase (https://arxiv.org/abs/2304.02643).

Based on: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py

Arguments:
  • in_chans: The number of input channels.
  • embed_dim: The embedding dimension, corresponding to the number of output channels of the vision transformer.
  • global_attn_indexes: The global attention indices.
  • kwargs: Keyword arguments for the image encoder base class.
ViT_Sam( in_chans: int = 3, embed_dim: int = 768, global_attn_indexes: Tuple[int, ...] = [2, 5, 8, 11], **kwargs)
55    def __init__(
56        self,
57        in_chans: int = 3,
58        embed_dim: int = 768,
59        global_attn_indexes: Tuple[int, ...] = [2, 5, 8, 11],
60        **kwargs,
61    ) -> None:
62        if not _sam_import_success:
63            raise RuntimeError(
64                "The vision transformer backend can only be initialized if segment anything is installed. "
65                "Please install segment anything from https://github.com/facebookresearch/segment-anything "
66                "and then rerun your code."
67            )
68
69        super().__init__(embed_dim=embed_dim, global_attn_indexes=global_attn_indexes, **kwargs)
70        self.chunks_for_projection = global_attn_indexes
71        self.in_chans = in_chans
72        self.embed_dim = embed_dim
chunks_for_projection
in_chans
embed_dim
def forward(self, x: torch.Tensor) -> torch.Tensor:
74    def forward(self, x: torch.Tensor) -> torch.Tensor:
75        """Apply the vision transformer to input data.
76
77        Args:
78            x: The input data.
79
80        Returns:
81            The vision transformer output.
82        """
83        x = self.patch_embed(x)
84        if self.pos_embed is not None:
85            x = x + self.pos_embed
86
87        list_from_encoder = []
88        for i, blk in enumerate(self.blocks):
89            x = blk(x)
90            if i in self.chunks_for_projection:
91                list_from_encoder.append(x)
92
93        x = x.permute(0, 3, 1, 2)
94        list_from_encoder = [e.permute(0, 3, 1, 2) for e in list_from_encoder]
95        return x, list_from_encoder[:3]

Apply the vision transformer to input data.

Arguments:
  • x: The input data.
Returns:

The vision transformer output.

class ViT_MAE:
 98class ViT_MAE(VisionTransformer):
 99    """Vision Transformer derived from the Masked Auto Encoder Codebase (https://arxiv.org/abs/2111.06377).
100
101    Based on:
102    https://github.com/facebookresearch/mae/blob/main/models_vit.py#L20-L53
103
104    Args:
105        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
106        in_chans: The number of input channels.
107        depth: The depth of the vision transformer.
108        kwargs: Additional keyword arguments for the vision transformer base class.
109    """
110    def __init__(
111        self,
112        img_size: int = 1024,  # chosen to match our experiments with segment anything
113        in_chans: int = 3,
114        depth: int = 12,
115        **kwargs
116    ):
117        if not _timm_import_success:
118            raise RuntimeError(
119                "The vision transformer backend can only be initialized if timm is installed. "
120                "Please install timm (using conda/mamba) for using https://github.com/facebookresearch/mae/ "
121                "and then rerun your code"
122            )
123        super().__init__(img_size=img_size, depth=depth, **kwargs)
124        self.img_size = img_size
125        self.in_chans = in_chans
126        self.depth = depth
127
128    def convert_to_expected_dim(self, inputs_):
129        """@private
130        """
131        inputs_ = inputs_[:, 1:, :]  # removing the class tokens
132        # reshape the outputs to desired shape (N x H*W X C -> N x H x W x C)
133        rdim = inputs_.shape[1]
134        dshape = int(rdim ** 0.5)  # finding the square root of the outputs for obtaining the patch shape
135        inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape))
136        inputs_ = inputs_.permute(0, 3, 1, 2)
137        return inputs_
138
139    def forward_features(self, x):
140        """@private
141        """
142        B = x.shape[0]
143        x = self.patch_embed(x)
144
145        cls_tokens = self.cls_token.expand(B, -1, -1)
146        x = torch.cat((cls_tokens, x), dim=1)
147
148        x = x + self.pos_embed
149        x = self.pos_drop(x)
150
151        # chunks obtained for getting the projections for conjuctions with upsampling blocks
152        _chunks = int(self.depth / 4)
153        chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1]
154
155        list_from_encoder = []
156        for i, blk in enumerate(self.blocks):
157            x = blk(x)
158            if i in chunks_for_projection:
159                list_from_encoder.append(self.convert_to_expected_dim(x))
160
161        x = self.convert_to_expected_dim(x)
162        return x, list_from_encoder[:3]
163
164    def forward(self, x: torch.Tensor) -> torch.Tensor:
165        """Apply the vision transformer to input data.
166
167        Args:
168            x: The input data.
169
170        Returns:
171            The vision transformer output.
172        """
173        x, list_from_encoder = self.forward_features(x)
174        return x, list_from_encoder

Vision Transformer derived from the Masked Auto Encoder Codebase (https://arxiv.org/abs/2111.06377).

Based on: https://github.com/facebookresearch/mae/blob/main/models_vit.py#L20-L53

Arguments:
  • img_size: The size of the input for the image encoder. Input images will be resized to match this size.
  • in_chans: The number of input channels.
  • depth: The depth of the vision transformer.
  • kwargs: Additional keyword arguments for the vision transformer base class.
ViT_MAE(img_size: int = 1024, in_chans: int = 3, depth: int = 12, **kwargs)
110    def __init__(
111        self,
112        img_size: int = 1024,  # chosen to match our experiments with segment anything
113        in_chans: int = 3,
114        depth: int = 12,
115        **kwargs
116    ):
117        if not _timm_import_success:
118            raise RuntimeError(
119                "The vision transformer backend can only be initialized if timm is installed. "
120                "Please install timm (using conda/mamba) for using https://github.com/facebookresearch/mae/ "
121                "and then rerun your code"
122            )
123        super().__init__(img_size=img_size, depth=depth, **kwargs)
124        self.img_size = img_size
125        self.in_chans = in_chans
126        self.depth = depth
img_size
in_chans
depth
def forward(self, x: torch.Tensor) -> torch.Tensor:
164    def forward(self, x: torch.Tensor) -> torch.Tensor:
165        """Apply the vision transformer to input data.
166
167        Args:
168            x: The input data.
169
170        Returns:
171            The vision transformer output.
172        """
173        x, list_from_encoder = self.forward_features(x)
174        return x, list_from_encoder

Apply the vision transformer to input data.

Arguments:
  • x: The input data.
Returns:

The vision transformer output.

class ViT_Sam2:
177class ViT_Sam2(ImageEncoder):
178    """Vision Transformer derived from the Segment Anything 2 Codebase (https://arxiv.org/abs/2408.00714).
179
180    Based on https://github.com/facebookresearch/sam2/blob/main/sam2/modeling/backbones/image_encoder.py.
181
182    Args:
183        backbone_channel_list: The channels throughout the entire backbone.
184        embed_dim: The initial embedding dimension.
185        num_heads: The initial number of heads.
186        stages: The number of blocks per stage.
187        global_att_blocks: The parameter to decide which blocks have global attention.
188        window_pos_embed_bkg_spatial_size: The spatial size of windowed positional embedding.
189        window_spec: The window size per stage, when not using global attention.
190        scalp: The count of lowest resolution features to discard.
191    """
192    def __init__(
193        self,
194        backbone_channel_list: List[int],
195        img_size: int = 1024,
196        embed_dim: int = 96,
197        num_heads: int = 1,
198        stages: Tuple[int, ...] = (2, 3, 16, 3),
199        global_att_blocks: Tuple[int, ...] = (12, 16, 20),
200        window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
201        window_spec: Tuple[int, ...] = (8, 4, 14, 7),
202        scalp: int = 1,
203    ):
204        if not _sam2_import_success:
205            raise RuntimeError(
206                "The vision transformer backend can only be initialized if segment anything 2 is installed. "
207                "Please install segment anything 2 from https://github.com/facebookresearch/sam2 "
208                "and then rerun your code"
209            )
210
211        trunk = Hiera(
212            embed_dim=embed_dim,
213            num_heads=num_heads,
214            stages=stages,
215            global_att_blocks=global_att_blocks,
216            window_pos_embed_bkg_spatial_size=window_pos_embed_bkg_spatial_size,
217            window_spec=window_spec,
218        )
219        neck = FpnNeck(
220            position_encoding=PositionEmbeddingSine(num_pos_feats=256),
221            d_model=256,
222            backbone_channel_list=backbone_channel_list,
223            fpn_top_down_levels=[2, 3],
224            fpn_interp_model="nearest",
225        )
226
227        super().__init__(trunk=trunk, neck=neck, scalp=scalp)
228        self.scalp = scalp
229        self.embed_dim = embed_dim
230        self.img_size = img_size
231
232    def forward(self, x: torch.Tensor):
233        # The forward pass throught the backbone.
234        features, pos = self.neck(self.trunk(x))
235        if self.scalp > 0:  # This discard the "n" lowest resolution features.
236            features, pos = features[:-self.scalp], pos[:-self.scalp]
237
238        return features[-1], features

Vision Transformer derived from the Segment Anything 2 Codebase (https://arxiv.org/abs/2408.00714).

Based on https://github.com/facebookresearch/sam2/blob/main/sam2/modeling/backbones/image_encoder.py.

Arguments:
  • backbone_channel_list: The channels throughout the entire backbone.
  • embed_dim: The initial embedding dimension.
  • num_heads: The initial number of heads.
  • stages: The number of blocks per stage.
  • global_att_blocks: The parameter to decide which blocks have global attention.
  • window_pos_embed_bkg_spatial_size: The spatial size of windowed positional embedding.
  • window_spec: The window size per stage, when not using global attention.
  • scalp: The count of lowest resolution features to discard.
ViT_Sam2( backbone_channel_list: List[int], img_size: int = 1024, embed_dim: int = 96, num_heads: int = 1, stages: Tuple[int, ...] = (2, 3, 16, 3), global_att_blocks: Tuple[int, ...] = (12, 16, 20), window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14), window_spec: Tuple[int, ...] = (8, 4, 14, 7), scalp: int = 1)
192    def __init__(
193        self,
194        backbone_channel_list: List[int],
195        img_size: int = 1024,
196        embed_dim: int = 96,
197        num_heads: int = 1,
198        stages: Tuple[int, ...] = (2, 3, 16, 3),
199        global_att_blocks: Tuple[int, ...] = (12, 16, 20),
200        window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
201        window_spec: Tuple[int, ...] = (8, 4, 14, 7),
202        scalp: int = 1,
203    ):
204        if not _sam2_import_success:
205            raise RuntimeError(
206                "The vision transformer backend can only be initialized if segment anything 2 is installed. "
207                "Please install segment anything 2 from https://github.com/facebookresearch/sam2 "
208                "and then rerun your code"
209            )
210
211        trunk = Hiera(
212            embed_dim=embed_dim,
213            num_heads=num_heads,
214            stages=stages,
215            global_att_blocks=global_att_blocks,
216            window_pos_embed_bkg_spatial_size=window_pos_embed_bkg_spatial_size,
217            window_spec=window_spec,
218        )
219        neck = FpnNeck(
220            position_encoding=PositionEmbeddingSine(num_pos_feats=256),
221            d_model=256,
222            backbone_channel_list=backbone_channel_list,
223            fpn_top_down_levels=[2, 3],
224            fpn_interp_model="nearest",
225        )
226
227        super().__init__(trunk=trunk, neck=neck, scalp=scalp)
228        self.scalp = scalp
229        self.embed_dim = embed_dim
230        self.img_size = img_size
scalp
embed_dim
img_size
def forward(self, x: torch.Tensor):
232    def forward(self, x: torch.Tensor):
233        # The forward pass throught the backbone.
234        features, pos = self.neck(self.trunk(x))
235        if self.scalp > 0:  # This discard the "n" lowest resolution features.
236            features, pos = features[:-self.scalp], pos[:-self.scalp]
237
238        return features[-1], features
class CustomCompose:
246class CustomCompose:
247    def __init__(self, rescale_transform, other_transforms, src_transform):
248        self.rescale_transform = rescale_transform
249        self.other_transforms = other_transforms
250        self.src_transform = src_transform
251
252    def __call__(self, x, valid_masks=None):
253        if valid_masks is not None:
254            nodata = (x * (1 - valid_masks.float())).max()
255        x_aug = self.rescale_transform(x)
256        parms = self.rescale_transform._params
257
258        # sanity check, comment if this is working
259        # valid_masks = self.rescale_transform(valid_masks.float(), params=parms)
260        # assert (x_aug==self.rescale_transform(x, params=parms)).all() #
261
262        if valid_masks is not None:
263            valid_masks = x_aug != nodata
264            _, c, h, w = x_aug.shape
265            zero_ratio = ((valid_masks == 0).sum((1, 2, 3)) / (h * w * c)).cpu().numpy()
266        else:
267            zero_ratio = -1
268
269        if self.other_transforms:
270            x_aug = self.other_transforms(x_aug)
271        x_src = self.src_transform(x_aug)
272        dx = parms["src"][:, 1, 0] - parms["src"][:, 0, 0]
273
274        # dy = (parms['src'][:,2,1] - parms['src'][:,1,1])
275        # assert (dx == dy).all()
276
277        h, w = x_aug.shape[-2:]
278        # assert h == w
279
280        return x_aug, x_src, dx / h, zero_ratio, valid_masks
CustomCompose(rescale_transform, other_transforms, src_transform)
247    def __init__(self, rescale_transform, other_transforms, src_transform):
248        self.rescale_transform = rescale_transform
249        self.other_transforms = other_transforms
250        self.src_transform = src_transform
rescale_transform
other_transforms
src_transform
def get_2d_sincos_pos_embed_with_resolution(embed_dim, grid_size, res, cls_token=False, device='cpu'):
283def get_2d_sincos_pos_embed_with_resolution(embed_dim, grid_size, res, cls_token=False, device="cpu"):
284    """
285    grid_size: int of the grid height and width
286    res: array of size n, representing the resolution of a pixel (say, in meters),
287    return:
288    pos_embed: [n,grid_size*grid_size, embed_dim] or [n,1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
289    """
290    # res = torch.FloatTensor(res).to(device)
291    res = res.to(device)
292    grid_h = torch.arange(grid_size, dtype=torch.float32, device=device)
293    grid_w = torch.arange(grid_size, dtype=torch.float32, device=device)
294    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")  # here h goes first,direction reversed for numpy
295    grid = torch.stack(grid, dim=0)  # 2 x h x w
296
297    # grid = grid.reshape([2, 1, grid_size, grid_size])
298    grid = torch.einsum("chw,n->cnhw", grid, res)  # 2 x n x h x w
299    _, n, h, w = grid.shape
300    pos_embed = get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid)  # (nxH*W, D/2)
301    pos_embed = pos_embed.reshape(n, h * w, embed_dim)
302    if cls_token:
303        pos_embed = torch.cat(
304            [torch.zeros([n, 1, embed_dim], dtype=torch.float32, device=pos_embed.device), pos_embed], dim=1
305        )
306
307    return pos_embed

grid_size: int of the grid height and width res: array of size n, representing the resolution of a pixel (say, in meters), return: pos_embed: [n,grid_size*grid_size, embed_dim] or [n,1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)

def get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid):
310def get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid):
311    assert embed_dim % 2 == 0
312
313    # use half of dimensions to encode grid_h
314    emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid[0])  # (H*W, D/2)
315    emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid[1])  # (H*W, D/2)
316
317    emb = torch.cat([emb_h, emb_w], dim=1)  # (H*W, D)
318    return emb
def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos):
321def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos):
322    """
323    embed_dim: output dimension for each position
324    pos: a list of positions to be encoded: size (M,)
325    out: (M, D)
326    """
327    assert embed_dim % 2 == 0
328    # old_shape = pos
329    omega = torch.arange(embed_dim // 2, dtype=torch.float32, device=pos.device)
330    omega /= embed_dim / 2.0
331    omega = 1.0 / 10000**omega  # (D/2,)
332
333    pos = pos.reshape(-1)  # (M,)
334    out = torch.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
335
336    emb_sin = torch.sin(out)  # (M, D/2)
337    emb_cos = torch.cos(out)  # (M, D/2)
338
339    emb = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)
340    return emb

embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)

class PatchEmbedUnSafe:
343class PatchEmbedUnSafe(PatchEmbed):
344    """Image to Patch Embedding"""
345
346    def forward(self, x):
347        B, C, H, W = x.shape
348
349        # NOTE: Comment code from ScaleMAE: Dropped size check in timm
350        # assert H == self.img_size[0] and W == self.img_size[1], \
351        #     f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
352
353        x = self.proj(x).flatten(2).transpose(1, 2)
354        return x

Image to Patch Embedding

def forward(self, x):
346    def forward(self, x):
347        B, C, H, W = x.shape
348
349        # NOTE: Comment code from ScaleMAE: Dropped size check in timm
350        # assert H == self.img_size[0] and W == self.img_size[1], \
351        #     f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
352
353        x = self.proj(x).flatten(2).transpose(1, 2)
354        return x
class ViT_ScaleMAE:
357class ViT_ScaleMAE(VisionTransformer):
358    """Vision Transformer dervied from the Scale Masked Auto Encoder codebase (TODO: paper and github link).
359
360    NOTE: For downstream tasks, the "base_resoulution" parameter needs to be adjusted manually when using
361    the model on a different zoom factor dataset.
362    """
363
364    def __init__(
365        self, img_size=224, patch_size=16, in_chans=3, embed_dim=1024, depth=12, base_resolution=2.5, **kwargs
366    ):
367        super().__init__(img_size=img_size, embed_dim=embed_dim, **kwargs)
368        self.img_size = img_size
369        self.in_chans = in_chans
370        self.depth = depth
371        self.base_resolution = base_resolution
372
373        self.patch_embed = PatchEmbedUnSafe(
374            img_size=img_size,
375            patch_size=patch_size,
376            in_chans=in_chans,
377            embed_dim=embed_dim,
378        )
379
380    def transform_inputs(self, x):
381        import kornia.augmentation as K
382        from kornia.constants import Resample
383
384        self._transforms = CustomCompose(
385            rescale_transform=K.RandomResizedCrop(
386                (448, 448),
387                ratio=(1.0, 1.0),
388                scale=(1.0, 1.0),
389                resample=Resample.BICUBIC.name,
390            ),
391            other_transforms=None,
392            src_transform=K.Resize((224, 224)),
393        )
394        x, _, ratios, _, _ = self._transforms(x)
395        input_res = ratios * self.base_resolution
396        return x, input_res
397
398    def convert_to_expected_dim(self, x):
399        inputs_ = x[:, 1:, :]  # removing the class tokens
400        # reshape the outputs to desired shape (N X H*W X C -> N X H X W X C)
401        rdim = inputs_.shape[1]
402        dshape = int(rdim ** 0.5)  # finding square root of the outputs for obtaining the patch shape
403        inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape))
404        inputs_ = inputs_.permute(0, 3, 1, 2)
405        return inputs_
406
407    def forward_features(self, x):
408        x, input_res = self.transform_inputs(x)
409
410        B, _, h, w = x.shape
411        x = self.patch_embed(x)
412
413        num_patches = int((h * w) / (self.patch_embed.patch_size[0] * self.patch_embed.patch_size[1]))
414        pos_embed = get_2d_sincos_pos_embed_with_resolution(
415            x.shape[-1],
416            int(num_patches ** 0.5),
417            input_res,
418            cls_token=True,
419            device=x.device,
420        )
421
422        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
423        x = torch.cat((cls_tokens, x), dim=1)
424        x = x + pos_embed
425        x = self.pos_drop(x)
426
427        # chunks obtained for getting the projections for conjuctions with upsampling blocks
428        _chunks = int(self.depth / 4)
429        chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1]
430
431        list_from_encoder = []
432        for i, blk in enumerate(self.blocks):
433            x = blk(x)
434            if i in chunks_for_projection:
435                list_from_encoder.append(self.convert_to_expected_dim(x))
436
437        x = self.convert_to_expected_dim(x)
438
439        return x, list_from_encoder
440
441    def forward(self, x):
442        x, list_from_encoder = self.forward_features(x)
443        return x, list_from_encoder

Vision Transformer dervied from the Scale Masked Auto Encoder codebase (TODO: paper and github link).

NOTE: For downstream tasks, the "base_resoulution" parameter needs to be adjusted manually when using the model on a different zoom factor dataset.

ViT_ScaleMAE( img_size=224, patch_size=16, in_chans=3, embed_dim=1024, depth=12, base_resolution=2.5, **kwargs)
364    def __init__(
365        self, img_size=224, patch_size=16, in_chans=3, embed_dim=1024, depth=12, base_resolution=2.5, **kwargs
366    ):
367        super().__init__(img_size=img_size, embed_dim=embed_dim, **kwargs)
368        self.img_size = img_size
369        self.in_chans = in_chans
370        self.depth = depth
371        self.base_resolution = base_resolution
372
373        self.patch_embed = PatchEmbedUnSafe(
374            img_size=img_size,
375            patch_size=patch_size,
376            in_chans=in_chans,
377            embed_dim=embed_dim,
378        )
img_size
in_chans
depth
base_resolution
patch_embed
def transform_inputs(self, x):
380    def transform_inputs(self, x):
381        import kornia.augmentation as K
382        from kornia.constants import Resample
383
384        self._transforms = CustomCompose(
385            rescale_transform=K.RandomResizedCrop(
386                (448, 448),
387                ratio=(1.0, 1.0),
388                scale=(1.0, 1.0),
389                resample=Resample.BICUBIC.name,
390            ),
391            other_transforms=None,
392            src_transform=K.Resize((224, 224)),
393        )
394        x, _, ratios, _, _ = self._transforms(x)
395        input_res = ratios * self.base_resolution
396        return x, input_res
def convert_to_expected_dim(self, x):
398    def convert_to_expected_dim(self, x):
399        inputs_ = x[:, 1:, :]  # removing the class tokens
400        # reshape the outputs to desired shape (N X H*W X C -> N X H X W X C)
401        rdim = inputs_.shape[1]
402        dshape = int(rdim ** 0.5)  # finding square root of the outputs for obtaining the patch shape
403        inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape))
404        inputs_ = inputs_.permute(0, 3, 1, 2)
405        return inputs_
def forward_features(self, x):
407    def forward_features(self, x):
408        x, input_res = self.transform_inputs(x)
409
410        B, _, h, w = x.shape
411        x = self.patch_embed(x)
412
413        num_patches = int((h * w) / (self.patch_embed.patch_size[0] * self.patch_embed.patch_size[1]))
414        pos_embed = get_2d_sincos_pos_embed_with_resolution(
415            x.shape[-1],
416            int(num_patches ** 0.5),
417            input_res,
418            cls_token=True,
419            device=x.device,
420        )
421
422        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
423        x = torch.cat((cls_tokens, x), dim=1)
424        x = x + pos_embed
425        x = self.pos_drop(x)
426
427        # chunks obtained for getting the projections for conjuctions with upsampling blocks
428        _chunks = int(self.depth / 4)
429        chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1]
430
431        list_from_encoder = []
432        for i, blk in enumerate(self.blocks):
433            x = blk(x)
434            if i in chunks_for_projection:
435                list_from_encoder.append(self.convert_to_expected_dim(x))
436
437        x = self.convert_to_expected_dim(x)
438
439        return x, list_from_encoder
def forward(self, x):
441    def forward(self, x):
442        x, list_from_encoder = self.forward_features(x)
443        return x, list_from_encoder
class ViT_DINOv3:
446class ViT_DINOv3(DinoVisionTransformer):
447    """Vision Transformer derived from the DINOv3 Codebase (https://arxiv.org/abs/2508.10104).
448
449    Based on:
450    https://github.com/facebookresearch/dinov3/blob/main/dinov3/models/vision_transformer.py.
451
452    Args:
453        img_size: The input image size.
454        patch_size: The patch size.
455        embed_dim: The embedding dimension.
456        depth: The depth of the network.
457        num_heads: The number of heads.
458        ffn_ratio: The FFN ratio.
459        n_storage_tokens: The number of storage (class) tokens to remove.
460        kwargs: Keyword arguments for the image encoder base class.
461    """
462    def __init__(
463        self,
464        in_chans: int = 3,
465        img_size: int = 224,
466        patch_size: int = 16,
467        embed_dim: int = 768,
468        depth: int = 12,
469        num_heads: int = 12,
470        ffn_ratio: float = 4.0,
471        n_storage_tokens: int = 0,
472        **kwargs
473    ):
474        if not _dinov3_import_success:
475            raise RuntimeError(
476                "The vision transformer backend can only be initialized if DINOv3 is installed. "
477                "Please install DINOv3 from https://github.com/facebookresearch/dinov3 "
478                "and then rerun your code."
479            )
480
481        super().__init__(
482            in_chans=in_chans,
483            img_size=img_size,
484            patch_size=patch_size,
485            embed_dim=embed_dim,
486            depth=depth,
487            num_heads=num_heads,
488            ffn_ratio=ffn_ratio,
489            n_storage_tokens=n_storage_tokens,
490            **kwargs
491        )
492
493        self.in_chans = in_chans
494        self.img_size = img_size
495        self.n_storage_tokens = n_storage_tokens
496        self.attn_outs = [i for i in range(depth) if i % 3 == 2]
497
498    def forward(self, x) -> Tuple[torch.Tensor, List[torch.Tensor]]:
499
500        B = x.shape[0]
501
502        x, hw_tuple = self.prepare_tokens_with_masks(x)
503
504        list_of_encoder = []
505        for i, blk in enumerate(self.blocks):
506            rope_sincos = self.rope_embed(H=hw_tuple[0], W=hw_tuple[1])
507            x = blk(x, rope_sincos)
508            if i in self.attn_outs:
509                list_of_encoder.append(x)
510
511        x = self.norm(x)
512        x = x[:, self.n_storage_tokens + 1:].reshape(
513            B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
514        ).permute(0, 3, 1, 2).contiguous()
515
516        list_of_encoder = [
517            o[:, self.n_storage_tokens + 1:].reshape(
518                B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
519            ).permute(0, 3, 1, 2).contiguous() for o in list_of_encoder
520        ]
521
522        return x, list_of_encoder[:3]

Vision Transformer derived from the DINOv3 Codebase (https://arxiv.org/abs/2508.10104).

Based on: https://github.com/facebookresearch/dinov3/blob/main/dinov3/models/vision_transformer.py.

Arguments:
  • img_size: The input image size.
  • patch_size: The patch size.
  • embed_dim: The embedding dimension.
  • depth: The depth of the network.
  • num_heads: The number of heads.
  • ffn_ratio: The FFN ratio.
  • n_storage_tokens: The number of storage (class) tokens to remove.
  • kwargs: Keyword arguments for the image encoder base class.
ViT_DINOv3( in_chans: int = 3, img_size: int = 224, patch_size: int = 16, embed_dim: int = 768, depth: int = 12, num_heads: int = 12, ffn_ratio: float = 4.0, n_storage_tokens: int = 0, **kwargs)
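As a hedged usage sketch (assuming the optional dinov3 dependency is installed): the forward pass returns the final feature map together with the first three intermediate feature maps, which are collected at every third block.

import torch
from torch_em.model.vit import get_vision_transformer

encoder = get_vision_transformer(backbone="dinov3", model="vit_b", img_size=224)

x = torch.randn(1, 3, 224, 224)
features, intermediates = encoder(x)

# Final feature map of shape (batch, embed_dim, img_size // patch_size, img_size // patch_size),
# e.g. (1, 768, 14, 14) for "vit_b" with the default patch size of 16.
print(features.shape)
# Three intermediate feature maps with the same layout.
print(len(intermediates), intermediates[0].shape)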
462    def __init__(
463        self,
464        in_chans: int = 3,
465        img_size: int = 224,
466        patch_size: int = 16,
467        embed_dim: int = 768,
468        depth: int = 12,
469        num_heads: int = 12,
470        ffn_ratio: float = 4.0,
471        n_storage_tokens: int = 0,
472        **kwargs
473    ):
474        if not _dinov3_import_success:
475            raise RuntimeError(
476                "The vision transformer backend can only be initialized if DINOv3 is installed. "
477                "Please install DINOv3 from https://github.com/facebookresearch/dinov3 "
478                "and then rerun your code."
479            )
480
481        super().__init__(
482            in_chans=in_chans,
483            img_size=img_size,
484            patch_size=patch_size,
485            embed_dim=embed_dim,
486            depth=depth,
487            num_heads=num_heads,
488            ffn_ratio=ffn_ratio,
489            n_storage_tokens=n_storage_tokens,
490            **kwargs
491        )
492
493        self.in_chans = in_chans
494        self.img_size = img_size
495        self.n_storage_tokens = n_storage_tokens
496        self.attn_outs = [i for i in range(depth) if i % 3 == 2]
in_chans
img_size
n_storage_tokens
attn_outs
def forward(self, x) -> Tuple[torch.Tensor, List[torch.Tensor]]:
498    def forward(self, x) -> Tuple[torch.Tensor, List[torch.Tensor]]:
499
500        B = x.shape[0]
501
502        x, hw_tuple = self.prepare_tokens_with_masks(x)
503
504        list_of_encoder = []
505        for i, blk in enumerate(self.blocks):
506            rope_sincos = self.rope_embed(H=hw_tuple[0], W=hw_tuple[1])
507            x = blk(x, rope_sincos)
508            if i in self.attn_outs:
509                list_of_encoder.append(x)
510
511        x = self.norm(x)
512        x = x[:, self.n_storage_tokens + 1:].reshape(
513            B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
514        ).permute(0, 3, 1, 2).contiguous()
515
516        list_of_encoder = [
517            o[:, self.n_storage_tokens + 1:].reshape(
518                B, self.img_size // self.patch_size, self.img_size // self.patch_size, -1
519            ).permute(0, 3, 1, 2).contiguous() for o in list_of_encoder
520        ]
521
522        return x, list_of_encoder[:3]
def get_vision_transformer( backbone: str, model: str, img_size: int = 1024, **kwargs) -> torch.nn.modules.module.Module:
525def get_vision_transformer(backbone: str, model: str, img_size: int = 1024, **kwargs) -> nn.Module:
526    """Get vision transformer encoder.
527
528    Args:
529        backbone: The name of the vision transformer implementation. One of "sam", "sam2", "mae", "scalemae" or "dinov3".
530        model: The name of the model, e.g. "vit_b", "vit_l" or "vit_h"; the supported names depend on the backbone.
531        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
532        kwargs: Additional kwargs which can be expected by the vision transformer,
533            e.g. 'base_resolution' for `ViT_ScaleMAE`.
534
535    Returns:
536        The vision transformer.
537    """
538    if backbone == "sam":
539        if model == "vit_b":
540            encoder = ViT_Sam(
541                depth=12, embed_dim=768, img_size=img_size, mlp_ratio=4,
542                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
543                num_heads=12, patch_size=16, qkv_bias=True, use_rel_pos=True,
544                global_attn_indexes=[2, 5, 8, 11],
545                window_size=14, out_chans=256,
546            )
547        elif model == "vit_l":
548            encoder = ViT_Sam(
549                depth=24, embed_dim=1024, img_size=img_size, mlp_ratio=4,
550                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
551                num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True,
552                global_attn_indexes=[5, 11, 17, 23],
553                window_size=14, out_chans=256,
554            )
555        elif model == "vit_h":
556            encoder = ViT_Sam(
557                depth=32, embed_dim=1280, img_size=img_size, mlp_ratio=4,
558                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
559                num_heads=16, patch_size=16, qkv_bias=True, use_rel_pos=True,
560                global_attn_indexes=[7, 15, 23, 31],
561                window_size=14, out_chans=256,
562            )
563        else:
564            raise ValueError(f"'{model}' is not supported by SAM. Currently, 'vit_b', 'vit_l', 'vit_h' are supported.")
565
566    elif backbone == "sam2":
567        if model == "hvit_t":
568            encoder = ViT_Sam2(
569                img_size=img_size, embed_dim=96, num_heads=1, stages=[1, 2, 7, 2], global_att_blocks=[5, 7, 9],
570                window_pos_embed_bkg_spatial_size=[7, 7], backbone_channel_list=[768, 384, 192, 96],
571            )
572        elif model == "hvit_s":
573            encoder = ViT_Sam2(
574                img_size=img_size, embed_dim=96, num_heads=1, stages=[1, 2, 11, 2], global_att_blocks=[7, 10, 13],
575                window_pos_embed_bkg_spatial_size=[7, 7], backbone_channel_list=[768, 384, 192, 96],
576            )
577        elif model == "hvit_b":
578            encoder = ViT_Sam2(
579                img_size=img_size, embed_dim=112, num_heads=2, backbone_channel_list=[896, 448, 224, 112],
580            )
581        elif model == "hvit_l":
582            encoder = ViT_Sam2(
583                img_size=img_size, embed_dim=144, num_heads=2, stages=[2, 6, 36, 4], global_att_blocks=[23, 33, 43],
584                window_spec=[8, 4, 16, 8], backbone_channel_list=[1152, 576, 288, 144],
585            )
586        else:
587            raise ValueError(
588                f"'{model}' is not supported by SAM2. Currently, 'hvit_t', 'hvit_s', 'hvit_b', 'hvit_l' are supported."
589            )
590
591    elif backbone == "mae":
592        if model == "vit_b":
593            encoder = ViT_MAE(
594                img_size=img_size, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
595                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
596            )
597        elif model == "vit_l":
598            encoder = ViT_MAE(
599                img_size=img_size, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
600                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
601            )
602        elif model == "vit_h":
603            encoder = ViT_MAE(
604                img_size=img_size, patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4,
605                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
606            )
607        else:
608            raise ValueError(f"'{model}' is not supported by MAE. Currently, 'vit_b', 'vit_l', 'vit_h' are supported.")
609
610    elif backbone == "scalemae":
611        base_resolution = kwargs.get("base_resolution", 2.5)
612
613        if model == "vit_b":
614            encoder = ViT_ScaleMAE(
615                img_size=img_size, patch_size=8, embed_dim=768, depth=12, num_heads=12,
616                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
617                base_resolution=base_resolution,
618            )
619        elif model == "vit_l":
620            encoder = ViT_ScaleMAE(
621                img_size=img_size, patch_size=8, embed_dim=1024, depth=24, num_heads=16,
622                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
623                base_resolution=base_resolution,
624            )
625        elif model == "vit_h":
626            encoder = ViT_ScaleMAE(
627                img_size=img_size, patch_size=8, embed_dim=1280, depth=32, num_heads=16,
628                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
629                base_resolution=base_resolution,
630            )
631        else:
632            raise ValueError(
633                f"'{model}' is not supported by ScaleMAE. Currently, 'vit_b', 'vit_l' and 'vit_h' are supported."
634            )
635
636    elif backbone == "dinov3":
637
638        if model == "vit_s":
639            encoder = ViT_DINOv3(
640                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=384,
641                num_heads=6, layerscale_init=1.0e-05, norm_layer="layernormbf16", n_storage_tokens=4, mask_k_bias=True,
642            )
643        elif model == "vit_s+":
644            encoder = ViT_DINOv3(
645                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=384,
646                num_heads=6, ffn_ratio=6, layerscale_init=1.0e-05, norm_layer="layernormbf16",
647                ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True,
648            )
649
650        elif model == "vit_b":
651            encoder = ViT_DINOv3(
652                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32",
653                layerscale_init=1.0e-05, norm_layer="layernormbf16", n_storage_tokens=4, mask_k_bias=True,
654            )
655        elif model == "vit_l":
656            encoder = ViT_DINOv3(
657                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1024,
658                depth=24, num_heads=16, layerscale_init=1.0e-05, norm_layer="layernormbf16",
659                n_storage_tokens=4, mask_k_bias=True,
660            )
661        elif model == "vit_l+":
662            encoder = ViT_DINOv3(
663                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1024,
664                depth=24, num_heads=16, ffn_ratio=6.0, layerscale_init=1.0e-05, norm_layer="layernormbf16",
665                ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True,
666            )
667        elif model == "vit_h+":
668            encoder = ViT_DINOv3(
669                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=1280,
670                depth=32, num_heads=20, ffn_ratio=6.0, layerscale_init=1.0e-05, norm_layer="layernormbf16",
671                ffn_layer="swiglu", n_storage_tokens=4, mask_k_bias=True,
672            )
673        elif model == "vit_7b":
674            encoder = ViT_DINOv3(
675                img_size=img_size, pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=4096,
676                depth=40, num_heads=32, ffn_ratio=3, qkv_bias=False, drop_path_rate=0.0, layerscale_init=1.0e-05,
677                norm_layer="layernormbf16", ffn_layer="swiglu64", n_storage_tokens=4, mask_k_bias=True,
678                untie_global_and_local_cls_norm=True,
679            )
680        else:
681            raise ValueError(
682                f"'{model}' is not supported by DINOv3. Currently, "
683                "'vit_s', 'vit_s+', 'vit_b', 'vit_l', 'vit_l+', 'vit_h+' and 'vit_7b' are supported."
684            )
685
686    else:
687        raise ValueError(
688            "The backbones supported for 'UNETR' are 'sam', 'sam2', 'mae', 'scalemae' or 'dinov3'. "
689            "Please choose one of them."
690        )
691
692    return encoder

Get vision transformer encoder.

Arguments:
  • backbone: The name of the vision transformer implementation. One of "sam", "sam2", "mae", "scalemae" or "dinov3".
  • model: The name of the model, e.g. "vit_b", "vit_l" or "vit_h"; the supported names depend on the backbone.
  • img_size: The size of the input for the image encoder. Input images will be resized to match this size.
  • kwargs: Additional kwargs which can be expected by the vision transformer, e.g. 'base_resolution' for ViT_ScaleMAE.
Returns:

The vision transformer.
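
A short usage sketch. Each backbone requires its respective optional dependency (segment_anything, sam2, timm, scale-mae or dinov3) to be installed, and the img_size and model choices below are only examples.

from torch_em.model.vit import get_vision_transformer

sam_encoder = get_vision_transformer(backbone="sam", model="vit_b", img_size=1024)
sam2_encoder = get_vision_transformer(backbone="sam2", model="hvit_b", img_size=1024)
mae_encoder = get_vision_transformer(backbone="mae", model="vit_l", img_size=224)
scalemae_encoder = get_vision_transformer(
    backbone="scalemae", model="vit_b", img_size=224, base_resolution=2.5
)
dinov3_encoder = get_vision_transformer(backbone="dinov3", model="vit_s", img_size=224)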