torch_em.self_training.loss
"""Loss functions (and loss + metric pairs) for self training and for the Probabilistic UNet."""
import torch
import torch_em
import torch.nn as nn
from torch_em.loss import DiceLoss


class DefaultSelfTrainingLoss(nn.Module):
    """Loss function for self training.

    Parameters:
        loss [nn.Module] - the loss function to be used. (default: torch_em.loss.DiceLoss)
        activation [nn.Module, callable] - the activation function to be applied to the prediction
            before passing it to the loss. (default: None)
    """
    def __init__(self, loss=None, activation=None):
        super().__init__()
        self.activation = activation
        # Build the default loss per instance: a module instance used as a default argument is
        # created once at import time and shared by every object of this class.
        self.loss = torch_em.loss.DiceLoss() if loss is None else loss
        # TODO serialize the class names and kwargs instead
        self.init_kwargs = {}

    def __call__(self, model, input_, labels, label_filter=None):
        """Return the loss between `model(input_)` (after the optional activation) and `labels`.

        If `label_filter` is given, it is multiplied into both the prediction and the labels
        before the loss is computed, so entries where the filter is zero become zero in both.
        """
        prediction = model(input_)
        if self.activation is not None:
            prediction = self.activation(prediction)
        if label_filter is None:
            loss = self.loss(prediction, labels)
        else:
            loss = self.loss(prediction * label_filter, labels * label_filter)
        return loss


class DefaultSelfTrainingLossAndMetric(nn.Module):
    """Loss and metric function for self training.

    Parameters:
        loss [nn.Module] - the loss function to be used. (default: torch_em.loss.DiceLoss)
        metric [nn.Module] - the metric function to be used. (default: torch_em.loss.DiceLoss)
        activation [nn.Module, callable] - the activation function to be applied to the prediction
            before passing it to the loss and the metric. (default: None)
    """
    def __init__(self, loss=None, metric=None, activation=None):
        super().__init__()
        self.activation = activation
        # Build the defaults per instance instead of sharing module instances created at import time.
        self.loss = torch_em.loss.DiceLoss() if loss is None else loss
        self.metric = torch_em.loss.DiceLoss() if metric is None else metric
        # TODO serialize the class names and dicts instead
        self.init_kwargs = {}

    def __call__(self, model, input_, labels, label_filter=None):
        """Return `(loss, metric)` for `model(input_)` (after the optional activation) vs. `labels`.

        `label_filter`, if given, is multiplied into prediction and labels for the loss only;
        the metric is always computed on the unfiltered prediction and labels.
        """
        prediction = model(input_)
        if self.activation is not None:
            prediction = self.activation(prediction)
        if label_filter is None:
            loss = self.loss(prediction, labels)
        else:
            loss = self.loss(prediction * label_filter, labels * label_filter)
        metric = self.metric(prediction, labels)
        return loss, metric


def l2_regularisation(m):
    """Return the sum of the (unsquared) L2 norms of all parameter tensors of module `m`.

    Returns 0 if `m` has no parameters, so the result can always be added to other penalties.
    """
    return sum(W.norm(2) for W in m.parameters())


class ProbabilisticUNetLoss(nn.Module):
    """Loss function for Probabilistic UNet.

    The loss is the negative ELBO reported by `model.elbo` plus an L2 penalty (weight 1e-5) on the
    parameters of `model.posterior`, `model.prior` and `model.fcomb.layers`.

    Parameters:
        loss [nn.Module] - the loss function to be used. Currently only `None`, i.e. the ELBO-based
            loss described above, is supported. (default: None)
    """
    # TODO : Implement a generic utility function for all Probabilistic UNet schemes (ELBO, GECO, etc.)
    def __init__(self, loss=None):
        super().__init__()
        self.loss = loss

    def __call__(self, model, input_, labels, label_filter=None):
        """Return the regularised negative ELBO of `model` for `input_` and `labels`.

        Raises:
            NotImplementedError: if a custom `loss` was passed to the constructor.
        """
        if self.loss is not None:
            raise NotImplementedError(
                "ProbabilisticUNetLoss currently only supports loss=None (negative ELBO + L2 regularisation)."
            )
        # `forward` is called for its effect on `model`: `elbo` below does not receive `input_`,
        # so it presumably relies on state stored by this call.
        model.forward(input_, labels)
        elbo = model.elbo(labels, label_filter)
        reg_loss = l2_regularisation(model.posterior) + l2_regularisation(model.prior) + \
            l2_regularisation(model.fcomb.layers)
        return -elbo + 1e-5 * reg_loss


class ProbabilisticUNetLossAndMetric(nn.Module):
    """Loss and metric function for Probabilistic UNet.

    Parameters:
        loss [nn.Module] - the loss function to be used. Currently only `None` (negative ELBO plus
            L2 regularisation) is supported. (default: None)
        metric [nn.Module] - the metric function to be used. (default: torch_em.loss.DiceLoss)
        activation [nn.Module, callable] - the activation function to be applied to each prediction
            sample before the samples are averaged. (default: torch.nn.Sigmoid)
        prior_samples [int] - the number of samples drawn with `model.sample` and averaged before
            evaluating the metric; must be at least 1. (default: 16)
    """
    # TODO : Implement a generic utility function for all Probabilistic UNet schemes (ELBO, GECO, etc.)
    def __init__(self, loss=None, metric=None, activation=torch.nn.Sigmoid(), prior_samples=16):
        super().__init__()
        if prior_samples < 1:
            raise ValueError(f"prior_samples must be at least 1, got {prior_samples}")
        self.activation = activation
        # Build the default metric per instance instead of sharing one created at import time.
        self.metric = DiceLoss() if metric is None else metric
        self.loss = loss
        self.prior_samples = prior_samples

    def __call__(self, model, input_, labels, label_filter=None):
        """Return `(loss, metric)`: the regularised negative ELBO and the metric of the mean sample.

        Raises:
            NotImplementedError: if a custom `loss` was passed to the constructor.
        """
        if self.loss is not None:
            raise NotImplementedError(
                "ProbabilisticUNetLossAndMetric currently only supports loss=None "
                "(negative ELBO + L2 regularisation)."
            )
        # `elbo` and `sample` below do not receive `input_`; they presumably use state set by `forward`.
        model.forward(input_, labels)
        elbo = model.elbo(labels, label_filter)
        reg_loss = l2_regularisation(model.posterior) + l2_regularisation(model.prior) + \
            l2_regularisation(model.fcomb.layers)
        loss = -elbo + 1e-5 * reg_loss

        samples_per_distribution = []
        for _ in range(self.prior_samples):
            samples = model.sample(testing=False)
            if self.activation is not None:
                samples = self.activation(samples)
            samples_per_distribution.append(samples)

        avg_samples = torch.stack(samples_per_distribution, dim=0).mean(dim=0)
        # The metric ignores `label_filter` and compares the averaged sample with all labels.
        metric = self.metric(avg_samples, labels)
        return loss, metric
class
DefaultSelfTrainingLoss(torch.nn.modules.module.Module):
class DefaultSelfTrainingLoss(nn.Module):
    """Loss function for self training.

    Parameters:
        loss [nn.Module] - the loss function to be used. (default: torch_em.loss.DiceLoss)
        activation [nn.Module, callable] - the activation function to be applied to the prediction
            before passing it to the loss. (default: None)
    """
    def __init__(self, loss=None, activation=None):
        super().__init__()
        self.activation = activation
        # Build the default loss per instance: a module instance used as a default argument is
        # created once at import time and shared by every object of this class.
        self.loss = torch_em.loss.DiceLoss() if loss is None else loss
        # TODO serialize the class names and kwargs instead
        self.init_kwargs = {}

    def __call__(self, model, input_, labels, label_filter=None):
        """Return the loss between `model(input_)` (after the optional activation) and `labels`.

        If `label_filter` is given, it is multiplied into both the prediction and the labels
        before the loss is computed, so entries where the filter is zero become zero in both.
        """
        prediction = model(input_)
        if self.activation is not None:
            prediction = self.activation(prediction)
        if label_filter is None:
            loss = self.loss(prediction, labels)
        else:
            loss = self.loss(prediction * label_filter, labels * label_filter)
        return loss
Loss function for self training.
Arguments:
- loss [nn.Module] - the loss function to be used. (default: torch_em.loss.DiceLoss)
- activation [nn.Module, callable] - the activation function to be applied to the prediction before passing it to the loss. (default: None)
DefaultSelfTrainingLoss(loss=DiceLoss(), activation=None)
def __init__(self, loss=None, activation=None):
    """Store the loss function (a fresh DiceLoss if none is given) and the optional activation."""
    super().__init__()
    self.activation = activation
    # Build the default loss per instance: a module instance used as a default argument is
    # created once at import time and shared by every object of this class.
    self.loss = torch_em.loss.DiceLoss() if loss is None else loss
    # TODO serialize the class names and kwargs instead
    self.init_kwargs = {}
Stores the loss function and the activation on the instance and sets `init_kwargs` to an empty dict.
Inherited Members
- torch.nn.modules.module.Module
- dump_patches
- training
- call_super_init
- forward
- register_buffer
- register_parameter
- add_module
- register_module
- get_submodule
- get_parameter
- get_buffer
- get_extra_state
- set_extra_state
- apply
- cuda
- ipu
- xpu
- cpu
- type
- float
- double
- half
- bfloat16
- to_empty
- to
- register_full_backward_pre_hook
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- register_state_dict_pre_hook
- state_dict
- register_load_state_dict_post_hook
- load_state_dict
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr
- compile
class
DefaultSelfTrainingLossAndMetric(torch.nn.modules.module.Module):
class DefaultSelfTrainingLossAndMetric(nn.Module):
    """Loss and metric function for self training.

    Parameters:
        loss [nn.Module] - the loss function to be used. (default: torch_em.loss.DiceLoss)
        metric [nn.Module] - the metric function to be used. (default: torch_em.loss.DiceLoss)
        activation [nn.Module, callable] - the activation function to be applied to the prediction
            before passing it to the loss and the metric. (default: None)
    """
    def __init__(self, loss=None, metric=None, activation=None):
        super().__init__()
        self.activation = activation
        # Build the defaults per instance instead of sharing module instances created at import time.
        self.loss = torch_em.loss.DiceLoss() if loss is None else loss
        self.metric = torch_em.loss.DiceLoss() if metric is None else metric
        # TODO serialize the class names and dicts instead
        self.init_kwargs = {}

    def __call__(self, model, input_, labels, label_filter=None):
        """Return `(loss, metric)` for `model(input_)` (after the optional activation) vs. `labels`.

        `label_filter`, if given, is multiplied into prediction and labels for the loss only;
        the metric is always computed on the unfiltered prediction and labels.
        """
        prediction = model(input_)
        if self.activation is not None:
            prediction = self.activation(prediction)
        if label_filter is None:
            loss = self.loss(prediction, labels)
        else:
            loss = self.loss(prediction * label_filter, labels * label_filter)
        metric = self.metric(prediction, labels)
        return loss, metric
Loss and metric function for self training.
Arguments:
- loss [nn.Module] - the loss function to be used. (default: torch_em.loss.DiceLoss)
- metric [nn.Module] - the metric function to be used. (default: torch_em.loss.DiceLoss)
- activation [nn.Module, callable] - the activation function to be applied to the prediction before passing it to the loss. (default: None)
DefaultSelfTrainingLossAndMetric(loss=DiceLoss(), metric=DiceLoss(), activation=None)
def __init__(self, loss=None, metric=None, activation=None):
    """Store the loss and metric functions (fresh DiceLoss instances if not given) and the activation."""
    super().__init__()
    self.activation = activation
    # Build the defaults per instance instead of sharing module instances created at import time.
    self.loss = torch_em.loss.DiceLoss() if loss is None else loss
    self.metric = torch_em.loss.DiceLoss() if metric is None else metric
    # TODO serialize the class names and dicts instead
    self.init_kwargs = {}
Stores the loss function, the metric function and the activation on the instance and sets `init_kwargs` to an empty dict.
Inherited Members
- torch.nn.modules.module.Module
- dump_patches
- training
- call_super_init
- forward
- register_buffer
- register_parameter
- add_module
- register_module
- get_submodule
- get_parameter
- get_buffer
- get_extra_state
- set_extra_state
- apply
- cuda
- ipu
- xpu
- cpu
- type
- float
- double
- half
- bfloat16
- to_empty
- to
- register_full_backward_pre_hook
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- register_state_dict_pre_hook
- state_dict
- register_load_state_dict_post_hook
- load_state_dict
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr
- compile
def
l2_regularisation(m):
class
ProbabilisticUNetLoss(torch.nn.modules.module.Module):
class ProbabilisticUNetLoss(nn.Module):
    """Loss function for Probabilistic UNet.

    The loss is the negative ELBO reported by `model.elbo` plus an L2 penalty (weight 1e-5) on the
    parameters of `model.posterior`, `model.prior` and `model.fcomb.layers`.

    Parameters:
        loss [nn.Module] - the loss function to be used. Currently only `None`, i.e. the ELBO-based
            loss described above, is supported. (default: None)
    """
    # TODO : Implement a generic utility function for all Probabilistic UNet schemes (ELBO, GECO, etc.)
    def __init__(self, loss=None):
        super().__init__()
        self.loss = loss

    def __call__(self, model, input_, labels, label_filter=None):
        """Return the regularised negative ELBO of `model` for `input_` and `labels`.

        Raises:
            NotImplementedError: if a custom `loss` was passed to the constructor.
        """
        # A custom loss was previously accepted but left `loss` unassigned (UnboundLocalError);
        # fail early with an explicit message before running the model.
        if self.loss is not None:
            raise NotImplementedError(
                "ProbabilisticUNetLoss currently only supports loss=None (negative ELBO + L2 regularisation)."
            )
        # `forward` is called for its effect on `model`: `elbo` below does not receive `input_`,
        # so it presumably relies on state stored by this call.
        model.forward(input_, labels)
        elbo = model.elbo(labels, label_filter)
        reg_loss = l2_regularisation(model.posterior) + l2_regularisation(model.prior) + \
            l2_regularisation(model.fcomb.layers)
        return -elbo + 1e-5 * reg_loss
Loss function for Probabilistic UNet.
Parameters:
TODO: Implement a generic utility function for all Probabilistic UNet schemes (ELBO, GECO, etc.)
loss [nn.Module] - the loss function to be used. Currently only `None` (the built-in negative ELBO loss with L2 regularisation) is supported. (default: None)
ProbabilisticUNetLoss(loss=None)
Stores the loss function on the instance.
Inherited Members
- torch.nn.modules.module.Module
- dump_patches
- training
- call_super_init
- forward
- register_buffer
- register_parameter
- add_module
- register_module
- get_submodule
- get_parameter
- get_buffer
- get_extra_state
- set_extra_state
- apply
- cuda
- ipu
- xpu
- cpu
- type
- float
- double
- half
- bfloat16
- to_empty
- to
- register_full_backward_pre_hook
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- register_state_dict_pre_hook
- state_dict
- register_load_state_dict_post_hook
- load_state_dict
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr
- compile
class
ProbabilisticUNetLossAndMetric(torch.nn.modules.module.Module):
class ProbabilisticUNetLossAndMetric(nn.Module):
    """Loss and metric function for Probabilistic UNet.

    Parameters:
        loss [nn.Module] - the loss function to be used. Currently only `None` (negative ELBO plus
            L2 regularisation) is supported. (default: None)
        metric [nn.Module] - the metric function to be used. (default: torch_em.loss.DiceLoss)
        activation [nn.Module, callable] - the activation function to be applied to each prediction
            sample before the samples are averaged. (default: torch.nn.Sigmoid)
        prior_samples [int] - the number of samples drawn with `model.sample` and averaged before
            evaluating the metric; must be at least 1. (default: 16)
    """
    # TODO : Implement a generic utility function for all Probabilistic UNet schemes (ELBO, GECO, etc.)
    def __init__(self, loss=None, metric=None, activation=torch.nn.Sigmoid(), prior_samples=16):
        super().__init__()
        # With zero samples `torch.stack` below would fail on an empty list; reject it up front.
        if prior_samples < 1:
            raise ValueError(f"prior_samples must be at least 1, got {prior_samples}")
        self.activation = activation
        # Build the default metric per instance instead of sharing one created at import time.
        self.metric = DiceLoss() if metric is None else metric
        self.loss = loss
        self.prior_samples = prior_samples

    def __call__(self, model, input_, labels, label_filter=None):
        """Return `(loss, metric)`: the regularised negative ELBO and the metric of the mean sample.

        Raises:
            NotImplementedError: if a custom `loss` was passed to the constructor.
        """
        # A custom loss was previously accepted but left `loss` unassigned (UnboundLocalError).
        if self.loss is not None:
            raise NotImplementedError(
                "ProbabilisticUNetLossAndMetric currently only supports loss=None "
                "(negative ELBO + L2 regularisation)."
            )
        # `elbo` and `sample` below do not receive `input_`; they presumably use state set by `forward`.
        model.forward(input_, labels)
        elbo = model.elbo(labels, label_filter)
        reg_loss = l2_regularisation(model.posterior) + l2_regularisation(model.prior) + \
            l2_regularisation(model.fcomb.layers)
        loss = -elbo + 1e-5 * reg_loss

        samples_per_distribution = []
        for _ in range(self.prior_samples):
            samples = model.sample(testing=False)
            if self.activation is not None:
                samples = self.activation(samples)
            samples_per_distribution.append(samples)

        avg_samples = torch.stack(samples_per_distribution, dim=0).mean(dim=0)
        # The metric ignores `label_filter` and compares the averaged sample with all labels.
        metric = self.metric(avg_samples, labels)
        return loss, metric
Loss and metric function for Probabilistic UNet.
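TODO: Implement a generic utility function for all Probabilistic UNet schemes (ELBO, GECO, etc.)
Arguments:
- loss [nn.Module] - the loss function to be used. Currently only `None` (the built-in negative ELBO loss with L2 regularisation) is supported. (default: None)
- metric [nn.Module] - the metric function to be used. (default: torch_em.loss.DiceLoss)
- activation [nn.Module, callable] - the activation function to be applied to each prediction sample before the samples are averaged. (default: torch.nn.Sigmoid)
- prior_samples [int] - the number of samples drawn with `model.sample` and averaged before evaluating the metric. (default: 16)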
ProbabilisticUNetLossAndMetric(loss=None, metric=DiceLoss(), activation=Sigmoid(), prior_samples=16)
def __init__(self, loss=None, metric=None, activation=torch.nn.Sigmoid(), prior_samples=16):
    """Store the loss, metric (a fresh DiceLoss if not given), activation and number of prior samples."""
    super().__init__()
    # With zero samples the later `torch.stack` would fail on an empty list; reject it up front.
    if prior_samples < 1:
        raise ValueError(f"prior_samples must be at least 1, got {prior_samples}")
    self.activation = activation
    # Build the default metric per instance instead of sharing one created at import time.
    self.metric = DiceLoss() if metric is None else metric
    self.loss = loss
    self.prior_samples = prior_samples
Stores the loss function, the metric function, the activation and the number of prior samples on the instance.
Inherited Members
- torch.nn.modules.module.Module
- dump_patches
- training
- call_super_init
- forward
- register_buffer
- register_parameter
- add_module
- register_module
- get_submodule
- get_parameter
- get_buffer
- get_extra_state
- set_extra_state
- apply
- cuda
- ipu
- xpu
- cpu
- type
- float
- double
- half
- bfloat16
- to_empty
- to
- register_full_backward_pre_hook
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- register_state_dict_pre_hook
- state_dict
- register_load_state_dict_post_hook
- load_state_dict
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr
- compile