torch_em.model.vim
@private
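Vision Mamba (ViM) encoder combined with the UNETR decoder from torch_em, for image segmentation. The encoder follows https://github.com/hustvl/Vim and requires the 'Vim' package to be installed; pretrained encoder weights for vim_t are available at https://huggingface.co/hustvl/Vim-tiny/blob/main/vim_tiny_73p1.pth.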
1"""@private 2""" 3 4# installation from https://github.com/hustvl/Vim 5# encoder from https://github.com/hustvl/Vim 6# decoder from https://github.com/constantinpape/torch-em 7 8# pretrained model weights: vim_t - https://huggingface.co/hustvl/Vim-tiny/blob/main/vim_tiny_73p1.pth 9 10import random 11 12import torch 13 14from .unetr import UNETR 15 16try: 17 from vim.models_mamba import VisionMamba, rms_norm_fn, RMSNorm, layer_norm_fn 18 _have_vim_installed = True 19except ImportError: 20 VisionMamba = object 21 rms_norm_fn = RMSNorm = layer_norm_fn = None 22 _have_vim_installed = False 23 24try: 25 from timm.models.vision_transformer import _cfg 26except ImportError: 27 _cfg = None 28 29 30class ViM(VisionMamba): 31 def __init__(self, **kwargs): 32 assert _have_vim_installed, "Please install 'Vim'." 33 super().__init__(**kwargs) 34 35 def convert_to_expected_dim(self, inputs_): 36 # reshape the outputs to desired shape (N x H*W X C -> N x H x W x C) 37 rdim = inputs_.shape[1] 38 dshape = int(rdim ** 0.5) # finding the square root of the outputs for obtaining the patch shape 39 inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape)) 40 inputs_ = inputs_.permute(0, 3, 1, 2) 41 return inputs_ 42 43 def forward_features( 44 self, x, inference_params=None, if_random_cls_token_position=False, if_random_token_rank=False 45 ): 46 # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py 47 # with slight modifications to add the dist_token 48 x = self.patch_embed(x) 49 B, M, _ = x.shape 50 51 if self.if_cls_token: 52 if self.use_double_cls_token: 53 cls_token_head = self.cls_token_head.expand(B, -1, -1) 54 cls_token_tail = self.cls_token_tail.expand(B, -1, -1) 55 token_position = [0, M + 1] 56 x = torch.cat((cls_token_head, x, cls_token_tail), dim=1) 57 M = x.shape[1] 58 else: 59 if self.use_middle_cls_token: 60 cls_token = self.cls_token.expand(B, -1, -1) 61 token_position = M // 2 62 # add cls token in the middle 63 x = torch.cat((x[:, :token_position, :], cls_token, x[:, token_position:, :]), dim=1) 64 elif if_random_cls_token_position: 65 cls_token = self.cls_token.expand(B, -1, -1) 66 token_position = random.randint(0, M) 67 x = torch.cat((x[:, :token_position, :], cls_token, x[:, token_position:, :]), dim=1) 68 print("token_position: ", token_position) 69 else: 70 cls_token = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 71 token_position = 0 72 x = torch.cat((cls_token, x), dim=1) 73 M = x.shape[1] 74 75 if self.if_abs_pos_embed: 76 x = x + self.pos_embed 77 x = self.pos_drop(x) 78 79 if if_random_token_rank: 80 # general random shuffle index 81 shuffle_indices = torch.randperm(M) 82 83 if isinstance(token_position, list): 84 print("original value: ", x[0, token_position[0], 0], x[0, token_position[1], 0]) 85 else: 86 print("original value: ", x[0, token_position, 0]) 87 print("original token_position: ", token_position) 88 89 # execute shuffle 90 x = x[:, shuffle_indices, :] 91 92 if isinstance(token_position, list): 93 # find new position of cls token after shuffle 94 new_token_position = [ 95 torch.where(shuffle_indices == token_position[i])[0].item() for i in range(len(token_position)) 96 ] 97 token_position = new_token_position 98 else: 99 # find new position of cls token after the shuffle 100 token_position = torch.where(shuffle_indices == token_position)[0].item() 101 102 if isinstance(token_position, list): 103 print("new value: ", x[0, token_position[0], 0], x[0, token_position[1], 0]) 104 else: 
105 print("new value: ", x[0, token_position, 0]) 106 print("new token_position: ", token_position) 107 108 if_flip_img_sequences = False 109 if self.flip_img_sequences_ratio > 0 and (self.flip_img_sequences_ratio - random.random()) > 1e-5: 110 x = x.flip([1]) 111 if_flip_img_sequences = True 112 113 # mamba impl 114 residual = None 115 hidden_states = x 116 if not self.if_bidirectional: 117 for layer in self.layers: 118 119 if if_flip_img_sequences and self.if_rope: 120 hidden_states = hidden_states.flip([1]) 121 if residual is not None: 122 residual = residual.flip([1]) 123 124 # rope about 125 if self.if_rope: 126 hidden_states = self.rope(hidden_states) 127 if residual is not None and self.if_rope_residual: 128 residual = self.rope(residual) 129 130 if if_flip_img_sequences and self.if_rope: 131 hidden_states = hidden_states.flip([1]) 132 if residual is not None: 133 residual = residual.flip([1]) 134 135 hidden_states, residual = layer(hidden_states, residual, inference_params=inference_params) 136 else: 137 # get two layers in a single for-loop 138 for i in range(len(self.layers) // 2): 139 if self.if_rope: 140 hidden_states = self.rope(hidden_states) 141 if residual is not None and self.if_rope_residual: 142 residual = self.rope(residual) 143 144 hidden_states_f, residual_f = self.layers[i * 2]( 145 hidden_states, residual, inference_params=inference_params 146 ) 147 hidden_states_b, residual_b = self.layers[i * 2 + 1]( 148 hidden_states.flip([1]), 149 None if residual is None else residual.flip([1]), 150 inference_params=inference_params 151 ) 152 hidden_states = hidden_states_f + hidden_states_b.flip([1]) 153 residual = residual_f + residual_b.flip([1]) 154 155 if not self.fused_add_norm: 156 if residual is None: 157 residual = hidden_states 158 else: 159 residual = residual + self.drop_path(hidden_states) 160 hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype)) 161 else: 162 # Set prenorm = False here since we don't need the residual 163 fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn 164 hidden_states = fused_add_norm_fn( 165 self.drop_path(hidden_states), 166 self.norm_f.weight, 167 self.norm_f.bias, 168 eps=self.norm_f.eps, 169 residual=residual, 170 prenorm=False, 171 residual_in_fp32=self.residual_in_fp32, 172 ) 173 174 if self.final_pool_type == 'none': 175 return hidden_states[:, -1, :] 176 elif self.final_pool_type == 'mean': 177 return hidden_states.mean(dim=1) 178 elif self.final_pool_type == 'max': 179 return hidden_states.max(dim=1) 180 elif self.final_pool_type == 'all': 181 return hidden_states 182 else: 183 raise NotImplementedError 184 185 def forward(self, x, inference_params=None, if_random_cls_token_position=False, if_random_token_rank=False): 186 x = self.forward_features(x, inference_params, if_random_cls_token_position, if_random_token_rank) 187 188 if self.final_pool_type == 'max': 189 x = x.max(dim=1)[0] 190 191 if self.if_cls_token: # remove the class token 192 x = x[:, 1:, :] 193 194 # let's get the patches back from the 1d tokens 195 x = self.convert_to_expected_dim(x) 196 197 return x # from here, the tokens can be upsampled easily (N x H x W x C) 198 199 200def get_vim_encoder(model_type="vim_t", with_cls_token=True): 201 if model_type == "vim_t": 202 embed_dim = 192 203 elif model_type == "vim_s": 204 embed_dim = 384 205 elif model_type == "vim_b": 206 embed_dim = 768 207 else: 208 raise ValueError("Choose from 'vim_t' / 'vim_s' / 'vim_b'") 209 210 encoder = ViM( 211 img_size=1024, 212 
patch_size=16, 213 embed_dim=embed_dim, 214 depth=24, 215 rms_norm=True, 216 residual_in_fp32=True, 217 fused_add_norm=True, 218 final_pool_type='all', 219 if_abs_pos_embed=True, 220 if_rope=False, 221 if_rope_residual=False, 222 bimamba_type="v2", 223 if_cls_token=with_cls_token, 224 if_divide_out=True, 225 use_middle_cls_token=True, 226 ) 227 encoder.default_cfg = _cfg() 228 return encoder 229 230 231def get_vimunet_model( 232 out_channels, model_type="vim_t", with_cls_token=True, device=None, checkpoint=None 233): 234 if device is None: 235 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 236 237 encoder = get_vim_encoder(model_type, with_cls_token) 238 239 model_state = None 240 if checkpoint is not None: 241 state = torch.load(checkpoint, map_location="cpu", weights_only=False) 242 243 if checkpoint.endswith(".pth"): # from Vim 244 encoder_state = state["model"] 245 encoder.load_state_dict(encoder_state) 246 247 else: # from torch_em 248 model_state = state["model_state"] 249 250 encoder.img_size = encoder.patch_embed.img_size[0] 251 252 # TODO: Update design so that: we have a backbone to fetch encoder and decoder flexibly 253 # and is ideally not named as "UNETR" but something as for example "EncoderDecoderNet" 254 model = UNETR( 255 encoder=encoder, 256 out_channels=out_channels, 257 resize_input=False, 258 use_skip_connection=False, 259 final_activation="Sigmoid", 260 ) 261 262 if model_state is not None: 263 model.load_state_dict(model_state) 264 265 model.to(device) 266 267 return model
class ViM:
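Vision Mamba encoder. Wraps `VisionMamba` from the official Vim implementation and overrides the forward pass so that, instead of a classification output, it returns the patch features reshaped to an image-like layout (N x C x H x W) for use with a convolutional decoder.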
def convert_to_expected_dim(self, inputs_):
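Reshapes the sequence of patch tokens back into a 2d feature map, N x H*W x C -> N x C x H x W, assuming a square patch grid.

A minimal, self-contained sketch of the reshape in plain torch (no Vim install required); the shapes assume vim_t features (embed_dim 192) for a 1024 x 1024 input with 16 x 16 patches, i.e. a 64 x 64 token grid:

    import torch

    tokens = torch.randn(2, 4096, 192)               # N x H*W x C: 4096 tokens = 64 x 64 grid
    grid = int(tokens.shape[1] ** 0.5)               # 64 patches per side
    feat = torch.unflatten(tokens, 1, (grid, grid))  # N x H x W x C
    feat = feat.permute(0, 3, 1, 2)                  # N x C x H x W
    print(feat.shape)                                # torch.Size([2, 192, 64, 64])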
def forward_features(self, x, inference_params=None, if_random_cls_token_position=False, if_random_token_rank=False):
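Computes the token features: patch embedding, optional class token insertion (at head/tail, in the middle, or at a random position), absolute position embedding, optional token shuffling and sequence flipping, the stack of Mamba layers (processed pairwise on the forward and flipped sequence if `if_bidirectional` is set), and a final add + norm (fused or unfused). With `final_pool_type='all'`, as used by `get_vim_encoder`, the full token sequence is returned.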
def forward(self, x, inference_params=None, if_random_cls_token_position=False, if_random_token_rank=False):
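Runs `forward_features`, drops the first token if a class token is used, and returns the remaining patch tokens reshaped to N x C x H x W via `convert_to_expected_dim`, ready for upsampling by a decoder.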
def get_vim_encoder(model_type='vim_t', with_cls_token=True):
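Builds a `ViM` encoder for 1024 x 1024 inputs with a 16 x 16 patch size and 24 layers. `model_type` selects the embedding dimension: 192 for 'vim_t', 384 for 'vim_s', and 768 for 'vim_b'.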
def get_vimunet_model(out_channels, model_type='vim_t', with_cls_token=True, device=None, checkpoint=None):
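Builds a ViM-UNETR model: the `ViM` encoder combined with the UNETR decoder from torch_em, without skip connections and with a sigmoid as final activation. If `checkpoint` ends in '.pth' it is treated as pretrained Vim encoder weights and loaded into the encoder only; otherwise it is treated as a torch_em checkpoint and the full model state is loaded.

A minimal usage sketch, assuming 'Vim' is installed and the pretrained Vim-tiny weights linked above were downloaded to 'vim_tiny_73p1.pth' (the local file name is an assumption):

    import torch
    from torch_em.model.vim import get_vimunet_model

    # build a ViM-tiny UNETR for binary segmentation, initialized with pretrained encoder weights
    model = get_vimunet_model(out_channels=1, model_type="vim_t", checkpoint="vim_tiny_73p1.pth")
    model.eval()

    # the encoder is configured for 1024 x 1024 inputs (see get_vim_encoder)
    x = torch.randn(1, 3, 1024, 1024, device=next(model.parameters()).device)
    with torch.no_grad():
        y = model(x)
    print(y.shape)  # expected: torch.Size([1, 1, 1024, 1024])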