torch_em.model.unetr
```python
from collections import OrderedDict
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from .vit import get_vision_transformer
from .unet import Decoder, ConvBlock2d, Upsampler2d

try:
    from micro_sam.util import get_sam_model
except ImportError:
    get_sam_model = None


#
# UNETR IMPLEMENTATION [Vision Transformer (ViT from SAM / MAE / ScaleMAE) + UNet Decoder from `torch_em`]
#


class UNETR(nn.Module):
    """A U-Net Transformer using a vision transformer as encoder and a convolutional decoder.

    Args:
        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
        backbone: The name of the vision transformer implementation. One of "sam" or "mae".
        encoder: The vision transformer. Can either be a name, such as "vit_b" or a torch module.
        decoder: The convolutional decoder.
        out_channels: The number of output channels of the UNETR.
        use_sam_stats: Whether to normalize the input data with the statistics of the pretrained SAM model.
        use_mae_stats: Whether to normalize the input data with the statistics of the pretrained MAE model.
        resize_input: Whether to resize the input images to match `img_size`.
            By default, it resizes the inputs to match the `img_size`.
        encoder_checkpoint: Checkpoint for initializing the vision transformer.
            Can either be a filepath or an already loaded checkpoint.
        final_activation: The activation to apply to the UNETR output.
        use_skip_connection: Whether to use skip connections. By default, it uses skip connections.
        embed_dim: The embedding dimensionality, corresponding to the output dimension of the vision transformer.
        use_conv_transpose: Whether to use transposed convolutions instead of resampling for upsampling.
            By default, it uses resampling for upsampling.
    """
    def _load_encoder_from_checkpoint(self, backbone, encoder, checkpoint):
        """Function to load pretrained weights to the image encoder.
        """
        if isinstance(checkpoint, str):
            if backbone == "sam" and isinstance(encoder, str):
                # If we have a SAM encoder, then we first try to load the full SAM Model
                # (using micro_sam) and otherwise fall back on directly loading the encoder state
                # from the checkpoint
                try:
                    _, model = get_sam_model(model_type=encoder, checkpoint_path=checkpoint, return_sam=True)
                    encoder_state = model.image_encoder.state_dict()
                except Exception:
                    # Try loading the encoder state directly from a checkpoint.
                    encoder_state = torch.load(checkpoint, weights_only=False)

            elif backbone == "mae":
                # vit initialization hints from:
                # - https://github.com/facebookresearch/mae/blob/main/main_finetune.py#L233-L242
                encoder_state = torch.load(checkpoint, weights_only=False)["model"]
                encoder_state = OrderedDict({
                    k: v for k, v in encoder_state.items() if (k != "mask_token" and not k.startswith("decoder"))
                })
                # Let's remove the `head` from our current encoder (as the MAE pretrained don't expect it)
                current_encoder_state = self.encoder.state_dict()
                if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
                    del self.encoder.head

            elif backbone == "scalemae":
                # Load the encoder state directly from a checkpoint.
                encoder_state = torch.load(checkpoint)["model"]
                encoder_state = OrderedDict({
                    k: v for k, v in encoder_state.items()
                    if not k.startswith(("mask_token", "decoder", "fcn", "fpn", "pos_embed"))
                })

                # Let's remove the `head` from our current encoder (as the MAE pretrained don't expect it)
                current_encoder_state = self.encoder.state_dict()
                if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
                    del self.encoder.head

                if "pos_embed" in current_encoder_state:  # NOTE: ScaleMAE uses 'pos. embeddings' in a diff. format.
                    del self.encoder.pos_embed

        else:
            encoder_state = checkpoint

        self.encoder.load_state_dict(encoder_state)

    def __init__(
        self,
        img_size: int = 1024,
        backbone: str = "sam",
        encoder: Optional[Union[nn.Module, str]] = "vit_b",
        decoder: Optional[nn.Module] = None,
        out_channels: int = 1,
        use_sam_stats: bool = False,
        use_mae_stats: bool = False,
        resize_input: bool = True,
        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
        final_activation: Optional[Union[str, nn.Module]] = None,
        use_skip_connection: bool = True,
        embed_dim: Optional[int] = None,
        use_conv_transpose: bool = False,
        **kwargs
    ) -> None:
        super().__init__()

        self.use_sam_stats = use_sam_stats
        self.use_mae_stats = use_mae_stats
        self.use_skip_connection = use_skip_connection
        self.resize_input = resize_input

        if isinstance(encoder, str):  # "vit_b" / "vit_l" / "vit_h"
            print(f"Using {encoder} from {backbone.upper()}")
            self.encoder = get_vision_transformer(img_size=img_size, backbone=backbone, model=encoder, **kwargs)

            if encoder_checkpoint is not None:
                self._load_encoder_from_checkpoint(backbone, encoder, encoder_checkpoint)

            in_chans = self.encoder.in_chans
            if embed_dim is None:
                embed_dim = self.encoder.embed_dim

        else:  # `nn.Module` ViT backbone
            self.encoder = encoder

            have_neck = False
            for name, _ in self.encoder.named_parameters():
                if name.startswith("neck"):
                    have_neck = True

            if embed_dim is None:
                if have_neck:
                    embed_dim = self.encoder.neck[2].out_channels  # the value is 256
                else:
                    embed_dim = self.encoder.patch_embed.proj.out_channels

            try:
                in_chans = self.encoder.patch_embed.proj.in_channels
            except AttributeError:  # for getting the input channels while using vit_t from MobileSam
                in_chans = self.encoder.patch_embed.seq[0].c.in_channels

        # parameters for the decoder network
        depth = 3
        initial_features = 64
        gain = 2
        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
        scale_factors = depth * [2]
        self.out_channels = out_channels

        # choice of upsampler - to use (bilinear interpolation + conv) or conv transpose
        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d

        if decoder is None:
            self.decoder = Decoder(
                features=features_decoder,
                scale_factors=scale_factors[::-1],
                conv_block_impl=ConvBlock2d,
                sampler_impl=_upsampler,
            )
        else:
            self.decoder = decoder

        if use_skip_connection:
            self.deconv1 = Deconv2DBlock(
                in_channels=embed_dim,
                out_channels=features_decoder[0],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv2 = nn.Sequential(
                Deconv2DBlock(
                    in_channels=embed_dim,
                    out_channels=features_decoder[0],
                    use_conv_transpose=use_conv_transpose,
                ),
                Deconv2DBlock(
                    in_channels=features_decoder[0],
                    out_channels=features_decoder[1],
                    use_conv_transpose=use_conv_transpose,
                )
            )
            self.deconv3 = nn.Sequential(
                Deconv2DBlock(
                    in_channels=embed_dim,
                    out_channels=features_decoder[0],
                    use_conv_transpose=use_conv_transpose,
                ),
                Deconv2DBlock(
                    in_channels=features_decoder[0],
                    out_channels=features_decoder[1],
                    use_conv_transpose=use_conv_transpose,
                ),
                Deconv2DBlock(
                    in_channels=features_decoder[1],
                    out_channels=features_decoder[2],
                    use_conv_transpose=use_conv_transpose,
                )
            )
            self.deconv4 = ConvBlock2d(in_chans, features_decoder[-1])
        else:
            self.deconv1 = Deconv2DBlock(
                in_channels=embed_dim,
                out_channels=features_decoder[0],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv2 = Deconv2DBlock(
                in_channels=features_decoder[0],
                out_channels=features_decoder[1],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv3 = Deconv2DBlock(
                in_channels=features_decoder[1],
                out_channels=features_decoder[2],
                use_conv_transpose=use_conv_transpose,
            )
            self.deconv4 = Deconv2DBlock(
                in_channels=features_decoder[2],
                out_channels=features_decoder[3],
                use_conv_transpose=use_conv_transpose,
            )

        self.base = ConvBlock2d(embed_dim, features_decoder[0])
        self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1)
        self.deconv_out = _upsampler(
            scale_factor=2, in_channels=features_decoder[-1], out_channels=features_decoder[-1]
        )
        self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1])
        self.final_activation = self._get_activation(final_activation)

    def _get_activation(self, activation):
        return_activation = None
        if activation is None:
            return None
        if isinstance(activation, nn.Module):
            return activation
        if isinstance(activation, str):
            return_activation = getattr(nn, activation, None)
        if return_activation is None:
            raise ValueError(f"Invalid activation: {activation}")

        return return_activation()

    @staticmethod
    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
        """Compute the output size given input size and target long side length.

        Args:
            oldh: The input image height.
            oldw: The input image width.
            long_side_length: The longest side length for resizing.

        Returns:
            The new image height.
            The new image width.
        """
        scale = long_side_length * 1.0 / max(oldh, oldw)
        newh, neww = oldh * scale, oldw * scale
        neww = int(neww + 0.5)
        newh = int(newh + 0.5)
        return (newh, neww)

    def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor:
        """Resize the image so that the longest side has the correct length.

        Expects batched images with shape BxCxHxW and float format.

        Args:
            image: The input image.

        Returns:
            The resized image.
        """
        target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.encoder.img_size)
        return F.interpolate(
            image, target_size, mode="bilinear", align_corners=False, antialias=True
        )

    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        """@private
        """
        device = x.device

        if self.use_sam_stats:
            pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(1, -1, 1, 1).to(device)
            pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(1, -1, 1, 1).to(device)
        elif self.use_mae_stats:
            # TODO: add mean std from mae / scalemae experiments (or open up arguments for this)
            raise NotImplementedError
        else:
            pixel_mean = torch.Tensor([0.0, 0.0, 0.0]).view(1, -1, 1, 1).to(device)
            pixel_std = torch.Tensor([1.0, 1.0, 1.0]).view(1, -1, 1, 1).to(device)

        if self.resize_input:
            x = self.resize_longest_side(x)
        input_shape = x.shape[-2:]

        x = (x - pixel_mean) / pixel_std
        h, w = x.shape[-2:]
        padh = self.encoder.img_size - h
        padw = self.encoder.img_size - w
        x = F.pad(x, (0, padw, 0, padh))
        return x, input_shape

    def postprocess_masks(
        self, masks: torch.Tensor, input_size: Tuple[int, ...], original_size: Tuple[int, ...],
    ) -> torch.Tensor:
        """@private
        """
        masks = F.interpolate(
            masks,
            (self.encoder.img_size, self.encoder.img_size),
            mode="bilinear",
            align_corners=False,
        )
        masks = masks[..., : input_size[0], : input_size[1]]
        masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
        return masks

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the UNETR to the input data.

        Args:
            x: The input tensor.

        Returns:
            The UNETR output.
        """
        original_shape = x.shape[-2:]

        # Reshape the inputs to the shape expected by the encoder
        # and normalize the inputs if normalization is part of the model.
        x, input_shape = self.preprocess(x)

        use_skip_connection = getattr(self, "use_skip_connection", True)

        encoder_outputs = self.encoder(x)

        if isinstance(encoder_outputs[-1], list):
            # `encoder_outputs` can be arranged in only two forms:
            # - either we only return the image embeddings
            # - or, we return the image embeddings and the "list" of global attention layers
            z12, from_encoder = encoder_outputs
        else:
            z12 = encoder_outputs

        if use_skip_connection:
            from_encoder = from_encoder[::-1]
            z9 = self.deconv1(from_encoder[0])
            z6 = self.deconv2(from_encoder[1])
            z3 = self.deconv3(from_encoder[2])
            z0 = self.deconv4(x)

        else:
            z9 = self.deconv1(z12)
            z6 = self.deconv2(z9)
            z3 = self.deconv3(z6)
            z0 = self.deconv4(z3)

        updated_from_encoder = [z9, z6, z3]

        x = self.base(z12)
        x = self.decoder(x, encoder_inputs=updated_from_encoder)
        x = self.deconv_out(x)

        x = torch.cat([x, z0], dim=1)
        x = self.decoder_head(x)

        x = self.out_conv(x)
        if self.final_activation is not None:
            x = self.final_activation(x)

        x = self.postprocess_masks(x, input_shape, original_shape)
        return x


#
# ADDITIONAL FUNCTIONALITIES
#


class SingleDeconv2DBlock(nn.Module):
    """@private
    """
    def __init__(self, scale_factor, in_channels, out_channels):
        super().__init__()
        self.block = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, padding=0, output_padding=0)

    def forward(self, x):
        return self.block(x)


class SingleConv2DBlock(nn.Module):
    """@private
    """
    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        self.block = nn.Conv2d(
            in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=((kernel_size - 1) // 2)
        )

    def forward(self, x):
        return self.block(x)


class Conv2DBlock(nn.Module):
    """@private
    """
    def __init__(self, in_channels, out_channels, kernel_size=3):
        super().__init__()
        self.block = nn.Sequential(
            SingleConv2DBlock(in_channels, out_channels, kernel_size),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True)
        )

    def forward(self, x):
        return self.block(x)


class Deconv2DBlock(nn.Module):
    """@private
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, use_conv_transpose=True):
        super().__init__()
        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d
        self.block = nn.Sequential(
            _upsampler(scale_factor=2, in_channels=in_channels, out_channels=out_channels),
            SingleConv2DBlock(out_channels, out_channels, kernel_size),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True)
        )

    def forward(self, x):
        return self.block(x)
```
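The private helper blocks at the end of the module (`SingleDeconv2DBlock`, `SingleConv2DBlock`, `Conv2DBlock`, `Deconv2DBlock`) provide the upsampling and convolution units used by the decoder path. A minimal sketch of how `Deconv2DBlock` doubles the spatial resolution (the blocks are marked `@private`, so importing them directly is for illustration only):

```python
import torch
from torch_em.model.unetr import Deconv2DBlock

# Upsample by 2x with a transposed convolution, then refine with a 3x3 conv + BatchNorm + ReLU.
block = Deconv2DBlock(in_channels=768, out_channels=512, use_conv_transpose=True)

x = torch.randn(1, 768, 64, 64)  # e.g. ViT-B patch embeddings reshaped to a 64x64 grid
print(block(x).shape)            # torch.Size([1, 512, 128, 128])
```

With `use_conv_transpose=False` the same block upsamples via `Upsampler2d` (interpolation followed by a convolution) instead of a transposed convolution.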
class UNETR(torch.nn.modules.module.Module):
A U-Net Transformer using a vision transformer as encoder and a convolutional decoder.
Arguments:
- img_size: The size of the input for the image encoder. Input images will be resized to match this size.
- backbone: The name of the vision transformer implementation. One of "sam" or "mae".
- encoder: The vision transformer. Can either be a name, such as "vit_b" or a torch module.
- decoder: The convolutional decoder.
- out_channels: The number of output channels of the UNETR.
- use_sam_stats: Whether to normalize the input data with the statistics of the pretrained SAM model.
- use_mae_stats: Whether to normalize the input data with the statistics of the pretrained MAE model.
- resize_input: Whether to resize the input images to match `img_size`. By default, it resizes the inputs to match the `img_size`.
- encoder_checkpoint: Checkpoint for initializing the vision transformer. Can either be a filepath or an already loaded checkpoint.
- final_activation: The activation to apply to the UNETR output.
- use_skip_connection: Whether to use skip connections. By default, it uses skip connections.
- embed_dim: The embedding dimensionality, corresponding to the output dimension of the vision transformer.
- use_conv_transpose: Whether to use transposed convolutions instead of resampling for upsampling. By default, it uses resampling for upsampling.
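A minimal construction sketch, assuming `torch_em` and the SAM vision transformer dependencies are installed; the chosen output channels and activation are only examples:

```python
from torch_em.model.unetr import UNETR

# UNETR with a SAM ViT-B image encoder and a two-channel output
# (e.g. foreground and boundary predictions).
model = UNETR(
    img_size=1024,
    backbone="sam",
    encoder="vit_b",
    out_channels=2,
    use_sam_stats=True,          # normalize inputs with the SAM pixel statistics
    final_activation="Sigmoid",  # resolved via getattr(torch.nn, "Sigmoid")
)
```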
```python
UNETR(
    img_size: int = 1024,
    backbone: str = 'sam',
    encoder: Union[torch.nn.modules.module.Module, str, NoneType] = 'vit_b',
    decoder: Optional[torch.nn.modules.module.Module] = None,
    out_channels: int = 1,
    use_sam_stats: bool = False,
    use_mae_stats: bool = False,
    resize_input: bool = True,
    encoder_checkpoint: Union[str, collections.OrderedDict, NoneType] = None,
    final_activation: Union[torch.nn.modules.module.Module, str, NoneType] = None,
    use_skip_connection: bool = True,
    embed_dim: Optional[int] = None,
    use_conv_transpose: bool = False,
    **kwargs
)
```
Initialize internal Module state, shared by both nn.Module and ScriptModule.
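The constructor can also initialize the encoder from pretrained weights via `encoder_checkpoint`. A sketch for a SAM checkpoint (the path below is a placeholder):

```python
from torch_em.model.unetr import UNETR

# Load pretrained SAM encoder weights. If micro_sam is available, the full SAM model
# is loaded and its image encoder state is extracted; otherwise the checkpoint is
# read directly with torch.load.
model = UNETR(
    backbone="sam",
    encoder="vit_b",
    out_channels=1,
    encoder_checkpoint="/path/to/sam_vit_b_checkpoint.pth",  # placeholder path
    use_sam_stats=True,
)
```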
@staticmethod
def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
Compute the output size given input size and target long side length.
Arguments:
- oldh: The input image height.
- oldw: The input image width.
- long_side_length: The longest side length for resizing.
Returns:
The new image height. The new image width.
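A small worked example of the rounding behavior (the method is static, so it can be called on the class):

```python
from torch_em.model.unetr import UNETR

# The longer side is scaled to 1024; the shorter side is scaled proportionally
# and rounded to the nearest integer.
print(UNETR.get_preprocess_shape(512, 768, 1024))    # (683, 1024)
print(UNETR.get_preprocess_shape(1024, 1024, 1024))  # (1024, 1024)
```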
def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor:
Resize the image so that the longest side has the correct length.
Expects batched images with shape BxCxHxW and float format.
Arguments:
- image: The input image.
Returns:
The resized image.
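For example, with the default SAM ViT-B encoder (`img_size=1024`), an input of shape 1x3x512x768 is resized so that its longer side becomes 1024. A sketch, assuming the SAM vision transformer dependencies are installed:

```python
import torch
from torch_em.model.unetr import UNETR

model = UNETR(backbone="sam", encoder="vit_b")  # encoder.img_size == 1024
image = torch.randn(1, 3, 512, 768)
print(model.resize_longest_side(image).shape)   # torch.Size([1, 3, 683, 1024])
```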
def forward(self, x: torch.Tensor) -> torch.Tensor:
Apply the UNETR to the input data.
Arguments:
- x: The input tensor.
Returns:
The UNETR output.
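A sketch of running inference, assuming the SAM vision transformer dependencies are installed; the output is resized back to the input's spatial shape and has `out_channels` channels:

```python
import torch
from torch_em.model.unetr import UNETR

model = UNETR(backbone="sam", encoder="vit_b", out_channels=2, use_sam_stats=True)
model.eval()

with torch.no_grad():
    x = torch.randn(1, 3, 512, 512)  # B x C x H x W
    y = model(x)

print(y.shape)  # torch.Size([1, 2, 512, 512]) -- same spatial shape as the input
```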