torch_em.model.probabilistic_unet
1"""@private 2""" 3 4# This code is based on the original TensorFlow implementation: https://github.com/SimonKohl/probabilistic_unet 5# The below implementation is from: https://github.com/stefanknegt/Probabilistic-Unet-Pytorch 6 7import numpy as np 8 9import torch 10import torch.nn as nn 11from torch.distributions import Normal, Independent, kl 12 13from torch_em.model import UNet2d 14from torch_em.loss.dice import DiceLossWithLogits 15 16 17def truncated_normal_(tensor, mean=0, std=1): 18 """@private 19 """ 20 size = tensor.shape 21 tmp = tensor.new_empty(size + (4,)).normal_() 22 valid = (tmp < 2) & (tmp > -2) 23 ind = valid.max(-1, keepdim=True)[1] 24 tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1)) 25 tensor.data.mul_(std).add_(mean) 26 27 28def init_weights(m): 29 """@private 30 """ 31 if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 32 nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu') 33 # nn.init.normal_(m.weight, std=0.001) 34 # nn.init.normal_(m.bias, std=0.001) 35 truncated_normal_(m.bias, mean=0, std=0.001) 36 37 38def init_weights_orthogonal_normal(m): 39 """@private 40 """ 41 if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 42 nn.init.orthogonal_(m.weight) 43 truncated_normal_(m.bias, mean=0, std=0.001) 44 # nn.init.normal_(m.bias, std=0.001) 45 46 47class Encoder(nn.Module): 48 """ 49 A convolutional neural network, consisting of len(num_filters) times a block of no_convs_per_block 50 convolutional layers, after each block a pooling operation is performed. 51 And after each convolutional layer a non-linear (ReLU) activation function is applied. 52 """ 53 def __init__( 54 self, 55 input_channels, 56 num_filters, 57 no_convs_per_block, 58 initializers, 59 padding=True, 60 posterior=False, 61 num_classes=None 62 ): 63 64 super().__init__() 65 66 self.contracting_path = nn.ModuleList() 67 self.input_channels = input_channels 68 self.num_filters = num_filters 69 70 if posterior: 71 # To accomodate for the mask that is concatenated at the channel axis, we increase the input_channels. 72 assert num_classes is not None 73 self.input_channels += num_classes 74 75 layers = [] 76 output_dim = None # Initialize output_dim of the layers 77 78 for i in range(len(self.num_filters)): 79 """ 80 Determine input_dim and output_dim of conv layers in this block. The first layer is input x output, 81 All the subsequent layers are output x output. 82 """ 83 84 input_dim = self.input_channels if i == 0 else output_dim 85 output_dim = num_filters[i] 86 87 if i != 0: 88 layers.append(nn.AvgPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True)) 89 90 layers.append(nn.Conv2d(input_dim, output_dim, kernel_size=3)) 91 layers.append(nn.ReLU(inplace=True)) 92 93 for _ in range(no_convs_per_block-1): 94 layers.append(nn.Conv2d(output_dim, output_dim, kernel_size=3)) 95 layers.append(nn.ReLU(inplace=True)) 96 97 self.layers = nn.Sequential(*layers) 98 99 self.layers.apply(init_weights) 100 101 def forward(self, input): 102 output = self.layers(input) 103 return output 104 105 106class AxisAlignedConvGaussian(nn.Module): 107 """ 108 A convolutional net that parametrizes a Gaussian distribution with axis aligned covariance matrix. 
109 """ 110 111 def __init__( 112 self, 113 input_channels, 114 num_filters, 115 no_convs_per_block, 116 latent_dim, 117 initializers, 118 posterior=False, 119 num_classes=None 120 ): 121 122 super().__init__() 123 124 self.input_channels = input_channels 125 self.channel_axis = 1 126 self.num_filters = num_filters 127 self.no_convs_per_block = no_convs_per_block 128 self.latent_dim = latent_dim 129 130 self.posterior = posterior 131 if self.posterior: 132 self.name = 'Posterior' 133 else: 134 self.name = 'Prior' 135 136 self.encoder = Encoder( 137 self.input_channels, 138 self.num_filters, 139 self.no_convs_per_block, 140 initializers, 141 posterior=self.posterior, 142 num_classes=num_classes 143 ) 144 145 self.conv_layer = nn.Conv2d(num_filters[-1], 2 * self.latent_dim, (1, 1), stride=1) 146 self.show_img = 0 147 self.show_seg = 0 148 self.show_concat = 0 149 self.show_enc = 0 150 self.sum_input = 0 151 152 # 153 # @ Original paper's training details: 154 # All weights of all models are initialized with orthogonal initialization having the gain (multiplicative 155 # factor) set to 1, and the bias terms are initialized by sampling from a truncated normal with σ = 0.001 156 157 # nn.init.kaiming_normal_(self.conv_layer.weight, mode='fan_in', nonlinearity='relu') # from Stefan's impl. 158 # nn.init.normal_(self.conv_layer.weight, std=0.001) # suggested @issues from Stefan's impl. 159 160 # nn.init.normal_(self.conv_layer.bias) # from Stefan's impl. 161 # 162 163 nn.init.orthogonal_(self.conv_layer.weight, gain=1) 164 nn.init.trunc_normal_(self.conv_layer.bias, std=0.001) 165 166 def forward(self, input, segm=None): 167 168 # If segmentation is not none, concatenate the mask to the channel axis of the input 169 if segm is not None: 170 self.show_img = input 171 self.show_seg = segm 172 input = torch.cat((input, segm), dim=1) 173 self.show_concat = input 174 self.sum_input = torch.sum(input) 175 176 encoding = self.encoder(input) 177 self.show_enc = encoding 178 179 # We only want the mean of the resulting hxw image 180 encoding = torch.mean(encoding, dim=2, keepdim=True) 181 encoding = torch.mean(encoding, dim=3, keepdim=True) 182 183 # Convert encoding to 2 x latent dim and split up for mu and log_sigma 184 mu_log_sigma = self.conv_layer(encoding) 185 186 # We squeeze the second dimension twice, since otherwise it won't work when batch size is equal to 1 187 mu_log_sigma = torch.squeeze(mu_log_sigma, dim=2) 188 mu_log_sigma = torch.squeeze(mu_log_sigma, dim=2) 189 190 mu = mu_log_sigma[:, :self.latent_dim] 191 log_sigma = mu_log_sigma[:, self.latent_dim:] 192 193 # This is a multivariate normal with diagonal covariance matrix sigma 194 # https://github.com/pytorch/pytorch/pull/11178 195 dist = Independent(Normal(loc=mu, scale=torch.exp(log_sigma)), 1) 196 return dist 197 198 199class Fcomb(nn.Module): 200 """ 201 A function composed of no_convs_fcomb times a 1x1 convolution that combines the sample taken from the latent space, 202 and output of the UNet (the feature map) by concatenating them along their channel axis. 
203 """ 204 def __init__( 205 self, 206 num_filters, 207 latent_dim, 208 num_output_channels, 209 num_classes, 210 no_convs_fcomb, 211 initializers, 212 use_tile=True, 213 device=None 214 ): 215 216 super().__init__() 217 218 self.num_channels = num_output_channels 219 self.num_classes = num_classes 220 self.channel_axis = 1 221 self.spatial_axes = [2, 3] 222 self.num_filters = num_filters 223 self.latent_dim = latent_dim 224 self.use_tile = use_tile 225 self.no_convs_fcomb = no_convs_fcomb 226 self.name = 'Fcomb' 227 228 if device is None: 229 self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 230 else: 231 self.device = device 232 233 if self.use_tile: 234 layers = [] 235 236 # Decoder of N x a 1x1 convolution followed by a ReLU activation function except for the last layer 237 layers.append(nn.Conv2d(self.num_filters[0]+self.latent_dim, self.num_filters[0], kernel_size=1)) 238 layers.append(nn.ReLU(inplace=True)) 239 240 for _ in range(no_convs_fcomb-2): 241 layers.append(nn.Conv2d(self.num_filters[0], self.num_filters[0], kernel_size=1)) 242 layers.append(nn.ReLU(inplace=True)) 243 244 self.layers = nn.Sequential(*layers) 245 246 self.last_layer = nn.Conv2d(self.num_filters[0], self.num_classes, kernel_size=1) 247 248 if initializers['w'] == 'orthogonal': 249 self.layers.apply(init_weights_orthogonal_normal) 250 self.last_layer.apply(init_weights_orthogonal_normal) 251 else: 252 self.layers.apply(init_weights) 253 self.last_layer.apply(init_weights) 254 255 def tile(self, a, dim, n_tile): 256 """ 257 This function is taken form PyTorch forum and mimics the behavior of tf.tile. 258 Source: https://discuss.pytorch.org/t/how-to-tile-a-tensor/13853/3 259 """ 260 init_dim = a.size(dim) 261 repeat_idx = [1] * a.dim() 262 repeat_idx[dim] = n_tile 263 a = a.repeat(*(repeat_idx)) 264 order_index = torch.LongTensor( 265 np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]) 266 ).to(self.device) 267 return torch.index_select(a, dim, order_index) 268 269 def forward(self, feature_map, z): 270 """ 271 Z is (batch_size x latent_dim) and feature_map is (batch_size x no_channels x H x W). 272 So broadcast Z to batch_sizexlatent_dimxHxW. Behavior is exactly the same as tf.tile (verified) 273 """ 274 if self.use_tile: 275 z = torch.unsqueeze(z, 2) 276 z = self.tile(z, 2, feature_map.shape[self.spatial_axes[0]]) 277 z = torch.unsqueeze(z, 3) 278 z = self.tile(z, 3, feature_map.shape[self.spatial_axes[1]]) 279 280 # Concatenate the feature map (output of the UNet) and the sample taken from the latent space 281 feature_map = torch.cat((feature_map, z), dim=self.channel_axis) 282 output = self.layers(feature_map) 283 return self.last_layer(output) 284 285 286class ProbabilisticUNet(nn.Module): 287 """ This network implementation for the Probabilistic UNet of Kohl et al. (https://arxiv.org/abs/1806.05034). 288 This generative segmentation heuristic uses UNet combined with a conditional variational 289 autoencoder enabling to efficiently produce an unlimited number of plausible hypotheses. 
290 291 The following elements are initialized to get our desired network: 292 input_channels: the number of channels in the image (1 for grayscale and 3 for RGB) 293 num_classes: the number of classes to predict 294 num_filters: is a list consisting of the amount of filters layer 295 latent_dim: dimension of the latent space 296 no_cons_per_block: no convs per block in the (convolutional) encoder of prior and posterior 297 beta: KL and reconstruction loss are weighted using a KL weighting factor (β) 298 consensus_masking: activates consensus masking in the reconstruction loss 299 rl_swap: switches the reconstruction loss to dice loss from the default (binary cross-entroy loss) 300 301 Args: 302 input_channels [int] - (default: 1) 303 num_classes [int] - (default: 1) 304 num_filters [list] - (default: [32, 64, 128, 192]) 305 latent_dim [int] - (default: 6) 306 no_convs_fcomb [int] - (default: 4) 307 beta [float] - (default: 10.0) 308 consensus_masking [bool] - (default: False) 309 rl_swap [bool] - (default: False) 310 device [torch.device] - (default: None) 311 """ 312 313 def __init__( 314 self, 315 input_channels=1, 316 num_classes=1, 317 num_filters=[32, 64, 128, 192], 318 latent_dim=6, 319 no_convs_fcomb=4, 320 beta=10.0, 321 consensus_masking=False, 322 rl_swap=False, 323 device=None 324 ): 325 326 super().__init__() 327 328 self.input_channels = input_channels 329 self.num_classes = num_classes 330 self.num_filters = num_filters 331 self.latent_dim = latent_dim 332 self.no_convs_per_block = 3 333 self.no_convs_fcomb = no_convs_fcomb 334 self.initializers = {'w': 'he_normal', 'b': 'normal'} 335 self.beta = beta 336 self.z_prior_sample = 0 337 self.consensus_masking = consensus_masking 338 self.rl_swap = rl_swap 339 340 if device is None: 341 self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 342 else: 343 self.device = device 344 345 self.unet = UNet2d( 346 in_channels=self.input_channels, 347 out_channels=None, 348 depth=len(self.num_filters), 349 initial_features=num_filters[0] 350 ).to(self.device) 351 352 self.prior = AxisAlignedConvGaussian( 353 self.input_channels, 354 self.num_filters, 355 self.no_convs_per_block, 356 self.latent_dim, 357 self.initializers 358 ).to(self.device) 359 360 self.posterior = AxisAlignedConvGaussian( 361 self.input_channels, 362 self.num_filters, 363 self.no_convs_per_block, 364 self.latent_dim, 365 self.initializers, 366 posterior=True, 367 num_classes=num_classes 368 ).to(self.device) 369 370 self.fcomb = Fcomb( 371 self.num_filters, 372 self.latent_dim, 373 self.input_channels, 374 self.num_classes, 375 self.no_convs_fcomb, 376 {'w': 'orthogonal', 'b': 'normal'}, 377 use_tile=True, 378 device=self.device 379 ).to(self.device) 380 381 def _check_shape(self, patch): 382 spatial_shape = tuple(patch.shape)[2:] 383 depth = len(self.num_filters) 384 factor = [2**depth] * len(spatial_shape) 385 if any(sh % fac != 0 for sh, fac in zip(spatial_shape, factor)): 386 msg = f"Invalid shape for Probabilistic U-Net: {spatial_shape} is not divisible by {factor}" 387 raise ValueError(msg) 388 389 def forward(self, patch, segm=None): 390 """ 391 Construct prior latent space for patch and run patch through UNet, 392 in case training is True also construct posterior latent space 393 """ 394 self._check_shape(patch) 395 396 if segm is not None: 397 self.posterior_latent_space = self.posterior.forward(patch, segm) 398 self.prior_latent_space = self.prior.forward(patch) 399 self.unet_features = self.unet.forward(patch) 400 401 def sample(self, 
testing=False): 402 """ 403 Sample a segmentation by reconstructing from a prior sample and combining this with UNet features 404 """ 405 if testing is False: 406 # TODO: prior distribution ? (posterior in this case!) 407 z_prior = self.prior_latent_space.rsample() 408 self.z_prior_sample = z_prior 409 else: 410 # You can choose whether you mean a sample or the mean here. For the GED it is important to take a sample. 411 # z_prior = self.prior_latent_space.base_dist.loc 412 z_prior = self.prior_latent_space.sample() 413 self.z_prior_sample = z_prior 414 return self.fcomb.forward(self.unet_features, z_prior) 415 416 def reconstruct(self, use_posterior_mean=False, calculate_posterior=False, z_posterior=None): 417 """ 418 Reconstruct a segmentation from a posterior sample (decoding a posterior sample) and UNet feature map 419 use_posterior_mean: use posterior_mean instead of sampling z_q 420 calculate_posterior: use a provided sample or sample from posterior latent space 421 """ 422 if use_posterior_mean: 423 z_posterior = self.posterior_latent_space.loc 424 else: 425 if calculate_posterior: 426 z_posterior = self.posterior_latent_space.rsample() 427 return self.fcomb.forward(self.unet_features, z_posterior) 428 429 def kl_divergence(self, analytic=True, calculate_posterior=False, z_posterior=None): 430 """ 431 Calculate the KL divergence between the posterior and prior KL(Q||P) 432 analytic: calculate KL analytically or via sampling from the posterior 433 calculate_posterior: if we use samapling to approximate KL we can sample here or supply a sample 434 """ 435 if analytic: 436 # Neeed to add this to torch source code, see: https://github.com/pytorch/pytorch/issues/13545 437 kl_div = kl.kl_divergence(self.posterior_latent_space, self.prior_latent_space) 438 else: 439 if calculate_posterior: 440 z_posterior = self.posterior_latent_space.rsample() 441 log_posterior_prob = self.posterior_latent_space.log_prob(z_posterior) 442 log_prior_prob = self.prior_latent_space.log_prob(z_posterior) 443 kl_div = log_posterior_prob - log_prior_prob 444 return kl_div 445 446 def elbo(self, segm, consm=None, analytic_kl=True, reconstruct_posterior_mean=False): 447 """ 448 Calculate the evidence lower bound of the log-likelihood of P(Y|X) 449 consm: consensus response 450 """ 451 452 if self.rl_swap: 453 criterion = DiceLossWithLogits() 454 else: 455 criterion = nn.BCEWithLogitsLoss(size_average=False, reduce=False, reduction=None) 456 457 z_posterior = self.posterior_latent_space.rsample() 458 459 self.kl = torch.mean( 460 self.kl_divergence(analytic=analytic_kl, calculate_posterior=False, z_posterior=z_posterior) 461 ) 462 463 # Here we use the posterior sample sampled above 464 self.reconstruction = self.reconstruct(use_posterior_mean=reconstruct_posterior_mean, 465 calculate_posterior=False, z_posterior=z_posterior) 466 467 if self.consensus_masking is True and consm is not None: 468 reconstruction_loss = criterion(self.reconstruction * consm, segm * consm) 469 else: 470 reconstruction_loss = criterion(self.reconstruction, segm) 471 472 self.reconstruction_loss = torch.sum(reconstruction_loss) 473 self.mean_reconstruction_loss = torch.mean(reconstruction_loss) 474 475 return -(self.reconstruction_loss + self.beta * self.kl)
class Encoder(nn.Module):
A convolutional neural network consisting of len(num_filters) blocks of no_convs_per_block convolutional layers each. After each block a pooling operation is performed, and after each convolutional layer a non-linear (ReLU) activation function is applied.
Encoder(input_channels, num_filters, no_convs_per_block, initializers, padding=True, posterior=False, num_classes=None)
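A minimal usage sketch with hypothetical sizes (it assumes the padded convolutions of the implementation above): each block after the first halves the spatial resolution, so four blocks reduce a 128x128 patch by a factor of 8 while the channel count grows to num_filters[-1].

import torch

encoder = Encoder(
    input_channels=1,
    num_filters=[32, 64, 128, 192],
    no_convs_per_block=3,
    initializers={'w': 'he_normal', 'b': 'normal'},
)
x = torch.randn(4, 1, 128, 128)
features = encoder(x)
# Three pooling steps between the four blocks: 128 -> 64 -> 32 -> 16.
print(features.shape)  # torch.Size([4, 192, 16, 16])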
class AxisAlignedConvGaussian(nn.Module):
A convolutional net that parametrizes a Gaussian distribution with an axis-aligned covariance matrix.
AxisAlignedConvGaussian(input_channels, num_filters, no_convs_per_block, latent_dim, initializers, posterior=False, num_classes=None)

def forward(self, input, segm=None):
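A usage sketch with illustrative values: the net returns a torch.distributions.Independent wrapping a diagonal Normal over the latent space. The posterior variant additionally takes the ground-truth mask, which is concatenated to the input channels.

import torch

prior_net = AxisAlignedConvGaussian(
    input_channels=1, num_filters=[32, 64, 128, 192], no_convs_per_block=3,
    latent_dim=6, initializers={'w': 'he_normal', 'b': 'normal'},
)
posterior_net = AxisAlignedConvGaussian(
    input_channels=1, num_filters=[32, 64, 128, 192], no_convs_per_block=3,
    latent_dim=6, initializers={'w': 'he_normal', 'b': 'normal'},
    posterior=True, num_classes=1,
)

x = torch.randn(2, 1, 128, 128)                    # image patch
y = torch.randint(0, 2, (2, 1, 128, 128)).float()  # binary segmentation mask

prior_dist = prior_net(x)
posterior_dist = posterior_net(x, y)

z = prior_dist.rsample()        # differentiable sample, shape (2, 6)
log_p = prior_dist.log_prob(z)  # one log-density per batch element, shape (2,)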
class Fcomb(nn.Module):
A function composed of no_convs_fcomb 1x1 convolutions that combines the sample taken from the latent space and the output of the U-Net (the feature map) by concatenating them along their channel axis.
Fcomb(num_filters, latent_dim, num_output_channels, num_classes, no_convs_fcomb, initializers, use_tile=True, device=None)

def tile(self, a, dim, n_tile):
This function is taken from the PyTorch forum and mimics the behavior of tf.tile. Source: https://discuss.pytorch.org/t/how-to-tile-a-tensor/13853/3
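For intuition, the function repeats each element in place along dim rather than repeating the whole tensor, so it matches torch.repeat_interleave. A small illustration of the expected values (fcomb denotes a hypothetical Fcomb instance):

import torch

a = torch.tensor([1, 2])
# fcomb.tile(a, 0, 3) returns tensor([1, 1, 1, 2, 2, 2]), which is the same as:
print(torch.repeat_interleave(a, repeats=3, dim=0))  # tensor([1, 1, 1, 2, 2, 2])

In forward below, the tiled dimension always has size one (it is created by unsqueeze), so the call amounts to a plain broadcast of z over the spatial grid.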
def forward(self, feature_map, z):
z is (batch_size x latent_dim) and feature_map is (batch_size x no_channels x H x W). So broadcast z to (batch_size x latent_dim x H x W). Behavior is exactly the same as tf.tile (verified).
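A sketch with illustrative sizes, pinning the device to CPU: z is broadcast over the spatial grid, concatenated to the feature map, and fused by the 1x1 convolutions.

import torch

fcomb = Fcomb(
    num_filters=[32, 64, 128, 192], latent_dim=6, num_output_channels=1,
    num_classes=1, no_convs_fcomb=4,
    initializers={'w': 'orthogonal', 'b': 'normal'}, device=torch.device('cpu'),
)
feature_map = torch.randn(2, 32, 128, 128)  # U-Net output with num_filters[0] channels
z = torch.randn(2, 6)                       # one latent vector per batch element
logits = fcomb(feature_map, z)
print(logits.shape)  # torch.Size([2, 1, 128, 128])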
class ProbabilisticUNet(nn.Module):
Implementation of the Probabilistic U-Net of Kohl et al. (https://arxiv.org/abs/1806.05034). This generative segmentation model combines a U-Net with a conditional variational autoencoder, which enables it to efficiently produce an unlimited number of plausible segmentation hypotheses.
The following elements are initialized to get the desired network:
- input_channels: the number of channels in the image (1 for grayscale, 3 for RGB)
- num_classes: the number of classes to predict
- num_filters: a list with the number of filters per layer
- latent_dim: the dimension of the latent space
- no_convs_per_block: the number of convolutions per block in the (convolutional) encoders of prior and posterior
- beta: the KL weighting factor (β) that balances the KL and reconstruction terms of the loss
- consensus_masking: activates consensus masking in the reconstruction loss
- rl_swap: switches the reconstruction loss from the default (binary cross-entropy) to dice loss
Arguments:
- input_channels [int] - (default: 1)
- num_classes [int] - (default: 1)
- num_filters [list] - (default: [32, 64, 128, 192])
- latent_dim [int] - (default: 6)
- no_convs_fcomb [int] - (default: 4)
- beta [float] - (default: 10.0)
- consensus_masking [bool] - (default: False)
- rl_swap [bool] - (default: False)
- device [torch.device] - (default: None)
ProbabilisticUNet(input_channels=1, num_classes=1, num_filters=[32, 64, 128, 192], latent_dim=6, no_convs_fcomb=4, beta=10.0, consensus_masking=False, rl_swap=False, device=None)
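A construction sketch with the default configuration (device pinned to CPU for illustration). Note that _check_shape requires the spatial dimensions of input patches to be divisible by 2**len(num_filters), i.e. 16 for the defaults.

import torch

net = ProbabilisticUNet(
    input_channels=1,
    num_classes=1,
    num_filters=[32, 64, 128, 192],
    latent_dim=6,
    no_convs_fcomb=4,
    beta=10.0,
    device=torch.device('cpu'),
)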
def forward(self, patch, segm=None):
Construct the prior latent space for the patch and run the patch through the U-Net; if a segmentation mask is passed, the posterior latent space is constructed as well.
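forward returns nothing; it caches the latent spaces and U-Net features on the module for the subsequent sample / reconstruct / elbo calls. Continuing the construction sketch above with hypothetical inputs:

x = torch.randn(2, 1, 128, 128)                    # raw patch
y = torch.randint(0, 2, (2, 1, 128, 128)).float()  # binary ground-truth mask

net.forward(x, y)  # training: caches prior, posterior and U-Net features
net.forward(x)     # inference: caches only prior and U-Net features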
def sample(self, testing=False):
Sample a segmentation by reconstructing from a prior sample and combining it with the U-Net features.
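A sketch for drawing several plausible hypotheses at test time, continuing the sketch above. With num_classes=1 the outputs are logits, so a sigmoid yields per-pixel foreground probabilities (assuming a binary setup):

net.forward(x)
hypotheses = [torch.sigmoid(net.sample(testing=True)) for _ in range(4)]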
def reconstruct(self, use_posterior_mean=False, calculate_posterior=False, z_posterior=None):
Reconstruct a segmentation from a posterior sample (decoding a posterior sample) and the U-Net feature map.
- use_posterior_mean: use the posterior mean instead of sampling z_q
- calculate_posterior: sample from the posterior latent space instead of using a provided sample
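For example, decoding the posterior mean after a forward pass with a mask (continuing the sketch above):

net.forward(x, y)
recon_logits = net.reconstruct(use_posterior_mean=True)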
def kl_divergence(self, analytic=True, calculate_posterior=False, z_posterior=None):
Calculate the KL divergence between posterior and prior, KL(Q||P).
- analytic: calculate the KL analytically or via sampling from the posterior
- calculate_posterior: if we use sampling to approximate the KL, we can sample here or supply a sample
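Since both latent spaces are diagonal Gaussians, the analytic branch evaluates the standard closed form (a textbook identity, stated here for reference, with posterior $Q = \mathcal{N}(\mu_q, \mathrm{diag}(\sigma_q^2))$, prior $P = \mathcal{N}(\mu_p, \mathrm{diag}(\sigma_p^2))$ and $d$ = latent_dim):

$$\mathrm{KL}(Q \,\|\, P) = \frac{1}{2} \sum_{i=1}^{d} \left( \log \frac{\sigma_{p,i}^2}{\sigma_{q,i}^2} + \frac{\sigma_{q,i}^2 + (\mu_{q,i} - \mu_{p,i})^2}{\sigma_{p,i}^2} - 1 \right)$$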
def elbo(self, segm, consm=None, analytic_kl=True, reconstruct_posterior_mean=False):
Calculate the evidence lower bound of the log-likelihood of P(Y|X).
- consm: consensus response
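Putting the pieces together, a hedged training-step sketch continuing the examples above (the optimizer choice is illustrative; elbo returns the lower bound, so the quantity to minimize is its negative):

import torch

optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)

net.train()
net.forward(x, y)   # caches prior, posterior and U-Net features
elbo = net.elbo(y)  # -(reconstruction loss + beta * KL)
loss = -elbo        # minimize reconstruction loss + beta * KL
optimizer.zero_grad()
loss.backward()
optimizer.step()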