torch_em.loss.dice
from typing import Optional

import torch
import torch.nn as nn

# Accepted values for `reduce_channel` in `dice_score` and the dice-based losses below.
_CHANNEL_REDUCTIONS = ("sum", "mean", "max", "min", None)


def _check_reduce_channel(reduce_channel: Optional[str]) -> None:
    """Raise a ValueError if `reduce_channel` is not one of `_CHANNEL_REDUCTIONS`."""
    if reduce_channel not in _CHANNEL_REDUCTIONS:
        raise ValueError(f"Unsupported channel reduction {reduce_channel}")


def flatten_samples(input_: torch.Tensor) -> torch.Tensor:
    """Flatten a tensor such that the channel axis is first and all other axes are merged into the second axis.

    The shapes are transformed as follows:
        (N, C, H, W) --> (C, N * H * W)
        (N, C, D, H, W) --> (C, N * D * H * W)
        (N, C) --> (C, N)
    The input must be at least 2D.

    Args:
        input_: The input tensor, with the channel axis at position 1.

    Returns:
        The flattened tensor of shape (C, M), where M is the product of all non-channel axis sizes.
    """
    num_channels = input_.size(1)
    # Swap batch and channel axes (e.g. NCHW -> CNHW); `.view` requires contiguous memory.
    permuted = input_.transpose(0, 1).contiguous()
    return permuted.view(num_channels, -1)


def dice_score(
    input_: torch.Tensor,
    target: torch.Tensor,
    invert: bool = False,
    channelwise: bool = True,
    reduce_channel: Optional[str] = "sum",
    eps: float = 1e-7,
) -> torch.Tensor:
    """Compute the dice score between input and target.

    The score is `2 * sum(input_ * target) / max(sum(input_ ** 2) + sum(target ** 2), eps)`.

    Args:
        input_: The input tensor.
        target: The target tensor, with the same shape as `input_`.
        invert: Whether to return the dice error (1 - dice score) instead of the dice score.
            For channelwise scores the inversion is applied per channel, before the channel reduction.
        channelwise: Whether to compute the dice score independently per channel (axis 1).
        reduce_channel: How to reduce the per-channel scores: "sum", "mean", "max", "min",
            or None to return one score per channel. Only used if `channelwise` is True.
        eps: Lower bound to which the denominator is clamped, for numerical stability.

    Returns:
        The dice score (or dice error). A scalar tensor, or a tensor of shape (C,) if `channelwise`
        is True and `reduce_channel` is None.

    Raises:
        ValueError: If `input_` and `target` have different shapes or `reduce_channel` is not supported.
    """
    if input_.shape != target.shape:
        raise ValueError(f"Expect input and target of same shape, got: {input_.shape}, {target.shape}.")
    # Validate up front so that an invalid reduction is rejected even when channelwise is False.
    _check_reduce_channel(reduce_channel)

    if not channelwise:
        numerator = (input_ * target).sum()
        denominator = (input_ * input_).sum() + (target * target).sum()
        score = 2.0 * (numerator / denominator.clamp(min=eps))
        return 1.0 - score if invert else score

    # Flatten to shape (C, M) and sum over the last axis to keep one value per channel.
    input_ = flatten_samples(input_)
    target = flatten_samples(target)
    numerator = (input_ * target).sum(-1)
    denominator = (input_ * input_).sum(-1) + (target * target).sum(-1)
    channelwise_score = 2.0 * (numerator / denominator.clamp(min=eps))
    if invert:
        channelwise_score = 1.0 - channelwise_score

    if reduce_channel is None:
        return channelwise_score
    if reduce_channel == "sum":
        return channelwise_score.sum()
    if reduce_channel == "mean":
        return channelwise_score.mean()
    if reduce_channel == "max":
        return channelwise_score.max()
    # Only "min" remains after validation.
    return channelwise_score.min()


class DiceLoss(nn.Module):
    """Loss computed based on the dice error between a binary input and binary target.

    Args:
        channelwise: Whether to compute the dice score independently per channel.
        eps: Lower bound to which the dice denominator is clamped, for numerical stability.
        reduce_channel: How to reduce the per-channel dice errors: "sum", "mean", "max", "min",
            or None to return one value per channel. Only used if `channelwise` is True.
    """

    def __init__(self, channelwise: bool = True, eps: float = 1e-7, reduce_channel: Optional[str] = "sum"):
        _check_reduce_channel(reduce_channel)

        super().__init__()
        self.channelwise = channelwise
        self.eps = eps
        self.reduce_channel = reduce_channel

        # All torch_em classes should store init kwargs to easily recreate the init call.
        self.init_kwargs = {"channelwise": channelwise, "eps": eps, "reduce_channel": reduce_channel}

    def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the loss.

        Args:
            input_: The binary input.
            target: The binary target.

        Returns:
            The dice loss.
        """
        return dice_score(
            input_=input_,
            target=target,
            invert=True,
            channelwise=self.channelwise,
            eps=self.eps,
            reduce_channel=self.reduce_channel,
        )


class DiceLossWithLogits(nn.Module):
    """Loss computed based on the dice error between logits and binary target.

    A sigmoid is applied to the logits before computing the dice error.

    Args:
        channelwise: Whether to compute the dice score independently per channel.
        eps: Lower bound to which the dice denominator is clamped, for numerical stability.
        reduce_channel: How to reduce the per-channel dice errors: "sum", "mean", "max", "min",
            or None to return one value per channel. Only used if `channelwise` is True.
    """

    def __init__(self, channelwise: bool = True, eps: float = 1e-7, reduce_channel: Optional[str] = "sum"):
        _check_reduce_channel(reduce_channel)

        super().__init__()
        self.channelwise = channelwise
        self.eps = eps
        self.reduce_channel = reduce_channel

        # All torch_em classes should store init kwargs to easily recreate the init call.
        self.init_kwargs = {"channelwise": channelwise, "eps": eps, "reduce_channel": reduce_channel}

    def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the loss.

        Args:
            input_: The logits.
            target: The binary target.

        Returns:
            The dice loss.
        """
        return dice_score(
            input_=torch.sigmoid(input_),
            target=target,
            invert=True,
            channelwise=self.channelwise,
            eps=self.eps,
            reduce_channel=self.reduce_channel,
        )


class BCEDiceLoss(nn.Module):
    """Loss computed based on the binary cross entropy and the dice error between binary inputs and binary target.

    The loss is `alpha * dice_error + beta * bce`.

    Args:
        alpha: The weight of the dice error term.
        beta: The weight of the binary cross entropy term.
        channelwise: Whether to compute the dice score independently per channel.
        eps: Lower bound to which the dice denominator is clamped, for numerical stability.
        reduce_channel: How to reduce the per-channel dice errors: "sum", "mean", "max", "min",
            or None to keep one value per channel. Only used if `channelwise` is True.
    """

    def __init__(
        self,
        alpha: float = 1.0,
        beta: float = 1.0,
        channelwise: bool = True,
        eps: float = 1e-7,
        reduce_channel: Optional[str] = "sum",
    ):
        _check_reduce_channel(reduce_channel)

        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.channelwise = channelwise
        self.eps = eps
        self.reduce_channel = reduce_channel

        # All torch_em classes should store init kwargs to easily recreate the init call.
        self.init_kwargs = {
            "alpha": alpha, "beta": beta, "channelwise": channelwise, "eps": eps, "reduce_channel": reduce_channel,
        }

    def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the loss.

        Args:
            input_: The binary input (values in [0, 1], as required by binary cross entropy).
            target: The binary target.

        Returns:
            The combined BCE and dice loss.
        """
        loss_dice = dice_score(
            input_=input_,
            target=target,
            invert=True,
            channelwise=self.channelwise,
            eps=self.eps,
            reduce_channel=self.reduce_channel,
        )
        loss_bce = nn.functional.binary_cross_entropy(input_, target)
        return self.alpha * loss_dice + self.beta * loss_bce


# TODO think about how to handle combined losses like this for mixed precision training
class BCEDiceLossWithLogits(nn.Module):
    """Loss computed based on the binary cross entropy and the dice error between logits and binary target.

    The loss is `alpha * dice_error + beta * bce`; the dice error is computed on the sigmoid of the logits.

    Args:
        alpha: The weight of the dice error term.
        beta: The weight of the binary cross entropy term.
        channelwise: Whether to compute the dice score independently per channel.
        eps: Lower bound to which the dice denominator is clamped, for numerical stability.
        reduce_channel: How to reduce the per-channel dice errors: "sum", "mean", "max", "min",
            or None to keep one value per channel. Only used if `channelwise` is True.
    """

    def __init__(
        self,
        alpha: float = 1.0,
        beta: float = 1.0,
        channelwise: bool = True,
        eps: float = 1e-7,
        reduce_channel: Optional[str] = "sum",
    ):
        _check_reduce_channel(reduce_channel)

        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.channelwise = channelwise
        self.eps = eps
        self.reduce_channel = reduce_channel

        # All torch_em classes should store init kwargs to easily recreate the init call.
        self.init_kwargs = {
            "alpha": alpha, "beta": beta, "channelwise": channelwise, "eps": eps, "reduce_channel": reduce_channel,
        }

    def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the loss.

        Args:
            input_: The logits.
            target: The binary target.

        Returns:
            The combined BCE and dice loss.
        """
        loss_dice = dice_score(
            input_=torch.sigmoid(input_),
            target=target,
            invert=True,
            channelwise=self.channelwise,
            eps=self.eps,
            reduce_channel=self.reduce_channel,
        )
        # The BCE term works on the raw logits; the sigmoid is applied inside the loss function.
        loss_bce = nn.functional.binary_cross_entropy_with_logits(input_, target)
        return self.alpha * loss_dice + self.beta * loss_bce
def flatten_samples(input_: torch.Tensor) -> torch.Tensor:
    """Flatten a tensor such that the channel axis is first and all other axes are merged into the second axis.

    The shapes are transformed as follows:
        (N, C, H, W) --> (C, N * H * W)
        (N, C, D, H, W) --> (C, N * D * H * W)
        (N, C) --> (C, N)
    The input must be at least 2D.

    Args:
        input_: The input tensor, with the channel axis at position 1.

    Returns:
        The flattened tensor of shape (C, M), where M is the product of all non-channel axis sizes.
    """
    num_channels = input_.size(1)
    # Swap batch and channel axes (e.g. NCHW -> CNHW); `.view` requires contiguous memory.
    permuted = input_.transpose(0, 1).contiguous()
    return permuted.view(num_channels, -1)
Flattens a tensor such that the channel axis comes first and the batch axis, together with all remaining axes, is merged into the second axis.
The shapes are transformed as follows:
(N, C, H, W) --> (C, N * H * W); (N, C, D, H, W) --> (C, N * D * H * W); (N, C) --> (C, N)
The input must be at least 2D.
Arguments:
- input_: The input tensor.
Returns:
The transformed input tensor.
def dice_score(
    input_: torch.Tensor,
    target: torch.Tensor,
    invert: bool = False,
    channelwise: bool = True,
    reduce_channel: Optional[str] = "sum",
    eps: float = 1e-7,
) -> torch.Tensor:
    """Compute the dice score between input and target.

    The score is `2 * sum(input_ * target) / max(sum(input_ ** 2) + sum(target ** 2), eps)`.

    Args:
        input_: The input tensor.
        target: The target tensor, with the same shape as `input_`.
        invert: Whether to return the dice error (1 - dice score) instead of the dice score.
            For channelwise scores the inversion is applied per channel, before the channel reduction.
        channelwise: Whether to compute the dice score independently per channel (axis 1).
        reduce_channel: How to reduce the per-channel scores: "sum", "mean", "max", "min",
            or None to return one score per channel. Only used if `channelwise` is True.
        eps: Lower bound to which the denominator is clamped, for numerical stability.

    Returns:
        The dice score (or dice error). A scalar tensor, or a tensor of shape (C,) if `channelwise`
        is True and `reduce_channel` is None.

    Raises:
        ValueError: If `input_` and `target` have different shapes or `reduce_channel` is not supported.
    """
    if input_.shape != target.shape:
        raise ValueError(f"Expect input and target of same shape, got: {input_.shape}, {target.shape}.")
    # Validate up front so that an invalid reduction is rejected even when channelwise is False.
    if reduce_channel not in ("sum", "mean", "max", "min", None):
        raise ValueError(f"Unsupported channel reduction {reduce_channel}")

    if not channelwise:
        numerator = (input_ * target).sum()
        denominator = (input_ * input_).sum() + (target * target).sum()
        score = 2.0 * (numerator / denominator.clamp(min=eps))
        return 1.0 - score if invert else score

    # Flatten to shape (C, M) and sum over the last axis to keep one value per channel.
    input_ = flatten_samples(input_)
    target = flatten_samples(target)
    numerator = (input_ * target).sum(-1)
    denominator = (input_ * input_).sum(-1) + (target * target).sum(-1)
    channelwise_score = 2.0 * (numerator / denominator.clamp(min=eps))
    if invert:
        channelwise_score = 1.0 - channelwise_score

    if reduce_channel is None:
        return channelwise_score
    if reduce_channel == "sum":
        return channelwise_score.sum()
    if reduce_channel == "mean":
        return channelwise_score.mean()
    if reduce_channel == "max":
        return channelwise_score.max()
    # Only "min" remains after validation.
    return channelwise_score.min()
Compute the dice score between input and target.
Arguments:
- input_: The input tensor.
- target: The target tensor.
- invert: Whether to invert the returned dice score to obtain the dice error instead of the dice score.
- channelwise: Whether to return the dice score independently per channel.
- reduce_channel: How to reduce the per-channel dice scores: "sum", "mean", "max", "min", or None to return one score per channel. Only used if channelwise is True.
- eps: Lower bound to which the denominator is clamped, for numerical stability.
Returns:
The dice score.
class DiceLoss(nn.Module):
    """Dice-error loss for a binary input and a binary target.

    Args:
        channelwise: If True, the dice score is computed separately for every channel.
        eps: Lower bound to which the dice denominator is clamped, for numerical stability.
        reduce_channel: Reduction over the per-channel values: "sum", "mean", "max", "min" or None.
    """

    def __init__(self, channelwise: bool = True, eps: float = 1e-7, reduce_channel: Optional[str] = "sum"):
        allowed = ("sum", "mean", "max", "min", None)
        if reduce_channel not in allowed:
            raise ValueError(f"Unsupported channel reduction {reduce_channel}")

        super().__init__()
        self.channelwise, self.eps, self.reduce_channel = channelwise, eps, reduce_channel
        # torch_em convention: keep the constructor arguments so the loss can be re-created from them.
        self.init_kwargs = dict(channelwise=channelwise, eps=eps, reduce_channel=reduce_channel)

    def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the dice error between a binary input and a binary target.

        Args:
            input_: The binary input.
            target: The binary target.

        Returns:
            The dice loss.
        """
        options = dict(channelwise=self.channelwise, eps=self.eps, reduce_channel=self.reduce_channel)
        return dice_score(input_=input_, target=target, invert=True, **options)
Loss computed based on the dice error between a binary input and binary target.
Arguments:
- channelwise: Whether to return the dice score independently per channel.
- eps: Lower bound to which the dice denominator is clamped, for numerical stability.
- reduce_channel: How to reduce the per-channel dice errors: "sum", "mean", "max", "min", or None to return one value per channel.
105 def __init__(self, channelwise: bool = True, eps: float = 1e-7, reduce_channel: Optional[str] = "sum"): 106 if reduce_channel not in ("sum", "mean", "max", "min", None): 107 raise ValueError(f"Unsupported channel reduction {reduce_channel}") 108 109 super().__init__() 110 self.channelwise = channelwise 111 self.eps = eps 112 self.reduce_channel = reduce_channel 113 114 # all torch_em classes should store init kwargs to easily recreate the init call 115 self.init_kwargs = {"channelwise": channelwise, "eps": self.eps, "reduce_channel": self.reduce_channel}
Validate reduce_channel and store the loss configuration, including init_kwargs.
def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Return the dice error between a binary input and a binary target.

    Args:
        input_: The binary input.
        target: The binary target.

    Returns:
        The dice loss, reduced over channels according to `self.reduce_channel`.
    """
    loss = dice_score(
        input_=input_,
        target=target,
        invert=True,
        channelwise=self.channelwise,
        reduce_channel=self.reduce_channel,
        eps=self.eps,
    )
    return loss
Compute the loss.
Arguments:
- input_: The binary input.
- target: The binary target.
Returns:
The dice loss.
class DiceLossWithLogits(nn.Module):
    """Dice-error loss for logits and a binary target; a sigmoid maps the logits to [0, 1] first.

    Args:
        channelwise: If True, the dice score is computed separately for every channel.
        eps: Lower bound to which the dice denominator is clamped, for numerical stability.
        reduce_channel: Reduction over the per-channel values: "sum", "mean", "max", "min" or None.
    """

    def __init__(self, channelwise: bool = True, eps: float = 1e-7, reduce_channel: Optional[str] = "sum"):
        if reduce_channel is not None and reduce_channel not in ("sum", "mean", "max", "min"):
            raise ValueError(f"Unsupported channel reduction {reduce_channel}")

        super().__init__()
        self.channelwise = channelwise
        self.eps = eps
        self.reduce_channel = reduce_channel
        # torch_em convention: keep the constructor arguments so the loss can be re-created from them.
        self.init_kwargs = dict(channelwise=channelwise, eps=eps, reduce_channel=reduce_channel)

    def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the dice error between sigmoid(logits) and a binary target.

        Args:
            input_: The logits.
            target: The binary target.

        Returns:
            The dice loss.
        """
        probabilities = torch.sigmoid(input_)
        return dice_score(
            input_=probabilities,
            target=target,
            invert=True,
            channelwise=self.channelwise,
            eps=self.eps,
            reduce_channel=self.reduce_channel,
        )
Loss computed based on the dice error between logits and binary target.
Arguments:
- channelwise: Whether to return the dice score independently per channel.
- eps: Lower bound to which the dice denominator is clamped, for numerical stability.
- reduce_channel: How to reduce the per-channel dice errors: "sum", "mean", "max", "min", or None to return one value per channel.
145 def __init__(self, channelwise: bool = True, eps: float = 1e-7, reduce_channel: Optional[str] = "sum"): 146 if reduce_channel not in ("sum", "mean", "max", "min", None): 147 raise ValueError(f"Unsupported channel reduction {reduce_channel}") 148 149 super().__init__() 150 self.channelwise = channelwise 151 self.eps = eps 152 self.reduce_channel = reduce_channel 153 154 # all torch_em classes should store init kwargs to easily recreate the init call 155 self.init_kwargs = {"channelwise": channelwise, "eps": self.eps, "reduce_channel": self.reduce_channel}
Validate reduce_channel and store the loss configuration, including init_kwargs.
def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Return the dice error between sigmoid(logits) and a binary target.

    Args:
        input_: The logits.
        target: The binary target.

    Returns:
        The dice loss, reduced over channels according to `self.reduce_channel`.
    """
    probabilities = torch.sigmoid(input_)
    settings = {"channelwise": self.channelwise, "eps": self.eps, "reduce_channel": self.reduce_channel}
    return dice_score(input_=probabilities, target=target, invert=True, **settings)
Compute the loss.
Arguments:
- input_: The logits.
- target: The binary target.
Returns:
The dice loss.
class BCEDiceLoss(nn.Module):
    """Loss computed based on the binary cross entropy and the dice error between binary inputs and binary target.

    The loss is `alpha * dice_error + beta * bce`.

    Args:
        alpha: The weight of the dice error term.
        beta: The weight of the binary cross entropy term.
        channelwise: Whether to compute the dice score independently per channel.
        eps: Lower bound to which the dice denominator is clamped, for numerical stability.
        reduce_channel: How to reduce the per-channel dice errors: "sum", "mean", "max", "min",
            or None to keep one value per channel. Only used if `channelwise` is True.
    """

    def __init__(
        self,
        alpha: float = 1.0,
        beta: float = 1.0,
        channelwise: bool = True,
        eps: float = 1e-7,
        reduce_channel: Optional[str] = "sum",
    ):
        if reduce_channel not in ("sum", "mean", "max", "min", None):
            raise ValueError(f"Unsupported channel reduction {reduce_channel}")

        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.channelwise = channelwise
        self.eps = eps
        self.reduce_channel = reduce_channel

        # All torch_em classes should store init kwargs to easily recreate the init call.
        self.init_kwargs = {
            "alpha": alpha, "beta": beta, "channelwise": channelwise, "eps": eps, "reduce_channel": reduce_channel,
        }

    def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the loss.

        Args:
            input_: The binary input (values in [0, 1], as required by binary cross entropy).
            target: The binary target.

        Returns:
            The combined BCE and dice loss.
        """
        loss_dice = dice_score(
            input_=input_,
            target=target,
            invert=True,
            channelwise=self.channelwise,
            eps=self.eps,
            reduce_channel=self.reduce_channel,
        )
        loss_bce = nn.functional.binary_cross_entropy(input_, target)
        return self.alpha * loss_dice + self.beta * loss_bce
Loss computed based on the binary cross entropy and the dice error between binary inputs and binary target.
Arguments:
- alpha: The weight of the dice error term.
- beta: The weight of the binary cross entropy term.
- channelwise: Whether to return the dice score independently per channel.
- eps: Lower bound to which the dice denominator is clamped, for numerical stability.
- reduce_channel: How to return the dice score over the channel axis.
186 def __init__(self, alpha: float = 1.0, beta: float = 1.0, channelwise: bool = True, eps: float = 1e-7): 187 super().__init__() 188 self.alpha = alpha 189 self.beta = beta 190 self.channelwise = channelwise 191 self.eps = eps 192 193 # All torch_em classes should store init kwargs to easily recreate the init call. 194 self.init_kwargs = {"alpha": alpha, "beta": beta, "channelwise": channelwise, "eps": self.eps}
Store the loss weights alpha and beta and the dice configuration, including init_kwargs.
def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Return `alpha * dice_error + beta * bce` for a binary input and a binary target.

    Args:
        input_: The binary input.
        target: The binary target.

    Returns:
        The combined BCE and dice loss.
    """
    dice_error = dice_score(input_=input_, target=target, invert=True, channelwise=self.channelwise, eps=self.eps)
    bce = nn.functional.binary_cross_entropy(input_, target)
    weighted_dice = self.alpha * dice_error
    return weighted_dice + self.beta * bce
Compute the loss.
Arguments:
- input_: The binary input.
- target: The binary target.
Returns:
The combined BCE and dice loss.
class BCEDiceLossWithLogits(nn.Module):
    """Loss computed based on the binary cross entropy and the dice error between logits and binary target.

    The loss is `alpha * dice_error + beta * bce`; the dice error is computed on the sigmoid of the logits.

    Args:
        alpha: The weight of the dice error term.
        beta: The weight of the binary cross entropy term.
        channelwise: Whether to compute the dice score independently per channel.
        eps: Lower bound to which the dice denominator is clamped, for numerical stability.
        reduce_channel: How to reduce the per-channel dice errors: "sum", "mean", "max", "min",
            or None to keep one value per channel. Only used if `channelwise` is True.
    """

    def __init__(
        self,
        alpha: float = 1.0,
        beta: float = 1.0,
        channelwise: bool = True,
        eps: float = 1e-7,
        reduce_channel: Optional[str] = "sum",
    ):
        if reduce_channel not in ("sum", "mean", "max", "min", None):
            raise ValueError(f"Unsupported channel reduction {reduce_channel}")

        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.channelwise = channelwise
        self.eps = eps
        self.reduce_channel = reduce_channel

        # All torch_em classes should store init kwargs to easily recreate the init call.
        self.init_kwargs = {
            "alpha": alpha, "beta": beta, "channelwise": channelwise, "eps": eps, "reduce_channel": reduce_channel,
        }

    def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the loss.

        Args:
            input_: The logits.
            target: The binary target.

        Returns:
            The combined BCE and dice loss.
        """
        loss_dice = dice_score(
            input_=torch.sigmoid(input_),
            target=target,
            invert=True,
            channelwise=self.channelwise,
            eps=self.eps,
            reduce_channel=self.reduce_channel,
        )
        # The BCE term works on the raw logits; the sigmoid is applied inside the loss function.
        loss_bce = nn.functional.binary_cross_entropy_with_logits(input_, target)
        return self.alpha * loss_dice + self.beta * loss_bce
Loss computed based on the binary cross entropy and the dice error between logits and binary target.
Arguments:
- alpha: The weight of the dice error term.
- beta: The weight of the binary cross entropy term.
- channelwise: Whether to return the dice score independently per channel.
- eps: Lower bound to which the dice denominator is clamped, for numerical stability.
- reduce_channel: How to return the dice score over the channel axis.
227 def __init__(self, alpha: float = 1.0, beta: float = 1.0, channelwise: bool = True, eps: float = 1e-7): 228 super().__init__() 229 self.alpha = alpha 230 self.beta = beta 231 self.channelwise = channelwise 232 self.eps = eps 233 234 # All torch_em classes should store init kwargs to easily recreate the init call. 235 self.init_kwargs = {"alpha": alpha, "beta": beta, "channelwise": channelwise, "eps": self.eps}
Store the loss weights alpha and beta and the dice configuration, including init_kwargs.
def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Compute the loss.

    Args:
        input_: The logits.
        target: The binary target.

    Returns:
        The combined loss `alpha * dice_error + beta * bce`.
    """
    loss_dice = dice_score(
        input_=torch.sigmoid(input_),
        target=target,
        invert=True,
        channelwise=self.channelwise,
        eps=self.eps,
    )
    # The BCE term works on the raw logits; the sigmoid is applied inside the loss function.
    loss_bce = nn.functional.binary_cross_entropy_with_logits(input_, target)
    return self.alpha * loss_dice + self.beta * loss_bce
Compute the loss.
Arguments:
- input_: The logits.
- target: The binary target.
Returns:
The combined BCE and dice loss.