torch_em.loss.wrapper

  1from typing import Callable, Sequence, Tuple, Union
  2
  3import torch
  4import torch.nn as nn
  5
  6
  7class LossWrapper(nn.Module):
  8    """A wrapper around a torch loss function.
  9
 10    Applies transformations to prediction and/or target before passing it to the loss.
 11
 12    Args:
 13        loss: The loss function.
 14        transform: The transformation applied to prediction and/or target.
 15            Must take both the prediction and target as arguments and then return them, potentially transformed.
 16    """
 17    def __init__(self, loss: nn.Module, transform: Callable):
 18        super().__init__()
 19        self.loss = loss
 20
 21        if not callable(transform):
 22            raise ValueError("transform has to be callable.")
 23        self.transform = transform
 24        self.init_kwargs = {'loss': loss, 'transform': transform}
 25
 26    def apply_transform(self, prediction, target, **kwargs):
 27        """@private
 28        """
 29        # Check if the prediction and target are lists.
 30        # If they are, apply the transform to each element individually.
 31        if isinstance(prediction, (list, tuple)):
 32            assert isinstance(target, (list, tuple))
 33            transformed_prediction, transformed_target = [], []
 34            for pred, targ in zip(prediction, target):
 35                tr_pred, tr_targ = self.transform(pred, targ, **kwargs)
 36                transformed_prediction.append(tr_pred)
 37                transformed_target.append(tr_targ)
 38            return transformed_prediction, transformed_target
 39        # Otherwise, we expect that prediction and target are both tensors.
 40        else:
 41            prediction, target = self.transform(prediction, target, **kwargs)
 42            return prediction, target
 43
 44    def forward(
 45        self,
 46        prediction: Union[Sequence[torch.Tensor], torch.Tensor],
 47        target: Union[Sequence[torch.Tensor], torch.Tensor],
 48        **kwargs
 49    ) -> torch.Tensor:
 50        """Apply the tranformations to prediction and/or target before computing the loss.
 51
 52        Args:
 53            prediction: The prediction.
 54            target: The target.
 55            kwargs: Additional keyword arguments for the transformation.
 56
 57        Returns:
 58            The loss.
 59        """
 60        prediction, target = self.apply_transform(prediction, target, **kwargs)
 61        loss = self.loss(prediction, target)
 62        return loss
 63
 64
 65#
 66# Loss transformations
 67#
 68def _crop(prediction, target, mask, channel_dim):
 69    if mask.shape[channel_dim] != 1:
 70        raise ValueError(
 71            "_crop only supports a mask with a singleton channel axis. Please consider using masking_method=multiply."
 72        )
 73    mask = mask.type(torch.bool)
 74    # remove singleton axis
 75    mask = mask.squeeze(channel_dim)
 76    # move channel axis to end
 77    prediction = prediction.moveaxis(channel_dim, -1)
 78    target = target.moveaxis(channel_dim, -1)
 79    # output has shape N x C
 80    # correct for torch_em.loss.dice.flatten_samples
 81    return prediction[mask], target[mask]
 82
 83
 84def _multiply(prediction, target, mask, channel_dim):
 85    prediction = prediction * mask
 86    target = target * mask
 87    return prediction, target
 88
 89
 90class ApplyMask:
 91    """Apply a mask to prediction and target, so that only values in the mask are taken into account for the loss.
 92
 93    Supports two different masking methods:
 94    - 'crop': Crop away the mask from the prediction and target. This only works if the mask just has a single channel,
 95        and if the loss function does not require spatial inputs.
 96    - 'multiply': Multiply the prediction and target with zeros outside of the mask.
 97
 98    Args:
 99        masking_method: The masking method to use. Can be one of 'crop' or 'multiply'.
100        channel_dim: The dimension of the channel axis.
101    """
102    MASKING_FUNCS = {"crop": _crop, "multiply": _multiply}
103
104    def __init__(self, masking_method: str = "crop", channel_dim: int = 1):
105        if masking_method not in self.MASKING_FUNCS.keys():
106            raise ValueError(f"{masking_method} is not available, please use one of {list(self.MASKING_FUNCS.keys())}.")
107        self.masking_func = self.MASKING_FUNCS[masking_method]
108        self.channel_dim = channel_dim
109        self.init_kwargs = {"masking_method": masking_method, "channel_dim": channel_dim}
110
111    def __call__(
112        self, prediction: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
113    ) -> Tuple[torch.Tensor, torch.Tensor]:
114        """Mask predictions.
115
116        Args:
117            prediction: The prediction tensor.
118            target: The target tensor.
119            mask: The mask tensor.
120
121        Returns:
122            The masked prediction.
123            The masked target.
124        """
125        mask.requires_grad = False
126        return self.masking_func(prediction, target, mask, self.channel_dim)
127
128
class ApplyAndRemoveMask(ApplyMask):
    """Extract mask from extra channels from a target tensor and use it to mask the prediction.

    The target is expected to have twice as many channels as the prediction along axis 1:
    the first half contains the actual target, the second half contains the mask.

    Supports the same masking methods as `ApplyMask`.
    """
    def __call__(self, prediction: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Remove masking channels from the target and then apply the mask.

        Args:
            prediction: The prediction tensor.
            target: The target tensor, with extra channels that contain the mask.

        Returns:
            The masked prediction.
            The masked target, with extra channels removed.
        """
        assert target.dim() == prediction.dim(), f"{target.dim()}, {prediction.dim()}"
        assert target.size(1) == 2 * prediction.size(1), f"{target.size(1)}, {prediction.size(1)}"
        assert target.shape[2:] == prediction.shape[2:], f"{str(target.shape)}, {str(prediction.shape)}"

        # NOTE: the split is always performed along axis 1, independent of `self.channel_dim`.
        separating_channel = target.size(1) // 2
        mask = target[:, separating_channel:]
        target = target[:, :separating_channel]
        return super().__call__(prediction, target, mask)
153
154
class MaskIgnoreLabel(ApplyMask):
    """Mask ignore label from the target.

    Supports the same masking methods as `ApplyMask`.

    Args:
        ignore_label: The ignore label, which will be masked.
        masking_method: The masking method to use. Can be one of 'crop' or 'multiply'.
        channel_dim: The dimension of the channel axis.
    """
    def __init__(self, ignore_label: int = -1, masking_method: str = "crop", channel_dim: int = 1):
        super().__init__(masking_method, channel_dim)
        self.ignore_label = ignore_label
        self.init_kwargs["ignore_label"] = ignore_label

    def __call__(self, prediction: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Mask ignore label in the prediction and target.

        Args:
            prediction: The prediction tensor.
            target: The target tensor.

        Returns:
            The masked prediction.
            The masked target.
        """
        # The mask is True wherever the target differs from the ignore label,
        # so it has the same shape as the target.
        mask = (target != self.ignore_label)
        return super().__call__(prediction, target, mask)
class LossWrapper(torch.nn.modules.module.Module):
 8class LossWrapper(nn.Module):
 9    """A wrapper around a torch loss function.
10
11    Applies transformations to prediction and/or target before passing it to the loss.
12
13    Args:
14        loss: The loss function.
15        transform: The transformation applied to prediction and/or target.
16            Must take both the prediction and target as arguments and then return them, potentially transformed.
17    """
18    def __init__(self, loss: nn.Module, transform: Callable):
19        super().__init__()
20        self.loss = loss
21
22        if not callable(transform):
23            raise ValueError("transform has to be callable.")
24        self.transform = transform
25        self.init_kwargs = {'loss': loss, 'transform': transform}
26
27    def apply_transform(self, prediction, target, **kwargs):
28        """@private
29        """
30        # Check if the prediction and target are lists.
31        # If they are, apply the transform to each element individually.
32        if isinstance(prediction, (list, tuple)):
33            assert isinstance(target, (list, tuple))
34            transformed_prediction, transformed_target = [], []
35            for pred, targ in zip(prediction, target):
36                tr_pred, tr_targ = self.transform(pred, targ, **kwargs)
37                transformed_prediction.append(tr_pred)
38                transformed_target.append(tr_targ)
39            return transformed_prediction, transformed_target
40        # Otherwise, we expect that prediction and target are both tensors.
41        else:
42            prediction, target = self.transform(prediction, target, **kwargs)
43            return prediction, target
44
45    def forward(
46        self,
47        prediction: Union[Sequence[torch.Tensor], torch.Tensor],
48        target: Union[Sequence[torch.Tensor], torch.Tensor],
49        **kwargs
50    ) -> torch.Tensor:
51        """Apply the tranformations to prediction and/or target before computing the loss.
52
53        Args:
54            prediction: The prediction.
55            target: The target.
56            kwargs: Additional keyword arguments for the transformation.
57
58        Returns:
59            The loss.
60        """
61        prediction, target = self.apply_transform(prediction, target, **kwargs)
62        loss = self.loss(prediction, target)
63        return loss

A wrapper around a torch loss function.

Applies transformations to prediction and/or target before passing it to the loss.

Arguments:
  • loss: The loss function.
  • transform: The transformation applied to prediction and/or target. Must take both the prediction and target as arguments and then return them, potentially transformed.
LossWrapper(loss: torch.nn.modules.module.Module, transform: Callable)
18    def __init__(self, loss: nn.Module, transform: Callable):
19        super().__init__()
20        self.loss = loss
21
22        if not callable(transform):
23            raise ValueError("transform has to be callable.")
24        self.transform = transform
25        self.init_kwargs = {'loss': loss, 'transform': transform}

Initialize internal Module state, shared by both nn.Module and ScriptModule.

loss
transform
init_kwargs
def forward( self, prediction: Union[Sequence[torch.Tensor], torch.Tensor], target: Union[Sequence[torch.Tensor], torch.Tensor], **kwargs) -> torch.Tensor:
45    def forward(
46        self,
47        prediction: Union[Sequence[torch.Tensor], torch.Tensor],
48        target: Union[Sequence[torch.Tensor], torch.Tensor],
49        **kwargs
50    ) -> torch.Tensor:
51        """Apply the tranformations to prediction and/or target before computing the loss.
52
53        Args:
54            prediction: The prediction.
55            target: The target.
56            kwargs: Additional keyword arguments for the transformation.
57
58        Returns:
59            The loss.
60        """
61        prediction, target = self.apply_transform(prediction, target, **kwargs)
62        loss = self.loss(prediction, target)
63        return loss

Apply the transformations to prediction and/or target before computing the loss.

Arguments:
  • prediction: The prediction.
  • target: The target.
  • kwargs: Additional keyword arguments for the transformation.
Returns:

The loss.

class ApplyMask:
 91class ApplyMask:
 92    """Apply a mask to prediction and target, so that only values in the mask are taken into account for the loss.
 93
 94    Supports two different masking methods:
 95    - 'crop': Crop away the mask from the prediction and target. This only works if the mask just has a single channel,
 96        and if the loss function does not require spatial inputs.
 97    - 'multiply': Multiply the prediction and target with zeros outside of the mask.
 98
 99    Args:
100        masking_method: The masking method to use. Can be one of 'crop' or 'multiply'.
101        channel_dim: The dimension of the channel axis.
102    """
103    MASKING_FUNCS = {"crop": _crop, "multiply": _multiply}
104
105    def __init__(self, masking_method: str = "crop", channel_dim: int = 1):
106        if masking_method not in self.MASKING_FUNCS.keys():
107            raise ValueError(f"{masking_method} is not available, please use one of {list(self.MASKING_FUNCS.keys())}.")
108        self.masking_func = self.MASKING_FUNCS[masking_method]
109        self.channel_dim = channel_dim
110        self.init_kwargs = {"masking_method": masking_method, "channel_dim": channel_dim}
111
112    def __call__(
113        self, prediction: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
114    ) -> Tuple[torch.Tensor, torch.Tensor]:
115        """Mask predictions.
116
117        Args:
118            prediction: The prediction tensor.
119            target: The target tensor.
120            mask: The mask tensor.
121
122        Returns:
123            The masked prediction.
124            The masked target.
125        """
126        mask.requires_grad = False
127        return self.masking_func(prediction, target, mask, self.channel_dim)

Apply a mask to prediction and target, so that only values in the mask are taken into account for the loss.

Supports two different masking methods:

  • 'crop': Crop away the mask from the prediction and target. This only works if the mask just has a single channel, and if the loss function does not require spatial inputs.
  • 'multiply': Multiply the prediction and target with zeros outside of the mask.
Arguments:
  • masking_method: The masking method to use. Can be one of 'crop' or 'multiply'.
  • channel_dim: The dimension of the channel axis.
ApplyMask(masking_method: str = 'crop', channel_dim: int = 1)
105    def __init__(self, masking_method: str = "crop", channel_dim: int = 1):
106        if masking_method not in self.MASKING_FUNCS.keys():
107            raise ValueError(f"{masking_method} is not available, please use one of {list(self.MASKING_FUNCS.keys())}.")
108        self.masking_func = self.MASKING_FUNCS[masking_method]
109        self.channel_dim = channel_dim
110        self.init_kwargs = {"masking_method": masking_method, "channel_dim": channel_dim}
MASKING_FUNCS = {'crop': <function _crop>, 'multiply': <function _multiply>}
masking_func
channel_dim
init_kwargs
class ApplyAndRemoveMask(ApplyMask):
130class ApplyAndRemoveMask(ApplyMask):
131    """Extract mask from extra channels from a target tensor and use it to mask the prediction.
132
133    Supports the same masking methods as `ApplyMask`.
134    """
135    def __call__(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
136        """Remove masking channels from the target and then apply the mask.
137
138        Args:
139            prediction: The prediction tensor.
140            target: The target tensor, with extra channels that contain the mask.
141
142        Returns:
143            The masked prediction.
144            The masked target, with extra channels removed.
145        """
146        assert target.dim() == prediction.dim(), f"{target.dim()}, {prediction.dim()}"
147        assert target.size(1) == 2 * prediction.size(1), f"{target.size(1)}, {prediction.size(1)}"
148        assert target.shape[2:] == prediction.shape[2:], f"{str(target.shape)}, {str(prediction.shape)}"
149        seperating_channel = target.size(1) // 2
150        mask = target[:, seperating_channel:]
151        target = target[:, :seperating_channel]
152        prediction, target = super().__call__(prediction, target, mask)
153        return prediction, target

Extract mask from extra channels from a target tensor and use it to mask the prediction.

Supports the same masking methods as ApplyMask.

class MaskIgnoreLabel(ApplyMask):
156class MaskIgnoreLabel(ApplyMask):
157    """Mask ignore label from the target.
158
159    Supports the same masking methods as `ApplyMask`.
160
161    Args:
162        ignore_label: The ignore label, which will be msaked.
163        masking_method: The masking method to use. Can be one of 'crop' or 'multiply'.
164        channel_dim: The dimension of the channel axis.
165    """
166    def __init__(self, ignore_label: int = -1, masking_method: str = "crop", channel_dim: int = 1):
167        super().__init__(masking_method, channel_dim)
168        self.ignore_label = ignore_label
169        self.init_kwargs["ignore_label"] = ignore_label
170
171    def __call__(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
172        """Mask ignore label in the prediction and target.
173
174        Args:
175            prediction: The prediction tensor.
176            target: The target tensor.
177
178        Returns:
179            The masked prediction.
180            The masked target.
181        """
182        mask = (target != self.ignore_label)
183        prediction, target = super().__call__(prediction, target, mask)
184        return prediction, target

Mask ignore label from the target.

Supports the same masking methods as ApplyMask.

Arguments:
  • ignore_label: The ignore label, which will be masked.
  • masking_method: The masking method to use. Can be one of 'crop' or 'multiply'.
  • channel_dim: The dimension of the channel axis.
MaskIgnoreLabel( ignore_label: int = -1, masking_method: str = 'crop', channel_dim: int = 1)
166    def __init__(self, ignore_label: int = -1, masking_method: str = "crop", channel_dim: int = 1):
167        super().__init__(masking_method, channel_dim)
168        self.ignore_label = ignore_label
169        self.init_kwargs["ignore_label"] = ignore_label
ignore_label