torch_em.loss.cldice
import torch
import torch.nn as nn
import torch.nn.functional as F

from .dice import dice_score

# From "clDice -- A Novel Topology-Preserving Loss Function for Tubular Structure Segmentation":
# https://arxiv.org/abs/2003.07311


class SoftSkeletonize(torch.nn.Module):
    """`SoftSkeletonize` is a differentiable approximation for skeletonization,
    which applies iterative min- and max-pooling as a proxy for
    morphological erosion and dilation.

    Only 4D tensors (pooled with 2D kernels) and 5D tensors (pooled with 3D
    kernels) are supported; any other rank raises a `ValueError`.

    Args:
        num_iter: Number of iterations for soft-skeletonization.
            Should be greater than or equal to the maximum observed radius.
    """
    def __init__(self, num_iter: int = 5):
        super().__init__()
        self.num_iter = num_iter

    def soft_erode(self, input_: torch.Tensor) -> torch.Tensor:
        """Soft erosion: minimum over a 3-element window along each pooled axis.

        Min-pooling is expressed as a negated max-pool of the negated input.

        Args:
            input_: A 4D or 5D tensor.

        Returns:
            The eroded tensor, same shape as the input.

        Raises:
            ValueError: If the input is neither 4D nor 5D.
        """
        if input_.ndim == 4:
            p1 = -F.max_pool2d(-input_, (3, 1), (1, 1), (1, 0))
            p2 = -F.max_pool2d(-input_, (1, 3), (1, 1), (0, 1))
            return torch.min(p1, p2)
        if input_.ndim == 5:
            p1 = -F.max_pool3d(-input_, (3, 1, 1), (1, 1, 1), (1, 0, 0))
            p2 = -F.max_pool3d(-input_, (1, 3, 1), (1, 1, 1), (0, 1, 0))
            p3 = -F.max_pool3d(-input_, (1, 1, 3), (1, 1, 1), (0, 0, 1))
            return torch.min(torch.min(p1, p2), p3)
        # Previously this fell through and returned None, which only failed
        # later with an unrelated TypeError.
        raise ValueError(f"Expect a 4D or 5D input tensor, got {input_.ndim}D.")

    def soft_dilate(self, input_: torch.Tensor) -> torch.Tensor:
        """Soft dilation: maximum over a full 3x3 (or 3x3x3) window.

        Args:
            input_: A 4D or 5D tensor.

        Returns:
            The dilated tensor, same shape as the input.

        Raises:
            ValueError: If the input is neither 4D nor 5D.
        """
        if input_.ndim == 4:
            return F.max_pool2d(input_, (3, 3), (1, 1), (1, 1))
        if input_.ndim == 5:
            return F.max_pool3d(input_, (3, 3, 3), (1, 1, 1), (1, 1, 1))
        raise ValueError(f"Expect a 4D or 5D input tensor, got {input_.ndim}D.")

    def soft_open(self, input_: torch.Tensor) -> torch.Tensor:
        """Soft opening: erosion followed by dilation."""
        return self.soft_dilate(self.soft_erode(input_))

    def soft_skel(self, input_: torch.Tensor) -> torch.Tensor:
        """Accumulate the soft skeleton over `num_iter` erosion steps.

        Args:
            input_: A 4D or 5D tensor.

        Returns:
            The soft skeleton, same shape as the input.
        """
        input1 = self.soft_open(input_)
        skel = F.relu(input_ - input1)

        for _ in range(self.num_iter):
            input_ = self.soft_erode(input_)
            input1 = self.soft_open(input_)
            # What the opening removes at this erosion level.
            delta = F.relu(input_ - input1)
            # Add only the part of delta not already covered by the skeleton.
            skel = skel + F.relu(delta - skel * delta)

        return skel

    def forward(self, input_: torch.Tensor) -> torch.Tensor:
        """Skeletonize the input prediction.

        Args:
            input_: The input logits.

        Returns:
            The skeletonization.
        """
        return self.soft_skel(input_)


def cldice_score(
    input_: torch.Tensor,
    target: torch.Tensor,
    num_iter: int = 5,
    invert: bool = False,
    eps: float = 1e-7,
) -> torch.Tensor:
    """Adapted from .dice.py `dice_score`. Compute the soft clDice score between input and target.

    The score is computed over the whole tensor; per-channel scoring and
    channel reduction are not implemented.

    Args:
        input_: The input tensor.
        target: The target tensor.
        num_iter: Number of iterations for soft-skeletonization.
        invert: Whether to invert the returned dice score to obtain the cldice error instead of the cldice score.
        eps: The lower bound the denominators are clamped to for numerical stability.

    Returns:
        The clDice score, or `1 - score` if `invert` is set.

    Raises:
        ValueError: If input and target differ in shape.
    """
    if input_.shape != target.shape:
        raise ValueError(f"Expect input and target of same shape, got: {input_.shape}, {target.shape}.")

    soft_skeletonize = SoftSkeletonize(num_iter=num_iter)
    skel_input = soft_skeletonize(input_)
    skel_target = soft_skeletonize(target)

    # Fraction of the predicted skeleton that lies inside the target ...
    t_prec = (skel_input * target).sum() / (skel_input.sum()).clamp(min=eps)
    # ... and fraction of the target skeleton that lies inside the prediction.
    t_sens = (skel_target * input_).sum() / (skel_target.sum()).clamp(min=eps)
    # Harmonic mean of the two.
    score = 2. * (t_prec * t_sens) / (t_prec + t_sens).clamp(min=eps)

    if invert:
        score = 1. - score

    return score


class SoftclDiceLoss(nn.Module):
    """Soft clDice loss for segmentation of tubular structures.

    The soft clDice loss computes topology-aware loss by computing the
    soft skeleton of both the prediction and target
    and measuring overlap of the two skeletons. Teaches the model to learn
    skeletons directly. In the clDice paper, the authors recommend using
    the combined soft-Dice and soft-clDice loss to learn topology-aware
    segmentations, which is implemented below as `CombinedclDiceLoss`.

    Per-channel scoring and channel reduction are not implemented.

    Args:
        num_iter: Number of iterations for soft-skeletonization.
        eps: The lower bound the denominators are clamped to for numerical
            stability.
        exclude_background: Whether to exclude background channel 0 from the
            loss computation.
            Useful for multi-class segmentation.
    """
    def __init__(self, num_iter: int = 5, eps: float = 1e-7,
                 exclude_background: bool = False):
        super().__init__()

        self.num_iter = num_iter
        self.eps = eps
        self.exclude_background = exclude_background
        self.init_kwargs = {"num_iter": num_iter, "eps": eps, "exclude_background": exclude_background}

    def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the soft clDice loss between the input logits and binary target.

        Args:
            input_: The input logits.
            target: The binary target.

        Returns:
            The soft clDice loss, i.e. `1 - clDice score`.

        Raises:
            ValueError: If input and target differ in shape.
        """
        if input_.shape != target.shape:
            raise ValueError(f"Expect input and target of same shape, got: {input_.shape}, {target.shape}.")

        if self.exclude_background:
            # Drop index 0 along axis 1 (the background channel).
            target = target[:, 1:, :, :]
            input_ = input_[:, 1:, :, :]

        cldice = cldice_score(input_, target, num_iter=self.num_iter, invert=True, eps=self.eps)

        return cldice


# TODO implement `channelwise` for multiclass segmentation
# TODO consider if `exclude_background` is needed for multiclass segmentation
class CombinedclDiceLoss(SoftclDiceLoss):
    """Combined soft-Dice and soft-clDice loss for segmentation of tubular structures.

    The soft-clDice loss computes topology-aware loss by computing the
    soft skeleton of both the prediction and target and measuring overlap
    of the two skeletons. This encourages the model to preserve the
    connectivity and topology of tubular structures. The final loss is a
    weighted combination of soft Dice and clDice, controlled by alpha.

    Per-channel scoring and channel reduction are not implemented.

    Args:
        num_iter: Number of iterations for soft-skeletonization.
        alpha: The weight for combining the soft Dice and soft clDice loss.
        eps: The epsilon value passed to both scores for numerical
            stability.
        exclude_background: Whether to exclude background channel 0 from the
            loss computation. Useful for multi-class segmentation.
    """
    def __init__(self, num_iter: int = 5, alpha: float = 0.5, eps: float = 1e-7,
                 exclude_background: bool = False):
        super().__init__(num_iter=num_iter, eps=eps, exclude_background=exclude_background)

        self.alpha = alpha
        self.init_kwargs = {"num_iter": num_iter, "alpha": alpha, "eps": eps, "exclude_background": exclude_background}

    def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute combined soft Dice and clDice loss.

        Args:
            input_: The input logits.
            target: The binary target.

        Returns:
            Combined clDice loss: `(1 - alpha) * dice_loss + alpha * cldice_loss`.

        Raises:
            ValueError: If input and target differ in shape.
        """
        if input_.shape != target.shape:
            raise ValueError(f"Expect input and target of same shape, got: {input_.shape}, {target.shape}.")

        if self.exclude_background:
            # Drop index 0 along axis 1 (the background channel).
            target = target[:, 1:, :, :]
            input_ = input_[:, 1:, :, :]
        dice = dice_score(input_, target, invert=True, channelwise=False, eps=self.eps)
        cldice = cldice_score(input_, target, num_iter=self.num_iter, invert=True, eps=self.eps)

        return (1.0 - self.alpha) * dice + self.alpha * cldice
class SoftSkeletonize(torch.nn.Module):
    """`SoftSkeletonize` is a differentiable approximation for skeletonization,
    which applies iterative min- and max-pooling as a proxy for
    morphological erosion and dilation.

    Only 4D tensors (pooled with 2D kernels) and 5D tensors (pooled with 3D
    kernels) are supported; any other rank raises a `ValueError`.

    Args:
        num_iter: Number of iterations for soft-skeletonization.
            Should be greater than or equal to the maximum observed radius.
    """
    def __init__(self, num_iter: int = 5):
        super().__init__()
        self.num_iter = num_iter

    def soft_erode(self, input_: torch.Tensor) -> torch.Tensor:
        """Soft erosion: minimum over a 3-element window along each pooled axis.

        Min-pooling is expressed as a negated max-pool of the negated input.

        Args:
            input_: A 4D or 5D tensor.

        Returns:
            The eroded tensor, same shape as the input.

        Raises:
            ValueError: If the input is neither 4D nor 5D.
        """
        if input_.ndim == 4:
            p1 = -F.max_pool2d(-input_, (3, 1), (1, 1), (1, 0))
            p2 = -F.max_pool2d(-input_, (1, 3), (1, 1), (0, 1))
            return torch.min(p1, p2)
        if input_.ndim == 5:
            p1 = -F.max_pool3d(-input_, (3, 1, 1), (1, 1, 1), (1, 0, 0))
            p2 = -F.max_pool3d(-input_, (1, 3, 1), (1, 1, 1), (0, 1, 0))
            p3 = -F.max_pool3d(-input_, (1, 1, 3), (1, 1, 1), (0, 0, 1))
            return torch.min(torch.min(p1, p2), p3)
        # Previously this fell through and returned None, which only failed
        # later with an unrelated TypeError.
        raise ValueError(f"Expect a 4D or 5D input tensor, got {input_.ndim}D.")

    def soft_dilate(self, input_: torch.Tensor) -> torch.Tensor:
        """Soft dilation: maximum over a full 3x3 (or 3x3x3) window.

        Args:
            input_: A 4D or 5D tensor.

        Returns:
            The dilated tensor, same shape as the input.

        Raises:
            ValueError: If the input is neither 4D nor 5D.
        """
        if input_.ndim == 4:
            return F.max_pool2d(input_, (3, 3), (1, 1), (1, 1))
        if input_.ndim == 5:
            return F.max_pool3d(input_, (3, 3, 3), (1, 1, 1), (1, 1, 1))
        raise ValueError(f"Expect a 4D or 5D input tensor, got {input_.ndim}D.")

    def soft_open(self, input_: torch.Tensor) -> torch.Tensor:
        """Soft opening: erosion followed by dilation."""
        return self.soft_dilate(self.soft_erode(input_))

    def soft_skel(self, input_: torch.Tensor) -> torch.Tensor:
        """Accumulate the soft skeleton over `num_iter` erosion steps.

        Args:
            input_: A 4D or 5D tensor.

        Returns:
            The soft skeleton, same shape as the input.
        """
        input1 = self.soft_open(input_)
        skel = F.relu(input_ - input1)

        for _ in range(self.num_iter):
            input_ = self.soft_erode(input_)
            input1 = self.soft_open(input_)
            # What the opening removes at this erosion level.
            delta = F.relu(input_ - input1)
            # Add only the part of delta not already covered by the skeleton.
            skel = skel + F.relu(delta - skel * delta)

        return skel

    def forward(self, input_: torch.Tensor) -> torch.Tensor:
        """Skeletonize the input prediction.

        Args:
            input_: The input logits.

        Returns:
            The skeletonization.
        """
        return self.soft_skel(input_)
SoftSkeletonize is a differentiable approximation for skeletonization,
which applies iterative min- and max-pooling as a proxy for
morphological erosion and dilation.
Arguments:
- num_iter: Number of iterations for soft-skeletonization. Should be greater than or equal to the maximum observed radius.
def __init__(self, num_iter: int = 5):
    """Set up the module and remember the number of skeletonization iterations.

    Args:
        num_iter: Number of iterations for soft-skeletonization.
    """
    super().__init__()
    self.num_iter = num_iter
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def soft_erode(self, input_: torch.Tensor) -> torch.Tensor:
    """Soft erosion: minimum over a 3-element window along each pooled axis.

    Min-pooling is expressed as a negated max-pool of the negated input.
    4D tensors use 2D pooling, 5D tensors use 3D pooling.

    Args:
        input_: A 4D or 5D tensor.

    Returns:
        The eroded tensor, same shape as the input.

    Raises:
        ValueError: If the input is neither 4D nor 5D.
    """
    rank = input_.ndim
    if rank == 4:
        p1 = -F.max_pool2d(-input_, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0))
        p2 = -F.max_pool2d(-input_, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1))
        return torch.min(p1, p2)
    if rank == 5:
        p1 = -F.max_pool3d(-input_, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0))
        p2 = -F.max_pool3d(-input_, kernel_size=(1, 3, 1), stride=(1, 1, 1), padding=(0, 1, 0))
        p3 = -F.max_pool3d(-input_, kernel_size=(1, 1, 3), stride=(1, 1, 1), padding=(0, 0, 1))
        return torch.min(torch.min(p1, p2), p3)
    # Previously this fell through and returned None, which only failed
    # later with an unrelated TypeError.
    raise ValueError(f"Expect a 4D or 5D input tensor, got {rank}D.")
def soft_skel(self, input_: torch.Tensor):
    """Accumulate the soft skeleton over `self.num_iter` erosion steps.

    Starts from what a soft opening removes from the input, then repeatedly
    erodes the input and adds what the opening removes at that level.

    Args:
        input_: The tensor to skeletonize.

    Returns:
        The soft skeleton.
    """
    eroded = input_
    skeleton = F.relu(eroded - self.soft_open(eroded))

    for _ in range(self.num_iter):
        eroded = self.soft_erode(eroded)
        residual = F.relu(eroded - self.soft_open(eroded))
        # Add only the part of the residual not already covered by the skeleton.
        skeleton = skeleton + F.relu(residual - skeleton * residual)

    return skeleton
def forward(self, input_: torch.Tensor):
    """Run soft skeletonization on the input prediction.

    Args:
        input_: The input logits.

    Returns:
        The skeletonization, as produced by `soft_skel`.
    """
    skeleton = self.soft_skel(input_)
    return skeleton
Skeletonize the input prediction.
Arguments:
- input_: The input logits.
Returns:
The skeletonization.
def cldice_score(
    input_: torch.Tensor,
    target: torch.Tensor,
    num_iter: int = 5,
    invert: bool = False,
    eps: float = 1e-7,
) -> torch.Tensor:
    """Compute the soft clDice score between input and target.

    Adapted from `dice_score` in `.dice`. The score is computed over the whole
    tensor; per-channel scoring and channel reduction are not implemented.

    Args:
        input_: The input tensor.
        target: The target tensor.
        num_iter: Number of iterations for soft-skeletonization.
        invert: Whether to return the clDice error (`1 - score`) instead of the clDice score.
        eps: The lower bound the denominators are clamped to for numerical stability.

    Returns:
        The clDice score, or `1 - score` if `invert` is set.

    Raises:
        ValueError: If input and target differ in shape.
    """
    if input_.shape != target.shape:
        raise ValueError(f"Expect input and target of same shape, got: {input_.shape}, {target.shape}.")

    skeletonizer = SoftSkeletonize(num_iter=num_iter)
    input_skeleton = skeletonizer(input_)
    target_skeleton = skeletonizer(target)

    # Fraction of the predicted skeleton that lies inside the target ...
    precision = (input_skeleton * target).sum() / input_skeleton.sum().clamp(min=eps)
    # ... and fraction of the target skeleton that lies inside the prediction.
    sensitivity = (target_skeleton * input_).sum() / target_skeleton.sum().clamp(min=eps)
    # Harmonic mean of the two.
    harmonic_mean = 2. * (precision * sensitivity) / (precision + sensitivity).clamp(min=eps)

    if invert:
        return 1. - harmonic_mean
    return harmonic_mean
Adapted from .dice.py dice_score. Compute the soft clDice score between input and target.
Arguments:
- input_: The input tensor.
- target: The target tensor.
- num_iter: Number of iterations for soft-skeletonization.
- invert: Whether to invert the returned dice score to obtain the cldice error instead of the cldice score.
- channelwise: Not implemented; whether to return the dice score independently per channel.
- reduce_channel: Not implemented; how to return the dice score over the channel axis.
- eps: The epsilon value added to the denominator for numerical stability.
Returns:
The clDice score.
class SoftclDiceLoss(nn.Module):
    """Soft clDice loss for segmentation of tubular structures.

    The loss skeletonizes both the prediction and the target with a soft
    (differentiable) skeletonization and measures how well each skeleton is
    covered by the other segmentation. The clDice paper recommends combining
    it with a soft Dice loss; see `CombinedclDiceLoss` for that variant.

    Per-channel scoring and channel reduction are not implemented.

    Args:
        num_iter: Number of iterations for soft-skeletonization.
        eps: The lower bound the denominators are clamped to for numerical
            stability.
        exclude_background: Whether to exclude background channel 0 from the
            loss computation. Useful for multi-class segmentation.
    """
    def __init__(self, num_iter: int = 5, eps: float = 1e-7,
                 exclude_background: bool = False):
        super().__init__()

        # Kept so the loss can be re-created with the same arguments.
        self.init_kwargs = {"num_iter": num_iter, "eps": eps, "exclude_background": exclude_background}
        self.exclude_background = exclude_background
        self.eps = eps
        self.num_iter = num_iter

    def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the soft clDice loss between the input logits and binary target.

        Args:
            input_: The input logits.
            target: The binary target.

        Returns:
            The soft clDice loss, i.e. `1 - clDice score`.

        Raises:
            ValueError: If input and target differ in shape.
        """
        if input_.shape != target.shape:
            raise ValueError(f"Expect input and target of same shape, got: {input_.shape}, {target.shape}.")

        prediction, labels = input_, target
        if self.exclude_background:
            # Drop index 0 along axis 1 (the background channel).
            labels = target[:, 1:, :, :]
            prediction = input_[:, 1:, :, :]

        return cldice_score(prediction, labels, num_iter=self.num_iter, invert=True, eps=self.eps)
Soft clDice loss for segmentation of tubular structures.
The soft clDice loss computes topology-aware loss by computing the
soft skeleton of both the prediction and target
and measuring overlap of the two skeletons. Teaches the model to learn
skeletons directly. In the clDice paper, the authors recommend using
the combined soft-Dice and soft-clDice loss to learn topology-aware
segmentations, which is implemented below as `CombinedclDiceLoss`.
Arguments:
- num_iter: Number of iterations for soft-skeletonization.
- eps: The epsilon value added to the denominator for numerical stability.
- exclude_background: Whether to exclude background channel 0 from the loss computation. Useful for multi-class segmentation.
- channelwise: Not implemented; Whether to return the dice score independently per channel.
- reduce_channel: Not implemented; How to return the dice score over the channel axis.
def __init__(self, num_iter: int = 5, eps: float = 1e-7,
             exclude_background: bool = False):
    """Store the loss configuration.

    Args:
        num_iter: Number of iterations for soft-skeletonization.
        eps: The epsilon value used for numerical stability.
        exclude_background: Whether to exclude background channel 0 from the
            loss computation.
    """
    super().__init__()

    # Kept so the loss can be re-created with the same arguments.
    self.init_kwargs = {"num_iter": num_iter, "eps": eps, "exclude_background": exclude_background}
    self.exclude_background = exclude_background
    self.eps = eps
    self.num_iter = num_iter
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Compute the soft clDice loss between the input logits and binary target.

    Args:
        input_: The input logits.
        target: The binary target.

    Returns:
        The soft clDice loss, i.e. `1 - clDice score`.

    Raises:
        ValueError: If input and target differ in shape.
    """
    if input_.shape != target.shape:
        raise ValueError(f"Expect input and target of same shape, got: {input_.shape}, {target.shape}.")

    prediction, labels = input_, target
    if self.exclude_background:
        # Drop index 0 along axis 1 (the background channel).
        labels = target[:, 1:, :, :]
        prediction = input_[:, 1:, :, :]

    return cldice_score(prediction, labels, num_iter=self.num_iter, invert=True, eps=self.eps)
Compute the soft clDice loss (1 - clDice score) between the input logits and binary target.
Arguments:
- input_: The input logits.
- target: The binary target.
Returns:
The soft clDice loss (1 - clDice score).
class CombinedclDiceLoss(SoftclDiceLoss):
    """Combined soft-Dice and soft-clDice loss for segmentation of tubular structures.

    The soft-clDice term skeletonizes both the prediction and the target and
    measures how well each skeleton is covered by the other segmentation,
    which encourages the model to preserve the connectivity and topology of
    tubular structures. The final loss is a weighted combination of the soft
    Dice loss and the soft clDice loss, controlled by `alpha`.

    Per-channel scoring and channel reduction are not implemented.

    Args:
        num_iter: Number of iterations for soft-skeletonization.
        alpha: The weight for combining the soft Dice and soft clDice loss.
        eps: The epsilon value passed to both scores for numerical
            stability.
        exclude_background: Whether to exclude background channel 0 from the
            loss computation. Useful for multi-class segmentation.
    """
    def __init__(self, num_iter: int = 5, alpha: float = 0.5, eps: float = 1e-7,
                 exclude_background: bool = False):
        super().__init__(num_iter=num_iter, eps=eps, exclude_background=exclude_background)

        # Replaces the parent's init_kwargs so that `alpha` is included.
        self.init_kwargs = {"num_iter": num_iter, "alpha": alpha, "eps": eps, "exclude_background": exclude_background}
        self.alpha = alpha

    def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute combined soft Dice and clDice loss.

        Args:
            input_: The input logits.
            target: The binary target.

        Returns:
            Combined clDice loss: `(1 - alpha) * dice_loss + alpha * cldice_loss`.

        Raises:
            ValueError: If input and target differ in shape.
        """
        if input_.shape != target.shape:
            raise ValueError(f"Expect input and target of same shape, got: {input_.shape}, {target.shape}.")

        prediction, labels = input_, target
        if self.exclude_background:
            # Drop index 0 along axis 1 (the background channel).
            labels = target[:, 1:, :, :]
            prediction = input_[:, 1:, :, :]

        dice_loss = dice_score(prediction, labels, invert=True, channelwise=False, eps=self.eps)
        cldice_loss = cldice_score(prediction, labels, num_iter=self.num_iter, invert=True, eps=self.eps)

        return (1.0 - self.alpha) * dice_loss + self.alpha * cldice_loss
Combined soft-Dice and soft-clDice loss for segmentation of tubular structures.
The soft-clDice loss computes topology-aware loss by computing the
soft skeleton of both the prediction and target and measuring overlap
of the two skeletons. This encourages the model to preserve the
connectivity and topology of tubular structures. The final loss is a
weighted combination of soft Dice and clDice, controlled by alpha.
Arguments:
- num_iter: Number of iterations for soft-skeletonization.
- alpha: The weight for combining the soft Dice and soft clDice loss.
- eps: The epsilon value added to the denominator for numerical stability.
- exclude_background: Whether to exclude background channel 0 from the loss computation. Useful for multi-class segmentation.
- invert: Not implemented; Whether to invert the returned dice score to obtain the dice error instead of the dice score.
- channelwise: Not implemented; Whether to return the dice score independently per channel.
- reduce_channel: Not implemented; How to return the dice score over the channel axis.
def __init__(self, num_iter: int = 5, alpha: float = 0.5, eps: float = 1e-7,
             exclude_background: bool = False):
    """Store the loss configuration.

    Args:
        num_iter: Number of iterations for soft-skeletonization.
        alpha: The weight for combining the soft Dice and soft clDice loss.
        eps: The epsilon value used for numerical stability.
        exclude_background: Whether to exclude background channel 0 from the
            loss computation.
    """
    super().__init__(num_iter=num_iter, eps=eps, exclude_background=exclude_background)

    # Replaces the parent's init_kwargs so that `alpha` is included.
    self.init_kwargs = {"num_iter": num_iter, "alpha": alpha, "eps": eps, "exclude_background": exclude_background}
    self.alpha = alpha
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Compute combined soft Dice and clDice loss.

    Args:
        input_: The input logits.
        target: The binary target.

    Returns:
        Combined clDice loss: `(1 - alpha) * dice_loss + alpha * cldice_loss`.

    Raises:
        ValueError: If input and target differ in shape.
    """
    if input_.shape != target.shape:
        raise ValueError(f"Expect input and target of same shape, got: {input_.shape}, {target.shape}.")

    prediction, labels = input_, target
    if self.exclude_background:
        # Drop index 0 along axis 1 (the background channel).
        labels = target[:, 1:, :, :]
        prediction = input_[:, 1:, :, :]

    dice_loss = dice_score(prediction, labels, invert=True, channelwise=False, eps=self.eps)
    cldice_loss = cldice_score(prediction, labels, num_iter=self.num_iter, invert=True, eps=self.eps)

    return (1.0 - self.alpha) * dice_loss + self.alpha * cldice_loss
Compute combined soft Dice and clDice loss.
Arguments:
- input_: The input logits.
- target: The binary target.
Returns:
Combined clDice loss.