torch_em.model.unetr

from collections import OrderedDict
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from .unet import Decoder, ConvBlock2d, Upsampler2d
from .vit import get_vision_transformer

try:
    from micro_sam.util import get_sam_model
except ImportError:
    get_sam_model = None


#
# UNETR IMPLEMENTATION [Vision Transformer (ViT from MAE / ViT from SAM) + UNet Decoder from `torch_em`]
#

class UNETR(nn.Module):

    def _load_encoder_from_checkpoint(self, backbone, encoder, checkpoint):

        if isinstance(checkpoint, str):
            if backbone == "sam" and isinstance(encoder, str):
                # If we have a SAM encoder, we first try to load the full SAM model
                # (using micro_sam) and otherwise fall back on loading the encoder state
                # directly from the checkpoint.
                try:
                    _, model = get_sam_model(
                        model_type=encoder,
                        checkpoint_path=checkpoint,
                        return_sam=True
                    )
                    encoder_state = model.image_encoder.state_dict()
                except Exception:
                    # Try loading the encoder state directly from a checkpoint.
                    encoder_state = torch.load(checkpoint)

            elif backbone == "mae":
                # ViT initialization hints from:
                #     - https://github.com/facebookresearch/mae/blob/main/main_finetune.py#L233-L242
                encoder_state = torch.load(checkpoint)["model"]
                encoder_state = OrderedDict({
                    k: v for k, v in encoder_state.items()
                    if (k != "mask_token" and not k.startswith("decoder"))
                })

                # Remove the `head` from the current encoder, since the MAE pretrained weights do not include it.
                current_encoder_state = self.encoder.state_dict()
                if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
                    del self.encoder.head

        else:
            encoder_state = checkpoint

        self.encoder.load_state_dict(encoder_state)

    def __init__(
        self,
        img_size: int = 1024,
        backbone: str = "sam",
        encoder: Optional[Union[nn.Module, str]] = "vit_b",
        decoder: Optional[nn.Module] = None,
        out_channels: int = 1,
        use_sam_stats: bool = False,
        use_mae_stats: bool = False,
        resize_input: bool = True,
        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
        final_activation: Optional[Union[str, nn.Module]] = None,
        use_skip_connection: bool = True,
        embed_dim: Optional[int] = None,
        use_conv_transpose=True,
    ) -> None:
        super().__init__()

        self.use_sam_stats = use_sam_stats
        self.use_mae_stats = use_mae_stats
        self.use_skip_connection = use_skip_connection
        self.resize_input = resize_input

        if isinstance(encoder, str):  # "vit_b" / "vit_l" / "vit_h"
            print(f"Using {encoder} from {backbone.upper()}")
            self.encoder = get_vision_transformer(img_size=img_size, backbone=backbone, model=encoder)
            if encoder_checkpoint is not None:
                self._load_encoder_from_checkpoint(backbone, encoder, encoder_checkpoint)

            in_chans = self.encoder.in_chans
            if embed_dim is None:
                embed_dim = self.encoder.embed_dim

        else:  # `nn.Module` ViT backbone
            self.encoder = encoder

            have_neck = False
            for name, _ in self.encoder.named_parameters():
                if name.startswith("neck"):
                    have_neck = True

            if embed_dim is None:
                if have_neck:
                    embed_dim = self.encoder.neck[2].out_channels  # the value is 256
                else:
                    embed_dim = self.encoder.patch_embed.proj.out_channels

            try:
                in_chans = self.encoder.patch_embed.proj.in_channels
            except AttributeError:  # to get the input channels when using vit_t from MobileSAM
                in_chans = self.encoder.patch_embed.seq[0].c.in_channels

        # parameters for the decoder network
        depth = 3
        initial_features = 64
        gain = 2
        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
        scale_factors = depth * [2]
        self.out_channels = out_channels

        # choice of upsampler: transposed convolution or (bilinear interpolation + conv)
        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d

        if decoder is None:
            self.decoder = Decoder(
                features=features_decoder,
                scale_factors=scale_factors[::-1],
                conv_block_impl=ConvBlock2d,
                sampler_impl=_upsampler,
                norm="OldDefault",
            )
        else:
            self.decoder = decoder

        if use_skip_connection:
            self.deconv1 = Deconv2DBlock(embed_dim, features_decoder[0])
            self.deconv2 = nn.Sequential(
                Deconv2DBlock(embed_dim, features_decoder[0]),
                Deconv2DBlock(features_decoder[0], features_decoder[1])
            )
            self.deconv3 = nn.Sequential(
                Deconv2DBlock(embed_dim, features_decoder[0]),
                Deconv2DBlock(features_decoder[0], features_decoder[1]),
                Deconv2DBlock(features_decoder[1], features_decoder[2])
            )
            self.deconv4 = ConvBlock2d(in_chans, features_decoder[-1], norm="OldDefault")
        else:
            self.deconv1 = Deconv2DBlock(embed_dim, features_decoder[0])
            self.deconv2 = Deconv2DBlock(features_decoder[0], features_decoder[1])
            self.deconv3 = Deconv2DBlock(features_decoder[1], features_decoder[2])
            self.deconv4 = Deconv2DBlock(features_decoder[2], features_decoder[3])

        self.base = ConvBlock2d(embed_dim, features_decoder[0], norm="OldDefault")

        self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1)

        self.deconv_out = _upsampler(
            scale_factor=2, in_channels=features_decoder[-1], out_channels=features_decoder[-1]
        )

        self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1], norm="OldDefault")

        self.final_activation = self._get_activation(final_activation)

    def _get_activation(self, activation):
        return_activation = None
        if activation is None:
            return None
        if isinstance(activation, nn.Module):
            return activation
        if isinstance(activation, str):
            return_activation = getattr(nn, activation, None)
        if return_activation is None:
            raise ValueError(f"Invalid activation: {activation}")
        return return_activation()

    @staticmethod
    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
        """Compute the output size given input size and target long side length.
        """
        scale = long_side_length * 1.0 / max(oldh, oldw)
        newh, neww = oldh * scale, oldw * scale
        neww = int(neww + 0.5)
        newh = int(newh + 0.5)
        return (newh, neww)

    def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor:
        """Resizes the image so that the longest side has the correct length.

        Expects batched images with shape BxCxHxW and float format.
        """
        target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.encoder.img_size)
        return F.interpolate(
            image, target_size, mode="bilinear", align_corners=False, antialias=True
        )

    def preprocess(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]:
        device = x.device

        if self.use_sam_stats:
            pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(1, -1, 1, 1).to(device)
            pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(1, -1, 1, 1).to(device)
        elif self.use_mae_stats:
            # TODO: add mean and std from the MAE experiments (or expose them as arguments)
            raise NotImplementedError
        else:
            pixel_mean = torch.Tensor([0.0, 0.0, 0.0]).view(1, -1, 1, 1).to(device)
            pixel_std = torch.Tensor([1.0, 1.0, 1.0]).view(1, -1, 1, 1).to(device)

        if self.resize_input:
            x = self.resize_longest_side(x)
        input_shape = x.shape[-2:]

        x = (x - pixel_mean) / pixel_std
        h, w = x.shape[-2:]
        padh = self.encoder.img_size - h
        padw = self.encoder.img_size - w
        x = F.pad(x, (0, padw, 0, padh))
        return x, input_shape

    def postprocess_masks(
        self,
        masks: torch.Tensor,
        input_size: Tuple[int, ...],
        original_size: Tuple[int, ...],
    ) -> torch.Tensor:
        masks = F.interpolate(
            masks,
            (self.encoder.img_size, self.encoder.img_size),
            mode="bilinear",
            align_corners=False,
        )
        masks = masks[..., : input_size[0], : input_size[1]]
        masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
        return masks

    def forward(self, x):
        original_shape = x.shape[-2:]

        # Reshape the inputs to the shape expected by the encoder
        # and normalize the inputs if normalization is part of the model.
        x, input_shape = self.preprocess(x)

        use_skip_connection = getattr(self, "use_skip_connection", True)

        encoder_outputs = self.encoder(x)

        if isinstance(encoder_outputs[-1], list):
            # `encoder_outputs` can be arranged in only two forms:
            #   - either we only return the image embeddings
            #   - or, we return the image embeddings and the "list" of global attention layers
            z12, from_encoder = encoder_outputs
        else:
            z12 = encoder_outputs

        if use_skip_connection:
            from_encoder = from_encoder[::-1]
            z9 = self.deconv1(from_encoder[0])
            z6 = self.deconv2(from_encoder[1])
            z3 = self.deconv3(from_encoder[2])
            z0 = self.deconv4(x)

        else:
            z9 = self.deconv1(z12)
            z6 = self.deconv2(z9)
            z3 = self.deconv3(z6)
            z0 = self.deconv4(z3)

        updated_from_encoder = [z9, z6, z3]

        x = self.base(z12)
        x = self.decoder(x, encoder_inputs=updated_from_encoder)
        x = self.deconv_out(x)

        x = torch.cat([x, z0], dim=1)
        x = self.decoder_head(x)

        x = self.out_conv(x)
        if self.final_activation is not None:
            x = self.final_activation(x)

        x = self.postprocess_masks(x, input_shape, original_shape)
        return x


#
#  ADDITIONAL FUNCTIONALITIES
#

class SingleDeconv2DBlock(nn.Module):
    def __init__(self, scale_factor, in_channels, out_channels):
        super().__init__()
        # NOTE: `scale_factor` is accepted for interface compatibility with `Upsampler2d`;
        # the transposed convolution below always upsamples by a factor of 2.
        self.block = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size=2, stride=2, padding=0, output_padding=0
        )

    def forward(self, x):
        return self.block(x)


class SingleConv2DBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        self.block = nn.Conv2d(
            in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=((kernel_size - 1) // 2)
        )

    def forward(self, x):
        return self.block(x)


class Conv2DBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3):
        super().__init__()
        self.block = nn.Sequential(
            SingleConv2DBlock(in_channels, out_channels, kernel_size),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True)
        )

    def forward(self, x):
        return self.block(x)


class Deconv2DBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, use_conv_transpose=True):
        super().__init__()
        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d
        self.block = nn.Sequential(
            _upsampler(scale_factor=2, in_channels=in_channels, out_channels=out_channels),
            SingleConv2DBlock(out_channels, out_channels, kernel_size),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True)
        )

    def forward(self, x):
        return self.block(x)
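
The snippet below is a minimal usage sketch, not part of the module source: it builds a UNETR with the default SAM-style `vit_b` encoder (randomly initialized, since no `encoder_checkpoint` is given) and runs a forward pass on a dummy batch. It assumes `torch_em` and its ViT dependencies are installed and that the default configuration is used; the output is resized back to the input resolution by `postprocess_masks`.

import torch
from torch_em.model.unetr import UNETR

model = UNETR(
    img_size=1024,               # encoder input resolution; inputs are resized and padded to this
    backbone="sam",
    encoder="vit_b",             # randomly initialized, since encoder_checkpoint is None
    out_channels=1,
    final_activation="Sigmoid",  # resolved via getattr(nn, "Sigmoid") in _get_activation
)
model.eval()

with torch.no_grad():
    x = torch.randn(1, 3, 512, 512)  # B x C x H x W
    y = model(x)

print(y.shape)  # expected: torch.Size([1, 1, 512, 512]) -- masks are resized back to the input shape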
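Since `get_preprocess_shape` is a plain static method, the resizing arithmetic can be checked in isolation. A quick sketch: the longest side is scaled to `long_side_length` and the other side is scaled proportionally and rounded to the nearest integer.

from torch_em.model.unetr import UNETR

print(UNETR.get_preprocess_shape(oldh=768, oldw=1536, long_side_length=1024))   # -> (512, 1024)
print(UNETR.get_preprocess_shape(oldh=1024, oldw=1024, long_side_length=1024))  # -> (1024, 1024)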

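Pretrained encoder weights are loaded through `encoder_checkpoint`. A hedged sketch follows; the checkpoint path is hypothetical and has to point to an actual SAM (or MAE) checkpoint on disk. With `backbone="sam"`, `_load_encoder_from_checkpoint` first tries `micro_sam.util.get_sam_model` and falls back to `torch.load` on the checkpoint.

import os
from torch_em.model.unetr import UNETR

sam_checkpoint = "/path/to/sam_vit_b_checkpoint.pth"  # hypothetical path, replace with a real checkpoint

if os.path.exists(sam_checkpoint):
    model = UNETR(
        backbone="sam",
        encoder="vit_b",
        encoder_checkpoint=sam_checkpoint,  # encoder weights are loaded in _load_encoder_from_checkpoint
        use_sam_stats=True,                 # normalize inputs with SAM's pixel mean / std
        out_channels=2,
    )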

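As a small sanity check of the building blocks, the sketch below (assuming the module is importable as `torch_em.model.unetr`) shows that `Deconv2DBlock` doubles the spatial resolution while mapping `in_channels` to `out_channels`, and that `SingleConv2DBlock` keeps the spatial size thanks to its symmetric padding.

import torch
from torch_em.model.unetr import Deconv2DBlock, SingleConv2DBlock

x = torch.randn(1, 256, 16, 16)

deconv = Deconv2DBlock(in_channels=256, out_channels=128)
print(deconv(x).shape)  # -> torch.Size([1, 128, 32, 32]): 2x upsampling + channel reduction

conv = SingleConv2DBlock(in_channels=256, out_channels=64, kernel_size=3)
print(conv(x).shape)    # -> torch.Size([1, 64, 16, 16]): padding keeps the spatial size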