torch_em.loss.distance_based
from typing import Optional

import torch
import torch.nn as nn

from .dice import DiceLoss


class DistanceLoss(nn.Module):
    """Loss for distance based instance segmentation.

    Expects input and targets with three channels: foreground and two distance channels.
    Typically the distance channels are centroid and inverted boundary distance.

    Args:
        mask_distances_in_bg: whether to mask the loss for distance predictions in the background.
        foreground_loss: the loss for comparing foreground predictions and target.
            If None, a new `DiceLoss()` is created for this instance.
        distance_loss: the loss for comparing distance predictions and target.
            If None, a new `nn.MSELoss(reduction="mean")` is created for this instance.
    """
    def __init__(
        self,
        mask_distances_in_bg: bool = True,
        foreground_loss: Optional[nn.Module] = None,
        distance_loss: Optional[nn.Module] = None,
    ) -> None:
        super().__init__()

        # Build the default losses per instance: module objects used directly as default
        # argument values are created once at definition time and would be shared
        # (and registered as a submodule) by every DistanceLoss instance.
        if foreground_loss is None:
            foreground_loss = DiceLoss()
        if distance_loss is None:
            distance_loss = nn.MSELoss(reduction="mean")

        self.foreground_loss = foreground_loss
        self.distance_loss = distance_loss
        self.mask_distances_in_bg = mask_distances_in_bg

        # NOTE: only `mask_distances_in_bg` is recorded; custom `foreground_loss` /
        # `distance_loss` modules passed to the constructor are not part of `init_kwargs`.
        self.init_kwargs = {"mask_distances_in_bg": mask_distances_in_bg}

    def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the combined foreground and distance loss.

        Args:
            input_: the prediction; channel 0 is the foreground, channels 1 and 2 are distances.
            target: the target, with the same shape and channel layout as `input_`.

        Returns:
            The sum of the foreground loss and the two distance losses.
        """
        assert input_.shape == target.shape, input_.shape
        assert input_.shape[1] == 3, input_.shape

        # IMPORTANT: preserve the channels (slice with 0:1 instead of indexing with 0)!
        # Otherwise the Dice Loss will do all kinds of shenanigans,
        # because it always interprets the first axis as channel
        # and treats it differently (sums over it independently).
        # This will lead to a very large dice loss that dominates over everything else.
        fg_input, fg_target = input_[:, 0:1], target[:, 0:1]
        overall_loss = self.foreground_loss(fg_input, fg_target)

        # Channels 1 and 2 (typically centroid and inverted boundary distance)
        # are compared with the same distance loss.
        for channel in (1, 2):
            dist_input = input_[:, channel:channel + 1]
            dist_target = target[:, channel:channel + 1]
            if self.mask_distances_in_bg:
                # Multiply both sides by the foreground target so that (for a binary
                # foreground target) background pixels contribute no error.
                dist_input, dist_target = dist_input * fg_target, dist_target * fg_target
            overall_loss = overall_loss + self.distance_loss(dist_input, dist_target)
        return overall_loss


class DiceBasedDistanceLoss(DistanceLoss):
    """Similar to `DistanceLoss`, using the dice score for all losses.

    Args:
        mask_distances_in_bg: whether to mask the loss for distance predictions in the background.
    """
    def __init__(self, mask_distances_in_bg: bool) -> None:
        super().__init__(mask_distances_in_bg, foreground_loss=DiceLoss(), distance_loss=DiceLoss())
class
DistanceLoss(torch.nn.modules.module.Module):
class DistanceLoss(nn.Module):
    """Loss for distance based instance segmentation.

    Expects input and targets with three channels: foreground and two distance channels.
    Typically the distance channels are centroid and inverted boundary distance.

    Args:
        mask_distances_in_bg: whether to mask the loss for distance predictions in the background.
        foreground_loss: the loss for comparing foreground predictions and target.
            If None, a new `DiceLoss()` is created for this instance.
        distance_loss: the loss for comparing distance predictions and target.
            If None, a new `nn.MSELoss(reduction="mean")` is created for this instance.
    """
    def __init__(
        self,
        mask_distances_in_bg: bool = True,
        foreground_loss: Optional[nn.Module] = None,
        distance_loss: Optional[nn.Module] = None,
    ) -> None:
        super().__init__()

        # Build the default losses per instance: module objects used directly as default
        # argument values are created once at definition time and would be shared
        # (and registered as a submodule) by every DistanceLoss instance.
        if foreground_loss is None:
            foreground_loss = DiceLoss()
        if distance_loss is None:
            distance_loss = nn.MSELoss(reduction="mean")

        self.foreground_loss = foreground_loss
        self.distance_loss = distance_loss
        self.mask_distances_in_bg = mask_distances_in_bg

        # NOTE: only `mask_distances_in_bg` is recorded; custom `foreground_loss` /
        # `distance_loss` modules passed to the constructor are not part of `init_kwargs`.
        self.init_kwargs = {"mask_distances_in_bg": mask_distances_in_bg}

    def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the combined foreground and distance loss.

        Args:
            input_: the prediction; channel 0 is the foreground, channels 1 and 2 are distances.
            target: the target, with the same shape and channel layout as `input_`.

        Returns:
            The sum of the foreground loss and the two distance losses.
        """
        assert input_.shape == target.shape, input_.shape
        assert input_.shape[1] == 3, input_.shape

        # IMPORTANT: preserve the channels (slice with 0:1 instead of indexing with 0)!
        # Otherwise the Dice Loss will do all kinds of shenanigans,
        # because it always interprets the first axis as channel
        # and treats it differently (sums over it independently).
        # This will lead to a very large dice loss that dominates over everything else.
        fg_input, fg_target = input_[:, 0:1], target[:, 0:1]
        overall_loss = self.foreground_loss(fg_input, fg_target)

        # Channels 1 and 2 (typically centroid and inverted boundary distance)
        # are compared with the same distance loss.
        for channel in (1, 2):
            dist_input = input_[:, channel:channel + 1]
            dist_target = target[:, channel:channel + 1]
            if self.mask_distances_in_bg:
                # Multiply both sides by the foreground target so that (for a binary
                # foreground target) background pixels contribute no error.
                dist_input, dist_target = dist_input * fg_target, dist_target * fg_target
            overall_loss = overall_loss + self.distance_loss(dist_input, dist_target)
        return overall_loss
Loss for distance based instance segmentation.
Expects input and targets with three channels: foreground and two distance channels. Typically the distance channels are centroid and inverted boundary distance.
Arguments:
- mask_distances_in_bg: whether to mask the loss for distance predictions in the background.
- foreground_loss: the loss for comparing foreground predictions and target.
- distance_loss: the loss for comparing distance predictions and target.
DistanceLoss( mask_distances_in_bg: bool = True, foreground_loss: torch.nn.modules.module.Module = DiceLoss(), distance_loss: torch.nn.modules.module.Module = MSELoss())
def __init__(
    self,
    mask_distances_in_bg: bool = True,
    foreground_loss: Optional[nn.Module] = None,
    distance_loss: Optional[nn.Module] = None,
) -> None:
    """Set up the foreground and distance losses.

    Args:
        mask_distances_in_bg: whether to mask the loss for distance predictions in the background.
        foreground_loss: the loss for comparing foreground predictions and target.
            If None, a new `DiceLoss()` is created for this instance.
        distance_loss: the loss for comparing distance predictions and target.
            If None, a new `nn.MSELoss(reduction="mean")` is created for this instance.
    """
    super().__init__()

    # Build the default losses per instance: module objects used directly as default
    # argument values are created once at definition time and would be shared
    # (and registered as a submodule) by every instance.
    if foreground_loss is None:
        foreground_loss = DiceLoss()
    if distance_loss is None:
        distance_loss = nn.MSELoss(reduction="mean")

    self.foreground_loss = foreground_loss
    self.distance_loss = distance_loss
    self.mask_distances_in_bg = mask_distances_in_bg

    # NOTE: only `mask_distances_in_bg` is recorded; custom `foreground_loss` /
    # `distance_loss` modules passed to the constructor are not part of `init_kwargs`.
    self.init_kwargs = {"mask_distances_in_bg": mask_distances_in_bg}
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def
forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Return the foreground loss plus the two distance losses.

    Channel 0 of `input_`/`target` is compared with `self.foreground_loss`; channels 1 and 2
    are each compared with `self.distance_loss`, optionally weighted by the foreground target.
    """
    assert input_.shape == target.shape, input_.shape
    assert input_.shape[1] == 3, input_.shape

    # Slice with ranges (0:1, 1:2, 2:3) so the channel axis survives. The Dice loss always
    # interprets the first axis as channel and sums over it independently; dropping it
    # would produce a very large Dice loss that dominates over everything else.
    fg_target = target[:, 0:1]
    total = self.foreground_loss(input_[:, 0:1], fg_target)

    for ch in (1, 2):
        pred_ch = input_[:, ch:ch + 1]
        true_ch = target[:, ch:ch + 1]
        if self.mask_distances_in_bg:
            # Weight both sides by the foreground target to suppress background pixels.
            pred_ch = pred_ch * fg_target
            true_ch = true_ch * fg_target
        total = total + self.distance_loss(pred_ch, true_ch)
    return total
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
class DiceBasedDistanceLoss(DistanceLoss):
    """Variant of `DistanceLoss` that scores the foreground and both distances with Dice losses.

    Args:
        mask_distances_in_bg: whether to mask the loss for distance predictions in the background.
    """

    def __init__(self, mask_distances_in_bg: bool) -> None:
        # One DiceLoss for the foreground term and a separate one for the distance terms.
        fg_loss = DiceLoss()
        dist_loss = DiceLoss()
        super().__init__(
            mask_distances_in_bg=mask_distances_in_bg,
            foreground_loss=fg_loss,
            distance_loss=dist_loss,
        )
Similar to DistanceLoss, using the dice score for all losses.
Arguments:
- mask_distances_in_bg: whether to mask the loss for distance predictions in the background.
DiceBasedDistanceLoss(mask_distances_in_bg: bool)
def __init__(self, mask_distances_in_bg: bool) -> None:
    """Configure the parent loss with Dice losses for the foreground and distance terms.

    Args:
        mask_distances_in_bg: whether to mask the loss for distance predictions in the background.
    """
    fg_loss = DiceLoss()
    dist_loss = DiceLoss()
    super().__init__(
        mask_distances_in_bg=mask_distances_in_bg,
        foreground_loss=fg_loss,
        distance_loss=dist_loss,
    )
Initialize internal Module state, shared by both nn.Module and ScriptModule.