torch_em.model.unetr
from collections import OrderedDict
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from .vit import get_vision_transformer
from .unet import Decoder, ConvBlock2d, Upsampler2d

try:
    from micro_sam.util import get_sam_model
except ImportError:
    get_sam_model = None


#
# UNETR IMPLEMENTATION [Vision Transformer (ViT from MAE / ViT from SAM) + UNet Decoder from `torch_em`]
#


class UNETR(nn.Module):
    """A U-Net Transformer using a vision transformer as encoder and a convolutional decoder.

    Args:
        img_size: The size of the input for the image encoder. Input images will be resized to match this size.
        backbone: The name of the vision transformer implementation. One of "sam" or "mae".
        encoder: The vision transformer. Can either be a name, such as "vit_b" or a torch module.
        decoder: The convolutional decoder.
        out_channels: The number of output channels of the UNETR.
        use_sam_stats: Whether to normalize the input data with the statistics of the pretrained SAM model.
        use_mae_stats: Whether to normalize the input data with the statistics of the pretrained MAE model.
        resize_input: Whether to resize the input images to match `img_size`.
        encoder_checkpoint: Checkpoint for initializing the vision transformer.
            Can either be a filepath or an already loaded checkpoint.
        final_activation: The activation to apply to the UNETR output.
        use_skip_connection: Whether to use skip connections.
        embed_dim: The embedding dimensionality, corresponding to the output dimension of the vision transformer.
        use_conv_transpose: Whether to use transposed convolutions instead of resampling for upsampling.
    """
    def _load_encoder_from_checkpoint(self, backbone, encoder, checkpoint):
        """Function to load pretrained weights to the image encoder.
        """
        if isinstance(checkpoint, str):
            if backbone == "sam" and isinstance(encoder, str):
                # If we have a SAM encoder, then we first try to load the full SAM Model
                # (using micro_sam) and otherwise fall back on directly loading the encoder state
                # from the checkpoint
                try:
                    _, model = get_sam_model(model_type=encoder, checkpoint_path=checkpoint, return_sam=True)
                    encoder_state = model.image_encoder.state_dict()
                except Exception:
                    # Try loading the encoder state directly from a checkpoint.
                    encoder_state = torch.load(checkpoint, weights_only=False)

            elif backbone == "mae":
                # vit initialization hints from:
                # - https://github.com/facebookresearch/mae/blob/main/main_finetune.py#L233-L242
                encoder_state = torch.load(checkpoint, weights_only=False)["model"]
                encoder_state = OrderedDict({
                    k: v for k, v in encoder_state.items() if (k != "mask_token" and not k.startswith("decoder"))
                })
                # Let's remove the `head` from our current encoder (as the MAE pretrained don't expect it)
                current_encoder_state = self.encoder.state_dict()
                if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
                    del self.encoder.head

        else:
            encoder_state = checkpoint

        self.encoder.load_state_dict(encoder_state)

    def __init__(
        self,
        img_size: int = 1024,
        backbone: str = "sam",
        encoder: Optional[Union[nn.Module, str]] = "vit_b",
        decoder: Optional[nn.Module] = None,
        out_channels: int = 1,
        use_sam_stats: bool = False,
        use_mae_stats: bool = False,
        resize_input: bool = True,
        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
        final_activation: Optional[Union[str, nn.Module]] = None,
        use_skip_connection: bool = True,
        embed_dim: Optional[int] = None,
        use_conv_transpose: bool = True,
    ) -> None:
        super().__init__()

        self.use_sam_stats = use_sam_stats
        self.use_mae_stats = use_mae_stats
        self.use_skip_connection = use_skip_connection
        self.resize_input = resize_input

        if isinstance(encoder, str):  # "vit_b" / "vit_l" / "vit_h"
            print(f"Using {encoder} from {backbone.upper()}")
            self.encoder = get_vision_transformer(img_size=img_size, backbone=backbone, model=encoder)
            if encoder_checkpoint is not None:
                self._load_encoder_from_checkpoint(backbone, encoder, encoder_checkpoint)

            in_chans = self.encoder.in_chans
            if embed_dim is None:
                embed_dim = self.encoder.embed_dim

        else:  # `nn.Module` ViT backbone
            self.encoder = encoder

            have_neck = False
            for name, _ in self.encoder.named_parameters():
                if name.startswith("neck"):
                    have_neck = True

            if embed_dim is None:
                if have_neck:
                    embed_dim = self.encoder.neck[2].out_channels  # the value is 256
                else:
                    embed_dim = self.encoder.patch_embed.proj.out_channels

            try:
                in_chans = self.encoder.patch_embed.proj.in_channels
            except AttributeError:  # for getting the input channels while using vit_t from MobileSam
                in_chans = self.encoder.patch_embed.seq[0].c.in_channels

        # parameters for the decoder network
        depth = 3
        initial_features = 64
        gain = 2
        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
        scale_factors = depth * [2]
        self.out_channels = out_channels

        # choice of upsampler - to use (bilinear interpolation + conv) or conv transpose
        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d

        if decoder is None:
            self.decoder = Decoder(
                features=features_decoder,
                scale_factors=scale_factors[::-1],
                conv_block_impl=ConvBlock2d,
                sampler_impl=_upsampler,
            )
        else:
            self.decoder = decoder

        if use_skip_connection:
            self.deconv1 = Deconv2DBlock(embed_dim, features_decoder[0])
            self.deconv2 = nn.Sequential(
                Deconv2DBlock(embed_dim, features_decoder[0]),
                Deconv2DBlock(features_decoder[0], features_decoder[1])
            )
            self.deconv3 = nn.Sequential(
                Deconv2DBlock(embed_dim, features_decoder[0]),
                Deconv2DBlock(features_decoder[0], features_decoder[1]),
                Deconv2DBlock(features_decoder[1], features_decoder[2])
            )
            self.deconv4 = ConvBlock2d(in_chans, features_decoder[-1])
        else:
            self.deconv1 = Deconv2DBlock(embed_dim, features_decoder[0])
            self.deconv2 = Deconv2DBlock(features_decoder[0], features_decoder[1])
            self.deconv3 = Deconv2DBlock(features_decoder[1], features_decoder[2])
            self.deconv4 = Deconv2DBlock(features_decoder[2], features_decoder[3])

        self.base = ConvBlock2d(embed_dim, features_decoder[0])
        self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1)
        self.deconv_out = _upsampler(
            scale_factor=2, in_channels=features_decoder[-1], out_channels=features_decoder[-1]
        )
        self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1])
        self.final_activation = self._get_activation(final_activation)

    def _get_activation(self, activation):
        return_activation = None
        if activation is None:
            return None
        if isinstance(activation, nn.Module):
            return activation
        if isinstance(activation, str):
            return_activation = getattr(nn, activation, None)
        if return_activation is None:
            raise ValueError(f"Invalid activation: {activation}")
        return return_activation()

    @staticmethod
    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
        """Compute the output size given input size and target long side length.

        Args:
            oldh: The input image height.
            oldw: The input image width.
            long_side_length: The longest side length for resizing.

        Returns:
            The new image height.
            The new image width.
        """
        scale = long_side_length * 1.0 / max(oldh, oldw)
        newh, neww = oldh * scale, oldw * scale
        neww = int(neww + 0.5)
        newh = int(newh + 0.5)
        return (newh, neww)

    def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor:
        """Resize the image so that the longest side has the correct length.

        Expects batched images with shape BxCxHxW and float format.

        Args:
            image: The input image.

        Returns:
            The resized image.
        """
        target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.encoder.img_size)
        return F.interpolate(
            image, target_size, mode="bilinear", align_corners=False, antialias=True
        )

    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        """@private
        """
        device = x.device

        if self.use_sam_stats:
            pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(1, -1, 1, 1).to(device)
            pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(1, -1, 1, 1).to(device)
        elif self.use_mae_stats:
            # TODO: add mean std from mae experiments (or open up arguments for this)
            raise NotImplementedError
        else:
            pixel_mean = torch.Tensor([0.0, 0.0, 0.0]).view(1, -1, 1, 1).to(device)
            pixel_std = torch.Tensor([1.0, 1.0, 1.0]).view(1, -1, 1, 1).to(device)

        if self.resize_input:
            x = self.resize_longest_side(x)
        input_shape = x.shape[-2:]

        x = (x - pixel_mean) / pixel_std
        h, w = x.shape[-2:]
        padh = self.encoder.img_size - h
        padw = self.encoder.img_size - w
        x = F.pad(x, (0, padw, 0, padh))
        return x, input_shape

    def postprocess_masks(
        self,
        masks: torch.Tensor,
        input_size: Tuple[int, ...],
        original_size: Tuple[int, ...],
    ) -> torch.Tensor:
        """@private
        """
        masks = F.interpolate(
            masks,
            (self.encoder.img_size, self.encoder.img_size),
            mode="bilinear",
            align_corners=False,
        )
        masks = masks[..., : input_size[0], : input_size[1]]
        masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
        return masks

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the UNETR to the input data.

        Args:
            x: The input tensor.

        Returns:
            The UNETR output.
        """
        original_shape = x.shape[-2:]

        # Reshape the inputs to the shape expected by the encoder
        # and normalize the inputs if normalization is part of the model.
        x, input_shape = self.preprocess(x)

        use_skip_connection = getattr(self, "use_skip_connection", True)

        encoder_outputs = self.encoder(x)

        if isinstance(encoder_outputs[-1], list):
            # `encoder_outputs` can be arranged in only two forms:
            # - either we only return the image embeddings
            # - or, we return the image embeddings and the "list" of global attention layers
            z12, from_encoder = encoder_outputs
        else:
            z12 = encoder_outputs

        if use_skip_connection:
            from_encoder = from_encoder[::-1]
            z9 = self.deconv1(from_encoder[0])
            z6 = self.deconv2(from_encoder[1])
            z3 = self.deconv3(from_encoder[2])
            z0 = self.deconv4(x)

        else:
            z9 = self.deconv1(z12)
            z6 = self.deconv2(z9)
            z3 = self.deconv3(z6)
            z0 = self.deconv4(z3)

        updated_from_encoder = [z9, z6, z3]

        x = self.base(z12)
        x = self.decoder(x, encoder_inputs=updated_from_encoder)
        x = self.deconv_out(x)

        x = torch.cat([x, z0], dim=1)
        x = self.decoder_head(x)

        x = self.out_conv(x)
        if self.final_activation is not None:
            x = self.final_activation(x)

        x = self.postprocess_masks(x, input_shape, original_shape)
        return x


#
# ADDITIONAL FUNCTIONALITIES
#


class SingleDeconv2DBlock(nn.Module):
    """@private
    """
    def __init__(self, scale_factor, in_channels, out_channels):
        super().__init__()
        self.block = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, padding=0, output_padding=0)

    def forward(self, x):
        return self.block(x)


class SingleConv2DBlock(nn.Module):
    """@private
    """
    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        self.block = nn.Conv2d(
            in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=((kernel_size - 1) // 2)
        )

    def forward(self, x):
        return self.block(x)


class Conv2DBlock(nn.Module):
    """@private
    """
    def __init__(self, in_channels, out_channels, kernel_size=3):
        super().__init__()
        self.block = nn.Sequential(
            SingleConv2DBlock(in_channels, out_channels, kernel_size),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True)
        )

    def forward(self, x):
        return self.block(x)


class Deconv2DBlock(nn.Module):
    """@private
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, use_conv_transpose=True):
        super().__init__()
        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d
        self.block = nn.Sequential(
            _upsampler(scale_factor=2, in_channels=in_channels, out_channels=out_channels),
            SingleConv2DBlock(out_channels, out_channels, kernel_size),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True)
        )

    def forward(self, x):
        return self.block(x)
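The helper blocks defined after the UNETR class implement the upsampling used along the decoder path; with the default use_conv_transpose=True, Deconv2DBlock doubles the spatial resolution of its input and then applies a 3x3 convolution, batch norm and ReLU. A minimal shape check, assuming torch_em and its dependencies are installed (the channel sizes below are only illustrative):

import torch
from torch_em.model.unetr import Deconv2DBlock

# Upsample by a factor of 2 while mapping 256 -> 128 channels.
block = Deconv2DBlock(in_channels=256, out_channels=128)
x = torch.rand(1, 256, 64, 64)
print(block(x).shape)  # torch.Size([1, 128, 128, 128])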
class UNETR(torch.nn.modules.module.Module):
A U-Net Transformer using a vision transformer as encoder and a convolutional decoder.
Arguments:
- img_size: The size of the input for the image encoder. Input images will be resized to match this size.
- backbone: The name of the vision transformer implementation. One of "sam" or "mae".
- encoder: The vision transformer. Can either be a name, such as "vit_b" or a torch module.
- decoder: The convolutional decoder.
- out_channels: The number of output channels of the UNETR.
- use_sam_stats: Whether to normalize the input data with the statistics of the pretrained SAM model.
- use_mae_stats: Whether to normalize the input data with the statistics of the pretrained MAE model.
- resize_input: Whether to resize the input images to match img_size.
- encoder_checkpoint: Checkpoint for initializing the vision transformer. Can either be a filepath or an already loaded checkpoint.
- final_activation: The activation to apply to the UNETR output.
- use_skip_connection: Whether to use skip connections.
- embed_dim: The embedding dimensionality, corresponding to the output dimension of the vision transformer.
- use_conv_transpose: Whether to use transposed convolutions instead of resampling for upsampling.
UNETR(
    img_size: int = 1024,
    backbone: str = 'sam',
    encoder: Union[torch.nn.modules.module.Module, str, NoneType] = 'vit_b',
    decoder: Optional[torch.nn.modules.module.Module] = None,
    out_channels: int = 1,
    use_sam_stats: bool = False,
    use_mae_stats: bool = False,
    resize_input: bool = True,
    encoder_checkpoint: Union[str, collections.OrderedDict, NoneType] = None,
    final_activation: Union[torch.nn.modules.module.Module, str, NoneType] = None,
    use_skip_connection: bool = True,
    embed_dim: Optional[int] = None,
    use_conv_transpose: bool = True
)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
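A minimal construction sketch, assuming the SAM image-encoder dependencies used by get_vision_transformer are available; the parameter choices are illustrative rather than recommended settings:

from torch_em.model.unetr import UNETR

# 2d UNETR with a (randomly initialized) SAM ViT-B encoder and two output channels.
model = UNETR(
    img_size=1024,
    backbone="sam",
    encoder="vit_b",
    out_channels=2,
    use_sam_stats=True,         # normalize inputs with the SAM pixel statistics
    final_activation="Sigmoid",
)

To start from pretrained weights, pass a filepath or an already loaded state dict via encoder_checkpoint.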
@staticmethod
def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
Compute the output size given input size and target long side length.
Arguments:
- oldh: The input image height.
- oldw: The input image width.
- long_side_length: The longest side length for resizing.
Returns:
The new image height. The new image width.
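For example, resizing a 512 x 768 image to a long side of 1024 scales both sides by 1024 / 768, so the new height is int(512 * 1024 / 768 + 0.5) = 683:

from torch_em.model.unetr import UNETR

# get_preprocess_shape is a staticmethod, so it can be called without instantiating the model.
print(UNETR.get_preprocess_shape(512, 768, 1024))  # (683, 1024)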
def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor:
Resize the image so that the longest side has the correct length.
Expects batched images with shape BxCxHxW and float format.
Arguments:
- image: The input image.
Returns:
The resized image.
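A short usage sketch, assuming a model built with the default img_size of 1024 (and the SAM encoder dependencies installed), so that encoder.img_size is 1024:

import torch
from torch_em.model.unetr import UNETR

model = UNETR(backbone="sam", encoder="vit_b")  # encoder.img_size == 1024
image = torch.rand(1, 3, 512, 768)              # BxCxHxW, float
print(model.resize_longest_side(image).shape)   # torch.Size([1, 3, 683, 1024])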
def forward(self, x: torch.Tensor) -> torch.Tensor:
Apply the UNETR to the input data.
Arguments:
- x: The input tensor.
Returns:
The UNETR output.
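A minimal end-to-end sketch with randomly initialized weights, assuming the SAM encoder dependencies are installed; the input is resized and padded internally and the output is mapped back to the original resolution:

import torch
from torch_em.model.unetr import UNETR

model = UNETR(backbone="sam", encoder="vit_b", out_channels=2, final_activation="Sigmoid")
model.eval()

x = torch.rand(1, 3, 512, 512)
with torch.no_grad():
    y = model(x)
print(y.shape)  # torch.Size([1, 2, 512, 512])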