torch_em.loss.combined_loss

 1from typing import List
 2
 3import torch
 4
 5
 6class CombinedLoss(torch.nn.Module):
 7    """Combination of multiple losses.
 8
 9    Args:
10        losses: The loss functions to combine.
11        loss_weights: The weights for the loss functions.
12    """
13    def __init__(self, *losses: torch.nn.Module, loss_weights: List[float] = None):
14        super().__init__()
15        self.losses = torch.nn.ModuleList(losses)
16        n_losses = len(self.losses)
17        if loss_weights is None:
18            try:
19                self.loss_weights = [1.0 / n_losses] * n_losses
20            except ZeroDivisionError:
21                self.loss_weights = None
22        else:
23            assert len(loss_weights) == n_losses
24            self.loss_weights = loss_weights
25
26    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
27        """Compute the combined loss.
28
29        Args:
30            x: The prediction.
31            y: The target.
32
33        Returns:
34            The loss value.
35        """
36        assert self.loss_weights is not None
37        loss_value = sum([loss(x, y) * weight for loss, weight in zip(self.losses, self.loss_weights)])
38        return loss_value
class CombinedLoss(torch.nn.modules.module.Module):
 7class CombinedLoss(torch.nn.Module):
 8    """Combination of multiple losses.
 9
10    Args:
11        losses: The loss functions to combine.
12        loss_weights: The weights for the loss functions.
13    """
14    def __init__(self, *losses: torch.nn.Module, loss_weights: List[float] = None):
15        super().__init__()
16        self.losses = torch.nn.ModuleList(losses)
17        n_losses = len(self.losses)
18        if loss_weights is None:
19            try:
20                self.loss_weights = [1.0 / n_losses] * n_losses
21            except ZeroDivisionError:
22                self.loss_weights = None
23        else:
24            assert len(loss_weights) == n_losses
25            self.loss_weights = loss_weights
26
27    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
28        """Compute the combined loss.
29
30        Args:
31            x: The prediction.
32            y: The target.
33
34        Returns:
35            The loss value.
36        """
37        assert self.loss_weights is not None
38        loss_value = sum([loss(x, y) * weight for loss, weight in zip(self.losses, self.loss_weights)])
39        return loss_value

Combination of multiple losses.

Arguments:
  • losses: The loss functions to combine.
  • loss_weights: The weights for the loss functions.
CombinedLoss( *losses: torch.nn.modules.module.Module, loss_weights: List[float] = None)
14    def __init__(self, *losses: torch.nn.Module, loss_weights: List[float] = None):
15        super().__init__()
16        self.losses = torch.nn.ModuleList(losses)
17        n_losses = len(self.losses)
18        if loss_weights is None:
19            try:
20                self.loss_weights = [1.0 / n_losses] * n_losses
21            except ZeroDivisionError:
22                self.loss_weights = None
23        else:
24            assert len(loss_weights) == n_losses
25            self.loss_weights = loss_weights

Initialize internal Module state, shared by both nn.Module and ScriptModule.

losses
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
27    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
28        """Compute the combined loss.
29
30        Args:
31            x: The prediction.
32            y: The target.
33
34        Returns:
35            The loss value.
36        """
37        assert self.loss_weights is not None
38        loss_value = sum([loss(x, y) * weight for loss, weight in zip(self.losses, self.loss_weights)])
39        return loss_value

Compute the combined loss.

Arguments:
  • x: The prediction.
  • y: The target.
Returns:

The loss value.