torch_em.model.unetr
Full module source:

```python
from collections import OrderedDict
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from .unet import Decoder, ConvBlock2d, Upsampler2d
from .vit import get_vision_transformer

try:
    from micro_sam.util import get_sam_model
except ImportError:
    get_sam_model = None


#
# UNETR IMPLEMENTATION [Vision Transformer (ViT from MAE / ViT from SAM) + UNet Decoder from `torch_em`]
#


class UNETR(nn.Module):

    def _load_encoder_from_checkpoint(self, backbone, encoder, checkpoint):

        if isinstance(checkpoint, str):
            if backbone == "sam" and isinstance(encoder, str):
                # For a SAM encoder we first try to load the full SAM model (using micro_sam)
                # and otherwise fall back to loading the encoder state directly from the checkpoint.
                try:
                    _, model = get_sam_model(
                        model_type=encoder,
                        checkpoint_path=checkpoint,
                        return_sam=True
                    )
                    encoder_state = model.image_encoder.state_dict()
                except Exception:
                    # Try loading the encoder state directly from a checkpoint.
                    encoder_state = torch.load(checkpoint)

            elif backbone == "mae":
                # ViT initialization hints from:
                # - https://github.com/facebookresearch/mae/blob/main/main_finetune.py#L233-L242
                encoder_state = torch.load(checkpoint)["model"]
                encoder_state = OrderedDict({
                    k: v for k, v in encoder_state.items()
                    if (k != "mask_token" and not k.startswith("decoder"))
                })

                # Remove the `head` from our current encoder (the MAE pretrained weights don't include it).
                current_encoder_state = self.encoder.state_dict()
                if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
                    del self.encoder.head

        else:
            encoder_state = checkpoint

        self.encoder.load_state_dict(encoder_state)

    def __init__(
        self,
        img_size: int = 1024,
        backbone: str = "sam",
        encoder: Optional[Union[nn.Module, str]] = "vit_b",
        decoder: Optional[nn.Module] = None,
        out_channels: int = 1,
        use_sam_stats: bool = False,
        use_mae_stats: bool = False,
        resize_input: bool = True,
        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
        final_activation: Optional[Union[str, nn.Module]] = None,
        use_skip_connection: bool = True,
        embed_dim: Optional[int] = None,
        use_conv_transpose: bool = True,
    ) -> None:
        super().__init__()

        self.use_sam_stats = use_sam_stats
        self.use_mae_stats = use_mae_stats
        self.use_skip_connection = use_skip_connection
        self.resize_input = resize_input

        if isinstance(encoder, str):  # "vit_b" / "vit_l" / "vit_h"
            print(f"Using {encoder} from {backbone.upper()}")
            self.encoder = get_vision_transformer(img_size=img_size, backbone=backbone, model=encoder)
            if encoder_checkpoint is not None:
                self._load_encoder_from_checkpoint(backbone, encoder, encoder_checkpoint)

            in_chans = self.encoder.in_chans
            if embed_dim is None:
                embed_dim = self.encoder.embed_dim

        else:  # `nn.Module` ViT backbone
            self.encoder = encoder

            have_neck = False
            for name, _ in self.encoder.named_parameters():
                if name.startswith("neck"):
                    have_neck = True

            if embed_dim is None:
                if have_neck:
                    embed_dim = self.encoder.neck[2].out_channels  # the value is 256
                else:
                    embed_dim = self.encoder.patch_embed.proj.out_channels

            try:
                in_chans = self.encoder.patch_embed.proj.in_channels
            except AttributeError:  # for getting the input channels while using vit_t from MobileSam
                in_chans = self.encoder.patch_embed.seq[0].c.in_channels

        # parameters for the decoder network
        depth = 3
        initial_features = 64
        gain = 2
        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
        scale_factors = depth * [2]
        self.out_channels = out_channels

        # choice of upsampler - to use (bilinear interpolation + conv) or conv transpose
        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d

        if decoder is None:
            self.decoder = Decoder(
                features=features_decoder,
                scale_factors=scale_factors[::-1],
                conv_block_impl=ConvBlock2d,
                sampler_impl=_upsampler,
                norm="OldDefault",
            )
        else:
            self.decoder = decoder

        if use_skip_connection:
            self.deconv1 = Deconv2DBlock(embed_dim, features_decoder[0])
            self.deconv2 = nn.Sequential(
                Deconv2DBlock(embed_dim, features_decoder[0]),
                Deconv2DBlock(features_decoder[0], features_decoder[1])
            )
            self.deconv3 = nn.Sequential(
                Deconv2DBlock(embed_dim, features_decoder[0]),
                Deconv2DBlock(features_decoder[0], features_decoder[1]),
                Deconv2DBlock(features_decoder[1], features_decoder[2])
            )
            self.deconv4 = ConvBlock2d(in_chans, features_decoder[-1], norm="OldDefault")
        else:
            self.deconv1 = Deconv2DBlock(embed_dim, features_decoder[0])
            self.deconv2 = Deconv2DBlock(features_decoder[0], features_decoder[1])
            self.deconv3 = Deconv2DBlock(features_decoder[1], features_decoder[2])
            self.deconv4 = Deconv2DBlock(features_decoder[2], features_decoder[3])

        self.base = ConvBlock2d(embed_dim, features_decoder[0], norm="OldDefault")

        self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1)

        self.deconv_out = _upsampler(
            scale_factor=2, in_channels=features_decoder[-1], out_channels=features_decoder[-1]
        )

        self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1], norm="OldDefault")

        self.final_activation = self._get_activation(final_activation)

    def _get_activation(self, activation):
        return_activation = None
        if activation is None:
            return None
        if isinstance(activation, nn.Module):
            return activation
        if isinstance(activation, str):
            return_activation = getattr(nn, activation, None)
        if return_activation is None:
            raise ValueError(f"Invalid activation: {activation}")
        return return_activation()

    @staticmethod
    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
        """Compute the output size given input size and target long side length."""
        scale = long_side_length * 1.0 / max(oldh, oldw)
        newh, neww = oldh * scale, oldw * scale
        neww = int(neww + 0.5)
        newh = int(newh + 0.5)
        return (newh, neww)

    def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor:
        """Resizes the image so that the longest side has the correct length.

        Expects batched images with shape BxCxHxW and float format.
        """
        target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.encoder.img_size)
        return F.interpolate(
            image, target_size, mode="bilinear", align_corners=False, antialias=True
        )

    def preprocess(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]:
        device = x.device

        if self.use_sam_stats:
            pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(1, -1, 1, 1).to(device)
            pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(1, -1, 1, 1).to(device)
        elif self.use_mae_stats:
            # TODO: add mean std from mae experiments (or open up arguments for this)
            raise NotImplementedError
        else:
            pixel_mean = torch.Tensor([0.0, 0.0, 0.0]).view(1, -1, 1, 1).to(device)
            pixel_std = torch.Tensor([1.0, 1.0, 1.0]).view(1, -1, 1, 1).to(device)

        if self.resize_input:
            x = self.resize_longest_side(x)
        input_shape = x.shape[-2:]

        x = (x - pixel_mean) / pixel_std
        h, w = x.shape[-2:]
        padh = self.encoder.img_size - h
        padw = self.encoder.img_size - w
        x = F.pad(x, (0, padw, 0, padh))
        return x, input_shape

    def postprocess_masks(
        self,
        masks: torch.Tensor,
        input_size: Tuple[int, ...],
        original_size: Tuple[int, ...],
    ) -> torch.Tensor:
        masks = F.interpolate(
            masks,
            (self.encoder.img_size, self.encoder.img_size),
            mode="bilinear",
            align_corners=False,
        )
        masks = masks[..., : input_size[0], : input_size[1]]
        masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
        return masks

    def forward(self, x):
        original_shape = x.shape[-2:]

        # Reshape the inputs to the shape expected by the encoder
        # and normalize the inputs if normalization is part of the model.
        x, input_shape = self.preprocess(x)

        use_skip_connection = getattr(self, "use_skip_connection", True)

        encoder_outputs = self.encoder(x)

        if isinstance(encoder_outputs[-1], list):
            # `encoder_outputs` can be arranged in only two forms:
            # - either we only return the image embeddings
            # - or, we return the image embeddings and the "list" of global attention layers
            z12, from_encoder = encoder_outputs
        else:
            z12 = encoder_outputs

        # Note: the skip-connection path requires an encoder that also returns the
        # intermediate (global attention) features, i.e. the tuple form above.
        if use_skip_connection:
            from_encoder = from_encoder[::-1]
            z9 = self.deconv1(from_encoder[0])
            z6 = self.deconv2(from_encoder[1])
            z3 = self.deconv3(from_encoder[2])
            z0 = self.deconv4(x)

        else:
            z9 = self.deconv1(z12)
            z6 = self.deconv2(z9)
            z3 = self.deconv3(z6)
            z0 = self.deconv4(z3)

        updated_from_encoder = [z9, z6, z3]

        x = self.base(z12)
        x = self.decoder(x, encoder_inputs=updated_from_encoder)
        x = self.deconv_out(x)

        x = torch.cat([x, z0], dim=1)
        x = self.decoder_head(x)

        x = self.out_conv(x)
        if self.final_activation is not None:
            x = self.final_activation(x)

        x = self.postprocess_masks(x, input_shape, original_shape)
        return x


#
# ADDITIONAL FUNCTIONALITIES
#


class SingleDeconv2DBlock(nn.Module):
    def __init__(self, scale_factor, in_channels, out_channels):
        super().__init__()
        # `scale_factor` is accepted for interface compatibility with `Upsampler2d`,
        # but the transposed convolution always upsamples by a fixed factor of 2.
        self.block = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size=2, stride=2, padding=0, output_padding=0
        )

    def forward(self, x):
        return self.block(x)


class SingleConv2DBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        self.block = nn.Conv2d(
            in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=((kernel_size - 1) // 2)
        )

    def forward(self, x):
        return self.block(x)


class Conv2DBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3):
        super().__init__()
        self.block = nn.Sequential(
            SingleConv2DBlock(in_channels, out_channels, kernel_size),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True)
        )

    def forward(self, x):
        return self.block(x)


class Deconv2DBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, use_conv_transpose=True):
        super().__init__()
        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d
        self.block = nn.Sequential(
            _upsampler(scale_factor=2, in_channels=in_channels, out_channels=out_channels),
            SingleConv2DBlock(out_channels, out_channels, kernel_size),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True)
        )

    def forward(self, x):
        return self.block(x)
```
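A minimal usage sketch (not part of the module): it assumes `torch_em` and its ViT dependencies are installed, and the encoder is randomly initialized since no `encoder_checkpoint` is passed. Inputs of arbitrary spatial size are handled by the built-in resize/pad/crop logic.

```python
import torch
from torch_em.model.unetr import UNETR

# Build a UNETR with a vit_b SAM-style encoder, SAM input normalization,
# two output channels, and a sigmoid on the output.
model = UNETR(
    img_size=1024,
    backbone="sam",
    encoder="vit_b",
    out_channels=2,
    use_sam_stats=True,
    final_activation="Sigmoid",
)
model.eval()

# The input is resized so its longest side matches the encoder's img_size,
# padded to a square, and the prediction is interpolated back to the
# original shape by postprocess_masks.
x = torch.randn(1, 3, 512, 768)
with torch.no_grad():
    y = model(x)
print(y.shape)  # torch.Size([1, 2, 512, 768])
```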
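Both upsampler choices double the spatial resolution. A short sketch of the helper blocks defined above (assuming the same `torch_em.model.unetr` import path; `Upsampler2d`, the bilinear alternative, lives in `torch_em.model.unet`):

```python
import torch
from torch_em.model.unetr import SingleDeconv2DBlock, Deconv2DBlock

x = torch.randn(1, 128, 32, 32)

# Transposed convolution: fixed 2x upsampling (scale_factor is ignored).
up = SingleDeconv2DBlock(scale_factor=2, in_channels=128, out_channels=64)
print(up(x).shape)  # torch.Size([1, 64, 64, 64])

# Deconv2DBlock = upsampler + 3x3 conv + batch norm + ReLU, also 2x overall.
block = Deconv2DBlock(in_channels=128, out_channels=64)
print(block(x).shape)  # torch.Size([1, 64, 64, 64])
```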