torch_em.model.vim

# installation from https://github.com/hustvl/Vim
# encoder from https://github.com/hustvl/Vim
# decoder from https://github.com/constantinpape/torch-em

# pretrained model weights: vim_t - https://huggingface.co/hustvl/Vim-tiny/blob/main/vim_tiny_73p1.pth
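A minimal sketch of how these pretrained vim_t weights can be fetched and plugged into `get_vimunet_model` (defined below). Using huggingface_hub for the download is an assumption; any locally available copy of the .pth file works just as well.

from huggingface_hub import hf_hub_download

# download the ImageNet-pretrained Vim-tiny checkpoint from the hugging face hub
checkpoint = hf_hub_download(repo_id="hustvl/Vim-tiny", filename="vim_tiny_73p1.pth")
model = get_vimunet_model(out_channels=1, model_type="vim_t", checkpoint=checkpoint)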
import torch

from .unetr import UNETR

try:
    from vim.models_mamba import VisionMamba, rms_norm_fn, RMSNorm, layer_norm_fn
    _have_vim_installed = True
except ImportError:
    VisionMamba = object
    rms_norm_fn = RMSNorm = layer_norm_fn = None
    _have_vim_installed = False

try:
    from timm.models.vision_transformer import _cfg
except ImportError:
    _cfg = None

class ViM(VisionMamba):
    def __init__(self, **kwargs):
        assert _have_vim_installed, "Please install Vim from https://github.com/hustvl/Vim."
        super().__init__(**kwargs)

    def convert_to_expected_dim(self, inputs_):
        # reshape the token outputs to an image-like grid (N x H*W x C -> N x C x H x W)
        rdim = inputs_.shape[1]
        dshape = int(rdim ** 0.5)  # the token count is H*W, so the side of the square patch grid is its square root
        inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape))
        inputs_ = inputs_.permute(0, 3, 1, 2)
        return inputs_

    def forward_features(self, x, inference_params=None):
        # adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
        # with slight modifications to add the class token
        x = self.patch_embed(x)
        if self.if_cls_token:
            cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
            x = torch.cat((cls_token, x), dim=1)

        if self.if_abs_pos_embed:
            x = x + self.pos_embed
            x = self.pos_drop(x)

        # mamba blocks
        residual = None
        hidden_states = x
        for layer in self.layers:
            # apply rotary position embeddings, if configured
            if self.if_rope:
                hidden_states = self.rope(hidden_states)
                if residual is not None and self.if_rope_residual:
                    residual = self.rope(residual)

            hidden_states, residual = layer(
                hidden_states, residual, inference_params=inference_params
            )

        if not self.fused_add_norm:
            if residual is None:
                residual = hidden_states
            else:
                residual = residual + self.drop_path(hidden_states)
            hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
        else:
            # set prenorm=False here since we don't need the residual
            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
            hidden_states = fused_add_norm_fn(
                self.drop_path(hidden_states),
                self.norm_f.weight,
                self.norm_f.bias,
                eps=self.norm_f.eps,
                residual=residual,
                prenorm=False,
                residual_in_fp32=self.residual_in_fp32,
            )

        if self.final_pool_type == 'none':
            return hidden_states[:, -1, :]  # return only the last token
        elif self.final_pool_type == 'mean':
            return hidden_states.mean(dim=1)
        elif self.final_pool_type == 'max':
            return hidden_states.max(dim=1).values  # torch.max along a dim returns (values, indices)
        elif self.final_pool_type == 'all':
            return hidden_states
        else:
            raise NotImplementedError(f"Unsupported final_pool_type: {self.final_pool_type}")

    def forward(self, x, inference_params=None):
        x = self.forward_features(x, inference_params)

        if self.if_cls_token:  # remove the class token
            x = x[:, 1:, :]

        # get the 2d patch grid back from the 1d token sequence
        x = self.convert_to_expected_dim(x)

        return x  # from here, the tokens can be upsampled easily (N x C x H x W)

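To make the token-to-grid conversion concrete, here is a small shape-only sketch of what `convert_to_expected_dim` does for the vim_t configuration, where a 1024 x 1024 input with patch size 16 yields 64 * 64 = 4096 tokens with 192 channels:

import torch

tokens = torch.randn(1, 4096, 192)  # N x H*W x C

grid_side = int(tokens.shape[1] ** 0.5)  # 64
grid = torch.unflatten(tokens, 1, (grid_side, grid_side))  # N x H x W x C: (1, 64, 64, 192)
grid = grid.permute(0, 3, 1, 2)  # N x C x H x W: (1, 192, 64, 64)
assert grid.shape == (1, 192, 64, 64)
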
def get_vim_encoder(model_type="vim_t", with_cls_token=True):
    # the three variants only differ in their embedding dimension; they correspond to
    # `vim_(tiny/small/base)_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_rope_also_residual` in Vim
    # NOTE: unlike the Vim defaults, all variants here can use a class token (AA);
    # only vim_t has an ImageNet-pretrained model available (see the link at the top of this file)
    embed_dims = {"vim_t": 192, "vim_s": 384, "vim_b": 768}
    if model_type not in embed_dims:
        raise ValueError(f"Invalid model type: '{model_type}'. Choose from 'vim_t' / 'vim_s' / 'vim_b'.")

    encoder = ViM(
        img_size=1024,
        patch_size=16,
        embed_dim=embed_dims[model_type],
        depth=24,
        rms_norm=True,
        residual_in_fp32=True,
        fused_add_norm=True,
        final_pool_type='all',
        if_abs_pos_embed=True,
        if_rope=True,
        if_rope_residual=True,
        bimamba_type="v2",
        if_cls_token=with_cls_token,
    )

    assert _cfg is not None, "Please install timm."
    encoder.default_cfg = _cfg()
    return encoder

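A quick usage sketch for the encoder on its own; this requires a working Vim installation, and with the fixed img_size of 1024 the forward pass should return a 64 x 64 feature grid:

encoder = get_vim_encoder(model_type="vim_t", with_cls_token=True)
encoder.eval()

with torch.no_grad():
    features = encoder(torch.randn(1, 3, 1024, 1024))

print(features.shape)  # expected: torch.Size([1, 192, 64, 64])
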
def get_vimunet_model(
    out_channels, model_type="vim_t", with_cls_token=True, device=None, checkpoint=None
):
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    encoder = get_vim_encoder(model_type, with_cls_token)

    model_state = None
    if checkpoint is not None:
        state = torch.load(checkpoint, map_location="cpu")

        if checkpoint.endswith(".pth"):  # pretrained encoder weights from Vim
            encoder_state = state["model"]
            encoder.load_state_dict(encoder_state)

        else:  # finetuned model weights from torch_em
            model_state = state["model_state"]

    encoder.img_size = encoder.patch_embed.img_size[0]

    model = UNETR(
        encoder=encoder,
        out_channels=out_channels,
        resize_input=False,
        use_skip_connection=False,
        final_activation="Sigmoid",
    )

    if model_state is not None:
        model.load_state_dict(model_state)

    model.to(device)

    return model
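
And a sketch of end-to-end inference with the assembled ViM-UNETR model, assuming a single output channel (a foreground probability map, given the sigmoid activation); with resize_input=False the prediction should come back at the input resolution:

model = get_vimunet_model(out_channels=1, model_type="vim_t")
model.eval()

with torch.no_grad():
    prediction = model(torch.randn(1, 3, 1024, 1024))

print(prediction.shape)  # expected: torch.Size([1, 1, 1024, 1024])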