torch_em.model.probabilistic_unet

# This code is based on the original TensorFlow implementation: https://github.com/SimonKohl/probabilistic_unet
# The implementation below is adapted from: https://github.com/stefanknegt/Probabilistic-Unet-Pytorch

import numpy as np

import torch
import torch.nn as nn
from torch.distributions import Normal, Independent, kl

from torch_em.model import UNet2d
from torch_em.loss.dice import DiceLossWithLogits


def truncated_normal_(tensor, mean=0, std=1):
    """Fill `tensor` in-place with samples from a standard normal truncated to (-2, 2),
    then scale by `std` and shift by `mean`.
    """
    size = tensor.shape
    # Draw 4 candidates per element and pick the first one that lies inside (-2, 2).
    tmp = tensor.new_empty(size + (4,)).normal_()
    valid = (tmp < 2) & (tmp > -2)
    ind = valid.max(-1, keepdim=True)[1]
    tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
    tensor.data.mul_(std).add_(mean)

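A minimal usage sketch of truncated_normal_ (the tensor and its size are illustrative, assuming the definitions above): the helper keeps one in-range candidate per element and then rescales, so the result almost surely lies within two scaled standard deviations of the mean.

    bias = torch.empty(8)
    truncated_normal_(bias, mean=0, std=0.001)
    print(bias.abs().max())  # almost surely < 2 * 0.001
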
def init_weights(m):
    """Kaiming-normal weights and truncated-normal biases for (transposed) conv layers."""
    if type(m) == nn.Conv2d or type(m) == nn.ConvTranspose2d:
        nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
        # nn.init.normal_(m.weight, std=0.001)
        # nn.init.normal_(m.bias, std=0.001)
        truncated_normal_(m.bias, mean=0, std=0.001)

def init_weights_orthogonal_normal(m):
    """Orthogonal weights and truncated-normal biases for (transposed) conv layers."""
    if type(m) == nn.Conv2d or type(m) == nn.ConvTranspose2d:
        nn.init.orthogonal_(m.weight)
        truncated_normal_(m.bias, mean=0, std=0.001)
        # nn.init.normal_(m.bias, std=0.001)

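Both initializers are intended to be used via nn.Module.apply, which is how the Encoder and Fcomb below register them; a short sketch (the layer sizes are illustrative):

    conv_stack = nn.Sequential(nn.Conv2d(1, 32, kernel_size=3), nn.ReLU(), nn.Conv2d(32, 32, kernel_size=3))
    conv_stack.apply(init_weights)                    # Kaiming-normal weights, truncated-normal biases
    conv_stack.apply(init_weights_orthogonal_normal)  # orthogonal weights instead (overwrites the previous init)
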
class Encoder(nn.Module):
    """
    A convolutional neural network consisting of len(num_filters) blocks of no_convs_per_block
    convolutional layers each. A pooling operation is performed between consecutive blocks,
    and a non-linear (ReLU) activation function is applied after each convolutional layer.
    """

    def __init__(
        self,
        input_channels,
        num_filters,
        no_convs_per_block,
        initializers,
        padding=True,
        posterior=False,
        num_classes=None
    ):

        super().__init__()

        self.contracting_path = nn.ModuleList()
        self.input_channels = input_channels
        self.num_filters = num_filters

        if posterior:
            # To accommodate the mask that is concatenated along the channel axis, we increase the input_channels.
            assert num_classes is not None
            self.input_channels += num_classes

        layers = []
        output_dim = None  # Initialize output_dim of the layers

        for i in range(len(self.num_filters)):
            # Determine input_dim and output_dim of the conv layers in this block.
            # The first layer is input x output; all subsequent layers are output x output.
            input_dim = self.input_channels if i == 0 else output_dim
            output_dim = num_filters[i]

            if i != 0:
                layers.append(nn.AvgPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True))

            # Wire through the `padding` argument (as in the reference PyTorch implementation),
            # so that the 3x3 convolutions preserve the spatial shape.
            layers.append(nn.Conv2d(input_dim, output_dim, kernel_size=3, padding=int(padding)))
            layers.append(nn.ReLU(inplace=True))

            for _ in range(no_convs_per_block-1):
                layers.append(nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=int(padding)))
                layers.append(nn.ReLU(inplace=True))

        self.layers = nn.Sequential(*layers)

        self.layers.apply(init_weights)

    def forward(self, input):
        output = self.layers(input)
        return output

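A usage sketch for the encoder on its own (batch size and input size are illustrative): with four filter levels, the three pooling operations reduce a 128x128 input to a 16x16 feature map with num_filters[-1] channels.

    encoder = Encoder(input_channels=1, num_filters=[32, 64, 128, 192], no_convs_per_block=3,
                      initializers={'w': 'he_normal', 'b': 'normal'})
    features = encoder(torch.randn(4, 1, 128, 128))
    print(features.shape)  # torch.Size([4, 192, 16, 16])
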
class AxisAlignedConvGaussian(nn.Module):
    """
    A convolutional net that parametrizes a Gaussian distribution with an axis-aligned covariance matrix.
    """

    def __init__(
        self,
        input_channels,
        num_filters,
        no_convs_per_block,
        latent_dim,
        initializers,
        posterior=False,
        num_classes=None
    ):

        super().__init__()

        self.input_channels = input_channels
        self.channel_axis = 1
        self.num_filters = num_filters
        self.no_convs_per_block = no_convs_per_block
        self.latent_dim = latent_dim

        self.posterior = posterior
        if self.posterior:
            self.name = 'Posterior'
        else:
            self.name = 'Prior'

        self.encoder = Encoder(
            self.input_channels,
            self.num_filters,
            self.no_convs_per_block,
            initializers,
            posterior=self.posterior,
            num_classes=num_classes
        )

        self.conv_layer = nn.Conv2d(num_filters[-1], 2 * self.latent_dim, (1, 1), stride=1)
        self.show_img = 0
        self.show_seg = 0
        self.show_concat = 0
        self.show_enc = 0
        self.sum_input = 0

        # From the original paper's training details:
        # All weights of all models are initialized with orthogonal initialization, with the gain (multiplicative
        # factor) set to 1, and the bias terms are initialized by sampling from a truncated normal with σ = 0.001.

        # nn.init.kaiming_normal_(self.conv_layer.weight, mode='fan_in', nonlinearity='relu')  # from Stefan's impl.
        # nn.init.normal_(self.conv_layer.weight, std=0.001)  # suggested @issues from Stefan's impl.
        # nn.init.normal_(self.conv_layer.bias)  # from Stefan's impl.

        nn.init.orthogonal_(self.conv_layer.weight, gain=1)
        nn.init.trunc_normal_(self.conv_layer.bias, std=0.001)

    def forward(self, input, segm=None):

        # If the segmentation is not None, concatenate the mask to the channel axis of the input.
        if segm is not None:
            self.show_img = input
            self.show_seg = segm
            input = torch.cat((input, segm), dim=1)
            self.show_concat = input
            self.sum_input = torch.sum(input)

        encoding = self.encoder(input)
        self.show_enc = encoding

        # We only want the mean of the resulting h x w image.
        encoding = torch.mean(encoding, dim=2, keepdim=True)
        encoding = torch.mean(encoding, dim=3, keepdim=True)

        # Convert the encoding to 2 x latent_dim channels and split it up into mu and log_sigma.
        mu_log_sigma = self.conv_layer(encoding)

        # We squeeze the spatial dimensions one at a time, since otherwise it won't work when the batch size is 1.
        mu_log_sigma = torch.squeeze(mu_log_sigma, dim=2)
        mu_log_sigma = torch.squeeze(mu_log_sigma, dim=2)

        mu = mu_log_sigma[:, :self.latent_dim]
        log_sigma = mu_log_sigma[:, self.latent_dim:]

        # This is a multivariate normal with a diagonal covariance matrix sigma.
        # https://github.com/pytorch/pytorch/pull/11178
        dist = Independent(Normal(loc=mu, scale=torch.exp(log_sigma)), 1)
        return dist

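A sketch of what this module returns (shapes are illustrative): calling it yields a torch.distributions.Independent over the latent dimensions, from which reparametrized samples can be drawn.

    prior_net = AxisAlignedConvGaussian(
        input_channels=1, num_filters=[32, 64, 128, 192], no_convs_per_block=3,
        latent_dim=6, initializers={'w': 'he_normal', 'b': 'normal'}
    )
    dist = prior_net(torch.randn(4, 1, 128, 128))
    z = dist.rsample()
    print(z.shape)  # torch.Size([4, 6]), one latent vector per batch element
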
class Fcomb(nn.Module):
    """
    A function composed of no_convs_fcomb 1x1 convolutions that combines the sample taken from the latent space
    and the output of the UNet (the feature map) by concatenating them along their channel axis.
    """
    def __init__(
        self,
        num_filters,
        latent_dim,
        num_output_channels,
        num_classes,
        no_convs_fcomb,
        initializers,
        use_tile=True,
        device=None
    ):

        super().__init__()

        self.num_channels = num_output_channels
        self.num_classes = num_classes
        self.channel_axis = 1
        self.spatial_axes = [2, 3]
        self.num_filters = num_filters
        self.latent_dim = latent_dim
        self.use_tile = use_tile
        self.no_convs_fcomb = no_convs_fcomb
        self.name = 'Fcomb'

        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device

        if self.use_tile:
            layers = []

            # Decoder of N 1x1 convolutions, each followed by a ReLU activation, except for the last layer.
            layers.append(nn.Conv2d(self.num_filters[0]+self.latent_dim, self.num_filters[0], kernel_size=1))
            layers.append(nn.ReLU(inplace=True))

            for _ in range(no_convs_fcomb-2):
                layers.append(nn.Conv2d(self.num_filters[0], self.num_filters[0], kernel_size=1))
                layers.append(nn.ReLU(inplace=True))

            self.layers = nn.Sequential(*layers)

            self.last_layer = nn.Conv2d(self.num_filters[0], self.num_classes, kernel_size=1)

            if initializers['w'] == 'orthogonal':
                self.layers.apply(init_weights_orthogonal_normal)
                self.last_layer.apply(init_weights_orthogonal_normal)
            else:
                self.layers.apply(init_weights)
                self.last_layer.apply(init_weights)

    def tile(self, a, dim, n_tile):
        """
        This function is taken from the PyTorch forum and mimics the behavior of tf.tile.
        Source: https://discuss.pytorch.org/t/how-to-tile-a-tensor/13853/3
        """
        init_dim = a.size(dim)
        repeat_idx = [1] * a.dim()
        repeat_idx[dim] = n_tile
        a = a.repeat(*(repeat_idx))
        order_index = torch.LongTensor(
            np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])
        ).to(self.device)
        return torch.index_select(a, dim, order_index)

    def forward(self, feature_map, z):
        """
        z has shape (batch_size x latent_dim) and feature_map has shape (batch_size x no_channels x H x W),
        so we broadcast z to (batch_size x latent_dim x H x W). The behavior is exactly the same as tf.tile (verified).
        """
        if self.use_tile:
            z = torch.unsqueeze(z, 2)
            z = self.tile(z, 2, feature_map.shape[self.spatial_axes[0]])
            z = torch.unsqueeze(z, 3)
            z = self.tile(z, 3, feature_map.shape[self.spatial_axes[1]])

            # Concatenate the feature map (output of the UNet) and the sample taken from the latent space.
            feature_map = torch.cat((feature_map, z), dim=self.channel_axis)
            output = self.layers(feature_map)
            return self.last_layer(output)

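A sketch of the broadcast performed in Fcomb.forward (all sizes are illustrative): the latent vector is tiled across the spatial grid, concatenated with the UNet feature map, and decoded by the 1x1 convolutions into per-pixel logits.

    fcomb = Fcomb(num_filters=[32, 64, 128, 192], latent_dim=6, num_output_channels=1, num_classes=1,
                  no_convs_fcomb=4, initializers={'w': 'orthogonal', 'b': 'normal'}, device=torch.device('cpu'))
    features = torch.randn(4, 32, 128, 128)  # UNet output with num_filters[0] channels
    z = torch.randn(4, 6)                    # one latent sample per batch element
    logits = fcomb(features, z)
    print(logits.shape)  # torch.Size([4, 1, 128, 128])
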
class ProbabilisticUNet(nn.Module):
    """ A network implementation of the Probabilistic UNet of Kohl et al. (https://arxiv.org/abs/1806.05034).
    This generative segmentation approach combines a UNet with a conditional variational
    autoencoder, enabling it to efficiently produce an unlimited number of plausible hypotheses.

    The following elements are initialized to get our desired network:
    input_channels: the number of channels in the image (1 for grayscale and 3 for RGB)
    num_classes: the number of classes to predict
    num_filters: a list with the number of filters per layer
    latent_dim: dimension of the latent space
    no_convs_per_block: number of convs per block in the (convolutional) encoder of prior and posterior
    beta: the KL weighting factor (β) that balances the KL and reconstruction losses
    consensus_masking: activates consensus masking in the reconstruction loss
    rl_swap: switches the reconstruction loss from the default (binary cross-entropy loss) to dice loss

    Parameters:
        input_channels [int] - (default: 1)
        num_classes [int] - (default: 1)
        num_filters [list] - (default: [32, 64, 128, 192])
        latent_dim [int] - (default: 6)
        no_convs_fcomb [int] - (default: 4)
        beta [float] - (default: 10.0)
        consensus_masking [bool] - (default: False)
        rl_swap [bool] - (default: False)
        device [torch.device] - (default: None)
    """

    def __init__(
        self,
        input_channels=1,
        num_classes=1,
        num_filters=[32, 64, 128, 192],
        latent_dim=6,
        no_convs_fcomb=4,
        beta=10.0,
        consensus_masking=False,
        rl_swap=False,
        device=None
    ):

        super().__init__()

        self.input_channels = input_channels
        self.num_classes = num_classes
        self.num_filters = num_filters
        self.latent_dim = latent_dim
        self.no_convs_per_block = 3
        self.no_convs_fcomb = no_convs_fcomb
        self.initializers = {'w': 'he_normal', 'b': 'normal'}
        self.beta = beta
        self.z_prior_sample = 0
        self.consensus_masking = consensus_masking
        self.rl_swap = rl_swap

        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device

        self.unet = UNet2d(
            in_channels=self.input_channels,
            out_channels=None,
            depth=len(self.num_filters),
            initial_features=num_filters[0]
        ).to(self.device)

        self.prior = AxisAlignedConvGaussian(
            self.input_channels,
            self.num_filters,
            self.no_convs_per_block,
            self.latent_dim,
            self.initializers
        ).to(self.device)

        self.posterior = AxisAlignedConvGaussian(
            self.input_channels,
            self.num_filters,
            self.no_convs_per_block,
            self.latent_dim,
            self.initializers,
            posterior=True,
            num_classes=num_classes
        ).to(self.device)

        self.fcomb = Fcomb(
            self.num_filters,
            self.latent_dim,
            self.input_channels,
            self.num_classes,
            self.no_convs_fcomb,
            {'w': 'orthogonal', 'b': 'normal'},
            use_tile=True,
            device=self.device
        ).to(self.device)

    def _check_shape(self, patch):
        # The network downsamples len(num_filters) times, so each spatial dimension
        # of the input patch must be divisible by 2 ** depth.
        spatial_shape = tuple(patch.shape)[2:]
        depth = len(self.num_filters)
        factor = [2**depth] * len(spatial_shape)
        if any(sh % fac != 0 for sh, fac in zip(spatial_shape, factor)):
            msg = f"Invalid shape for Probabilistic U-Net: {spatial_shape} is not divisible by {factor}"
            raise ValueError(msg)

    def forward(self, patch, segm=None):
        """
        Construct the prior latent space for the patch and run the patch through the UNet;
        if a segmentation is passed (i.e. during training), also construct the posterior latent space.
        """
        self._check_shape(patch)

        if segm is not None:
            self.posterior_latent_space = self.posterior.forward(patch, segm)
        self.prior_latent_space = self.prior.forward(patch)
        self.unet_features = self.unet.forward(patch)

    def sample(self, testing=False):
        """
        Sample a segmentation by reconstructing from a prior sample and combining this with the UNet features.
        """
        if testing is False:
            # TODO: prior distribution ? (posterior in this case!)
            z_prior = self.prior_latent_space.rsample()
            self.z_prior_sample = z_prior
        else:
            # You can choose whether you want a sample or the mean here. For the GED it is important to take a sample.
            # z_prior = self.prior_latent_space.base_dist.loc
            z_prior = self.prior_latent_space.sample()
            self.z_prior_sample = z_prior
        return self.fcomb.forward(self.unet_features, z_prior)

    def reconstruct(self, use_posterior_mean=False, calculate_posterior=False, z_posterior=None):
        """
        Reconstruct a segmentation from a posterior sample (decoding a posterior sample) and the UNet feature map.
        use_posterior_mean: use the posterior mean instead of sampling z_q
        calculate_posterior: sample from the posterior latent space instead of using a provided sample
        """
        if use_posterior_mean:
            z_posterior = self.posterior_latent_space.loc
        else:
            if calculate_posterior:
                z_posterior = self.posterior_latent_space.rsample()
        return self.fcomb.forward(self.unet_features, z_posterior)

    def kl_divergence(self, analytic=True, calculate_posterior=False, z_posterior=None):
        """
        Calculate the KL divergence KL(Q||P) between the posterior and the prior.
        analytic: calculate the KL analytically or via sampling from the posterior
        calculate_posterior: if we use sampling to approximate the KL we can sample here or supply a sample
        """
        if analytic:
            # Need to add this to the torch source code, see: https://github.com/pytorch/pytorch/issues/13545
            kl_div = kl.kl_divergence(self.posterior_latent_space, self.prior_latent_space)
        else:
            if calculate_posterior:
                z_posterior = self.posterior_latent_space.rsample()
            log_posterior_prob = self.posterior_latent_space.log_prob(z_posterior)
            log_prior_prob = self.prior_latent_space.log_prob(z_posterior)
            kl_div = log_posterior_prob - log_prior_prob
        return kl_div

    def elbo(self, segm, consm=None, analytic_kl=True, reconstruct_posterior_mean=False):
        """
        Calculate the evidence lower bound of the log-likelihood of P(Y|X).
        consm: consensus response
        """

        if self.rl_swap:
            criterion = DiceLossWithLogits()
        else:
            criterion = nn.BCEWithLogitsLoss(reduction='none')

        z_posterior = self.posterior_latent_space.rsample()

        self.kl = torch.mean(
            self.kl_divergence(analytic=analytic_kl, calculate_posterior=False, z_posterior=z_posterior)
        )

        # Here we use the posterior sample sampled above.
        self.reconstruction = self.reconstruct(use_posterior_mean=reconstruct_posterior_mean,
                                               calculate_posterior=False, z_posterior=z_posterior)

        if self.consensus_masking is True and consm is not None:
            reconstruction_loss = criterion(self.reconstruction * consm, segm * consm)
        else:
            reconstruction_loss = criterion(self.reconstruction, segm)

        self.reconstruction_loss = torch.sum(reconstruction_loss)
        self.mean_reconstruction_loss = torch.mean(reconstruction_loss)

        return -(self.reconstruction_loss + self.beta * self.kl)