torch_em.model.probabilistic_unet

@private

  1"""@private
  2"""
  3
  4# This code is based on the original TensorFlow implementation: https://github.com/SimonKohl/probabilistic_unet
  5# The below implementation is from: https://github.com/stefanknegt/Probabilistic-Unet-Pytorch
  6
  7import numpy as np
  8
  9import torch
 10import torch.nn as nn
 11from torch.distributions import Normal, Independent, kl
 12
 13from torch_em.model import UNet2d
 14from torch_em.loss.dice import DiceLossWithLogits
 15
 16
def truncated_normal_(tensor, mean=0, std=1):
    """@private
    """
    # Draw four candidates per entry and keep one that lies within two standard deviations,
    # then scale by std and shift by mean.
    size = tensor.shape
    tmp = tensor.new_empty(size + (4,)).normal_()
    valid = (tmp < 2) & (tmp > -2)
    ind = valid.max(-1, keepdim=True)[1]
    tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
    tensor.data.mul_(std).add_(mean)


def init_weights(m):
    """@private
    """
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
        # nn.init.normal_(m.weight, std=0.001)
        # nn.init.normal_(m.bias, std=0.001)
        truncated_normal_(m.bias, mean=0, std=0.001)


def init_weights_orthogonal_normal(m):
    """@private
    """
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        nn.init.orthogonal_(m.weight)
        truncated_normal_(m.bias, mean=0, std=0.001)
        # nn.init.normal_(m.bias, std=0.001)


class Encoder(nn.Module):
    """
    A convolutional neural network consisting of len(num_filters) blocks of no_convs_per_block
    convolutional layers each. A pooling operation is performed before each block except the first,
    and a non-linear (ReLU) activation function is applied after each convolutional layer.
    """
    def __init__(
        self,
        input_channels,
        num_filters,
        no_convs_per_block,
        initializers,
        padding=True,
        posterior=False,
        num_classes=None
    ):

        super().__init__()

        self.contracting_path = nn.ModuleList()
        self.input_channels = input_channels
        self.num_filters = num_filters

        if posterior:
            # To accommodate the mask that is concatenated along the channel axis, we increase the input_channels.
            assert num_classes is not None
            self.input_channels += num_classes

        layers = []
        output_dim = None  # Initialize output_dim of the layers

        for i in range(len(self.num_filters)):
            # Determine input_dim and output_dim of the conv layers in this block.
            # The first layer is input x output; all subsequent layers are output x output.
            input_dim = self.input_channels if i == 0 else output_dim
            output_dim = num_filters[i]

            if i != 0:
                layers.append(nn.AvgPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True))

            layers.append(nn.Conv2d(input_dim, output_dim, kernel_size=3))
            layers.append(nn.ReLU(inplace=True))

            for _ in range(no_convs_per_block-1):
                layers.append(nn.Conv2d(output_dim, output_dim, kernel_size=3))
                layers.append(nn.ReLU(inplace=True))

        self.layers = nn.Sequential(*layers)

        self.layers.apply(init_weights)

    def forward(self, input):
        output = self.layers(input)
        return output


class AxisAlignedConvGaussian(nn.Module):
    """
    A convolutional net that parametrizes a Gaussian distribution with axis-aligned covariance matrix.
    """

    def __init__(
        self,
        input_channels,
        num_filters,
        no_convs_per_block,
        latent_dim,
        initializers,
        posterior=False,
        num_classes=None
    ):

        super().__init__()

        self.input_channels = input_channels
        self.channel_axis = 1
        self.num_filters = num_filters
        self.no_convs_per_block = no_convs_per_block
        self.latent_dim = latent_dim

        self.posterior = posterior
        if self.posterior:
            self.name = 'Posterior'
        else:
            self.name = 'Prior'

        self.encoder = Encoder(
            self.input_channels,
            self.num_filters,
            self.no_convs_per_block,
            initializers,
            posterior=self.posterior,
            num_classes=num_classes
        )

        self.conv_layer = nn.Conv2d(num_filters[-1], 2 * self.latent_dim, (1, 1), stride=1)
        self.show_img = 0
        self.show_seg = 0
        self.show_concat = 0
        self.show_enc = 0
        self.sum_input = 0

        # Original paper's training details:
        # All weights of all models are initialized with orthogonal initialization with the gain (multiplicative
        # factor) set to 1, and the bias terms are initialized by sampling from a truncated normal with σ = 0.001.

        # nn.init.kaiming_normal_(self.conv_layer.weight, mode='fan_in', nonlinearity='relu')  # from Stefan's impl.
        # nn.init.normal_(self.conv_layer.weight, std=0.001)  # suggested @issues from Stefan's impl.
        # nn.init.normal_(self.conv_layer.bias)  # from Stefan's impl.

        nn.init.orthogonal_(self.conv_layer.weight, gain=1)
        nn.init.trunc_normal_(self.conv_layer.bias, std=0.001)

    def forward(self, input, segm=None):

        # If the segmentation is not None, concatenate the mask to the channel axis of the input.
        if segm is not None:
            self.show_img = input
            self.show_seg = segm
            input = torch.cat((input, segm), dim=1)
            self.show_concat = input
            self.sum_input = torch.sum(input)

        encoding = self.encoder(input)
        self.show_enc = encoding

        # We only want the mean of the resulting h x w image.
        encoding = torch.mean(encoding, dim=2, keepdim=True)
        encoding = torch.mean(encoding, dim=3, keepdim=True)

        # Convert the encoding to 2 x latent_dim channels and split it into mu and log_sigma.
        mu_log_sigma = self.conv_layer(encoding)

        # We squeeze dim 2 twice (height, then the shifted width), since squeezing both spatial
        # dimensions at once won't work when the batch size is equal to 1.
        mu_log_sigma = torch.squeeze(mu_log_sigma, dim=2)
        mu_log_sigma = torch.squeeze(mu_log_sigma, dim=2)

        mu = mu_log_sigma[:, :self.latent_dim]
        log_sigma = mu_log_sigma[:, self.latent_dim:]

        # This is a multivariate normal with diagonal covariance matrix sigma.
        # https://github.com/pytorch/pytorch/pull/11178
        dist = Independent(Normal(loc=mu, scale=torch.exp(log_sigma)), 1)
        return dist


class Fcomb(nn.Module):
    """
    A function composed of no_convs_fcomb 1x1 convolutions that combines the sample taken from the latent space
    and the output of the UNet (the feature map) by concatenating them along their channel axis.
    """
    def __init__(
        self,
        num_filters,
        latent_dim,
        num_output_channels,
        num_classes,
        no_convs_fcomb,
        initializers,
        use_tile=True,
        device=None
    ):

        super().__init__()

        self.num_channels = num_output_channels
        self.num_classes = num_classes
        self.channel_axis = 1
        self.spatial_axes = [2, 3]
        self.num_filters = num_filters
        self.latent_dim = latent_dim
        self.use_tile = use_tile
        self.no_convs_fcomb = no_convs_fcomb
        self.name = 'Fcomb'

        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device

        if self.use_tile:
            layers = []

            # Decoder of no_convs_fcomb 1x1 convolutions, each followed by a ReLU activation,
            # except for the last layer.
            layers.append(nn.Conv2d(self.num_filters[0]+self.latent_dim, self.num_filters[0], kernel_size=1))
            layers.append(nn.ReLU(inplace=True))

            for _ in range(no_convs_fcomb-2):
                layers.append(nn.Conv2d(self.num_filters[0], self.num_filters[0], kernel_size=1))
                layers.append(nn.ReLU(inplace=True))

            self.layers = nn.Sequential(*layers)

            self.last_layer = nn.Conv2d(self.num_filters[0], self.num_classes, kernel_size=1)

            if initializers['w'] == 'orthogonal':
                self.layers.apply(init_weights_orthogonal_normal)
                self.last_layer.apply(init_weights_orthogonal_normal)
            else:
                self.layers.apply(init_weights)
                self.last_layer.apply(init_weights)

    def tile(self, a, dim, n_tile):
        """
        This function is taken from the PyTorch forum and mimics the behavior of tf.tile.
        Source: https://discuss.pytorch.org/t/how-to-tile-a-tensor/13853/3
        """
        init_dim = a.size(dim)
        repeat_idx = [1] * a.dim()
        repeat_idx[dim] = n_tile
        a = a.repeat(*(repeat_idx))
        order_index = torch.LongTensor(
            np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])
        ).to(self.device)
        return torch.index_select(a, dim, order_index)

    def forward(self, feature_map, z):
        """
        z is (batch_size x latent_dim) and feature_map is (batch_size x no_channels x H x W).
        So broadcast z to (batch_size x latent_dim x H x W). Behavior is exactly the same as tf.tile (verified).
        """
        if self.use_tile:
            z = torch.unsqueeze(z, 2)
            z = self.tile(z, 2, feature_map.shape[self.spatial_axes[0]])
            z = torch.unsqueeze(z, 3)
            z = self.tile(z, 3, feature_map.shape[self.spatial_axes[1]])

            # Concatenate the feature map (output of the UNet) and the sample taken from the latent space.
            feature_map = torch.cat((feature_map, z), dim=self.channel_axis)
            output = self.layers(feature_map)
            return self.last_layer(output)


class ProbabilisticUNet(nn.Module):
    """An implementation of the Probabilistic UNet of Kohl et al. (https://arxiv.org/abs/1806.05034).
    This generative segmentation model combines a UNet with a conditional variational
    autoencoder, which enables it to efficiently produce an unlimited number of plausible hypotheses.

    The following elements are initialized to get the desired network:
    input_channels: the number of channels in the image (1 for grayscale and 3 for RGB)
    num_classes: the number of classes to predict
    num_filters: a list with the number of filters per layer
    latent_dim: dimension of the latent space
    no_convs_per_block: number of convs per block in the (convolutional) encoder of prior and posterior
    beta: KL and reconstruction loss are weighted using a KL weighting factor (β)
    consensus_masking: activates consensus masking in the reconstruction loss
    rl_swap: switches the reconstruction loss from the default (binary cross-entropy loss) to dice loss

    Args:
        input_channels [int] - (default: 1)
        num_classes [int] - (default: 1)
        num_filters [list] - (default: [32, 64, 128, 192])
        latent_dim [int] - (default: 6)
        no_convs_fcomb [int] - (default: 4)
        beta [float] - (default: 10.0)
        consensus_masking [bool] - (default: False)
        rl_swap [bool] - (default: False)
        device [torch.device] - (default: None)
    """

    def __init__(
        self,
        input_channels=1,
        num_classes=1,
        num_filters=[32, 64, 128, 192],
        latent_dim=6,
        no_convs_fcomb=4,
        beta=10.0,
        consensus_masking=False,
        rl_swap=False,
        device=None
    ):

        super().__init__()

        self.input_channels = input_channels
        self.num_classes = num_classes
        self.num_filters = num_filters
        self.latent_dim = latent_dim
        self.no_convs_per_block = 3
        self.no_convs_fcomb = no_convs_fcomb
        self.initializers = {'w': 'he_normal', 'b': 'normal'}
        self.beta = beta
        self.z_prior_sample = 0
        self.consensus_masking = consensus_masking
        self.rl_swap = rl_swap

        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device

        self.unet = UNet2d(
            in_channels=self.input_channels,
            out_channels=None,
            depth=len(self.num_filters),
            initial_features=num_filters[0]
        ).to(self.device)

        self.prior = AxisAlignedConvGaussian(
            self.input_channels,
            self.num_filters,
            self.no_convs_per_block,
            self.latent_dim,
            self.initializers
        ).to(self.device)

        self.posterior = AxisAlignedConvGaussian(
            self.input_channels,
            self.num_filters,
            self.no_convs_per_block,
            self.latent_dim,
            self.initializers,
            posterior=True,
            num_classes=num_classes
        ).to(self.device)

        self.fcomb = Fcomb(
            self.num_filters,
            self.latent_dim,
            self.input_channels,
            self.num_classes,
            self.no_convs_fcomb,
            {'w': 'orthogonal', 'b': 'normal'},
            use_tile=True,
            device=self.device
        ).to(self.device)

    def _check_shape(self, patch):
        spatial_shape = tuple(patch.shape)[2:]
        depth = len(self.num_filters)
        factor = [2**depth] * len(spatial_shape)
        if any(sh % fac != 0 for sh, fac in zip(spatial_shape, factor)):
            msg = f"Invalid shape for Probabilistic U-Net: {spatial_shape} is not divisible by {factor}"
            raise ValueError(msg)

    def forward(self, patch, segm=None):
        """
        Construct the prior latent space for the patch and run the patch through the UNet;
        if a segmentation is passed, also construct the posterior latent space.
        """
        self._check_shape(patch)

        if segm is not None:
            self.posterior_latent_space = self.posterior.forward(patch, segm)
        self.prior_latent_space = self.prior.forward(patch)
        self.unet_features = self.unet.forward(patch)

    def sample(self, testing=False):
        """
        Sample a segmentation by reconstructing from a prior sample and combining this with the UNet features.
        """
        if testing is False:
            # TODO: prior distribution ? (posterior in this case!)
            z_prior = self.prior_latent_space.rsample()
            self.z_prior_sample = z_prior
        else:
            # You can choose whether you want a sample or the mean here. For the GED it is important to take a sample.
            # z_prior = self.prior_latent_space.base_dist.loc
            z_prior = self.prior_latent_space.sample()
            self.z_prior_sample = z_prior
        return self.fcomb.forward(self.unet_features, z_prior)

    def reconstruct(self, use_posterior_mean=False, calculate_posterior=False, z_posterior=None):
        """
        Reconstruct a segmentation from a posterior sample (decoding a posterior sample) and the UNet feature map.
        use_posterior_mean: use the posterior mean instead of sampling z_q
        calculate_posterior: sample from the posterior latent space instead of using a provided sample
        """
        if use_posterior_mean:
            # The latent space is an Independent-wrapped Normal, so the mean lives on the base distribution.
            z_posterior = self.posterior_latent_space.base_dist.loc
        else:
            if calculate_posterior:
                z_posterior = self.posterior_latent_space.rsample()
        return self.fcomb.forward(self.unet_features, z_posterior)

    def kl_divergence(self, analytic=True, calculate_posterior=False, z_posterior=None):
        """
        Calculate the KL divergence between the posterior and prior: KL(Q||P)
        analytic: calculate the KL analytically or via sampling from the posterior
        calculate_posterior: if we use sampling to approximate the KL we can sample here or supply a sample
        """
        if analytic:
            # Need to add this to the torch source code, see: https://github.com/pytorch/pytorch/issues/13545
            kl_div = kl.kl_divergence(self.posterior_latent_space, self.prior_latent_space)
        else:
            if calculate_posterior:
                z_posterior = self.posterior_latent_space.rsample()
            log_posterior_prob = self.posterior_latent_space.log_prob(z_posterior)
            log_prior_prob = self.prior_latent_space.log_prob(z_posterior)
            kl_div = log_posterior_prob - log_prior_prob
        return kl_div

    def elbo(self, segm, consm=None, analytic_kl=True, reconstruct_posterior_mean=False):
        """
        Calculate the evidence lower bound of the log-likelihood of P(Y|X)
        consm: consensus response
        """

        if self.rl_swap:
            criterion = DiceLossWithLogits()
        else:
            # Element-wise loss; we sum / average it explicitly below.
            criterion = nn.BCEWithLogitsLoss(reduction='none')

        z_posterior = self.posterior_latent_space.rsample()

        self.kl = torch.mean(
            self.kl_divergence(analytic=analytic_kl, calculate_posterior=False, z_posterior=z_posterior)
        )

        # Here we use the posterior sample sampled above.
        self.reconstruction = self.reconstruct(use_posterior_mean=reconstruct_posterior_mean,
                                               calculate_posterior=False, z_posterior=z_posterior)

        if self.consensus_masking is True and consm is not None:
            reconstruction_loss = criterion(self.reconstruction * consm, segm * consm)
        else:
            reconstruction_loss = criterion(self.reconstruction, segm)

        self.reconstruction_loss = torch.sum(reconstruction_loss)
        self.mean_reconstruction_loss = torch.mean(reconstruction_loss)

        return -(self.reconstruction_loss + self.beta * self.kl)