torch_em.model.vim
# installation from https://github.com/hustvl/Vim
# encoder from https://github.com/hustvl/Vim
# decoder from https://github.com/constantinpape/torch-em

# pretrained model weights: vim_t - https://huggingface.co/hustvl/Vim-tiny/blob/main/vim_tiny_73p1.pth

import torch

from .unetr import UNETR

try:
    from vim.models_mamba import VisionMamba, rms_norm_fn, RMSNorm, layer_norm_fn
    _have_vim_installed = True
except ImportError:
    VisionMamba = object
    rms_norm_fn = RMSNorm = layer_norm_fn = None
    _have_vim_installed = False

try:
    from timm.models.vision_transformer import _cfg
except ImportError:
    _cfg = None


class ViM(VisionMamba):
    """Vision Mamba encoder whose `forward` returns a 2d grid of patch features (N x C x H x W).

    All keyword arguments are forwarded unchanged to `VisionMamba`.
    """

    def __init__(
        self,
        **kwargs
    ):
        assert _have_vim_installed, "Please install Vim."
        super().__init__(**kwargs)

    def convert_to_expected_dim(self, inputs_):
        """Fold a flat sequence of patch tokens back into a square patch grid.

        Args:
            inputs_: Tensor of shape (N, H*W, C) with H == W.

        Returns:
            Tensor of shape (N, C, H, W).

        Raises:
            ValueError: If the number of tokens is not a perfect square.
        """
        num_tokens = inputs_.shape[1]
        # round (instead of truncating) so that float error in the square root cannot drop a row
        side = round(num_tokens ** 0.5)
        if side * side != num_tokens:
            raise ValueError(f"Expected a square number of patch tokens, got {num_tokens}.")
        inputs_ = torch.unflatten(inputs_, 1, (side, side))  # N x H x W x C
        return inputs_.permute(0, 3, 1, 2)  # N x C x H x W

    def forward_features(self, x, inference_params=None):
        """Embed the input into patch tokens and run them through the Mamba layers.

        Adapted from the timm vision transformer
        (https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py),
        with an optional class token prepended to the patch tokens.

        Args:
            x: Input batch, passed to `self.patch_embed`.
            inference_params: Forwarded unchanged to every Mamba layer.

        Returns:
            The normalized hidden states reduced according to `self.final_pool_type`:
            'none' -> last token, 'mean' / 'max' -> pooled over the token axis (dim 1),
            'all' -> every token.

        Raises:
            NotImplementedError: For any other `final_pool_type`.
        """
        x = self.patch_embed(x)
        if self.if_cls_token:
            # one shared class token per batch element, placed in front of the patch tokens
            cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
            x = torch.cat((cls_token, x), dim=1)

        if self.if_abs_pos_embed:
            x = x + self.pos_embed
            x = self.pos_drop(x)

        # mamba implementation
        residual = None
        hidden_states = x
        for layer in self.layers:
            # rotary position embedding, optionally also applied to the residual stream
            if self.if_rope:
                hidden_states = self.rope(hidden_states)
                if residual is not None and self.if_rope_residual:
                    residual = self.rope(residual)

            hidden_states, residual = layer(
                hidden_states, residual, inference_params=inference_params
            )

        if not self.fused_add_norm:
            if residual is None:
                residual = hidden_states
            else:
                residual = residual + self.drop_path(hidden_states)
            hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
        else:
            # Set prenorm = False here since we don't need the residual
            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
            hidden_states = fused_add_norm_fn(
                self.drop_path(hidden_states),
                self.norm_f.weight,
                self.norm_f.bias,
                eps=self.norm_f.eps,
                residual=residual,
                prenorm=False,
                residual_in_fp32=self.residual_in_fp32,
            )

        if self.final_pool_type == 'none':
            return hidden_states[:, -1, :]
        elif self.final_pool_type == 'mean':
            return hidden_states.mean(dim=1)
        elif self.final_pool_type == 'max':
            # `Tensor.max(dim=...)` returns a (values, indices) pair; the values are the pooled features
            return hidden_states.max(dim=1).values
        elif self.final_pool_type == 'all':
            return hidden_states
        else:
            raise NotImplementedError(f"final_pool_type {self.final_pool_type!r} is not supported")

    def forward(self, x, inference_params=None):
        """Encode an input batch into a grid of patch features of shape (N x C x H x W).

        Assumes `final_pool_type == 'all'` (as configured by `get_vim_encoder`), so that
        `forward_features` returns one feature vector per token.
        """
        x = self.forward_features(x, inference_params)

        if self.if_cls_token:  # remove the class token
            x = x[:, 1:, :]

        # get the patches back from the 1d tokens; from here they can be upsampled easily (N x C x H x W)
        return self.convert_to_expected_dim(x)


def get_vim_encoder(model_type="vim_t", with_cls_token=True):
    """Build a ViM encoder for 1024 x 1024 inputs with 16 x 16 patches.

    Args:
        model_type: Encoder size, one of 'vim_t', 'vim_s' or 'vim_b'.
        with_cls_token: Whether to prepend a class token to the patch tokens.

    Returns:
        The ViM encoder.

    Raises:
        ValueError: If `model_type` is not supported.
    """
    # The three sizes share every hyperparameter except the embedding dimension. Reference configs:
    # vim_t: `vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_rope_also_residual_with_cls_token`
    #        (*has an imagenet pretrained model)
    # vim_s: `vim_small_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_rope_also_residual`
    # vim_b: `vim_base_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_rope_also_residual`
    # AA: added a class token to the default models (controlled by `with_cls_token`)
    embed_dims = {"vim_t": 192, "vim_s": 384, "vim_b": 768}
    if model_type not in embed_dims:
        raise ValueError("Choose from 'vim_t' / 'vim_s' / 'vim_b'")

    encoder = ViM(
        img_size=1024,
        patch_size=16,
        embed_dim=embed_dims[model_type],
        depth=24,
        rms_norm=True,
        residual_in_fp32=True,
        fused_add_norm=True,
        final_pool_type='all',
        if_abs_pos_embed=True,
        if_rope=True,
        if_rope_residual=True,
        bimamba_type="v2",
        if_cls_token=with_cls_token,
    )

    # `_cfg` comes from the optional timm import and is None when timm is missing
    if _cfg is not None:
        encoder.default_cfg = _cfg()
    return encoder


def get_vimunet_model(
    out_channels, model_type="vim_t", with_cls_token=True, device=None, checkpoint=None
):
    """Build a UNETR model with a ViM encoder.

    Args:
        out_channels: Number of output channels of the model.
        model_type: Encoder size, one of 'vim_t', 'vim_s' or 'vim_b'.
        with_cls_token: Whether the encoder uses a class token.
        device: Device to move the model to; defaults to CUDA if available, otherwise CPU.
        checkpoint: Optional path (str or path-like) to weights. A '.pth' file is treated as an
            original Vim checkpoint (encoder weights under "model"); anything else as a
            torch_em checkpoint (full model weights under "model_state").

    Returns:
        The UNETR model, moved to `device`.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    encoder = get_vim_encoder(model_type, with_cls_token)

    model_state = None
    if checkpoint is not None:
        state = torch.load(checkpoint, map_location="cpu")

        # str() so that pathlib.Path checkpoints work as well as plain strings
        if str(checkpoint).endswith(".pth"):  # from Vim
            # NOTE(review): the Vim checkpoints referenced at the top of this file are trained at 224 x 224,
            # while the encoder is built with img_size=1024; resolution-dependent entries such as
            # `pos_embed` may not match and make this strict load fail -- confirm.
            encoder.load_state_dict(state["model"])

        else:  # from torch_em
            model_state = state["model_state"]

    # expose the input height configured in the patch embedding as `encoder.img_size`
    # (presumably consumed by UNETR -- confirm)
    encoder.img_size = encoder.patch_embed.img_size[0]

    model = UNETR(
        encoder=encoder,
        out_channels=out_channels,
        resize_input=False,
        use_skip_connection=False,
        final_activation="Sigmoid",
    )

    # torch_em checkpoints hold the weights of the full model, so they are loaded after UNETR is built
    if model_state is not None:
        model.load_state_dict(model_state)

    model.to(device)

    return model
class
ViM:
class ViM(VisionMamba):
    """Vision Mamba encoder whose `forward` returns a 2d grid of patch features (N x C x H x W).

    All keyword arguments are forwarded unchanged to `VisionMamba`.
    """

    def __init__(
        self,
        **kwargs
    ):
        assert _have_vim_installed, "Please install Vim."
        super().__init__(**kwargs)

    def convert_to_expected_dim(self, inputs_):
        """Fold a flat sequence of patch tokens back into a square patch grid.

        Args:
            inputs_: Tensor of shape (N, H*W, C) with H == W.

        Returns:
            Tensor of shape (N, C, H, W).

        Raises:
            ValueError: If the number of tokens is not a perfect square.
        """
        num_tokens = inputs_.shape[1]
        # round (instead of truncating) so that float error in the square root cannot drop a row
        side = round(num_tokens ** 0.5)
        if side * side != num_tokens:
            raise ValueError(f"Expected a square number of patch tokens, got {num_tokens}.")
        inputs_ = torch.unflatten(inputs_, 1, (side, side))  # N x H x W x C
        return inputs_.permute(0, 3, 1, 2)  # N x C x H x W

    def forward_features(self, x, inference_params=None):
        """Embed the input into patch tokens and run them through the Mamba layers.

        Adapted from the timm vision transformer
        (https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py),
        with an optional class token prepended to the patch tokens.

        Args:
            x: Input batch, passed to `self.patch_embed`.
            inference_params: Forwarded unchanged to every Mamba layer.

        Returns:
            The normalized hidden states reduced according to `self.final_pool_type`:
            'none' -> last token, 'mean' / 'max' -> pooled over the token axis (dim 1),
            'all' -> every token.

        Raises:
            NotImplementedError: For any other `final_pool_type`.
        """
        x = self.patch_embed(x)
        if self.if_cls_token:
            # one shared class token per batch element, placed in front of the patch tokens
            cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
            x = torch.cat((cls_token, x), dim=1)

        if self.if_abs_pos_embed:
            x = x + self.pos_embed
            x = self.pos_drop(x)

        # mamba implementation
        residual = None
        hidden_states = x
        for layer in self.layers:
            # rotary position embedding, optionally also applied to the residual stream
            if self.if_rope:
                hidden_states = self.rope(hidden_states)
                if residual is not None and self.if_rope_residual:
                    residual = self.rope(residual)

            hidden_states, residual = layer(
                hidden_states, residual, inference_params=inference_params
            )

        if not self.fused_add_norm:
            if residual is None:
                residual = hidden_states
            else:
                residual = residual + self.drop_path(hidden_states)
            hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
        else:
            # Set prenorm = False here since we don't need the residual
            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
            hidden_states = fused_add_norm_fn(
                self.drop_path(hidden_states),
                self.norm_f.weight,
                self.norm_f.bias,
                eps=self.norm_f.eps,
                residual=residual,
                prenorm=False,
                residual_in_fp32=self.residual_in_fp32,
            )

        if self.final_pool_type == 'none':
            return hidden_states[:, -1, :]
        elif self.final_pool_type == 'mean':
            return hidden_states.mean(dim=1)
        elif self.final_pool_type == 'max':
            # `Tensor.max(dim=...)` returns a (values, indices) pair; the values are the pooled features
            return hidden_states.max(dim=1).values
        elif self.final_pool_type == 'all':
            return hidden_states
        else:
            raise NotImplementedError(f"final_pool_type {self.final_pool_type!r} is not supported")

    def forward(self, x, inference_params=None):
        """Encode an input batch into a grid of patch features of shape (N x C x H x W).

        Assumes `final_pool_type == 'all'` (as configured by `get_vim_encoder`), so that
        `forward_features` returns one feature vector per token.
        """
        x = self.forward_features(x, inference_params)

        if self.if_cls_token:  # remove the class token
            x = x[:, 1:, :]

        # get the patches back from the 1d tokens; from here they can be upsampled easily (N x C x H x W)
        return self.convert_to_expected_dim(x)
residual_in_fp32=self.residual_in_fp32, 85 ) 86 87 if self.final_pool_type == 'none': 88 return hidden_states[:, -1, :] 89 elif self.final_pool_type == 'mean': 90 return hidden_states.mean(dim=1) 91 elif self.final_pool_type == 'max': 92 return hidden_states.max(dim=1) 93 elif self.final_pool_type == 'all': 94 return hidden_states 95 else: 96 raise NotImplementedError 97 98 def forward(self, x, inference_params=None): 99 x = self.forward_features(x, inference_params) 100 101 if self.if_cls_token: # remove the class token 102 x = x[:, 1:, :] 103 104 # let's get the patches back from the 1d tokens 105 x = self.convert_to_expected_dim(x) 106 107 return x # from here, the tokens can be upsampled easily (N x H x W x C)
def
convert_to_expected_dim(self, inputs_):
def convert_to_expected_dim(self, inputs_):
    """Fold a flat sequence of patch tokens back into a square patch grid.

    Args:
        inputs_: Tensor of shape (N, H*W, C) with H == W.

    Returns:
        Tensor of shape (N, C, H, W).

    Raises:
        ValueError: If the number of tokens is not a perfect square.
    """
    num_tokens = inputs_.shape[1]
    # round (instead of truncating) so that float error in the square root cannot drop a row
    side = round(num_tokens ** 0.5)
    if side * side != num_tokens:
        raise ValueError(f"Expected a square number of patch tokens, got {num_tokens}.")
    inputs_ = torch.unflatten(inputs_, 1, (side, side))  # N x H x W x C
    return inputs_.permute(0, 3, 1, 2)  # N x C x H x W
def
forward_features(self, x, inference_params=None):
def forward_features(self, x, inference_params=None):
    """Embed the input into patch tokens and run them through the Mamba layers.

    Adapted from the timm vision transformer
    (https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py),
    with an optional class token prepended to the patch tokens.

    Args:
        x: Input batch, passed to `self.patch_embed`.
        inference_params: Forwarded unchanged to every Mamba layer.

    Returns:
        The normalized hidden states reduced according to `self.final_pool_type`:
        'none' -> last token, 'mean' / 'max' -> pooled over the token axis (dim 1),
        'all' -> every token.

    Raises:
        NotImplementedError: For any other `final_pool_type`.
    """
    x = self.patch_embed(x)
    if self.if_cls_token:
        # one shared class token per batch element, placed in front of the patch tokens
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_token, x), dim=1)

    if self.if_abs_pos_embed:
        x = x + self.pos_embed
        x = self.pos_drop(x)

    # mamba implementation
    residual = None
    hidden_states = x
    for layer in self.layers:
        # rotary position embedding, optionally also applied to the residual stream
        if self.if_rope:
            hidden_states = self.rope(hidden_states)
            if residual is not None and self.if_rope_residual:
                residual = self.rope(residual)

        hidden_states, residual = layer(
            hidden_states, residual, inference_params=inference_params
        )

    if not self.fused_add_norm:
        if residual is None:
            residual = hidden_states
        else:
            residual = residual + self.drop_path(hidden_states)
        hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
    else:
        # Set prenorm = False here since we don't need the residual
        fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
        hidden_states = fused_add_norm_fn(
            self.drop_path(hidden_states),
            self.norm_f.weight,
            self.norm_f.bias,
            eps=self.norm_f.eps,
            residual=residual,
            prenorm=False,
            residual_in_fp32=self.residual_in_fp32,
        )

    if self.final_pool_type == 'none':
        return hidden_states[:, -1, :]
    elif self.final_pool_type == 'mean':
        return hidden_states.mean(dim=1)
    elif self.final_pool_type == 'max':
        # `Tensor.max(dim=...)` returns a (values, indices) pair; the values are the pooled features
        return hidden_states.max(dim=1).values
    elif self.final_pool_type == 'all':
        return hidden_states
    else:
        raise NotImplementedError(f"final_pool_type {self.final_pool_type!r} is not supported")
def
forward(self, x, inference_params=None):
def forward(self, x, inference_params=None):
    """Encode an input batch into a grid of patch features of shape (N x C x H x W).

    Assumes `final_pool_type == 'all'` (as configured by `get_vim_encoder`), so that
    `forward_features` returns one feature vector per token.
    """
    tokens = self.forward_features(x, inference_params)
    # the class token (when present) is the first token; only patch tokens go into the grid
    patch_tokens = tokens[:, 1:, :] if self.if_cls_token else tokens
    # fold the 1d token sequence back into the 2d patch grid, ready for upsampling
    return self.convert_to_expected_dim(patch_tokens)
def
get_vim_encoder(model_type='vim_t', with_cls_token=True):
def get_vim_encoder(model_type="vim_t", with_cls_token=True):
    """Build a ViM encoder for 1024 x 1024 inputs with 16 x 16 patches.

    Args:
        model_type: Encoder size, one of 'vim_t', 'vim_s' or 'vim_b'.
        with_cls_token: Whether to prepend a class token to the patch tokens.

    Returns:
        The ViM encoder.

    Raises:
        ValueError: If `model_type` is not supported.
    """
    # The three sizes share every hyperparameter except the embedding dimension. Reference configs:
    # vim_t: `vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_rope_also_residual_with_cls_token`
    #        (*has an imagenet pretrained model)
    # vim_s: `vim_small_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_rope_also_residual`
    # vim_b: `vim_base_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_rope_also_residual`
    # AA: added a class token to the default models (controlled by `with_cls_token`)
    embed_dims = {"vim_t": 192, "vim_s": 384, "vim_b": 768}
    if model_type not in embed_dims:
        raise ValueError("Choose from 'vim_t' / 'vim_s' / 'vim_b'")

    encoder = ViM(
        img_size=1024,
        patch_size=16,
        embed_dim=embed_dims[model_type],
        depth=24,
        rms_norm=True,
        residual_in_fp32=True,
        fused_add_norm=True,
        final_pool_type='all',
        if_abs_pos_embed=True,
        if_rope=True,
        if_rope_residual=True,
        bimamba_type="v2",
        if_cls_token=with_cls_token,
    )

    # `_cfg` comes from the optional timm import and is None when timm is missing
    if _cfg is not None:
        encoder.default_cfg = _cfg()
    return encoder
def
get_vimunet_model( out_channels, model_type='vim_t', with_cls_token=True, device=None, checkpoint=None):
def get_vimunet_model(
    out_channels, model_type="vim_t", with_cls_token=True, device=None, checkpoint=None
):
    """Build a UNETR model with a ViM encoder.

    Args:
        out_channels: Number of output channels of the model.
        model_type: Encoder size, one of 'vim_t', 'vim_s' or 'vim_b'.
        with_cls_token: Whether the encoder uses a class token.
        device: Device to move the model to; defaults to CUDA if available, otherwise CPU.
        checkpoint: Optional path (str or path-like) to weights. A '.pth' file is treated as an
            original Vim checkpoint (encoder weights under "model"); anything else as a
            torch_em checkpoint (full model weights under "model_state").

    Returns:
        The UNETR model, moved to `device`.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    encoder = get_vim_encoder(model_type, with_cls_token)

    model_state = None
    if checkpoint is not None:
        state = torch.load(checkpoint, map_location="cpu")

        # str() so that pathlib.Path checkpoints work as well as plain strings
        if str(checkpoint).endswith(".pth"):  # from Vim
            # NOTE(review): the public Vim checkpoints are trained at 224 x 224, while the encoder is
            # built with img_size=1024; resolution-dependent entries such as `pos_embed` may not
            # match and make this strict load fail -- confirm.
            encoder.load_state_dict(state["model"])

        else:  # from torch_em
            model_state = state["model_state"]

    # expose the input height configured in the patch embedding as `encoder.img_size`
    # (presumably consumed by UNETR -- confirm)
    encoder.img_size = encoder.patch_embed.img_size[0]

    model = UNETR(
        encoder=encoder,
        out_channels=out_channels,
        resize_input=False,
        use_skip_connection=False,
        final_activation="Sigmoid",
    )

    # torch_em checkpoints hold the weights of the full model, so they are loaded after UNETR is built
    if model_state is not None:
        model.load_state_dict(model_state)

    model.to(device)

    return model