torch_em.self_training.probabilistic_unet_trainer

import time

import torch
import torch_em


# Placeholder metric passed to the DefaultTrainer; the actual validation metric
# is computed by `loss_and_metric` in `_validate_impl`.
class DummyLoss(torch.nn.Module):
    pass

class ProbabilisticUNetTrainer(torch_em.trainer.DefaultTrainer):
    """Trainer for the 'Probabilistic UNet' of Kohl et al. (https://arxiv.org/abs/1806.05034).

    This approach combines a UNet with a conditional VAE (prior and posterior networks) to learn
    a distribution over segmentations instead of a single prediction. During training, the loss is
    computed from the UNet feature maps and samples drawn from the posterior distribution; for
    validation and logging, segmentations are sampled from the prior.

    Parameters:
        clipping_value [float] - max norm for clipping the gradients of the posterior encoder (default: None)
        prior_samples [int] - number of segmentations to sample from the prior for logging (default: 16)
        loss [callable] - training loss with signature loss(model, x, y) (default: None)
        loss_and_metric [callable] - validation callable returning a (loss, metric) pair (default: None)
    """

    def __init__(
            self,
            clipping_value=None,
            prior_samples=16,
            loss=None,
            loss_and_metric=None,
            **kwargs
    ):
        super().__init__(loss=loss, metric=DummyLoss(), **kwargs)
        assert loss is not None and loss_and_metric is not None

        self.loss_and_metric = loss_and_metric
        self.clipping_value = clipping_value
        self.prior_samples = prior_samples
        self.sigmoid = torch.nn.Sigmoid()

        self._kwargs = kwargs

    #
    # functionality for sampling from the network
    #

    def _sample(self):
        # Draw `prior_samples` segmentation samples from the prior network for logging.
        samples = [self.model.sample() for _ in range(self.prior_samples)]
        return samples

    #
    # training and validation functionality
    #

    def _train_epoch_impl(self, progress, forward_context, backprop):
        self.model.train()

        n_iter = 0
        t_per_iter = time.time()

        for x, y in self.train_loader:
            x, y = x.to(self.device), y.to(self.device)

            self.optimizer.zero_grad()

            with forward_context():
                # We pass the model, the input and the labels to the loss function, so that
                # how the loss is computed stays flexible, e.g. to enable an ELBO loss for
                # the PUNet (see the loss sketch after the class).
                loss = self.loss(self.model, x, y)

            backprop(loss)

            # Clip the gradients of the posterior encoder to counter exploding gradients.
            if self.clipping_value is not None:
                torch.nn.utils.clip_grad_norm_(self.model.posterior.encoder.layers.parameters(), self.clipping_value)

            if self.logger is not None:
                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
                samples = self._sample() if self._iteration % self.log_image_interval == 0 else None
                self.logger.log_train(self._iteration, loss, lr, x, y, samples)

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter

    def _validate_impl(self, forward_context):
        self.model.eval()

        metric_val = 0.0
        loss_val = 0.0

        with torch.no_grad():
            for x, y in self.val_loader:
                x, y = x.to(self.device), y.to(self.device)

                with forward_context():
                    # loss_and_metric returns the validation loss together with the metric
                    # (see the sketch after the class).
                    loss, metric = self.loss_and_metric(self.model, x, y)

                loss_val += loss.item()
                metric_val += metric

        metric_val /= len(self.val_loader)
        loss_val /= len(self.val_loader)

        if self.logger is not None:
            samples = self._sample()
            self.logger.log_validation(self._iteration, metric_val, loss_val, x, y, samples)

        return metric_val
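
Because the trainer hands the model itself to the loss, the loss callable owns the forward pass. Below is a minimal sketch of a negative-ELBO training loss with the expected loss(model, x, y) signature. The model interface it assumes (a forward(x, y) pass that conditions the posterior on the labels, plus reconstruction and kl attributes) is a hypothetical stand-in for the concrete probabilistic UNet implementation and should be adapted accordingly.

import torch

class ELBOLoss(torch.nn.Module):
    """Sketch of a negative-ELBO loss; the model attributes used here are assumptions."""
    def __init__(self, beta=1.0):
        super().__init__()
        self.beta = beta
        self.reconstruction_loss = torch.nn.BCEWithLogitsLoss(reduction="mean")

    def forward(self, model, x, y):
        # Assumed interface: the forward pass conditions the posterior on the labels
        # and caches the prior, posterior and reconstruction on the model.
        model.forward(x, y)
        reconstruction = self.reconstruction_loss(model.reconstruction, y)  # assumed attribute
        kl = model.kl  # assumed attribute: KL(posterior || prior)
        # Negative ELBO = reconstruction term + beta-weighted KL term.
        return reconstruction + self.beta * kl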
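Validation expects a single callable returning a (loss, metric) pair, matching the unpacking in _validate_impl. A sketch under the same assumed model interface, using torch_em's DiceLoss as the metric:

import torch
import torch_em

class ELBOLossAndMetric(torch.nn.Module):
    """Sketch of a validation callable returning (loss, metric); model interface as assumed above."""
    def __init__(self, metric=None):
        super().__init__()
        self.loss = ELBOLoss()
        self.metric = torch_em.loss.DiceLoss() if metric is None else metric

    def forward(self, model, x, y):
        loss = self.loss(model, x, y)  # also runs the forward pass on the model
        # Score a segmentation sampled from the prior against the labels.
        prediction = torch.sigmoid(model.sample())  # assumed: sample() decodes a prior draw
        metric = self.metric(prediction, y)
        return loss, metric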
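Putting it together, a hypothetical training setup might look as follows. The model and data loaders are placeholders, and the remaining keyword arguments are forwarded via **kwargs to torch_em.trainer.DefaultTrainer.

import torch

model = ...  # placeholder: a probabilistic UNet implementation
train_loader, val_loader = ..., ...  # placeholder torch_em data loaders

trainer = ProbabilisticUNetTrainer(
    name="punet-example",  # hypothetical checkpoint name
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-4),
    loss=ELBOLoss(),
    loss_and_metric=ELBOLossAndMetric(),
    clipping_value=1e5,  # clip the posterior-encoder gradients
    prior_samples=16,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)
trainer.fit(iterations=10000)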