torch_em.loss.distance_based

from typing import Optional

import torch
import torch.nn as nn

from .dice import DiceLoss
 5
 6
 7class DistanceLoss(nn.Module):
 8    """Loss for distance based instance segmentation.
 9
10    Expects input and targets with three channels: foreground and two distance channels.
11    Typically the distance channels are centroid and inverted boundary distance.
12
13    Args:
14        mask_distances_in_bg: whether to mask the loss for distance predictions in the background.
15        foreground_loss: the loss for comparing foreground predictions and target.
16        distance_loss: the loss for comparing distance predictions and target.
17    """
18    def __init__(
19        self,
20        mask_distances_in_bg: bool = True,
21        foreground_loss: nn.Module = DiceLoss(),
22        distance_loss: nn.Module = nn.MSELoss(reduction="mean")
23    ) -> None:
24        super().__init__()
25
26        self.foreground_loss = foreground_loss
27        self.distance_loss = distance_loss
28        self.mask_distances_in_bg = mask_distances_in_bg
29
30        self.init_kwargs = {"mask_distances_in_bg": mask_distances_in_bg}
31
32    def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
33        assert input_.shape == target.shape, input_.shape
34        assert input_.shape[1] == 3, input_.shape
35
36        # IMPORTANT: preserve the channels!
37        # Otherwise the Dice Loss will do all kinds of shennanigans.
38        # Because it always interprets the first axis as channel,
39        # and treats it differently (sums over it independently).
40        # This will lead to a very large dice loss that dominates over everything else.
41        fg_input, fg_target = input_[:, 0:1], target[:, 0:1]
42        fg_loss = self.foreground_loss(fg_input, fg_target)
43
44        cdist_input, cdist_target = input_[:, 1:2], target[:, 1:2]
45        if self.mask_distances_in_bg:
46            mask = fg_target
47            cdist_loss = self.distance_loss(cdist_input * mask, cdist_target * mask)
48        else:
49            cdist_loss = self.distance_loss(cdist_input, cdist_target)
50
51        bdist_input, bdist_target = input_[:, 2:3], target[:, 2:3]
52        if self.mask_distances_in_bg:
53            mask = fg_target
54            bdist_loss = self.distance_loss(bdist_input * mask, bdist_target * mask)
55        else:
56            bdist_loss = self.distance_loss(bdist_input, bdist_target)
57
58        overall_loss = fg_loss + cdist_loss + bdist_loss
59        return overall_loss
60
61
class DiceBasedDistanceLoss(DistanceLoss):
    """Similar to `DistanceLoss`, using the dice score for all losses.

    The foreground channel and the distance channels are each compared with their own
    `DiceLoss` instance.

    Args:
        mask_distances_in_bg: whether to mask the loss for distance predictions in the background.
            Defaults to True, the same default as `DistanceLoss`.
    """
    def __init__(self, mask_distances_in_bg: bool = True) -> None:
        super().__init__(mask_distances_in_bg, foreground_loss=DiceLoss(), distance_loss=DiceLoss())
class DistanceLoss(torch.nn.modules.module.Module):
 8class DistanceLoss(nn.Module):
 9    """Loss for distance based instance segmentation.
10
11    Expects input and targets with three channels: foreground and two distance channels.
12    Typically the distance channels are centroid and inverted boundary distance.
13
14    Args:
15        mask_distances_in_bg: whether to mask the loss for distance predictions in the background.
16        foreground_loss: the loss for comparing foreground predictions and target.
17        distance_loss: the loss for comparing distance predictions and target.
18    """
19    def __init__(
20        self,
21        mask_distances_in_bg: bool = True,
22        foreground_loss: nn.Module = DiceLoss(),
23        distance_loss: nn.Module = nn.MSELoss(reduction="mean")
24    ) -> None:
25        super().__init__()
26
27        self.foreground_loss = foreground_loss
28        self.distance_loss = distance_loss
29        self.mask_distances_in_bg = mask_distances_in_bg
30
31        self.init_kwargs = {"mask_distances_in_bg": mask_distances_in_bg}
32
33    def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
34        assert input_.shape == target.shape, input_.shape
35        assert input_.shape[1] == 3, input_.shape
36
37        # IMPORTANT: preserve the channels!
38        # Otherwise the Dice Loss will do all kinds of shennanigans.
39        # Because it always interprets the first axis as channel,
40        # and treats it differently (sums over it independently).
41        # This will lead to a very large dice loss that dominates over everything else.
42        fg_input, fg_target = input_[:, 0:1], target[:, 0:1]
43        fg_loss = self.foreground_loss(fg_input, fg_target)
44
45        cdist_input, cdist_target = input_[:, 1:2], target[:, 1:2]
46        if self.mask_distances_in_bg:
47            mask = fg_target
48            cdist_loss = self.distance_loss(cdist_input * mask, cdist_target * mask)
49        else:
50            cdist_loss = self.distance_loss(cdist_input, cdist_target)
51
52        bdist_input, bdist_target = input_[:, 2:3], target[:, 2:3]
53        if self.mask_distances_in_bg:
54            mask = fg_target
55            bdist_loss = self.distance_loss(bdist_input * mask, bdist_target * mask)
56        else:
57            bdist_loss = self.distance_loss(bdist_input, bdist_target)
58
59        overall_loss = fg_loss + cdist_loss + bdist_loss
60        return overall_loss

Loss for distance based instance segmentation.

Expects input and targets with three channels: foreground and two distance channels. Typically the distance channels are centroid and inverted boundary distance.

Arguments:
  • mask_distances_in_bg: whether to mask the loss for distance predictions in the background.
  • foreground_loss: the loss for comparing foreground predictions and target.
  • distance_loss: the loss for comparing distance predictions and target.
DistanceLoss( mask_distances_in_bg: bool = True, foreground_loss: torch.nn.modules.module.Module = DiceLoss(), distance_loss: torch.nn.modules.module.Module = MSELoss())
19    def __init__(
20        self,
21        mask_distances_in_bg: bool = True,
22        foreground_loss: nn.Module = DiceLoss(),
23        distance_loss: nn.Module = nn.MSELoss(reduction="mean")
24    ) -> None:
25        super().__init__()
26
27        self.foreground_loss = foreground_loss
28        self.distance_loss = distance_loss
29        self.mask_distances_in_bg = mask_distances_in_bg
30
31        self.init_kwargs = {"mask_distances_in_bg": mask_distances_in_bg}

Initialize the loss: store the foreground loss, the distance loss and the background-masking flag, and record the flag in `init_kwargs`.

foreground_loss
distance_loss
mask_distances_in_bg
init_kwargs
def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
33    def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
34        assert input_.shape == target.shape, input_.shape
35        assert input_.shape[1] == 3, input_.shape
36
37        # IMPORTANT: preserve the channels!
38        # Otherwise the Dice Loss will do all kinds of shennanigans.
39        # Because it always interprets the first axis as channel,
40        # and treats it differently (sums over it independently).
41        # This will lead to a very large dice loss that dominates over everything else.
42        fg_input, fg_target = input_[:, 0:1], target[:, 0:1]
43        fg_loss = self.foreground_loss(fg_input, fg_target)
44
45        cdist_input, cdist_target = input_[:, 1:2], target[:, 1:2]
46        if self.mask_distances_in_bg:
47            mask = fg_target
48            cdist_loss = self.distance_loss(cdist_input * mask, cdist_target * mask)
49        else:
50            cdist_loss = self.distance_loss(cdist_input, cdist_target)
51
52        bdist_input, bdist_target = input_[:, 2:3], target[:, 2:3]
53        if self.mask_distances_in_bg:
54            mask = fg_target
55            bdist_loss = self.distance_loss(bdist_input * mask, bdist_target * mask)
56        else:
57            bdist_loss = self.distance_loss(bdist_input, bdist_target)
58
59        overall_loss = fg_loss + cdist_loss + bdist_loss
60        return overall_loss

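Compute the foreground loss on channel 0 and the distance loss on channels 1 and 2, and return their sum.

If `mask_distances_in_bg` is set, the distance channels of both input and target are multiplied by the foreground target (channel 0 of the target) before they are compared.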

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class DiceBasedDistanceLoss(DistanceLoss):
class DiceBasedDistanceLoss(DistanceLoss):
    """Similar to `DistanceLoss`, using the dice score for all losses.

    The foreground channel and the distance channels are each compared with their own
    `DiceLoss` instance.

    Args:
        mask_distances_in_bg: whether to mask the loss for distance predictions in the background.
            Defaults to True, the same default as `DistanceLoss`.
    """
    def __init__(self, mask_distances_in_bg: bool = True) -> None:
        super().__init__(mask_distances_in_bg, foreground_loss=DiceLoss(), distance_loss=DiceLoss())

Similar to DistanceLoss, using the dice score for all losses.

Arguments:
  • mask_distances_in_bg: whether to mask the loss for distance predictions in the background.
DiceBasedDistanceLoss(mask_distances_in_bg: bool)
69    def __init__(self, mask_distances_in_bg: bool) -> None:
70        super().__init__(mask_distances_in_bg, foreground_loss=DiceLoss(), distance_loss=DiceLoss())

Initialize a `DistanceLoss` that uses a separate `DiceLoss` instance for the foreground channel and for the distance channels.