torch_em.model.vim

@private

  1"""@private
  2"""
  3
  4# installation from https://github.com/hustvl/Vim
  5# encoder from https://github.com/hustvl/Vim
  6# decoder from https://github.com/constantinpape/torch-em
  7
  8# pretrained model weights: vim_t - https://huggingface.co/hustvl/Vim-tiny/blob/main/vim_tiny_73p1.pth
  9
 10import random
 11
 12import torch
 13
 14from .unetr import UNETR
 15
 16try:
 17    from vim.models_mamba import VisionMamba, rms_norm_fn, RMSNorm, layer_norm_fn
 18    _have_vim_installed = True
 19except ImportError:
 20    VisionMamba = object
 21    rms_norm_fn = RMSNorm = layer_norm_fn = None
 22    _have_vim_installed = False
 23
 24try:
 25    from timm.models.vision_transformer import _cfg
 26except ImportError:
 27    _cfg = None
 28
 29
 30class ViM(VisionMamba):
 31    def __init__(self, **kwargs):
 32        assert _have_vim_installed, "Please install 'Vim'."
 33        super().__init__(**kwargs)
 34
 35    def convert_to_expected_dim(self, inputs_):
 36        # reshape the outputs to desired shape (N x H*W X C -> N x H x W x C)
 37        rdim = inputs_.shape[1]
 38        dshape = int(rdim ** 0.5)  # finding the square root of the outputs for obtaining the patch shape
 39        inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape))
 40        inputs_ = inputs_.permute(0, 3, 1, 2)
 41        return inputs_
 42
 43    def forward_features(
 44        self, x, inference_params=None, if_random_cls_token_position=False, if_random_token_rank=False
 45    ):
 46        # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
 47        # with slight modifications to add the dist_token
 48        x = self.patch_embed(x)
 49        B, M, _ = x.shape
 50
 51        if self.if_cls_token:
 52            if self.use_double_cls_token:
 53                cls_token_head = self.cls_token_head.expand(B, -1, -1)
 54                cls_token_tail = self.cls_token_tail.expand(B, -1, -1)
 55                token_position = [0, M + 1]
 56                x = torch.cat((cls_token_head, x, cls_token_tail), dim=1)
 57                M = x.shape[1]
 58            else:
 59                if self.use_middle_cls_token:
 60                    cls_token = self.cls_token.expand(B, -1, -1)
 61                    token_position = M // 2
 62                    # add cls token in the middle
 63                    x = torch.cat((x[:, :token_position, :], cls_token, x[:, token_position:, :]), dim=1)
 64                elif if_random_cls_token_position:
 65                    cls_token = self.cls_token.expand(B, -1, -1)
 66                    token_position = random.randint(0, M)
 67                    x = torch.cat((x[:, :token_position, :], cls_token, x[:, token_position:, :]), dim=1)
 68                    print("token_position: ", token_position)
 69                else:
 70                    cls_token = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
 71                    token_position = 0
 72                    x = torch.cat((cls_token, x), dim=1)
 73                M = x.shape[1]
 74
 75        if self.if_abs_pos_embed:
 76            x = x + self.pos_embed
 77            x = self.pos_drop(x)
 78
 79        if if_random_token_rank:
 80            # general random shuffle index
 81            shuffle_indices = torch.randperm(M)
 82
 83            if isinstance(token_position, list):
 84                print("original value: ", x[0, token_position[0], 0], x[0, token_position[1], 0])
 85            else:
 86                print("original value: ", x[0, token_position, 0])
 87            print("original token_position: ", token_position)
 88
 89            # execute shuffle
 90            x = x[:, shuffle_indices, :]
 91
 92            if isinstance(token_position, list):
 93                # find new position of cls token after shuffle
 94                new_token_position = [
 95                    torch.where(shuffle_indices == token_position[i])[0].item() for i in range(len(token_position))
 96                ]
 97                token_position = new_token_position
 98            else:
 99                # find new position of cls token after the shuffle
100                token_position = torch.where(shuffle_indices == token_position)[0].item()
101
102            if isinstance(token_position, list):
103                print("new value: ", x[0, token_position[0], 0], x[0, token_position[1], 0])
104            else:
105                print("new value: ", x[0, token_position, 0])
106            print("new token_position: ", token_position)
107
108        if_flip_img_sequences = False
109        if self.flip_img_sequences_ratio > 0 and (self.flip_img_sequences_ratio - random.random()) > 1e-5:
110            x = x.flip([1])
111            if_flip_img_sequences = True
112
113        # mamba impl
114        residual = None
115        hidden_states = x
116        if not self.if_bidirectional:
117            for layer in self.layers:
118
119                if if_flip_img_sequences and self.if_rope:
120                    hidden_states = hidden_states.flip([1])
121                    if residual is not None:
122                        residual = residual.flip([1])
123
124                # rope about
125                if self.if_rope:
126                    hidden_states = self.rope(hidden_states)
127                    if residual is not None and self.if_rope_residual:
128                        residual = self.rope(residual)
129
130                if if_flip_img_sequences and self.if_rope:
131                    hidden_states = hidden_states.flip([1])
132                    if residual is not None:
133                        residual = residual.flip([1])
134
135                hidden_states, residual = layer(hidden_states, residual, inference_params=inference_params)
136        else:
137            # get two layers in a single for-loop
138            for i in range(len(self.layers) // 2):
139                if self.if_rope:
140                    hidden_states = self.rope(hidden_states)
141                    if residual is not None and self.if_rope_residual:
142                        residual = self.rope(residual)
143
144                hidden_states_f, residual_f = self.layers[i * 2](
145                    hidden_states, residual, inference_params=inference_params
146                )
147                hidden_states_b, residual_b = self.layers[i * 2 + 1](
148                    hidden_states.flip([1]),
149                    None if residual is None else residual.flip([1]),
150                    inference_params=inference_params
151                )
152                hidden_states = hidden_states_f + hidden_states_b.flip([1])
153                residual = residual_f + residual_b.flip([1])
154
155        if not self.fused_add_norm:
156            if residual is None:
157                residual = hidden_states
158            else:
159                residual = residual + self.drop_path(hidden_states)
160            hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
161        else:
162            # Set prenorm = False here since we don't need the residual
163            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
164            hidden_states = fused_add_norm_fn(
165                self.drop_path(hidden_states),
166                self.norm_f.weight,
167                self.norm_f.bias,
168                eps=self.norm_f.eps,
169                residual=residual,
170                prenorm=False,
171                residual_in_fp32=self.residual_in_fp32,
172            )
173
174        if self.final_pool_type == 'none':
175            return hidden_states[:, -1, :]
176        elif self.final_pool_type == 'mean':
177            return hidden_states.mean(dim=1)
178        elif self.final_pool_type == 'max':
179            return hidden_states.max(dim=1)
180        elif self.final_pool_type == 'all':
181            return hidden_states
182        else:
183            raise NotImplementedError
184
185    def forward(self, x, inference_params=None, if_random_cls_token_position=False, if_random_token_rank=False):
186        x = self.forward_features(x, inference_params, if_random_cls_token_position, if_random_token_rank)
187
188        if self.final_pool_type == 'max':
189            x = x.max(dim=1)[0]
190
191        if self.if_cls_token:  # remove the class token
192            x = x[:, 1:, :]
193
194        # let's get the patches back from the 1d tokens
195        x = self.convert_to_expected_dim(x)
196
197        return x  # from here, the tokens can be upsampled easily (N x H x W x C)
198
199
def get_vim_encoder(model_type="vim_t", with_cls_token=True):
    """Construct a Vision Mamba encoder configured for 1024 x 1024 inputs.

    Args:
        model_type: Encoder variant; 'vim_t', 'vim_s' or 'vim_b' select an embedding
            width of 192, 384 or 768 respectively.
        with_cls_token: Whether the encoder adds a class token, which is placed in the
            middle of the token sequence (``use_middle_cls_token=True``).

    Returns:
        The ``ViM`` encoder, with ``default_cfg`` filled from timm's ``_cfg()``.

    Raises:
        ValueError: If ``model_type`` is not one of the supported variants.
    """
    supported = (("vim_t", 192), ("vim_s", 384), ("vim_b", 768))
    for name, width in supported:
        if model_type == name:
            embed_dim = width
            break
    else:
        raise ValueError("Choose from 'vim_t' / 'vim_s' / 'vim_b'")

    encoder = ViM(
        img_size=1024,
        patch_size=16,
        embed_dim=embed_dim,
        depth=24,
        rms_norm=True,
        residual_in_fp32=True,
        fused_add_norm=True,
        final_pool_type="all",
        if_abs_pos_embed=True,
        if_rope=False,
        if_rope_residual=False,
        bimamba_type="v2",
        if_cls_token=with_cls_token,
        if_divide_out=True,
        use_middle_cls_token=True,
    )
    encoder.default_cfg = _cfg()
    return encoder
229
230
def get_vimunet_model(
    out_channels, model_type="vim_t", with_cls_token=True, device=None, checkpoint=None
):
    """Build a ViM-UNet: a Vision Mamba encoder combined with torch_em's UNETR decoder.

    Args:
        out_channels: Number of output channels of the decoder.
        model_type: Encoder variant, one of 'vim_t', 'vim_s', 'vim_b'.
        with_cls_token: Whether the encoder uses a class token.
        device: Device to move the model to; defaults to CUDA when available, else CPU.
        checkpoint: Optional path (str or path-like) to weights. Files ending in '.pth'
            are treated as original Vim checkpoints (encoder weights under 'model');
            anything else as a torch_em checkpoint (full model weights under 'model_state').

    Returns:
        The UNETR model (sigmoid output activation, no skip connections) on ``device``.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    encoder = get_vim_encoder(model_type, with_cls_token)

    model_state = None
    if checkpoint is not None:
        # NOTE: weights_only=False unpickles arbitrary objects - only load trusted checkpoints.
        state = torch.load(checkpoint, map_location="cpu", weights_only=False)

        # str() so that pathlib.Path checkpoints work too (Path has no .endswith).
        if str(checkpoint).endswith(".pth"):  # from Vim
            encoder.load_state_dict(state["model"])
        else:  # from torch_em
            model_state = state["model_state"]

    # Expose the first spatial dim of the patch embedding's image size as ``img_size``;
    # presumably read by UNETR - verify against the UNETR implementation.
    encoder.img_size = encoder.patch_embed.img_size[0]

    # TODO: Update design so that: we have a backbone to fetch encoder and decoder flexibly
    # and is ideally not named as "UNETR" but something as for example "EncoderDecoderNet"
    model = UNETR(
        encoder=encoder,
        out_channels=out_channels,
        resize_input=False,
        use_skip_connection=False,
        final_activation="Sigmoid",
    )

    # torch_em checkpoints hold the full model, so they are loaded only after UNETR exists.
    if model_state is not None:
        model.load_state_dict(model_state)

    model.to(device)

    return model
class ViM(VisionMamba):
    """Vision Mamba (Vim) encoder that returns a spatial feature map.

    Subclasses ``vim.models_mamba.VisionMamba``. ``forward`` strips the class token(s)
    from the token sequence and folds the remaining patch tokens back into an
    (N, C, H, W) grid, so that a convolutional decoder (UNETR in this module) can
    upsample them.
    """

    def __init__(self, **kwargs):
        """Create the encoder; all keyword arguments are passed on to ``VisionMamba``.

        Raises:
            ImportError: If the ``vim`` package could not be imported.
        """
        # Explicit raise instead of ``assert`` so the check survives ``python -O``; otherwise a
        # missing install surfaces as a TypeError from ``object.__init__`` (the import fallback).
        if not _have_vim_installed:
            raise ImportError("Please install 'Vim'.")
        super().__init__(**kwargs)

    def convert_to_expected_dim(self, inputs_):
        """Fold a sequence of patch tokens into a channels-first feature map.

        Args:
            inputs_: Tensor of shape (N, H*W, C) containing only patch tokens (class tokens
                already removed); the patch grid must be square (H == W).

        Returns:
            Tensor of shape (N, C, H, W).

        Raises:
            ValueError: If the number of tokens is not a perfect square.
        """
        rdim = inputs_.shape[1]
        # Round before converting so floating-point error in the square root cannot lose a row.
        dshape = int(round(rdim ** 0.5))
        if dshape * dshape != rdim:
            raise ValueError(
                f"Expected a square number of patch tokens, got {rdim}. "
                "Are all class tokens removed and is the patch grid square?"
            )
        inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape))  # N x H x W x C
        return inputs_.permute(0, 3, 1, 2)  # N x C x H x W

    def forward_features(
        self, x, inference_params=None, if_random_cls_token_position=False, if_random_token_rank=False
    ):
        """Run patch embedding and the Mamba layers and return normalized token features.

        Args:
            x: Input batch, passed directly to ``self.patch_embed``.
            inference_params: Forwarded unchanged to every Mamba layer.
            if_random_cls_token_position: Insert the single class token at a random index
                (only used when neither the double nor the middle class token is configured).
            if_random_token_rank: Randomly permute the token sequence (including class
                tokens) before the Mamba layers.

        Returns:
            Depends on ``self.final_pool_type``: ``'none'`` -> the last token, ``'mean'`` ->
            the mean over tokens, ``'max'`` -> the ``torch.max`` result (values, indices) over
            tokens, ``'all'`` -> the full token sequence, including class tokens.

        Raises:
            NotImplementedError: For any other ``final_pool_type``.
        """
        # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
        # with slight modifications to add the dist_token
        x = self.patch_embed(x)
        B, M, _ = x.shape

        # Index (or [head, tail] indices) of the class token(s). It stays None without class
        # tokens, which the random token shuffle below has to tolerate.
        token_position = None
        if self.if_cls_token:
            if self.use_double_cls_token:
                cls_token_head = self.cls_token_head.expand(B, -1, -1)
                cls_token_tail = self.cls_token_tail.expand(B, -1, -1)
                token_position = [0, M + 1]
                x = torch.cat((cls_token_head, x, cls_token_tail), dim=1)
                M = x.shape[1]
            else:
                if self.use_middle_cls_token:
                    cls_token = self.cls_token.expand(B, -1, -1)
                    token_position = M // 2
                    # add cls token in the middle
                    x = torch.cat((x[:, :token_position, :], cls_token, x[:, token_position:, :]), dim=1)
                elif if_random_cls_token_position:
                    cls_token = self.cls_token.expand(B, -1, -1)
                    token_position = random.randint(0, M)
                    x = torch.cat((x[:, :token_position, :], cls_token, x[:, token_position:, :]), dim=1)
                    print("token_position: ", token_position)
                else:
                    cls_token = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
                    token_position = 0
                    x = torch.cat((cls_token, x), dim=1)
                M = x.shape[1]

        if self.if_abs_pos_embed:
            x = x + self.pos_embed
            x = self.pos_drop(x)

        if if_random_token_rank:
            # general random shuffle index
            shuffle_indices = torch.randperm(M)

            if token_position is not None:
                if isinstance(token_position, list):
                    print("original value: ", x[0, token_position[0], 0], x[0, token_position[1], 0])
                else:
                    print("original value: ", x[0, token_position, 0])
                print("original token_position: ", token_position)

            # execute shuffle
            x = x[:, shuffle_indices, :]

            if token_position is not None:
                # find the new position(s) of the cls token(s) after the shuffle
                if isinstance(token_position, list):
                    token_position = [
                        torch.where(shuffle_indices == pos)[0].item() for pos in token_position
                    ]
                else:
                    token_position = torch.where(shuffle_indices == token_position)[0].item()

                if isinstance(token_position, list):
                    print("new value: ", x[0, token_position[0], 0], x[0, token_position[1], 0])
                else:
                    print("new value: ", x[0, token_position, 0])
                print("new token_position: ", token_position)

        if_flip_img_sequences = False
        if self.flip_img_sequences_ratio > 0 and (self.flip_img_sequences_ratio - random.random()) > 1e-5:
            # NOTE(review): this flip is never undone below, so the returned token order is
            # reversed whenever it triggers.
            x = x.flip([1])
            if_flip_img_sequences = True

        # mamba impl
        residual = None
        hidden_states = x
        if not self.if_bidirectional:
            for layer in self.layers:

                if if_flip_img_sequences and self.if_rope:
                    hidden_states = hidden_states.flip([1])
                    if residual is not None:
                        residual = residual.flip([1])

                # rope about
                if self.if_rope:
                    hidden_states = self.rope(hidden_states)
                    if residual is not None and self.if_rope_residual:
                        residual = self.rope(residual)

                if if_flip_img_sequences and self.if_rope:
                    hidden_states = hidden_states.flip([1])
                    if residual is not None:
                        residual = residual.flip([1])

                hidden_states, residual = layer(hidden_states, residual, inference_params=inference_params)
        else:
            # get two layers in a single for-loop: one over the sequence, one over its reverse
            for i in range(len(self.layers) // 2):
                if self.if_rope:
                    hidden_states = self.rope(hidden_states)
                    if residual is not None and self.if_rope_residual:
                        residual = self.rope(residual)

                hidden_states_f, residual_f = self.layers[i * 2](
                    hidden_states, residual, inference_params=inference_params
                )
                hidden_states_b, residual_b = self.layers[i * 2 + 1](
                    hidden_states.flip([1]),
                    None if residual is None else residual.flip([1]),
                    inference_params=inference_params
                )
                hidden_states = hidden_states_f + hidden_states_b.flip([1])
                residual = residual_f + residual_b.flip([1])

        if not self.fused_add_norm:
            if residual is None:
                residual = hidden_states
            else:
                residual = residual + self.drop_path(hidden_states)
            hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
        else:
            # Set prenorm = False here since we don't need the residual
            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
            hidden_states = fused_add_norm_fn(
                self.drop_path(hidden_states),
                self.norm_f.weight,
                self.norm_f.bias,
                eps=self.norm_f.eps,
                residual=residual,
                prenorm=False,
                residual_in_fp32=self.residual_in_fp32,
            )

        if self.final_pool_type == 'none':
            return hidden_states[:, -1, :]
        elif self.final_pool_type == 'mean':
            return hidden_states.mean(dim=1)
        elif self.final_pool_type == 'max':
            return hidden_states.max(dim=1)
        elif self.final_pool_type == 'all':
            return hidden_states
        else:
            raise NotImplementedError

    def forward(self, x, inference_params=None, if_random_cls_token_position=False, if_random_token_rank=False):
        """Encode a batch and return its patch tokens as a channels-first feature map.

        All arguments are forwarded unchanged to ``forward_features``.

        Class-token handling: the head and tail tokens are removed for the double class
        token, the middle token (inserted at ``num_patches // 2``) for the middle class
        token, and otherwise the token at index 0. With ``if_random_cls_token_position``
        the inserted index is not known here, so index 0 is dropped. The spatial layout
        is also meaningless when the token order was randomized or flipped.

        Returns:
            Tensor of shape (N, C, h, w) with h == w patches per side.

        Raises:
            ValueError: If ``final_pool_type`` is not ``'all'``; pooled features have no
                per-patch tokens to fold back into a grid.
        """
        if self.final_pool_type != 'all':
            raise ValueError(
                "ViM.forward needs final_pool_type='all' to rebuild the patch grid, "
                f"got {self.final_pool_type!r}."
            )

        x = self.forward_features(x, inference_params, if_random_cls_token_position, if_random_token_rank)

        if self.if_cls_token:  # remove the class token(s) wherever forward_features inserted them
            if self.use_double_cls_token:
                # head token at index 0, tail token at the last index
                x = x[:, 1:-1, :]
            elif self.use_middle_cls_token:
                # forward_features inserts it at num_patches // 2, and num_patches = len - 1
                pos = (x.shape[1] - 1) // 2
                x = torch.cat((x[:, :pos, :], x[:, pos + 1:, :]), dim=1)
            else:
                x = x[:, 1:, :]

        # let's get the patches back from the 1d tokens
        return self.convert_to_expected_dim(x)  # N x C x H x W, ready to be upsampled
def __init__(self, **kwargs):
    """Create the ViM encoder; all keyword arguments are passed on to ``VisionMamba``.

    Raises:
        ImportError: If the ``vim`` package could not be imported.
    """
    # Explicit raise instead of ``assert`` so the check survives ``python -O``; otherwise a
    # missing install surfaces as a TypeError from ``object.__init__`` (the import fallback).
    if not _have_vim_installed:
        raise ImportError("Please install 'Vim'.")
    super().__init__(**kwargs)
def convert_to_expected_dim(self, inputs_):
    """Fold a sequence of patch tokens into a channels-first feature map.

    Args:
        inputs_: Tensor of shape (N, H*W, C) containing only patch tokens (class tokens
            already removed); the patch grid must be square (H == W).

    Returns:
        Tensor of shape (N, C, H, W).

    Raises:
        ValueError: If the number of tokens is not a perfect square.
    """
    rdim = inputs_.shape[1]
    # Round before converting so floating-point error in the square root cannot lose a row.
    dshape = int(round(rdim ** 0.5))
    if dshape * dshape != rdim:
        raise ValueError(
            f"Expected a square number of patch tokens, got {rdim}. "
            "Are all class tokens removed and is the patch grid square?"
        )
    inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape))  # N x H x W x C
    return inputs_.permute(0, 3, 1, 2)  # N x C x H x W
def forward_features(
    self, x, inference_params=None, if_random_cls_token_position=False, if_random_token_rank=False
):
    """Run patch embedding and the Mamba layers and return normalized token features.

    Args:
        x: Input batch, passed directly to ``self.patch_embed``.
        inference_params: Forwarded unchanged to every Mamba layer.
        if_random_cls_token_position: Insert the single class token at a random index
            (only used when neither the double nor the middle class token is configured).
        if_random_token_rank: Randomly permute the token sequence (including class
            tokens) before the Mamba layers.

    Returns:
        Depends on ``self.final_pool_type``: ``'none'`` -> the last token, ``'mean'`` ->
        the mean over tokens, ``'max'`` -> the ``torch.max`` result (values, indices) over
        tokens, ``'all'`` -> the full token sequence, including class tokens.

    Raises:
        NotImplementedError: For any other ``final_pool_type``.
    """
    # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    # with slight modifications to add the dist_token
    x = self.patch_embed(x)
    B, M, _ = x.shape

    # Index (or [head, tail] indices) of the class token(s). It stays None without class
    # tokens, which the random token shuffle below has to tolerate.
    token_position = None
    if self.if_cls_token:
        if self.use_double_cls_token:
            cls_token_head = self.cls_token_head.expand(B, -1, -1)
            cls_token_tail = self.cls_token_tail.expand(B, -1, -1)
            token_position = [0, M + 1]
            x = torch.cat((cls_token_head, x, cls_token_tail), dim=1)
            M = x.shape[1]
        else:
            if self.use_middle_cls_token:
                cls_token = self.cls_token.expand(B, -1, -1)
                token_position = M // 2
                # add cls token in the middle
                x = torch.cat((x[:, :token_position, :], cls_token, x[:, token_position:, :]), dim=1)
            elif if_random_cls_token_position:
                cls_token = self.cls_token.expand(B, -1, -1)
                token_position = random.randint(0, M)
                x = torch.cat((x[:, :token_position, :], cls_token, x[:, token_position:, :]), dim=1)
                print("token_position: ", token_position)
            else:
                cls_token = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
                token_position = 0
                x = torch.cat((cls_token, x), dim=1)
            M = x.shape[1]

    if self.if_abs_pos_embed:
        x = x + self.pos_embed
        x = self.pos_drop(x)

    if if_random_token_rank:
        # general random shuffle index
        shuffle_indices = torch.randperm(M)

        if token_position is not None:
            if isinstance(token_position, list):
                print("original value: ", x[0, token_position[0], 0], x[0, token_position[1], 0])
            else:
                print("original value: ", x[0, token_position, 0])
            print("original token_position: ", token_position)

        # execute shuffle
        x = x[:, shuffle_indices, :]

        if token_position is not None:
            # find the new position(s) of the cls token(s) after the shuffle
            if isinstance(token_position, list):
                token_position = [
                    torch.where(shuffle_indices == pos)[0].item() for pos in token_position
                ]
            else:
                token_position = torch.where(shuffle_indices == token_position)[0].item()

            if isinstance(token_position, list):
                print("new value: ", x[0, token_position[0], 0], x[0, token_position[1], 0])
            else:
                print("new value: ", x[0, token_position, 0])
            print("new token_position: ", token_position)

    if_flip_img_sequences = False
    if self.flip_img_sequences_ratio > 0 and (self.flip_img_sequences_ratio - random.random()) > 1e-5:
        # NOTE(review): this flip is never undone below, so the returned token order is
        # reversed whenever it triggers.
        x = x.flip([1])
        if_flip_img_sequences = True

    # mamba impl
    residual = None
    hidden_states = x
    if not self.if_bidirectional:
        for layer in self.layers:

            if if_flip_img_sequences and self.if_rope:
                hidden_states = hidden_states.flip([1])
                if residual is not None:
                    residual = residual.flip([1])

            # rope about
            if self.if_rope:
                hidden_states = self.rope(hidden_states)
                if residual is not None and self.if_rope_residual:
                    residual = self.rope(residual)

            if if_flip_img_sequences and self.if_rope:
                hidden_states = hidden_states.flip([1])
                if residual is not None:
                    residual = residual.flip([1])

            hidden_states, residual = layer(hidden_states, residual, inference_params=inference_params)
    else:
        # get two layers in a single for-loop: one over the sequence, one over its reverse
        for i in range(len(self.layers) // 2):
            if self.if_rope:
                hidden_states = self.rope(hidden_states)
                if residual is not None and self.if_rope_residual:
                    residual = self.rope(residual)

            hidden_states_f, residual_f = self.layers[i * 2](
                hidden_states, residual, inference_params=inference_params
            )
            hidden_states_b, residual_b = self.layers[i * 2 + 1](
                hidden_states.flip([1]),
                None if residual is None else residual.flip([1]),
                inference_params=inference_params
            )
            hidden_states = hidden_states_f + hidden_states_b.flip([1])
            residual = residual_f + residual_b.flip([1])

    if not self.fused_add_norm:
        if residual is None:
            residual = hidden_states
        else:
            residual = residual + self.drop_path(hidden_states)
        hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
    else:
        # Set prenorm = False here since we don't need the residual
        fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
        hidden_states = fused_add_norm_fn(
            self.drop_path(hidden_states),
            self.norm_f.weight,
            self.norm_f.bias,
            eps=self.norm_f.eps,
            residual=residual,
            prenorm=False,
            residual_in_fp32=self.residual_in_fp32,
        )

    if self.final_pool_type == 'none':
        return hidden_states[:, -1, :]
    elif self.final_pool_type == 'mean':
        return hidden_states.mean(dim=1)
    elif self.final_pool_type == 'max':
        return hidden_states.max(dim=1)
    elif self.final_pool_type == 'all':
        return hidden_states
    else:
        raise NotImplementedError
def forward(self, x, inference_params=None, if_random_cls_token_position=False, if_random_token_rank=False):
    """Encode a batch and return its patch tokens as a channels-first feature map.

    All arguments are forwarded unchanged to ``forward_features``.

    Class-token handling: the head and tail tokens are removed for the double class
    token, the middle token (inserted at ``num_patches // 2``) for the middle class
    token, and otherwise the token at index 0. With ``if_random_cls_token_position``
    the inserted index is not known here, so index 0 is dropped. The spatial layout
    is also meaningless when the token order was randomized or flipped.

    Returns:
        Tensor of shape (N, C, h, w) with h == w patches per side.

    Raises:
        ValueError: If ``final_pool_type`` is not ``'all'``; pooled features have no
            per-patch tokens to fold back into a grid.
    """
    if self.final_pool_type != 'all':
        raise ValueError(
            "ViM.forward needs final_pool_type='all' to rebuild the patch grid, "
            f"got {self.final_pool_type!r}."
        )

    x = self.forward_features(x, inference_params, if_random_cls_token_position, if_random_token_rank)

    if self.if_cls_token:  # remove the class token(s) wherever forward_features inserted them
        if self.use_double_cls_token:
            # head token at index 0, tail token at the last index
            x = x[:, 1:-1, :]
        elif self.use_middle_cls_token:
            # forward_features inserts it at num_patches // 2, and num_patches = len - 1
            pos = (x.shape[1] - 1) // 2
            x = torch.cat((x[:, :pos, :], x[:, pos + 1:, :]), dim=1)
        else:
            x = x[:, 1:, :]

    # let's get the patches back from the 1d tokens
    return self.convert_to_expected_dim(x)  # N x C x H x W, ready to be upsampled
def get_vim_encoder(model_type="vim_t", with_cls_token=True):
    """Construct a Vision Mamba encoder configured for 1024 x 1024 inputs.

    Args:
        model_type: Encoder variant; 'vim_t', 'vim_s' or 'vim_b' select an embedding
            width of 192, 384 or 768 respectively.
        with_cls_token: Whether the encoder adds a class token, which is placed in the
            middle of the token sequence (``use_middle_cls_token=True``).

    Returns:
        The ``ViM`` encoder, with ``default_cfg`` filled from timm's ``_cfg()``.

    Raises:
        ValueError: If ``model_type`` is not one of the supported variants.
    """
    supported = (("vim_t", 192), ("vim_s", 384), ("vim_b", 768))
    for name, width in supported:
        if model_type == name:
            embed_dim = width
            break
    else:
        raise ValueError("Choose from 'vim_t' / 'vim_s' / 'vim_b'")

    encoder = ViM(
        img_size=1024,
        patch_size=16,
        embed_dim=embed_dim,
        depth=24,
        rms_norm=True,
        residual_in_fp32=True,
        fused_add_norm=True,
        final_pool_type="all",
        if_abs_pos_embed=True,
        if_rope=False,
        if_rope_residual=False,
        bimamba_type="v2",
        if_cls_token=with_cls_token,
        if_divide_out=True,
        use_middle_cls_token=True,
    )
    encoder.default_cfg = _cfg()
    return encoder
def get_vimunet_model(
    out_channels, model_type="vim_t", with_cls_token=True, device=None, checkpoint=None
):
    """Build a ViM-UNet: a Vision Mamba encoder combined with torch_em's UNETR decoder.

    Args:
        out_channels: Number of output channels of the decoder.
        model_type: Encoder variant, one of 'vim_t', 'vim_s', 'vim_b'.
        with_cls_token: Whether the encoder uses a class token.
        device: Device to move the model to; defaults to CUDA when available, else CPU.
        checkpoint: Optional path (str or path-like) to weights. Files ending in '.pth'
            are treated as original Vim checkpoints (encoder weights under 'model');
            anything else as a torch_em checkpoint (full model weights under 'model_state').

    Returns:
        The UNETR model (sigmoid output activation, no skip connections) on ``device``.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    encoder = get_vim_encoder(model_type, with_cls_token)

    model_state = None
    if checkpoint is not None:
        # NOTE: weights_only=False unpickles arbitrary objects - only load trusted checkpoints.
        state = torch.load(checkpoint, map_location="cpu", weights_only=False)

        # str() so that pathlib.Path checkpoints work too (Path has no .endswith).
        if str(checkpoint).endswith(".pth"):  # from Vim
            encoder.load_state_dict(state["model"])
        else:  # from torch_em
            model_state = state["model_state"]

    # Expose the first spatial dim of the patch embedding's image size as ``img_size``;
    # presumably read by UNETR - verify against the UNETR implementation.
    encoder.img_size = encoder.patch_embed.img_size[0]

    # TODO: Update design so that: we have a backbone to fetch encoder and decoder flexibly
    # and is ideally not named as "UNETR" but something as for example "EncoderDecoderNet"
    model = UNETR(
        encoder=encoder,
        out_channels=out_channels,
        resize_input=False,
        use_skip_connection=False,
        final_activation="Sigmoid",
    )

    # torch_em checkpoints hold the full model, so they are loaded only after UNETR exists.
    if model_state is not None:
        model.load_state_dict(model_state)

    model.to(device)

    return model