torch_em.self_training.probabilistic_unet_trainer
import time
import torch
import torch_em


class DummyLoss(torch.nn.Module):
    """Placeholder passed as `metric` to the DefaultTrainer.

    The validation metric of the ProbabilisticUNetTrainer is computed by `loss_and_metric` instead.
    """
    pass


class ProbabilisticUNetTrainer(torch_em.trainer.DefaultTrainer):
    """Trainer for the 'Probabilistic UNet' of Kohl et al. (https://arxiv.org/abs/1806.05034).

    This approach combines a UNet with VAE-style prior and posterior networks to obtain generative
    segmentations. During training, the user-supplied `loss` receives the model, the inputs and the
    labels, so it can implement e.g. the ELBO of the Probabilistic UNet. Validation uses
    `loss_and_metric`. Samples drawn via `model.sample()` are passed to the logger.

    Parameters:
        clipping_value [float] - (default: None) If given, the gradient norm of the parameters in
            `model.posterior.encoder.layers` is clipped to this value after `backprop` is called
            in every training iteration.
        prior_samples [int] - (default: 16) Number of samples drawn via `model.sample()` for logging.
        loss [callable] - (default: None, but must be given) Called as `loss(model, x, y)`;
            returns the training loss tensor.
        loss_and_metric [callable] - (default: None, but must be given) Called as
            `loss_and_metric(model, x, y)`; returns `(loss, metric)` for a validation batch.
        kwargs - Forwarded to `torch_em.trainer.DefaultTrainer`.
    """

    def __init__(
        self,
        clipping_value=None,
        prior_samples=16,
        loss=None,
        loss_and_metric=None,
        **kwargs
    ):
        # Both callables are required. The previous check, `assert loss, loss_and_metric is not None`,
        # only validated `loss`: the second expression was merely used as the assertion message.
        if loss is None or loss_and_metric is None:
            raise ValueError(
                "ProbabilisticUNetTrainer requires both `loss` and `loss_and_metric` to be provided."
            )
        # The metric is computed by `loss_and_metric`, so only a placeholder module is passed on.
        super().__init__(loss=loss, metric=DummyLoss(), **kwargs)

        self.loss_and_metric = loss_and_metric
        self.clipping_value = clipping_value
        self.prior_samples = prior_samples
        # Not used by the methods of this class; kept so that code accessing `trainer.sigmoid` keeps working.
        self.sigmoid = torch.nn.Sigmoid()

        # Extra keyword arguments are stored on the trainer
        # (presumably for checkpoint serialization in DefaultTrainer — verify).
        self._kwargs = kwargs

    #
    # functionality for sampling from the network
    #

    def _sample(self):
        """Return a list of `self.prior_samples` samples drawn via `self.model.sample()`."""
        return [self.model.sample() for _ in range(self.prior_samples)]

    #
    # training and validation functionality
    #

    def _train_epoch_impl(self, progress, forward_context, backprop):
        """Run one training epoch, stopping early once `max_iteration` is reached.

        Returns:
            The average wall-clock time per iteration in seconds.
        """
        self.model.train()

        n_iter = 0
        t_per_iter = time.time()

        for x, y in self.train_loader:
            x, y = x.to(self.device), y.to(self.device)

            self.optimizer.zero_grad()

            with forward_context():
                # The model, the inputs and the labels are passed to the loss so that the way the
                # loss is computed stays flexible, e.g. to implement the ELBO for the PUNet.
                loss = self.loss(self.model, x, y)

            backprop(loss)

            # Counter exploding gradients in the posterior network.
            # NOTE(review): if `backprop` already performs `optimizer.step()` (and, with mixed precision,
            # operates on scaled gradients), clipping here runs after the update and cannot affect it.
            # Confirm against DefaultTrainer and move the clipping before the step if that is the case.
            if self.clipping_value is not None:
                torch.nn.utils.clip_grad_norm_(
                    self.model.posterior.encoder.layers.parameters(), self.clipping_value
                )

            if self.logger is not None:
                lr = self.optimizer.param_groups[0]["lr"]
                samples = None
                if self._iteration % self.log_image_interval == 0:
                    # The samples are only logged, so no autograd graph needs to be recorded for them.
                    with torch.no_grad():
                        samples = self._sample()
                self.logger.log_train(self._iteration, loss, lr, x, y, samples)

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter

    def _validate_impl(self, forward_context):
        """Evaluate `loss_and_metric` over the whole validation loader.

        Returns:
            The metric averaged over the validation batches.

        Raises:
            ValueError: If the validation loader yields no batches.
        """
        self.model.eval()

        n_batches = len(self.val_loader)
        if n_batches == 0:
            raise ValueError("The validation loader is empty; cannot compute the validation metric.")

        metric_val = 0.0
        loss_val = 0.0

        with torch.no_grad():
            for x, y in self.val_loader:
                x, y = x.to(self.device), y.to(self.device)

                with forward_context():
                    loss, metric = self.loss_and_metric(self.model, x, y)

                loss_val += loss.item()
                metric_val += metric

            metric_val /= n_batches
            loss_val /= n_batches

            if self.logger is not None:
                # Sampling stays inside `no_grad`: the samples are only logged, together with
                # the inputs and labels of the last validation batch.
                samples = self._sample()
                self.logger.log_validation(self._iteration, metric_val, loss_val, x, y, samples)

        return metric_val
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:
import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    """Example module: two stacked 5x5 convolutions, each followed by a ReLU."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and their parameters will also be converted when you call to(), etc.
As per the example above, an __init__() call to the parent class must be made before assignment on the child.
training (bool): Whether this module is in training or evaluation mode.
Inherited Members
- torch.nn.modules.module.Module
- Module
- dump_patches
- training
- call_super_init
- forward
- register_buffer
- register_parameter
- add_module
- register_module
- get_submodule
- get_parameter
- get_buffer
- get_extra_state
- set_extra_state
- apply
- cuda
- ipu
- xpu
- cpu
- type
- float
- double
- half
- bfloat16
- to_empty
- to
- register_full_backward_pre_hook
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- register_state_dict_pre_hook
- state_dict
- register_load_state_dict_post_hook
- load_state_dict
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr
- compile
class ProbabilisticUNetTrainer(torch_em.trainer.DefaultTrainer):
    """Trainer for the 'Probabilistic UNet' of Kohl et al. (https://arxiv.org/abs/1806.05034).

    This approach combines a UNet with VAE-style prior and posterior networks to obtain generative
    segmentations. During training, the user-supplied `loss` receives the model, the inputs and the
    labels, so it can implement e.g. the ELBO of the Probabilistic UNet. Validation uses
    `loss_and_metric`. Samples drawn via `model.sample()` are passed to the logger.

    Parameters:
        clipping_value [float] - (default: None) If given, the gradient norm of the parameters in
            `model.posterior.encoder.layers` is clipped to this value after `backprop` is called
            in every training iteration.
        prior_samples [int] - (default: 16) Number of samples drawn via `model.sample()` for logging.
        loss [callable] - (default: None, but must be given) Called as `loss(model, x, y)`;
            returns the training loss tensor.
        loss_and_metric [callable] - (default: None, but must be given) Called as
            `loss_and_metric(model, x, y)`; returns `(loss, metric)` for a validation batch.
        kwargs - Forwarded to `torch_em.trainer.DefaultTrainer`.
    """

    def __init__(
        self,
        clipping_value=None,
        prior_samples=16,
        loss=None,
        loss_and_metric=None,
        **kwargs
    ):
        # Both callables are required. The previous check, `assert loss, loss_and_metric is not None`,
        # only validated `loss`: the second expression was merely used as the assertion message.
        if loss is None or loss_and_metric is None:
            raise ValueError(
                "ProbabilisticUNetTrainer requires both `loss` and `loss_and_metric` to be provided."
            )
        # The metric is computed by `loss_and_metric`, so only a placeholder module is passed on.
        super().__init__(loss=loss, metric=DummyLoss(), **kwargs)

        self.loss_and_metric = loss_and_metric
        self.clipping_value = clipping_value
        self.prior_samples = prior_samples
        # Not used by the methods of this class; kept so that code accessing `trainer.sigmoid` keeps working.
        self.sigmoid = torch.nn.Sigmoid()

        # Extra keyword arguments are stored on the trainer
        # (presumably for checkpoint serialization in DefaultTrainer — verify).
        self._kwargs = kwargs

    #
    # functionality for sampling from the network
    #

    def _sample(self):
        """Return a list of `self.prior_samples` samples drawn via `self.model.sample()`."""
        return [self.model.sample() for _ in range(self.prior_samples)]

    #
    # training and validation functionality
    #

    def _train_epoch_impl(self, progress, forward_context, backprop):
        """Run one training epoch, stopping early once `max_iteration` is reached.

        Returns:
            The average wall-clock time per iteration in seconds.
        """
        self.model.train()

        n_iter = 0
        t_per_iter = time.time()

        for x, y in self.train_loader:
            x, y = x.to(self.device), y.to(self.device)

            self.optimizer.zero_grad()

            with forward_context():
                # The model, the inputs and the labels are passed to the loss so that the way the
                # loss is computed stays flexible, e.g. to implement the ELBO for the PUNet.
                loss = self.loss(self.model, x, y)

            backprop(loss)

            # Counter exploding gradients in the posterior network.
            # NOTE(review): if `backprop` already performs `optimizer.step()` (and, with mixed precision,
            # operates on scaled gradients), clipping here runs after the update and cannot affect it.
            # Confirm against DefaultTrainer and move the clipping before the step if that is the case.
            if self.clipping_value is not None:
                torch.nn.utils.clip_grad_norm_(
                    self.model.posterior.encoder.layers.parameters(), self.clipping_value
                )

            if self.logger is not None:
                lr = self.optimizer.param_groups[0]["lr"]
                samples = None
                if self._iteration % self.log_image_interval == 0:
                    # The samples are only logged, so no autograd graph needs to be recorded for them.
                    with torch.no_grad():
                        samples = self._sample()
                self.logger.log_train(self._iteration, loss, lr, x, y, samples)

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter

    def _validate_impl(self, forward_context):
        """Evaluate `loss_and_metric` over the whole validation loader.

        Returns:
            The metric averaged over the validation batches.

        Raises:
            ValueError: If the validation loader yields no batches.
        """
        self.model.eval()

        n_batches = len(self.val_loader)
        if n_batches == 0:
            raise ValueError("The validation loader is empty; cannot compute the validation metric.")

        metric_val = 0.0
        loss_val = 0.0

        with torch.no_grad():
            for x, y in self.val_loader:
                x, y = x.to(self.device), y.to(self.device)

                with forward_context():
                    loss, metric = self.loss_and_metric(self.model, x, y)

                loss_val += loss.item()
                metric_val += metric

            metric_val /= n_batches
            loss_val /= n_batches

            if self.logger is not None:
                # Sampling stays inside `no_grad`: the samples are only logged, together with
                # the inputs and labels of the last validation batch.
                samples = self._sample()
                self.logger.log_validation(self._iteration, metric_val, loss_val, x, y, samples)

        return metric_val
This trainer implements training for the 'Probabilistic UNet' of Kohl et al. (https://arxiv.org/abs/1806.05034). The approach combines a UNet with VAE-style prior and posterior networks to obtain generative segmentations. During training, the loss is computed from the UNet feature maps together with samples from the posterior distribution. For validation, segmentations are additionally sampled from the prior.
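Arguments:
- clipping_value [float] - (default: None) If given, the gradient norm of the posterior encoder's parameters (model.posterior.encoder.layers) is clipped to this value during training.
- prior_samples [int] - (default: 16) Number of samples drawn from the model (model.sample()) for logging.
- loss [callable] - (default: None; required) Called as loss(model, x, y) and returns the training loss.
- loss_and_metric [callable] - (default: None; required) Called as loss_and_metric(model, x, y) and returns a (loss, metric) pair for each validation batch.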
def __init__(
    self,
    clipping_value=None,
    prior_samples=16,
    loss=None,
    loss_and_metric=None,
    **kwargs
):
    """Set up the Probabilistic UNet trainer.

    Raises:
        ValueError: If `loss` or `loss_and_metric` is not given.
    """
    # Both callables are required. The previous check, `assert loss, loss_and_metric is not None`,
    # only validated `loss`: the second expression was merely used as the assertion message.
    if loss is None or loss_and_metric is None:
        raise ValueError(
            "ProbabilisticUNetTrainer requires both `loss` and `loss_and_metric` to be provided."
        )
    # The metric is computed by `loss_and_metric`, so only a placeholder module is passed on.
    super().__init__(loss=loss, metric=DummyLoss(), **kwargs)

    self.loss_and_metric = loss_and_metric
    self.clipping_value = clipping_value
    self.prior_samples = prior_samples
    # Not used by the trainer's own methods; kept so that code accessing `trainer.sigmoid` keeps working.
    self.sigmoid = torch.nn.Sigmoid()

    # Extra keyword arguments are stored on the trainer
    # (presumably for checkpoint serialization in DefaultTrainer — verify).
    self._kwargs = kwargs
Inherited Members
- torch_em.trainer.default_trainer.DefaultTrainer
- name
- id_
- train_loader
- val_loader
- model
- loss
- optimizer
- metric
- device
- lr_scheduler
- log_image_interval
- save_root
- compile_model
- mixed_precision
- early_stopping
- train_time
- scaler
- logger_class
- logger_kwargs
- checkpoint_folder
- iteration
- epoch
- Deserializer
- from_checkpoint
- Serializer
- save_checkpoint
- load_checkpoint
- fit