torch_em.loss.dice
import torch
import torch.nn as nn

# Channel reductions understood by `dice_score` when `channelwise=True`.
_CHANNEL_REDUCTIONS = ("sum", "mean", "max", "min", None)
# Reductions usable by the combined BCE + dice losses: the dice term must be a scalar so it can
# be added to the scalar BCE term, hence `None` (per-channel scores) is excluded there.
_SCALAR_CHANNEL_REDUCTIONS = ("sum", "mean", "max", "min")


def _check_reduce_channel(reduce_channel, allowed=_CHANNEL_REDUCTIONS):
    """Raise a ValueError if `reduce_channel` is not one of `allowed`."""
    if reduce_channel not in allowed:
        raise ValueError(f"Unsupported channel reduction {reduce_channel}")


# TODO refactor
def flatten_samples(input_):
    """Flatten a tensor such that the channel axis is first and the sample axis is second.

    The shapes are transformed as follows:
        (N, C, H, W) --> (C, N * H * W)
        (N, C, D, H, W) --> (C, N * D * H * W)
        (N, C) --> (C, N)
    The input must be at least 2d.
    """
    num_channels = input_.size(1)
    # Swap the first two axes, e.g. NCHW -> CNHW.
    permute_axes = list(range(input_.dim()))
    permute_axes[0], permute_axes[1] = permute_axes[1], permute_axes[0]
    permuted = input_.permute(*permute_axes).contiguous()
    # Flatten everything except the (new) leading channel axis.
    return permuted.view(num_channels, -1)


def dice_score(input_, target, invert=False, channelwise=True, reduce_channel="sum", eps=1e-7):
    """Compute the dice score 2 * sum(x * y) / max(sum(x * x) + sum(y * y), eps).

    Args:
        input_: The prediction tensor, must have the same shape as `target`.
        target: The target tensor.
        invert: Whether to return 1 - dice (usable as a loss) instead of the dice score.
            With `channelwise=True` the inversion is applied per channel, before the reduction.
        channelwise: Whether to compute one score per channel (axis 1) and then reduce them,
            or a single score over all elements.
        reduce_channel: Reduction of the per-channel scores: "sum", "mean", "max", "min",
            or None to return the per-channel scores. Only used if `channelwise` is True.
        eps: Lower bound for the denominator, avoids division by zero.

    Returns:
        The (inverted) dice score.

    Raises:
        ValueError: If the shapes of `input_` and `target` differ, or if `reduce_channel`
            is not supported (only checked when `channelwise` is True).
    """
    if input_.shape != target.shape:
        raise ValueError(f"Expect input and target of same shape, got: {input_.shape}, {target.shape}.")

    if channelwise:
        # Flatten input and target to have the shape (C, N), where N is the number of samples.
        input_ = flatten_samples(input_)
        target = flatten_samples(target)
        # Sum over the samples, leaving the channels intact.
        numerator = (input_ * target).sum(-1)
        denominator = (input_ * input_).sum(-1) + (target * target).sum(-1)
        channelwise_score = 2 * (numerator / denominator.clamp(min=eps))
        if invert:
            channelwise_score = 1. - channelwise_score

        # Reduce the dice score over the channels to compute the overall dice score.
        if reduce_channel is None:
            score = channelwise_score
        elif reduce_channel == "sum":
            score = channelwise_score.sum()
        elif reduce_channel == "mean":
            score = channelwise_score.mean()
        elif reduce_channel == "max":
            score = channelwise_score.max()
        elif reduce_channel == "min":
            score = channelwise_score.min()
        else:
            raise ValueError(f"Unsupported channel reduction {reduce_channel}")

    else:
        numerator = (input_ * target).sum()
        denominator = (input_ * input_).sum() + (target * target).sum()
        score = 2. * (numerator / denominator.clamp(min=eps))
        if invert:
            score = 1. - score

    return score


class DiceLoss(nn.Module):
    """Dice loss (1 - dice score) for inputs that are already probabilities.

    Args:
        channelwise: Whether to compute the dice score per channel and then reduce it.
        eps: Lower bound for the dice denominator.
        reduce_channel: Channel reduction, one of "sum", "mean", "max", "min" or None.
    """

    def __init__(self, channelwise=True, eps=1e-7, reduce_channel="sum"):
        _check_reduce_channel(reduce_channel)
        super().__init__()
        self.channelwise = channelwise
        self.eps = eps
        self.reduce_channel = reduce_channel

        # all torch_em classes should store init kwargs to easily recreate the init call
        self.init_kwargs = {"channelwise": channelwise, "eps": self.eps, "reduce_channel": self.reduce_channel}

    def forward(self, input_, target):
        return dice_score(input_, target,
                          invert=True, channelwise=self.channelwise,
                          eps=self.eps, reduce_channel=self.reduce_channel)


class DiceLossWithLogits(nn.Module):
    """Dice loss (1 - dice score) for logits; a sigmoid is applied to the input first.

    Args:
        channelwise: Whether to compute the dice score per channel and then reduce it.
        eps: Lower bound for the dice denominator.
        reduce_channel: Channel reduction, one of "sum", "mean", "max", "min" or None.
    """

    def __init__(self, channelwise=True, eps=1e-7, reduce_channel="sum"):
        # Validate eagerly, like DiceLoss, so a typo fails here and not in the first forward pass.
        _check_reduce_channel(reduce_channel)
        super().__init__()
        self.channelwise = channelwise
        self.eps = eps
        self.reduce_channel = reduce_channel

        # all torch_em classes should store init kwargs to easily recreate the init call
        self.init_kwargs = {"channelwise": channelwise, "eps": self.eps, "reduce_channel": self.reduce_channel}

    def forward(self, input_, target):
        return dice_score(
            torch.sigmoid(input_),
            target,
            invert=True,
            channelwise=self.channelwise,
            eps=self.eps,
            reduce_channel=self.reduce_channel,
        )


class BCEDiceLoss(nn.Module):
    """Weighted sum of dice loss and binary cross entropy for inputs that are probabilities.

    Args:
        alpha: Weight of the dice loss term.
        beta: Weight of the binary cross entropy term.
        channelwise: Whether to compute the dice score per channel and then reduce it.
        eps: Lower bound for the dice denominator.
        reduce_channel: Channel reduction for the dice term: "sum" (default), "mean", "max" or "min".
            "mean" keeps the dice term in [0, 1], on the same scale as the mean-reduced BCE term.
    """

    def __init__(self, alpha=1., beta=1., channelwise=True, eps=1e-7, reduce_channel="sum"):
        _check_reduce_channel(reduce_channel, _SCALAR_CHANNEL_REDUCTIONS)
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.channelwise = channelwise
        self.eps = eps
        self.reduce_channel = reduce_channel

        # all torch_em classes should store init kwargs to easily recreate the init call
        self.init_kwargs = {
            "alpha": alpha, "beta": beta, "channelwise": channelwise, "eps": self.eps,
            "reduce_channel": self.reduce_channel,
        }

    def forward(self, input_, target):
        loss_dice = dice_score(
            input_,
            target,
            invert=True,
            channelwise=self.channelwise,
            eps=self.eps,
            reduce_channel=self.reduce_channel,
        )
        loss_bce = nn.functional.binary_cross_entropy(
            input_, target
        )
        return self.alpha * loss_dice + self.beta * loss_bce


# TODO think about how to handle combined losses like this for mixed precision training
class BCEDiceLossWithLogits(nn.Module):
    """Weighted sum of dice loss (on sigmoid(input_)) and binary cross entropy with logits.

    Args:
        alpha: Weight of the dice loss term.
        beta: Weight of the binary cross entropy term.
        channelwise: Whether to compute the dice score per channel and then reduce it.
        eps: Lower bound for the dice denominator.
        reduce_channel: Channel reduction for the dice term: "sum" (default), "mean", "max" or "min".
    """

    def __init__(self, alpha=1., beta=1., channelwise=True, eps=1e-7, reduce_channel="sum"):
        _check_reduce_channel(reduce_channel, _SCALAR_CHANNEL_REDUCTIONS)
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.channelwise = channelwise
        self.eps = eps
        self.reduce_channel = reduce_channel

        # all torch_em classes should store init kwargs to easily recreate the init call
        self.init_kwargs = {
            "alpha": alpha, "beta": beta, "channelwise": channelwise, "eps": self.eps,
            "reduce_channel": self.reduce_channel,
        }

    def forward(self, input_, target):
        loss_dice = dice_score(
            torch.sigmoid(input_),
            target,
            invert=True,
            channelwise=self.channelwise,
            eps=self.eps,
            reduce_channel=self.reduce_channel,
        )
        loss_bce = nn.functional.binary_cross_entropy_with_logits(
            input_, target
        )
        return self.alpha * loss_dice + self.beta * loss_bce
def flatten_samples(input_):
    """
    Move the channel axis to the front and merge all remaining axes into one.

    Shape mapping (the input must be at least 2d):
        (N, C, H, W)    --> (C, N * H * W)
        (N, C, D, H, W) --> (C, N * D * H * W)
        (N, C)          --> (C, N)
    """
    n_channels = input_.size(1)
    # Swapping axes 0 and 1 turns e.g. NCHW into CNHW; make it contiguous so `view` is valid.
    channels_first = input_.transpose(0, 1).contiguous()
    return channels_first.view(n_channels, -1)
Flattens a tensor or a variable such that the channel axis is first and the sample axis is second. The shapes are transformed as follows: (N, C, H, W) --> (C, N * H * W); (N, C, D, H, W) --> (C, N * D * H * W); (N, C) --> (C, N). The input must be at least 2d.
def dice_score(input_, target, invert=False, channelwise=True, reduce_channel="sum", eps=1e-7):
    """
    Dice score 2 * sum(x * y) / max(sum(x * x) + sum(y * y), eps) between `input_` and `target`.

    With `channelwise=True` one score is computed per channel (axis 1), optionally inverted
    (1 - score) per channel, and then reduced with `reduce_channel` ("sum", "mean", "max",
    "min", or None to keep the per-channel scores). Otherwise a single score over all
    elements is returned and `reduce_channel` is ignored.

    Raises a ValueError if the shapes differ or the channel reduction is unsupported.
    """
    if input_.shape != target.shape:
        raise ValueError(f"Expect input and target of same shape, got: {input_.shape}, {target.shape}.")

    if not channelwise:
        overlap = (input_ * target).sum()
        norm = (input_ * input_).sum() + (target * target).sum()
        total = 2. * (overlap / norm.clamp(min=eps))
        return 1. - total if invert else total

    # (C, N * spatial) layout: summing over the last axis keeps one value per channel.
    pred = flatten_samples(input_)
    gt = flatten_samples(target)
    overlap = (pred * gt).sum(-1)
    norm = (pred * pred).sum(-1) + (gt * gt).sum(-1)
    per_channel = 2 * (overlap / norm.clamp(min=eps))
    if invert:
        per_channel = 1. - per_channel

    if reduce_channel is None:
        return per_channel
    if reduce_channel in ("sum", "mean", "max", "min"):
        # Tensor methods of the same name perform the full reduction.
        return getattr(per_channel, reduce_channel)()
    raise ValueError(f"Unsupported channel reduction {reduce_channel}")
class DiceLoss(nn.Module):
    """
    Dice loss, i.e. 1 - dice score, computed with `dice_score(..., invert=True)`.

    Args:
        channelwise: Whether the dice score is computed per channel and then reduced.
        eps: Lower bound for the dice denominator.
        reduce_channel: Channel reduction: "sum", "mean", "max", "min" or None.
    """

    def __init__(self, channelwise=True, eps=1e-7, reduce_channel="sum"):
        valid_reductions = ("sum", "mean", "max", "min", None)
        if reduce_channel not in valid_reductions:
            raise ValueError(f"Unsupported channel reduction {reduce_channel}")
        super().__init__()
        self.channelwise, self.eps, self.reduce_channel = channelwise, eps, reduce_channel

        # all torch_em classes should store init kwargs to easily recreate the init call
        self.init_kwargs = dict(channelwise=channelwise, eps=eps, reduce_channel=reduce_channel)

    def forward(self, input_, target):
        return dice_score(
            input_, target, invert=True, channelwise=self.channelwise,
            eps=self.eps, reduce_channel=self.reduce_channel,
        )
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    """Two 5x5 convolutions (1 -> 20 -> 20 channels), each followed by a ReLU."""

    def __init__(self):
        super().__init__()
        # Submodules assigned as attributes are registered automatically.
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        hidden = F.relu(self.conv1(x))
        return F.relu(self.conv2(hidden))
Submodules assigned in this way will be registered, and will have their parameters converted too when you call `to()`, etc.
As per the example above, an `__init__()` call to the parent class must be made before assignment on the child.
`training` (bool): whether this module is in training or evaluation mode.
def __init__(self, channelwise=True, eps=1e-7, reduce_channel="sum"):
    """
    Store the dice loss configuration, rejecting unsupported channel reductions up front.

    Args:
        channelwise: Whether the dice score is computed per channel and then reduced.
        eps: Lower bound for the dice denominator.
        reduce_channel: Channel reduction: "sum", "mean", "max", "min" or None.
    """
    if reduce_channel not in ("sum", "mean", "max", "min", None):
        raise ValueError(f"Unsupported channel reduction {reduce_channel}")
    super().__init__()
    self.channelwise = channelwise
    self.eps = eps
    self.reduce_channel = reduce_channel
    # all torch_em classes should store init kwargs to easily recreate the init call
    self.init_kwargs = dict(channelwise=channelwise, eps=eps, reduce_channel=reduce_channel)
Initializes internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, input_, target):
    """Return the dice loss (1 - dice score) of `input_` with respect to `target`."""
    settings = dict(channelwise=self.channelwise, eps=self.eps, reduce_channel=self.reduce_channel)
    return dice_score(input_, target, invert=True, **settings)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
Inherited Members
- torch.nn.modules.module.Module
- dump_patches
- training
- call_super_init
- register_buffer
- register_parameter
- add_module
- register_module
- get_submodule
- get_parameter
- get_buffer
- get_extra_state
- set_extra_state
- apply
- cuda
- ipu
- xpu
- cpu
- type
- float
- double
- half
- bfloat16
- to_empty
- to
- register_full_backward_pre_hook
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- register_state_dict_pre_hook
- state_dict
- register_load_state_dict_post_hook
- load_state_dict
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr
- compile
class DiceLossWithLogits(nn.Module):
    """Dice loss (1 - dice score) for logits; a sigmoid is applied to the input first.

    Args:
        channelwise: Whether to compute the dice score per channel and then reduce it.
        eps: Lower bound for the dice denominator.
        reduce_channel: Channel reduction, one of "sum", "mean", "max", "min" or None.

    Raises:
        ValueError: If `reduce_channel` is not supported.
    """

    def __init__(self, channelwise=True, eps=1e-7, reduce_channel="sum"):
        # Validate eagerly, like DiceLoss, so that a typo fails at construction
        # instead of at the first forward pass inside `dice_score`.
        if reduce_channel not in ("sum", "mean", "max", "min", None):
            raise ValueError(f"Unsupported channel reduction {reduce_channel}")
        super().__init__()
        self.channelwise = channelwise
        self.eps = eps
        self.reduce_channel = reduce_channel

        # all torch_em classes should store init kwargs to easily recreate the init call
        self.init_kwargs = {"channelwise": channelwise, "eps": self.eps, "reduce_channel": self.reduce_channel}

    def forward(self, input_, target):
        return dice_score(
            torch.sigmoid(input_),  # same sigmoid call as BCEDiceLossWithLogits
            target,
            invert=True,
            channelwise=self.channelwise,
            eps=self.eps,
            reduce_channel=self.reduce_channel,
        )
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    """Two 5x5 convolutions (1 -> 20 -> 20 channels), each followed by a ReLU."""

    def __init__(self):
        super().__init__()
        # Submodules assigned as attributes are registered automatically.
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        hidden = F.relu(self.conv1(x))
        return F.relu(self.conv2(hidden))
Submodules assigned in this way will be registered, and will have their parameters converted too when you call `to()`, etc.
As per the example above, an `__init__()` call to the parent class must be made before assignment on the child.
`training` (bool): whether this module is in training or evaluation mode.
def __init__(self, channelwise=True, eps=1e-7, reduce_channel="sum"):
    """Store the configuration of the dice loss on logits.

    Args:
        channelwise: Whether the dice score is computed per channel and then reduced.
        eps: Lower bound for the dice denominator.
        reduce_channel: Channel reduction: "sum", "mean", "max", "min" or None.

    Raises:
        ValueError: If `reduce_channel` is not supported.
    """
    # Validate eagerly, like DiceLoss, so that a typo fails at construction
    # instead of at the first forward pass inside `dice_score`.
    if reduce_channel not in ("sum", "mean", "max", "min", None):
        raise ValueError(f"Unsupported channel reduction {reduce_channel}")
    super().__init__()
    self.channelwise = channelwise
    self.eps = eps
    self.reduce_channel = reduce_channel

    # all torch_em classes should store init kwargs to easily recreate the init call
    self.init_kwargs = {"channelwise": channelwise, "eps": self.eps, "reduce_channel": self.reduce_channel}
Initializes internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, input_, target):
    """Apply a sigmoid to the logits `input_` and return the dice loss against `target`."""
    probabilities = torch.sigmoid(input_)
    return dice_score(probabilities, target, invert=True, channelwise=self.channelwise,
                      eps=self.eps, reduce_channel=self.reduce_channel)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
Inherited Members
- torch.nn.modules.module.Module
- dump_patches
- training
- call_super_init
- register_buffer
- register_parameter
- add_module
- register_module
- get_submodule
- get_parameter
- get_buffer
- get_extra_state
- set_extra_state
- apply
- cuda
- ipu
- xpu
- cpu
- type
- float
- double
- half
- bfloat16
- to_empty
- to
- register_full_backward_pre_hook
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- register_state_dict_pre_hook
- state_dict
- register_load_state_dict_post_hook
- load_state_dict
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr
- compile
class BCEDiceLoss(nn.Module):
    """Weighted sum of dice loss and binary cross entropy for inputs that are probabilities.

    Args:
        alpha: Weight of the dice loss term.
        beta: Weight of the binary cross entropy term.
        channelwise: Whether to compute the dice score per channel and then reduce it.
        eps: Lower bound for the dice denominator.
        reduce_channel: Channel reduction for the dice term: "sum" (default, previous behavior),
            "mean", "max" or "min". With "sum" the dice term grows with the number of channels;
            "mean" keeps it in [0, 1]. None is rejected because the dice term must be a scalar
            to be added to the scalar BCE term.

    Raises:
        ValueError: If `reduce_channel` is not supported.
    """

    def __init__(self, alpha=1., beta=1., channelwise=True, eps=1e-7, reduce_channel="sum"):
        if reduce_channel not in ("sum", "mean", "max", "min"):
            raise ValueError(f"Unsupported channel reduction {reduce_channel}")
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.channelwise = channelwise
        self.eps = eps
        self.reduce_channel = reduce_channel

        # all torch_em classes should store init kwargs to easily recreate the init call
        self.init_kwargs = {
            "alpha": alpha, "beta": beta, "channelwise": channelwise, "eps": self.eps,
            "reduce_channel": self.reduce_channel,
        }

    def forward(self, input_, target):
        loss_dice = dice_score(
            input_,
            target,
            invert=True,
            channelwise=self.channelwise,
            eps=self.eps,
            reduce_channel=self.reduce_channel,
        )
        loss_bce = nn.functional.binary_cross_entropy(
            input_, target
        )
        return self.alpha * loss_dice + self.beta * loss_bce
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    """Two 5x5 convolutions (1 -> 20 -> 20 channels), each followed by a ReLU."""

    def __init__(self):
        super().__init__()
        # Submodules assigned as attributes are registered automatically.
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        hidden = F.relu(self.conv1(x))
        return F.relu(self.conv2(hidden))
Submodules assigned in this way will be registered, and will have their parameters converted too when you call `to()`, etc.
As per the example above, an `__init__()` call to the parent class must be made before assignment on the child.
`training` (bool): whether this module is in training or evaluation mode.
def __init__(self, alpha=1., beta=1., channelwise=True, eps=1e-7):
    """
    Configure the weighted sum of dice loss and binary cross entropy.

    Args:
        alpha: Weight of the dice loss term.
        beta: Weight of the binary cross entropy term.
        channelwise: Whether the dice score is computed per channel.
        eps: Lower bound for the dice denominator.
    """
    super().__init__()
    self.alpha, self.beta = alpha, beta
    self.channelwise, self.eps = channelwise, eps
    # all torch_em classes should store init kwargs to easily recreate the init call
    self.init_kwargs = dict(alpha=alpha, beta=beta, channelwise=channelwise, eps=eps)
Initializes internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, input_, target):
    """Return alpha * dice loss + beta * binary cross entropy; `input_` holds probabilities."""
    dice_term = dice_score(input_, target, invert=True, channelwise=self.channelwise, eps=self.eps)
    bce_term = nn.functional.binary_cross_entropy(input_, target)
    return self.alpha * dice_term + self.beta * bce_term
Defines the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
Inherited Members
- torch.nn.modules.module.Module
- dump_patches
- training
- call_super_init
- register_buffer
- register_parameter
- add_module
- register_module
- get_submodule
- get_parameter
- get_buffer
- get_extra_state
- set_extra_state
- apply
- cuda
- ipu
- xpu
- cpu
- type
- float
- double
- half
- bfloat16
- to_empty
- to
- register_full_backward_pre_hook
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- register_state_dict_pre_hook
- state_dict
- register_load_state_dict_post_hook
- load_state_dict
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr
- compile
# TODO think about how to handle combined losses like this for mixed precision training
class BCEDiceLossWithLogits(nn.Module):
    """Weighted sum of dice loss (on sigmoid(input_)) and binary cross entropy with logits.

    Args:
        alpha: Weight of the dice loss term.
        beta: Weight of the binary cross entropy term.
        channelwise: Whether to compute the dice score per channel and then reduce it.
        eps: Lower bound for the dice denominator.
        reduce_channel: Channel reduction for the dice term: "sum" (default, previous behavior),
            "mean", "max" or "min". None is rejected because the dice term must be a scalar
            to be added to the scalar BCE term.

    Raises:
        ValueError: If `reduce_channel` is not supported.
    """

    def __init__(self, alpha=1., beta=1., channelwise=True, eps=1e-7, reduce_channel="sum"):
        if reduce_channel not in ("sum", "mean", "max", "min"):
            raise ValueError(f"Unsupported channel reduction {reduce_channel}")
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.channelwise = channelwise
        self.eps = eps
        self.reduce_channel = reduce_channel

        # all torch_em classes should store init kwargs to easily recreate the init call
        self.init_kwargs = {
            "alpha": alpha, "beta": beta, "channelwise": channelwise, "eps": self.eps,
            "reduce_channel": self.reduce_channel,
        }

    def forward(self, input_, target):
        loss_dice = dice_score(
            torch.sigmoid(input_),
            target,
            invert=True,
            channelwise=self.channelwise,
            eps=self.eps,
            reduce_channel=self.reduce_channel,
        )
        loss_bce = nn.functional.binary_cross_entropy_with_logits(
            input_, target
        )
        return self.alpha * loss_dice + self.beta * loss_bce
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    """Two 5x5 convolutions (1 -> 20 -> 20 channels), each followed by a ReLU."""

    def __init__(self):
        super().__init__()
        # Submodules assigned as attributes are registered automatically.
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        hidden = F.relu(self.conv1(x))
        return F.relu(self.conv2(hidden))
Submodules assigned in this way will be registered, and will have their parameters converted too when you call `to()`, etc.
As per the example above, an `__init__()` call to the parent class must be made before assignment on the child.
`training` (bool): whether this module is in training or evaluation mode.
def __init__(self, alpha=1., beta=1., channelwise=True, eps=1e-7):
    """
    Configure the weighted sum of dice loss (on sigmoid outputs) and BCE with logits.

    Args:
        alpha: Weight of the dice loss term.
        beta: Weight of the binary cross entropy term.
        channelwise: Whether the dice score is computed per channel.
        eps: Lower bound for the dice denominator.
    """
    super().__init__()
    self.alpha, self.beta = alpha, beta
    self.channelwise, self.eps = channelwise, eps
    # all torch_em classes should store init kwargs to easily recreate the init call
    self.init_kwargs = dict(alpha=alpha, beta=beta, channelwise=channelwise, eps=eps)
Initializes internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, input_, target):
    """Return alpha * dice loss on sigmoid(input_) + beta * BCE computed from the logits."""
    probabilities = torch.sigmoid(input_)
    dice_term = dice_score(probabilities, target, invert=True, channelwise=self.channelwise, eps=self.eps)
    bce_term = nn.functional.binary_cross_entropy_with_logits(input_, target)
    return self.alpha * dice_term + self.beta * bce_term
Defines the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
Inherited Members
- torch.nn.modules.module.Module
- dump_patches
- training
- call_super_init
- register_buffer
- register_parameter
- add_module
- register_module
- get_submodule
- get_parameter
- get_buffer
- get_extra_state
- set_extra_state
- apply
- cuda
- ipu
- xpu
- cpu
- type
- float
- double
- half
- bfloat16
- to_empty
- to
- register_full_backward_pre_hook
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- register_state_dict_pre_hook
- state_dict
- register_load_state_dict_post_hook
- load_state_dict
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr
- compile