torch_em.model.probabilistic_unet
# This code is based on the original TensorFlow implementation: https://github.com/SimonKohl/probabilistic_unet
# The below implementation is from: https://github.com/stefanknegt/Probabilistic-Unet-Pytorch

import numpy as np

import torch
import torch.nn as nn
from torch.distributions import Normal, Independent, kl

from torch_em.model import UNet2d
from torch_em.loss.dice import DiceLossWithLogits


def truncated_normal_(tensor, mean=0, std=1):
    """Fill `tensor` in-place with approximately truncated-normal values.

    Four standard-normal candidates are drawn per element and one candidate
    that lies inside (-2, 2) is selected (if none of the four qualifies, an
    out-of-range candidate is used). The result is scaled by `std` and
    shifted by `mean`.

    Args:
        tensor: The tensor to overwrite.
        mean: Offset added after scaling.
        std: Scale applied to the selected candidates.
    """
    size = tensor.shape
    # Initialization must not be recorded by autograd.
    with torch.no_grad():
        candidates = tensor.new_empty(size + (4,)).normal_()
        valid = (candidates < 2) & (candidates > -2)
        # Index of a maximal entry of the boolean mask, i.e. of a valid candidate when one exists.
        ind = valid.max(-1, keepdim=True)[1]
        tensor.copy_(candidates.gather(-1, ind).squeeze(-1))
        tensor.mul_(std).add_(mean)


def init_weights(m):
    """Kaiming-normal weights and small truncated-normal biases for (transposed) 2d convolutions.

    Intended to be used with `nn.Module.apply`; other module types are left untouched.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
        # Convolutions created with bias=False have nothing to initialize.
        if m.bias is not None:
            truncated_normal_(m.bias, mean=0, std=0.001)


def init_weights_orthogonal_normal(m):
    """Orthogonal weights and small truncated-normal biases for (transposed) 2d convolutions.

    Intended to be used with `nn.Module.apply`; other module types are left untouched.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.orthogonal_(m.weight)
        # Convolutions created with bias=False have nothing to initialize.
        if m.bias is not None:
            truncated_normal_(m.bias, mean=0, std=0.001)


class Encoder(nn.Module):
    """
    A convolutional neural network consisting of len(num_filters) blocks of no_convs_per_block
    3x3 convolutions, each followed by a ReLU. Every block except the first one is preceded by
    a 2x2 average pooling.

    Note: `initializers` and `padding` are accepted for interface compatibility but are not
    used by this implementation; the layers are always initialized with `init_weights`.
    """

    def __init__(
        self,
        input_channels,
        num_filters,
        no_convs_per_block,
        initializers,
        padding=True,
        posterior=False,
        num_classes=None
    ):

        super().__init__()

        self.contracting_path = nn.ModuleList()
        self.input_channels = input_channels
        self.num_filters = num_filters

        if posterior:
            # To accommodate the mask that is concatenated at the channel axis, we increase the input_channels.
            assert num_classes is not None
            self.input_channels += num_classes

        layers = []
        in_channels = self.input_channels

        for block_idx, out_channels in enumerate(self.num_filters):
            # The first conv of a block maps in_channels -> out_channels,
            # all subsequent convs of the block map out_channels -> out_channels.
            if block_idx != 0:
                layers.append(nn.AvgPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True))

            layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=3))
            layers.append(nn.ReLU(inplace=True))

            for _ in range(no_convs_per_block - 1):
                layers.append(nn.Conv2d(out_channels, out_channels, kernel_size=3))
                layers.append(nn.ReLU(inplace=True))

            in_channels = out_channels

        self.layers = nn.Sequential(*layers)

        self.layers.apply(init_weights)

    def forward(self, input):
        """Run `input` through the sequential stack of pooling / conv / ReLU layers."""
        return self.layers(input)


class AxisAlignedConvGaussian(nn.Module):
    """
    A convolutional net that parametrizes a Gaussian distribution with axis aligned covariance matrix.
    """

    def __init__(
        self,
        input_channels,
        num_filters,
        no_convs_per_block,
        latent_dim,
        initializers,
        posterior=False,
        num_classes=None
    ):

        super().__init__()

        self.input_channels = input_channels
        self.channel_axis = 1
        self.num_filters = num_filters
        self.no_convs_per_block = no_convs_per_block
        self.latent_dim = latent_dim

        self.posterior = posterior
        self.name = 'Posterior' if self.posterior else 'Prior'

        self.encoder = Encoder(
            self.input_channels,
            self.num_filters,
            self.no_convs_per_block,
            initializers,
            posterior=self.posterior,
            num_classes=num_classes
        )

        # 1x1 convolution that maps the pooled encoding to mu and log_sigma (latent_dim channels each).
        self.conv_layer = nn.Conv2d(num_filters[-1], 2 * self.latent_dim, (1, 1), stride=1)

        # Intermediate values of the most recent forward pass, kept for inspection.
        self.show_img = 0
        self.show_seg = 0
        self.show_concat = 0
        self.show_enc = 0
        self.sum_input = 0

        # Original paper's training details:
        # All weights of all models are initialized with orthogonal initialization having the gain (multiplicative
        # factor) set to 1, and the bias terms are initialized by sampling from a truncated normal with σ = 0.001
        nn.init.orthogonal_(self.conv_layer.weight, gain=1)
        nn.init.trunc_normal_(self.conv_layer.bias, std=0.001)

    def forward(self, input, segm=None):
        """Encode `input` (optionally concatenated with `segm`) into a diagonal Gaussian.

        Args:
            input: The image batch.
            segm: Optional segmentation that is concatenated to `input` along dim 1.

        Returns:
            An `Independent(Normal(...), 1)` distribution over the latent space.
        """
        # If segmentation is not none, concatenate the mask to the channel axis of the input
        if segm is not None:
            self.show_img = input
            self.show_seg = segm
            input = torch.cat((input, segm), dim=1)
            self.show_concat = input
            self.sum_input = torch.sum(input)

        encoding = self.encoder(input)
        self.show_enc = encoding

        # We only want the mean of the resulting hxw image
        encoding = torch.mean(encoding, dim=2, keepdim=True)
        encoding = torch.mean(encoding, dim=3, keepdim=True)

        # Convert encoding to 2 x latent dim and split up for mu and log_sigma
        mu_log_sigma = self.conv_layer(encoding)

        # Squeeze the two singleton spatial dimensions explicitly (dim 2 twice) so that a
        # batch dimension of size 1 is never removed by accident.
        mu_log_sigma = torch.squeeze(mu_log_sigma, dim=2)
        mu_log_sigma = torch.squeeze(mu_log_sigma, dim=2)

        mu = mu_log_sigma[:, :self.latent_dim]
        log_sigma = mu_log_sigma[:, self.latent_dim:]

        # This is a multivariate normal with diagonal covariance matrix sigma
        # https://github.com/pytorch/pytorch/pull/11178
        return Independent(Normal(loc=mu, scale=torch.exp(log_sigma)), 1)


class Fcomb(nn.Module):
    """
    A function composed of no_convs_fcomb times a 1x1 convolution that combines the sample taken from the latent space,
    and output of the UNet (the feature map) by concatenating them along their channel axis.
    """
    def __init__(
        self,
        num_filters,
        latent_dim,
        num_output_channels,
        num_classes,
        no_convs_fcomb,
        initializers,
        use_tile=True,
        device=None
    ):

        super().__init__()

        self.num_channels = num_output_channels
        self.num_classes = num_classes
        self.channel_axis = 1
        self.spatial_axes = [2, 3]
        self.num_filters = num_filters
        self.latent_dim = latent_dim
        self.use_tile = use_tile
        self.no_convs_fcomb = no_convs_fcomb
        self.name = 'Fcomb'

        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device

        # NOTE(review): the layers are only created when use_tile is True, mirroring the
        # upstream implementation - confirm that use_tile=False is not meant to be supported.
        if self.use_tile:
            layers = []

            # Decoder of N x a 1x1 convolution followed by a ReLU activation function except for the last layer
            layers.append(nn.Conv2d(self.num_filters[0] + self.latent_dim, self.num_filters[0], kernel_size=1))
            layers.append(nn.ReLU(inplace=True))

            for _ in range(no_convs_fcomb - 2):
                layers.append(nn.Conv2d(self.num_filters[0], self.num_filters[0], kernel_size=1))
                layers.append(nn.ReLU(inplace=True))

            self.layers = nn.Sequential(*layers)

            self.last_layer = nn.Conv2d(self.num_filters[0], self.num_classes, kernel_size=1)

            if initializers['w'] == 'orthogonal':
                self.layers.apply(init_weights_orthogonal_normal)
                self.last_layer.apply(init_weights_orthogonal_normal)
            else:
                self.layers.apply(init_weights)
                self.last_layer.apply(init_weights)

    def tile(self, a, dim, n_tile):
        """Repeat every entry of `a` along `dim` `n_tile` times, keeping repeated entries adjacent.

        The result lives on the same device as `a`, independent of `self.device`.
        """
        return torch.repeat_interleave(a, n_tile, dim=dim)

    def forward(self, feature_map, z):
        """
        Z is (batch_size x latent_dim) and feature_map is (batch_size x no_channels x H x W).
        So broadcast Z to batch_size x latent_dim x H x W and concatenate it to the feature map.
        """
        if self.use_tile:
            height = feature_map.shape[self.spatial_axes[0]]
            width = feature_map.shape[self.spatial_axes[1]]
            # Same values as tiling along both spatial axes, without materializing index tensors.
            z = z[:, :, None, None].expand(-1, -1, height, width)

        # Concatenate the feature map (output of the UNet) and the sample taken from the latent space
        feature_map = torch.cat((feature_map, z), dim=self.channel_axis)
        output = self.layers(feature_map)
        return self.last_layer(output)


class ProbabilisticUNet(nn.Module):
    """ This network implementation for the Probabilistic UNet of Kohl et al. (https://arxiv.org/abs/1806.05034).
    This generative segmentation heuristic uses UNet combined with a conditional variational
    autoencoder enabling to efficiently produce an unlimited number of plausible hypotheses.

    The following elements are initialized to get our desired network:
    input_channels: the number of channels in the image (1 for grayscale and 3 for RGB)
    num_classes: the number of classes to predict
    num_filters: is a list consisting of the amount of filters per layer
    latent_dim: dimension of the latent space
    no_convs_per_block: no convs per block in the (convolutional) encoder of prior and posterior (fixed to 3)
    no_convs_fcomb: number of 1x1 convolutions used to combine UNet features and latent sample
    beta: KL and reconstruction loss are weighted using a KL weighting factor (β)
    consensus_masking: activates consensus masking in the reconstruction loss
    rl_swap: switches the reconstruction loss to dice loss from the default (binary cross-entropy loss)

    Parameters:
        input_channels [int] - (default: 1)
        num_classes [int] - (default: 1)
        num_filters [list] - (default: [32, 64, 128, 192])
        latent_dim [int] - (default: 6)
        no_convs_fcomb [int] - (default: 4)
        beta [float] - (default: 10.0)
        consensus_masking [bool] - (default: False)
        rl_swap [bool] - (default: False)
        device [torch.device] - (default: None)
    """

    def __init__(
        self,
        input_channels=1,
        num_classes=1,
        num_filters=(32, 64, 128, 192),
        latent_dim=6,
        no_convs_fcomb=4,
        beta=10.0,
        consensus_masking=False,
        rl_swap=False,
        device=None
    ):

        super().__init__()

        self.input_channels = input_channels
        self.num_classes = num_classes
        # Copy so that neither the default nor a caller-owned sequence is shared with the model.
        self.num_filters = list(num_filters)
        self.latent_dim = latent_dim
        self.no_convs_per_block = 3
        self.no_convs_fcomb = no_convs_fcomb
        self.initializers = {'w': 'he_normal', 'b': 'normal'}
        self.beta = beta
        self.z_prior_sample = 0
        self.consensus_masking = consensus_masking
        self.rl_swap = rl_swap

        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device

        self.unet = UNet2d(
            in_channels=self.input_channels,
            out_channels=None,
            depth=len(self.num_filters),
            initial_features=self.num_filters[0]
        ).to(self.device)

        self.prior = AxisAlignedConvGaussian(
            self.input_channels,
            self.num_filters,
            self.no_convs_per_block,
            self.latent_dim,
            self.initializers
        ).to(self.device)

        self.posterior = AxisAlignedConvGaussian(
            self.input_channels,
            self.num_filters,
            self.no_convs_per_block,
            self.latent_dim,
            self.initializers,
            posterior=True,
            num_classes=num_classes
        ).to(self.device)

        self.fcomb = Fcomb(
            self.num_filters,
            self.latent_dim,
            self.input_channels,
            self.num_classes,
            self.no_convs_fcomb,
            {'w': 'orthogonal', 'b': 'normal'},
            use_tile=True,
            device=self.device
        ).to(self.device)

    def _check_shape(self, patch):
        """Raise a ValueError unless every spatial extent of `patch` is divisible by 2**depth."""
        spatial_shape = tuple(patch.shape)[2:]
        depth = len(self.num_filters)
        factor = [2**depth] * len(spatial_shape)
        if any(sh % fac != 0 for sh, fac in zip(spatial_shape, factor)):
            msg = f"Invalid shape for Probabilistic U-Net: {spatial_shape} is not divisible by {factor}"
            raise ValueError(msg)

    def forward(self, patch, segm=None):
        """
        Construct prior latent space for patch and run patch through UNet,
        in case a segmentation is given also construct posterior latent space.

        The results are stored on the module (`prior_latent_space`, `posterior_latent_space`,
        `unet_features`); nothing is returned.
        """
        self._check_shape(patch)

        # Modules are called directly (not via .forward) so that registered hooks run.
        if segm is not None:
            self.posterior_latent_space = self.posterior(patch, segm)
        self.prior_latent_space = self.prior(patch)
        self.unet_features = self.unet(patch)

    def sample(self, testing=False):
        """
        Sample a segmentation by reconstructing from a prior sample and combining this with UNet features.

        With `testing=False` a reparametrized sample (`rsample`) is drawn, otherwise a plain `sample`.
        The drawn latent is stored in `self.z_prior_sample`.
        """
        if testing is False:
            # TODO: prior distribution ? (posterior in this case!)
            z_prior = self.prior_latent_space.rsample()
        else:
            # You can choose whether you mean a sample or the mean here. For the GED it is important to take a sample.
            z_prior = self.prior_latent_space.sample()
        self.z_prior_sample = z_prior
        return self.fcomb(self.unet_features, z_prior)

    def reconstruct(self, use_posterior_mean=False, calculate_posterior=False, z_posterior=None):
        """
        Reconstruct a segmentation from a posterior sample (decoding a posterior sample) and UNet feature map
        use_posterior_mean: use posterior_mean instead of sampling z_q
        calculate_posterior: use a provided sample or sample from posterior latent space

        Raises:
            ValueError: if no latent is available, i.e. neither flag is set and `z_posterior` is None.
        """
        if use_posterior_mean:
            # `Independent` has no `loc` attribute; `mean` forwards to the wrapped Normal.
            z_posterior = self.posterior_latent_space.mean
        elif calculate_posterior:
            z_posterior = self.posterior_latent_space.rsample()

        if z_posterior is None:
            raise ValueError(
                "z_posterior must be given if neither use_posterior_mean nor calculate_posterior is set"
            )
        return self.fcomb(self.unet_features, z_posterior)

    def kl_divergence(self, analytic=True, calculate_posterior=False, z_posterior=None):
        """
        Calculate the KL divergence between the posterior and prior KL(Q||P)
        analytic: calculate KL analytically or via sampling from the posterior
        calculate_posterior: if we use sampling to approximate KL we can sample here or supply a sample

        Raises:
            ValueError: if the sampled estimate is requested without any posterior sample.
        """
        if analytic:
            return kl.kl_divergence(self.posterior_latent_space, self.prior_latent_space)

        if calculate_posterior:
            z_posterior = self.posterior_latent_space.rsample()
        if z_posterior is None:
            raise ValueError("z_posterior must be given if analytic and calculate_posterior are both False")

        log_posterior_prob = self.posterior_latent_space.log_prob(z_posterior)
        log_prior_prob = self.prior_latent_space.log_prob(z_posterior)
        return log_posterior_prob - log_prior_prob

    def elbo(self, segm, consm=None, analytic_kl=True, reconstruct_posterior_mean=False):
        """
        Calculate the evidence lower bound of the log-likelihood of P(Y|X)
        segm: target segmentation
        consm: consensus response, used as a mask if consensus_masking is enabled
        analytic_kl: compute the KL term analytically instead of from the posterior sample
        reconstruct_posterior_mean: decode the posterior mean instead of the posterior sample

        Returns the negative of (summed reconstruction loss + beta * mean KL). The individual
        terms are stored in `self.kl`, `self.reconstruction`, `self.reconstruction_loss` and
        `self.mean_reconstruction_loss`.
        """

        if self.rl_swap:
            criterion = DiceLossWithLogits()
        else:
            # Element-wise loss; replaces the deprecated size_average / reduce arguments.
            criterion = nn.BCEWithLogitsLoss(reduction='none')

        z_posterior = self.posterior_latent_space.rsample()

        self.kl = torch.mean(
            self.kl_divergence(analytic=analytic_kl, calculate_posterior=False, z_posterior=z_posterior)
        )

        # Here we use the posterior sample sampled above
        self.reconstruction = self.reconstruct(use_posterior_mean=reconstruct_posterior_mean,
                                               calculate_posterior=False, z_posterior=z_posterior)

        if self.consensus_masking is True and consm is not None:
            reconstruction_loss = criterion(self.reconstruction * consm, segm * consm)
        else:
            reconstruction_loss = criterion(self.reconstruction, segm)

        self.reconstruction_loss = torch.sum(reconstruction_loss)
        self.mean_reconstruction_loss = torch.mean(reconstruction_loss)

        return -(self.reconstruction_loss + self.beta * self.kl)
class Encoder(nn.Module):
    """
    Stack of len(num_filters) convolutional blocks. Each block holds no_convs_per_block
    3x3 convolutions, every one followed by a ReLU; all blocks but the first start with
    a 2x2 average pooling.

    `initializers` and `padding` are part of the signature but are not read here.
    """

    def __init__(
        self,
        input_channels,
        num_filters,
        no_convs_per_block,
        initializers,
        padding=True,
        posterior=False,
        num_classes=None
    ):

        super().__init__()

        self.contracting_path = nn.ModuleList()
        self.input_channels = input_channels
        self.num_filters = num_filters

        if posterior:
            # The posterior net also sees the mask, which is concatenated along the channel axis.
            assert num_classes is not None
            self.input_channels += num_classes

        stack = []
        in_channels = self.input_channels

        for block_idx, out_channels in enumerate(self.num_filters):
            if block_idx > 0:
                stack.append(nn.AvgPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True))

            # First conv of the block changes the channel count ...
            stack += [nn.Conv2d(in_channels, out_channels, kernel_size=3), nn.ReLU(inplace=True)]

            # ... the remaining ones keep it.
            for _ in range(1, no_convs_per_block):
                stack += [nn.Conv2d(out_channels, out_channels, kernel_size=3), nn.ReLU(inplace=True)]

            in_channels = out_channels

        self.layers = nn.Sequential(*stack)
        self.layers.apply(init_weights)

    def forward(self, input):
        """Apply the sequential layer stack to `input`."""
        return self.layers(input)
A convolutional neural network, consisting of len(num_filters) times a block of no_convs_per_block convolutional layers, after each block a pooling operation is performed. And after each convolutional layer a non-linear (ReLU) activation function is applied.
def __init__(
    self,
    input_channels,
    num_filters,
    no_convs_per_block,
    initializers,
    padding=True,
    posterior=False,
    num_classes=None
):
    """Build the pooling / conv / ReLU stack and initialize it with `init_weights`.

    `initializers` and `padding` are accepted but not read by this constructor.
    """
    super().__init__()

    self.contracting_path = nn.ModuleList()
    self.input_channels = input_channels
    self.num_filters = num_filters

    if posterior:
        # The posterior net also sees the mask, which is concatenated along the channel axis.
        assert num_classes is not None
        self.input_channels += num_classes

    stack = []
    in_channels = self.input_channels

    for block_idx, out_channels in enumerate(self.num_filters):
        # Every block except the first one starts by halving the resolution.
        if block_idx > 0:
            stack.append(nn.AvgPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True))

        # First conv of the block changes the channel count ...
        stack += [nn.Conv2d(in_channels, out_channels, kernel_size=3), nn.ReLU(inplace=True)]

        # ... the remaining ones keep it.
        for _ in range(1, no_convs_per_block):
            stack += [nn.Conv2d(out_channels, out_channels, kernel_size=3), nn.ReLU(inplace=True)]

        in_channels = out_channels

    self.layers = nn.Sequential(*stack)
    self.layers.apply(init_weights)
Initializes internal Module state, shared by both nn.Module and ScriptModule.
Defines the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module
instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
Inherited Members
- torch.nn.modules.module.Module
- dump_patches
- training
- call_super_init
- register_buffer
- register_parameter
- add_module
- register_module
- get_submodule
- get_parameter
- get_buffer
- get_extra_state
- set_extra_state
- apply
- cuda
- ipu
- xpu
- cpu
- type
- float
- double
- half
- bfloat16
- to_empty
- to
- register_full_backward_pre_hook
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- register_state_dict_pre_hook
- state_dict
- register_load_state_dict_post_hook
- load_state_dict
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr
- compile
class AxisAlignedConvGaussian(nn.Module):
    """
    Convolutional network whose output parametrizes a Gaussian with axis-aligned
    (diagonal) covariance over the latent space.
    """

    def __init__(
        self,
        input_channels,
        num_filters,
        no_convs_per_block,
        latent_dim,
        initializers,
        posterior=False,
        num_classes=None
    ):

        super().__init__()

        self.input_channels = input_channels
        self.channel_axis = 1
        self.num_filters = num_filters
        self.no_convs_per_block = no_convs_per_block
        self.latent_dim = latent_dim

        self.posterior = posterior
        self.name = 'Posterior' if self.posterior else 'Prior'

        self.encoder = Encoder(
            self.input_channels,
            self.num_filters,
            self.no_convs_per_block,
            initializers,
            posterior=self.posterior,
            num_classes=num_classes
        )

        # 1x1 convolution producing mu and log_sigma (latent_dim channels each).
        self.conv_layer = nn.Conv2d(num_filters[-1], 2 * self.latent_dim, (1, 1), stride=1)

        # Slots for intermediate values of the latest forward pass.
        self.show_img = 0
        self.show_seg = 0
        self.show_concat = 0
        self.show_enc = 0
        self.sum_input = 0

        # Training details of the original paper:
        # all weights use orthogonal initialization with gain 1, and the bias terms are
        # sampled from a truncated normal with σ = 0.001.
        nn.init.orthogonal_(self.conv_layer.weight, gain=1)
        nn.init.trunc_normal_(self.conv_layer.bias, std=0.001)

    def forward(self, input, segm=None):
        """Return the latent distribution for `input`, optionally conditioned on `segm`.

        When `segm` is given it is concatenated to `input` along dim 1 before encoding.
        The result is an `Independent(Normal(...), 1)` distribution.
        """
        if segm is not None:
            self.show_img, self.show_seg = input, segm
            input = torch.cat((input, segm), dim=1)
            self.show_concat = input
            self.sum_input = input.sum()

        encoding = self.encoder(input)
        self.show_enc = encoding

        # Average over both spatial axes, one after the other.
        pooled = encoding.mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)

        # Project to 2 * latent_dim channels and drop the two singleton spatial axes.
        # Squeezing dim 2 twice keeps a batch axis of size 1 intact.
        params = self.conv_layer(pooled).squeeze(dim=2).squeeze(dim=2)

        split = self.latent_dim
        mu, log_sigma = params[:, :split], params[:, split:]

        # Multivariate normal with diagonal covariance
        # https://github.com/pytorch/pytorch/pull/11178
        return Independent(Normal(loc=mu, scale=log_sigma.exp()), 1)
A convolutional net that parametrizes a Gaussian distribution with axis aligned covariance matrix.
def __init__(
    self,
    input_channels,
    num_filters,
    no_convs_per_block,
    latent_dim,
    initializers,
    posterior=False,
    num_classes=None
):
    """Create the encoder and the 1x1 projection to the Gaussian parameters."""
    super().__init__()

    self.input_channels = input_channels
    self.channel_axis = 1
    self.num_filters = num_filters
    self.no_convs_per_block = no_convs_per_block
    self.latent_dim = latent_dim

    self.posterior = posterior
    self.name = 'Posterior' if self.posterior else 'Prior'

    self.encoder = Encoder(
        self.input_channels,
        self.num_filters,
        self.no_convs_per_block,
        initializers,
        posterior=self.posterior,
        num_classes=num_classes
    )

    # 1x1 convolution producing mu and log_sigma (latent_dim channels each).
    self.conv_layer = nn.Conv2d(num_filters[-1], 2 * self.latent_dim, (1, 1), stride=1)

    # Slots for intermediate values of the latest forward pass.
    self.show_img = 0
    self.show_seg = 0
    self.show_concat = 0
    self.show_enc = 0
    self.sum_input = 0

    # Training details of the original paper:
    # all weights use orthogonal initialization with gain 1, and the bias terms are
    # sampled from a truncated normal with σ = 0.001.
    nn.init.orthogonal_(self.conv_layer.weight, gain=1)
    nn.init.trunc_normal_(self.conv_layer.bias, std=0.001)
Initializes internal Module state, shared by both nn.Module and ScriptModule.
159 def forward(self, input, segm=None): 160 161 # If segmentation is not none, concatenate the mask to the channel axis of the input 162 if segm is not None: 163 self.show_img = input 164 self.show_seg = segm 165 input = torch.cat((input, segm), dim=1) 166 self.show_concat = input 167 self.sum_input = torch.sum(input) 168 169 encoding = self.encoder(input) 170 self.show_enc = encoding 171 172 # We only want the mean of the resulting hxw image 173 encoding = torch.mean(encoding, dim=2, keepdim=True) 174 encoding = torch.mean(encoding, dim=3, keepdim=True) 175 176 # Convert encoding to 2 x latent dim and split up for mu and log_sigma 177 mu_log_sigma = self.conv_layer(encoding) 178 179 # We squeeze the second dimension twice, since otherwise it won't work when batch size is equal to 1 180 mu_log_sigma = torch.squeeze(mu_log_sigma, dim=2) 181 mu_log_sigma = torch.squeeze(mu_log_sigma, dim=2) 182 183 mu = mu_log_sigma[:, :self.latent_dim] 184 log_sigma = mu_log_sigma[:, self.latent_dim:] 185 186 # This is a multivariate normal with diagonal covariance matrix sigma 187 # https://github.com/pytorch/pytorch/pull/11178 188 dist = Independent(Normal(loc=mu, scale=torch.exp(log_sigma)), 1) 189 return dist
Defines the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module
instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
Inherited Members
- torch.nn.modules.module.Module
- dump_patches
- training
- call_super_init
- register_buffer
- register_parameter
- add_module
- register_module
- get_submodule
- get_parameter
- get_buffer
- get_extra_state
- set_extra_state
- apply
- cuda
- ipu
- xpu
- cpu
- type
- float
- double
- half
- bfloat16
- to_empty
- to
- register_full_backward_pre_hook
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- register_state_dict_pre_hook
- state_dict
- register_load_state_dict_post_hook
- load_state_dict
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr
- compile
class Fcomb(nn.Module):
    """
    A function composed of no_convs_fcomb times a 1x1 convolution that combines the sample taken from the latent space,
    and output of the UNet (the feature map) by concatenating them along their channel axis.
    """
    def __init__(
        self,
        num_filters,
        latent_dim,
        num_output_channels,
        num_classes,
        no_convs_fcomb,
        initializers,
        use_tile=True,
        device=None
    ):

        super().__init__()

        self.num_channels = num_output_channels
        self.num_classes = num_classes
        self.channel_axis = 1
        self.spatial_axes = [2, 3]
        self.num_filters = num_filters
        self.latent_dim = latent_dim
        self.use_tile = use_tile
        self.no_convs_fcomb = no_convs_fcomb
        self.name = 'Fcomb'

        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device

        # NOTE(review): the layers are only created when use_tile is True, mirroring the
        # upstream implementation - confirm that use_tile=False is not meant to be supported.
        if self.use_tile:
            layers = []

            # Decoder of N x a 1x1 convolution followed by a ReLU activation function except for the last layer
            layers.append(nn.Conv2d(self.num_filters[0] + self.latent_dim, self.num_filters[0], kernel_size=1))
            layers.append(nn.ReLU(inplace=True))

            for _ in range(no_convs_fcomb - 2):
                layers.append(nn.Conv2d(self.num_filters[0], self.num_filters[0], kernel_size=1))
                layers.append(nn.ReLU(inplace=True))

            self.layers = nn.Sequential(*layers)

            self.last_layer = nn.Conv2d(self.num_filters[0], self.num_classes, kernel_size=1)

            if initializers['w'] == 'orthogonal':
                self.layers.apply(init_weights_orthogonal_normal)
                self.last_layer.apply(init_weights_orthogonal_normal)
            else:
                self.layers.apply(init_weights)
                self.last_layer.apply(init_weights)

    def tile(self, a, dim, n_tile):
        """Repeat every entry of `a` along `dim` `n_tile` times, keeping repeated entries adjacent.

        Produces the same values as the previous repeat + index_select construction, but the
        result lives on the device of `a` instead of depending on `self.device`.
        """
        return torch.repeat_interleave(a, n_tile, dim=dim)

    def forward(self, feature_map, z):
        """
        Z is (batch_size x latent_dim) and feature_map is (batch_size x no_channels x H x W).
        So broadcast Z to batch_size x latent_dim x H x W and concatenate it to the feature map.
        """
        if self.use_tile:
            height = feature_map.shape[self.spatial_axes[0]]
            width = feature_map.shape[self.spatial_axes[1]]
            # Same values as tiling along both spatial axes, without materializing index tensors.
            z = z[:, :, None, None].expand(-1, -1, height, width)

        # Concatenate the feature map (output of the UNet) and the sample taken from the latent space
        feature_map = torch.cat((feature_map, z), dim=self.channel_axis)
        output = self.layers(feature_map)
        return self.last_layer(output)
A function composed of no_convs_fcomb times a 1x1 convolution that combines the sample taken from the latent space, and output of the UNet (the feature map) by concatenating them along their channel axis.
def __init__(
    self,
    num_filters,
    latent_dim,
    num_output_channels,
    num_classes,
    no_convs_fcomb,
    initializers,
    use_tile=True,
    device=None
):
    """Store the configuration and build the stack of 1x1 convolutions."""
    super().__init__()

    self.num_channels = num_output_channels
    self.num_classes = num_classes
    self.channel_axis = 1
    self.spatial_axes = [2, 3]
    self.num_filters = num_filters
    self.latent_dim = latent_dim
    self.use_tile = use_tile
    self.no_convs_fcomb = no_convs_fcomb
    self.name = 'Fcomb'

    # Fall back to CUDA when available if no device was requested.
    self.device = device if device is not None else torch.device(
        'cuda' if torch.cuda.is_available() else 'cpu'
    )

    # NOTE(review): the layers are only created when use_tile is True, mirroring the
    # upstream implementation - confirm that use_tile=False is not meant to be supported.
    if self.use_tile:
        width = self.num_filters[0]

        # N x (1x1 convolution + ReLU); the final 1x1 convolution has no activation.
        stack = [nn.Conv2d(width + self.latent_dim, width, kernel_size=1), nn.ReLU(inplace=True)]
        for _ in range(no_convs_fcomb - 2):
            stack += [nn.Conv2d(width, width, kernel_size=1), nn.ReLU(inplace=True)]

        self.layers = nn.Sequential(*stack)
        self.last_layer = nn.Conv2d(width, self.num_classes, kernel_size=1)

        if initializers['w'] == 'orthogonal':
            init_fn = init_weights_orthogonal_normal
        else:
            init_fn = init_weights
        self.layers.apply(init_fn)
        self.last_layer.apply(init_fn)
Initializes internal Module state, shared by both nn.Module and ScriptModule.
def tile(self, a, dim, n_tile):
    """
    Repeat every entry of `a` `n_tile` times in a row along axis `dim`.

    Entries `[x0, x1]` along `dim` become `[x0, x0, x1, x1]` for `n_tile=2`, which is the
    ordering the previous repeat + index_select implementation
    (https://discuss.pytorch.org/t/how-to-tile-a-tensor/13853/3) produced.

    a: tensor to expand
    dim: axis along which the entries are repeated
    n_tile: number of consecutive copies of each entry

    Returns a new tensor with the dtype and device of `a` whose size along `dim` is
    `a.size(dim) * n_tile`.
    """
    # repeat_interleave runs on the device of `a`, so the result no longer depends on
    # `self.device` matching the input, and no numpy index array has to be built.
    return torch.repeat_interleave(a, int(n_tile), dim=dim)
This function is taken from the PyTorch forum and mimics the behavior of tf.tile. Source: https://discuss.pytorch.org/t/how-to-tile-a-tensor/13853/3
def forward(self, feature_map, z):
    """
    Combine a latent sample with the UNet feature map and map the result to class logits.

    feature_map is (batch_size x no_channels x H x W) and z is (batch_size x latent_dim).
    z is expanded to (batch_size x latent_dim x H x W), concatenated to the feature map
    along the channel axis and passed through the 1x1 convolutions.
    Nothing is computed (None is returned) when `use_tile` is False.
    """
    if not self.use_tile:
        return None

    # Add the two spatial axes one after the other and fill each with copies of z.
    height = feature_map.shape[self.spatial_axes[0]]
    z = self.tile(z.unsqueeze(2), 2, height)
    width = feature_map.shape[self.spatial_axes[1]]
    z = self.tile(z.unsqueeze(3), 3, width)

    # Feature map first, latent channels last.
    fused = torch.cat((feature_map, z), dim=self.channel_axis)
    return self.last_layer(self.layers(fused))
Z is (batch_size x latent_dim) and feature_map is (batch_size x no_channels x H x W). So broadcast Z to (batch_size x latent_dim x H x W). Behavior is exactly the same as tf.tile (verified)
Inherited Members
- torch.nn.modules.module.Module
- dump_patches
- training
- call_super_init
- register_buffer
- register_parameter
- add_module
- register_module
- get_submodule
- get_parameter
- get_buffer
- get_extra_state
- set_extra_state
- apply
- cuda
- ipu
- xpu
- cpu
- type
- float
- double
- half
- bfloat16
- to_empty
- to
- register_full_backward_pre_hook
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- register_state_dict_pre_hook
- state_dict
- register_load_state_dict_post_hook
- load_state_dict
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr
- compile
class ProbabilisticUNet(nn.Module):
    """ This network implements the Probabilistic UNet of Kohl et al. (https://arxiv.org/abs/1806.05034).
    This generative segmentation heuristic uses UNet combined with a conditional variational
    autoencoder enabling to efficiently produce an unlimited number of plausible hypotheses.

    The following elements are initialized to get our desired network:
    input_channels: the number of channels in the image (1 for grayscale and 3 for RGB)
    num_classes: the number of classes to predict
    num_filters: sequence with the number of filters per level; its length sets the UNet depth
    latent_dim: dimension of the latent space
    no_convs_fcomb: number of 1x1 convolutions used to combine UNet features and latent sample
    beta: KL and reconstruction loss are weighted using a KL weighting factor (β)
    consensus_masking: activates consensus masking in the reconstruction loss
    rl_swap: switches the reconstruction loss to dice loss from the default (binary cross-entropy loss)
    device: device the sub-networks are moved to (CUDA if available, otherwise CPU, when None)

    The number of convolutions per block in the prior and posterior encoders is fixed to 3.

    Parameters:
        input_channels [int] - (default: 1)
        num_classes [int] - (default: 1)
        num_filters [list] - (default: [32, 64, 128, 192])
        latent_dim [int] - (default: 6)
        no_convs_fcomb [int] - (default: 4)
        beta [float] - (default: 10.0)
        consensus_masking [bool] - (default: False)
        rl_swap [bool] - (default: False)
        device [torch.device] - (default: None)
    """

    def __init__(
        self,
        input_channels=1,
        num_classes=1,
        num_filters=(32, 64, 128, 192),
        latent_dim=6,
        no_convs_fcomb=4,
        beta=10.0,
        consensus_masking=False,
        rl_swap=False,
        device=None
    ):

        super().__init__()

        self.input_channels = input_channels
        self.num_classes = num_classes
        # Store a private list: the default is immutable and a caller-owned sequence
        # is never shared with (or mutated through) this instance.
        self.num_filters = list(num_filters)
        self.latent_dim = latent_dim
        self.no_convs_per_block = 3
        self.no_convs_fcomb = no_convs_fcomb
        self.initializers = {'w': 'he_normal', 'b': 'normal'}
        self.beta = beta
        self.z_prior_sample = 0
        self.consensus_masking = consensus_masking
        self.rl_swap = rl_swap

        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device

        # Feature extractor; out_channels=None, the class logits are produced by `fcomb`.
        self.unet = UNet2d(
            in_channels=self.input_channels,
            out_channels=None,
            depth=len(self.num_filters),
            initial_features=self.num_filters[0]
        ).to(self.device)

        # Prior net: sees the image only.
        self.prior = AxisAlignedConvGaussian(
            self.input_channels,
            self.num_filters,
            self.no_convs_per_block,
            self.latent_dim,
            self.initializers
        ).to(self.device)

        # Posterior net: additionally receives the segmentation.
        self.posterior = AxisAlignedConvGaussian(
            self.input_channels,
            self.num_filters,
            self.no_convs_per_block,
            self.latent_dim,
            self.initializers,
            posterior=True,
            num_classes=num_classes
        ).to(self.device)

        self.fcomb = Fcomb(
            self.num_filters,
            self.latent_dim,
            self.input_channels,
            self.num_classes,
            self.no_convs_fcomb,
            {'w': 'orthogonal', 'b': 'normal'},
            use_tile=True,
            device=self.device
        ).to(self.device)

    def _check_shape(self, patch):
        """
        Raise a ValueError unless every spatial dimension (axes 2 and up) of `patch`
        is divisible by 2 ** len(num_filters).
        """
        spatial_shape = tuple(patch.shape)[2:]
        depth = len(self.num_filters)
        factor = [2**depth] * len(spatial_shape)
        if any(sh % fac != 0 for sh, fac in zip(spatial_shape, factor)):
            msg = f"Invalid shape for Probabilistic U-Net: {spatial_shape} is not divisible by {factor}"
            raise ValueError(msg)

    def forward(self, patch, segm=None):
        """
        Construct prior latent space for patch and run patch through UNet;
        if segm is given, also construct the posterior latent space.

        The results are stored on the module (`prior_latent_space`, `unet_features` and,
        if applicable, `posterior_latent_space`); nothing is returned.
        """
        self._check_shape(patch)

        if segm is not None:
            self.posterior_latent_space = self.posterior.forward(patch, segm)
        self.prior_latent_space = self.prior.forward(patch)
        self.unet_features = self.unet.forward(patch)

    def sample(self, testing=False):
        """
        Sample a segmentation by reconstructing from a prior sample and combining this with UNet features.

        testing: if exactly False the latent is drawn with rsample(), otherwise with sample().
        The drawn latent is kept in `z_prior_sample`.
        """
        if testing is False:
            z_prior = self.prior_latent_space.rsample()
        else:
            # A sample (not the mean) is taken here; for the GED it is important to take a sample.
            z_prior = self.prior_latent_space.sample()
        self.z_prior_sample = z_prior
        return self.fcomb.forward(self.unet_features, z_prior)

    def reconstruct(self, use_posterior_mean=False, calculate_posterior=False, z_posterior=None):
        """
        Reconstruct a segmentation from a posterior sample (decoding a posterior sample) and UNet feature map
        use_posterior_mean: use posterior_mean instead of sampling z_q
        calculate_posterior: sample from the posterior latent space here instead of using `z_posterior`
        z_posterior: latent sample to decode when neither of the two flags is set

        Raises a ValueError if no latent is available (both flags False and `z_posterior` is None).
        """
        if use_posterior_mean:
            # `.mean` instead of `.loc`: `Independent` exposes no `loc`, whereas `mean`
            # exists on both `Normal` and `Independent`.
            z_posterior = self.posterior_latent_space.mean
        elif calculate_posterior:
            z_posterior = self.posterior_latent_space.rsample()
        elif z_posterior is None:
            raise ValueError(
                "z_posterior must be given if neither use_posterior_mean nor calculate_posterior is set"
            )
        return self.fcomb.forward(self.unet_features, z_posterior)

    def kl_divergence(self, analytic=True, calculate_posterior=False, z_posterior=None):
        """
        Calculate the KL divergence between the posterior and prior KL(Q||P)
        analytic: calculate KL analytically or via sampling from the posterior
        calculate_posterior: if we use sampling to approximate KL we can sample here or supply a sample
        z_posterior: the supplied sample, only used if analytic and calculate_posterior are False
        """
        if analytic:
            # See https://github.com/pytorch/pytorch/issues/13545
            kl_div = kl.kl_divergence(self.posterior_latent_space, self.prior_latent_space)
        else:
            if calculate_posterior:
                z_posterior = self.posterior_latent_space.rsample()
            log_posterior_prob = self.posterior_latent_space.log_prob(z_posterior)
            log_prior_prob = self.prior_latent_space.log_prob(z_posterior)
            kl_div = log_posterior_prob - log_prior_prob
        return kl_div

    def elbo(self, segm, consm=None, analytic_kl=True, reconstruct_posterior_mean=False):
        """
        Calculate the evidence lower bound of the log-likelihood of P(Y|X)
        segm: target segmentation
        consm: consensus response, multiplied onto prediction and target if consensus masking is active
        analytic_kl: forwarded to `kl_divergence` as `analytic`
        reconstruct_posterior_mean: decode the posterior mean instead of a posterior sample

        Stores `kl`, `reconstruction`, `reconstruction_loss` (sum) and `mean_reconstruction_loss`
        on the module and returns -(reconstruction_loss + beta * kl).
        Requires a previous call of forward with a segmentation (posterior latent space).
        """

        if self.rl_swap:
            criterion = DiceLossWithLogits()
        else:
            # Element-wise loss; reduction="none" replaces the deprecated size_average / reduce flags.
            criterion = nn.BCEWithLogitsLoss(reduction="none")

        z_posterior = self.posterior_latent_space.rsample()

        self.kl = torch.mean(
            self.kl_divergence(analytic=analytic_kl, calculate_posterior=False, z_posterior=z_posterior)
        )

        # Here we use the posterior sample sampled above
        self.reconstruction = self.reconstruct(use_posterior_mean=reconstruct_posterior_mean,
                                               calculate_posterior=False, z_posterior=z_posterior)

        if self.consensus_masking is True and consm is not None:
            reconstruction_loss = criterion(self.reconstruction * consm, segm * consm)
        else:
            reconstruction_loss = criterion(self.reconstruction, segm)

        self.reconstruction_loss = torch.sum(reconstruction_loss)
        self.mean_reconstruction_loss = torch.mean(reconstruction_loss)

        return -(self.reconstruction_loss + self.beta * self.kl)
This network implements the Probabilistic UNet of Kohl et al. (https://arxiv.org/abs/1806.05034). This generative segmentation heuristic combines a UNet with a conditional variational autoencoder, which makes it possible to efficiently produce an unlimited number of plausible hypotheses.
The following elements are initialized to get our desired network: input_channels: the number of channels in the image (1 for grayscale and 3 for RGB) num_classes: the number of classes to predict num_filters: a list with the number of filters per layer latent_dim: dimension of the latent space no_convs_per_block: number of convs per block in the (convolutional) encoder of prior and posterior beta: KL and reconstruction loss are weighted using a KL weighting factor (β) consensus_masking: activates consensus masking in the reconstruction loss rl_swap: switches the reconstruction loss to dice loss from the default (binary cross-entropy loss)
Arguments:
- input_channels [int] - (default: 1)
- num_classes [int] - (default: 1)
- num_filters [list] - (default: [32, 64, 128, 192])
- latent_dim [int] - (default: 6)
- no_convs_fcomb [int] - (default: 4)
- beta [float] - (default: 10.0)
- consensus_masking [bool] - (default: False)
- rl_swap [bool] - (default: False)
- device [torch.device] - (default: None)
def __init__(
    self,
    input_channels=1,
    num_classes=1,
    num_filters=(32, 64, 128, 192),
    latent_dim=6,
    no_convs_fcomb=4,
    beta=10.0,
    consensus_masking=False,
    rl_swap=False,
    device=None
):
    """
    Create the UNet, the prior and posterior nets and the combination layers (Fcomb).

    input_channels: number of channels of the input image
    num_classes: number of classes to predict
    num_filters: sequence with the number of filters per level; its length sets the UNet depth
        and its first entry the initial number of UNet features
    latent_dim: dimension of the latent space
    no_convs_fcomb: number of 1x1 convolutions in Fcomb
    beta: weighting factor of the KL term
    consensus_masking: activates consensus masking in the reconstruction loss
    rl_swap: use the dice loss instead of binary cross-entropy as reconstruction loss
    device: device all sub-networks are moved to; CUDA if available, otherwise CPU, when None
    """

    super().__init__()

    self.input_channels = input_channels
    self.num_classes = num_classes
    # Store a private list: the default is immutable and a caller-owned sequence
    # is never shared with (or mutated through) this instance.
    self.num_filters = list(num_filters)
    self.latent_dim = latent_dim
    self.no_convs_per_block = 3
    self.no_convs_fcomb = no_convs_fcomb
    self.initializers = {'w': 'he_normal', 'b': 'normal'}
    self.beta = beta
    self.z_prior_sample = 0
    self.consensus_masking = consensus_masking
    self.rl_swap = rl_swap

    if device is None:
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        self.device = device

    # Feature extractor; out_channels=None, the class logits are produced by `fcomb`.
    self.unet = UNet2d(
        in_channels=self.input_channels,
        out_channels=None,
        depth=len(self.num_filters),
        initial_features=self.num_filters[0]
    ).to(self.device)

    # Prior net: sees the image only.
    self.prior = AxisAlignedConvGaussian(
        self.input_channels,
        self.num_filters,
        self.no_convs_per_block,
        self.latent_dim,
        self.initializers
    ).to(self.device)

    # Posterior net: additionally receives the segmentation.
    self.posterior = AxisAlignedConvGaussian(
        self.input_channels,
        self.num_filters,
        self.no_convs_per_block,
        self.latent_dim,
        self.initializers,
        posterior=True,
        num_classes=num_classes
    ).to(self.device)

    self.fcomb = Fcomb(
        self.num_filters,
        self.latent_dim,
        self.input_channels,
        self.num_classes,
        self.no_convs_fcomb,
        {'w': 'orthogonal', 'b': 'normal'},
        use_tile=True,
        device=self.device
    ).to(self.device)
Initializes internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, patch, segm=None):
    """
    Encode `patch` with the prior net and the UNet and, if a segmentation is given,
    with the posterior net as well.

    Nothing is returned; the results are kept on the module as `prior_latent_space`,
    `unet_features` and (only when `segm` is not None) `posterior_latent_space`.
    A ValueError from the shape check is propagated for unsupported patch shapes.
    """
    self._check_shape(patch)

    has_segmentation = segm is not None
    if has_segmentation:
        # The posterior net is conditioned on the image and its segmentation.
        self.posterior_latent_space = self.posterior.forward(patch, segm)

    self.prior_latent_space = self.prior.forward(patch)
    self.unet_features = self.unet.forward(patch)
Construct prior latent space for patch and run patch through UNet; if segm is given, also construct the posterior latent space
def sample(self, testing=False):
    """
    Draw a latent from the prior latent space and decode it together with the UNet features.

    testing: only if this is exactly False the latent is drawn with rsample();
        for every other value sample() is used. In both cases a sample and not the
        mean of the prior is decoded (for the GED it is important to take a sample).

    The drawn latent is stored in `z_prior_sample` and the decoded segmentation is returned.
    """
    prior_space = self.prior_latent_space
    draw = prior_space.rsample if testing is False else prior_space.sample

    z_prior = draw()
    self.z_prior_sample = z_prior
    return self.fcomb.forward(self.unet_features, z_prior)
Sample a segmentation by reconstructing from a prior sample and combining this with UNet features
def reconstruct(self, use_posterior_mean=False, calculate_posterior=False, z_posterior=None):
    """
    Reconstruct a segmentation from a posterior sample (decoding a posterior sample) and UNet feature map
    use_posterior_mean: use posterior_mean instead of sampling z_q
    calculate_posterior: sample from the posterior latent space here instead of using `z_posterior`
    z_posterior: latent sample to decode when neither of the two flags is set

    Returns the output of fcomb for the UNet features and the selected latent.
    Raises a ValueError if no latent is available (both flags False and `z_posterior` is None).
    """
    if use_posterior_mean:
        # `.mean` instead of `.loc`: `Independent` exposes no `loc`, whereas `mean`
        # exists on both `Normal` and `Independent`.
        z_posterior = self.posterior_latent_space.mean
    elif calculate_posterior:
        z_posterior = self.posterior_latent_space.rsample()
    elif z_posterior is None:
        # Fail here with a clear message instead of inside the combination layers.
        raise ValueError(
            "z_posterior must be given if neither use_posterior_mean nor calculate_posterior is set"
        )
    return self.fcomb.forward(self.unet_features, z_posterior)
Reconstruct a segmentation from a posterior sample (decoding a posterior sample) and UNet feature map use_posterior_mean: use posterior_mean instead of sampling z_q calculate_posterior: use a provided sample or sample from posterior latent space
def kl_divergence(self, analytic=True, calculate_posterior=False, z_posterior=None):
    """
    Compute the KL divergence KL(Q||P) between the posterior (Q) and the prior (P) latent space.

    analytic: use the closed-form KL of torch.distributions; otherwise estimate it from one sample
        as log Q(z) - log P(z)
    calculate_posterior: for the sampled estimate, draw z from the posterior here
        instead of using `z_posterior`
    z_posterior: the supplied sample, only used if analytic and calculate_posterior are False
    """
    posterior = self.posterior_latent_space
    prior = self.prior_latent_space

    if analytic:
        # Background on the analytic KL for these distributions:
        # https://github.com/pytorch/pytorch/issues/13545
        return kl.kl_divergence(posterior, prior)

    latent = posterior.rsample() if calculate_posterior else z_posterior
    log_q = posterior.log_prob(latent)
    log_p = prior.log_prob(latent)
    return log_q - log_p
Calculate the KL divergence between the posterior and prior KL(Q||P) analytic: calculate KL analytically or via sampling from the posterior calculate_posterior: if we use sampling to approximate KL we can sample here or supply a sample
def elbo(self, segm, consm=None, analytic_kl=True, reconstruct_posterior_mean=False):
    """
    Calculate the evidence lower bound of the log-likelihood of P(Y|X)
    segm: target segmentation
    consm: consensus response, multiplied onto prediction and target if consensus masking is active
    analytic_kl: forwarded to `kl_divergence` as `analytic`
    reconstruct_posterior_mean: decode the posterior mean instead of a posterior sample

    Stores `kl` (mean KL), `reconstruction`, `reconstruction_loss` (sum) and
    `mean_reconstruction_loss` on the module and returns -(reconstruction_loss + beta * kl).
    Reads `posterior_latent_space`, which forward only sets when it is called with a segmentation.
    """

    if self.rl_swap:
        criterion = DiceLossWithLogits()
    else:
        # Element-wise loss (summed / averaged below); reduction="none" replaces the
        # deprecated size_average / reduce flags and the invalid reduction=None.
        criterion = nn.BCEWithLogitsLoss(reduction="none")

    z_posterior = self.posterior_latent_space.rsample()

    self.kl = torch.mean(
        self.kl_divergence(analytic=analytic_kl, calculate_posterior=False, z_posterior=z_posterior)
    )

    # Here we use the posterior sample sampled above
    self.reconstruction = self.reconstruct(use_posterior_mean=reconstruct_posterior_mean,
                                           calculate_posterior=False, z_posterior=z_posterior)

    if self.consensus_masking is True and consm is not None:
        reconstruction_loss = criterion(self.reconstruction * consm, segm * consm)
    else:
        reconstruction_loss = criterion(self.reconstruction, segm)

    self.reconstruction_loss = torch.sum(reconstruction_loss)
    self.mean_reconstruction_loss = torch.mean(reconstruction_loss)

    return -(self.reconstruction_loss + self.beta * self.kl)
Calculate the evidence lower bound of the log-likelihood of P(Y|X) consm: consensus response
Inherited Members
- torch.nn.modules.module.Module
- dump_patches
- training
- call_super_init
- register_buffer
- register_parameter
- add_module
- register_module
- get_submodule
- get_parameter
- get_buffer
- get_extra_state
- set_extra_state
- apply
- cuda
- ipu
- xpu
- cpu
- type
- float
- double
- half
- bfloat16
- to_empty
- to
- register_full_backward_pre_hook
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- register_state_dict_pre_hook
- state_dict
- register_load_state_dict_post_hook
- load_state_dict
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr
- compile