torch_em.loss.wrapper

  1import torch
  2import torch.nn as nn
  3
  4
  5class LossWrapper(nn.Module):
  6    """ Wrapper around a torch loss function.
  7
  8    Applies transformations to prediction and/or target before passing it to the loss.
  9    """
 10    def __init__(self, loss, transform):
 11        super().__init__()
 12        self.loss = loss
 13
 14        if not callable(transform):
 15            raise ValueError("transform has to be callable.")
 16        self.transform = transform
 17        self.init_kwargs = {'loss': loss, 'transform': transform}
 18
 19    def apply_transform(self, prediction, target, **kwargs):
 20        # check if the tensors (prediction and target are lists)
 21        # if they are, apply the transform to each element inidvidually
 22        if isinstance(prediction, (list, tuple)):
 23            assert isinstance(target, (list, tuple))
 24            transformed_prediction, transformed_target = [], []
 25            for pred, targ in zip(prediction, target):
 26                tr_pred, tr_targ = self.transform(pred, targ, **kwargs)
 27                transformed_prediction.append(tr_pred)
 28                transformed_target.append(tr_targ)
 29            return transformed_prediction, transformed_target
 30        # tensor input
 31        else:
 32            prediction, target = self.transform(prediction, target, **kwargs)
 33            return prediction, target
 34
 35    def forward(self, prediction, target, **kwargs):
 36        prediction, target = self.apply_transform(prediction, target, **kwargs)
 37        loss = self.loss(prediction, target)
 38        return loss
 39
 40
 41#
 42# Loss transformations
 43#
 44
 45
 46class ApplyMask:
 47    def _crop(prediction, target, mask, channel_dim):
 48        if mask.shape[channel_dim] != 1:
 49            raise ValueError(
 50                "_crop only supports a mask with a singleton channel axis. \
 51                Please consider using masking_method=multiply."
 52            )
 53        mask = mask.type(torch.bool)
 54        # remove singleton axis
 55        mask = mask.squeeze(channel_dim)
 56        # move channel axis to end
 57        prediction = prediction.moveaxis(channel_dim, -1)
 58        target = target.moveaxis(channel_dim, -1)
 59        # output has shape N x C
 60        # correct for torch_em.loss.dice.flatten_samples
 61        return prediction[mask], target[mask]
 62
 63    def _multiply(prediction, target, mask, channel_dim):
 64        prediction = prediction * mask
 65        target = target * mask
 66        return prediction, target
 67
 68    MASKING_FUNCS = {
 69        "crop": _crop,
 70        "multiply": _multiply,
 71    }
 72
 73    def __init__(self, masking_method="crop", channel_dim=1):
 74        if masking_method not in self.MASKING_FUNCS.keys():
 75            raise ValueError(f"{masking_method} is not available, please use one of {list(self.MASKING_FUNCS.keys())}.")
 76        self.masking_func = self.MASKING_FUNCS[masking_method]
 77        self.channel_dim = channel_dim
 78
 79        self.init_kwargs = {
 80            "masking_method": masking_method,
 81            "channel_dim": channel_dim,
 82        }
 83
 84    def __call__(self, prediction, target, mask):
 85        mask.requires_grad = False
 86        return self.masking_func(prediction, target, mask, self.channel_dim)
 87
 88
 89class ApplyAndRemoveMask(ApplyMask):
 90    def __call__(self, prediction, target):
 91        assert target.dim() == prediction.dim(), f"{target.dim()}, {prediction.dim()}"
 92        assert target.size(1) == 2 * prediction.size(1), f"{target.size(1)}, {prediction.size(1)}"
 93        assert target.shape[2:] == prediction.shape[2:], f"{str(target.shape)}, {str(prediction.shape)}"
 94        seperating_channel = target.size(1) // 2
 95        mask = target[:, seperating_channel:]
 96        target = target[:, :seperating_channel]
 97        prediction, target = super().__call__(prediction, target, mask)
 98        return prediction, target
 99
100
class MaskIgnoreLabel(ApplyMask):
    """Exclude all positions whose target value equals ``ignore_label``.

    Args:
        ignore_label: Target value that marks positions to drop.
        masking_method: Forwarded to ``ApplyMask``.
        channel_dim: Forwarded to ``ApplyMask``.
    """
    def __init__(self, ignore_label=-1, masking_method="crop", channel_dim=1):
        super().__init__(masking_method, channel_dim)
        self.ignore_label = ignore_label
        # extend the recorded constructor arguments with the subclass-specific one
        self.init_kwargs.update(ignore_label=ignore_label)

    def __call__(self, prediction, target):
        keep = target != self.ignore_label
        return super().__call__(prediction, target, keep)
class LossWrapper(torch.nn.modules.module.Module):
 6class LossWrapper(nn.Module):
 7    """ Wrapper around a torch loss function.
 8
 9    Applies transformations to prediction and/or target before passing it to the loss.
10    """
11    def __init__(self, loss, transform):
12        super().__init__()
13        self.loss = loss
14
15        if not callable(transform):
16            raise ValueError("transform has to be callable.")
17        self.transform = transform
18        self.init_kwargs = {'loss': loss, 'transform': transform}
19
20    def apply_transform(self, prediction, target, **kwargs):
21        # check if the tensors (prediction and target are lists)
22        # if they are, apply the transform to each element inidvidually
23        if isinstance(prediction, (list, tuple)):
24            assert isinstance(target, (list, tuple))
25            transformed_prediction, transformed_target = [], []
26            for pred, targ in zip(prediction, target):
27                tr_pred, tr_targ = self.transform(pred, targ, **kwargs)
28                transformed_prediction.append(tr_pred)
29                transformed_target.append(tr_targ)
30            return transformed_prediction, transformed_target
31        # tensor input
32        else:
33            prediction, target = self.transform(prediction, target, **kwargs)
34            return prediction, target
35
36    def forward(self, prediction, target, **kwargs):
37        prediction, target = self.apply_transform(prediction, target, **kwargs)
38        loss = self.loss(prediction, target)
39        return loss

Wrapper around a torch loss function.

Applies transformations to prediction and/or target before passing it to the loss.

LossWrapper(loss, transform)
11    def __init__(self, loss, transform):
12        super().__init__()
13        self.loss = loss
14
15        if not callable(transform):
16            raise ValueError("transform has to be callable.")
17        self.transform = transform
18        self.init_kwargs = {'loss': loss, 'transform': transform}

Initializes internal Module state, shared by both nn.Module and ScriptModule.

loss
transform
init_kwargs
def apply_transform(self, prediction, target, **kwargs):
20    def apply_transform(self, prediction, target, **kwargs):
21        # check if the tensors (prediction and target are lists)
22        # if they are, apply the transform to each element inidvidually
23        if isinstance(prediction, (list, tuple)):
24            assert isinstance(target, (list, tuple))
25            transformed_prediction, transformed_target = [], []
26            for pred, targ in zip(prediction, target):
27                tr_pred, tr_targ = self.transform(pred, targ, **kwargs)
28                transformed_prediction.append(tr_pred)
29                transformed_target.append(tr_targ)
30            return transformed_prediction, transformed_target
31        # tensor input
32        else:
33            prediction, target = self.transform(prediction, target, **kwargs)
34            return prediction, target
def forward(self, prediction, target, **kwargs):
36    def forward(self, prediction, target, **kwargs):
37        prediction, target = self.apply_transform(prediction, target, **kwargs)
38        loss = self.loss(prediction, target)
39        return loss

Defines the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Inherited Members
torch.nn.modules.module.Module
dump_patches
training
call_super_init
register_buffer
register_parameter
add_module
register_module
get_submodule
get_parameter
get_buffer
get_extra_state
set_extra_state
apply
cuda
ipu
xpu
cpu
type
float
double
half
bfloat16
to_empty
to
register_full_backward_pre_hook
register_backward_hook
register_full_backward_hook
register_forward_pre_hook
register_forward_hook
register_state_dict_pre_hook
state_dict
register_load_state_dict_post_hook
load_state_dict
parameters
named_parameters
buffers
named_buffers
children
named_children
modules
named_modules
train
eval
requires_grad_
zero_grad
share_memory
extra_repr
compile
class ApplyMask:
class ApplyMask:
    """Mask prediction and target before they are passed to a loss.

    Available masking methods (keys of ``MASKING_FUNCS``):

    * ``"crop"``: keep only the positions where the mask is set. The mask must have a singleton
      channel axis; the channel axis of prediction and target is moved to the end, so the
      outputs have shape ``(num_selected_positions, C)``.
    * ``"multiply"``: multiply prediction and target element-wise with the mask.

    Args:
        masking_method: Name of the masking method, one of ``MASKING_FUNCS``.
        channel_dim: Index of the channel axis.

    Raises:
        ValueError: If ``masking_method`` is not available.
    """
    # NOTE: the helpers below are plain functions without ``self``; they are looked up through
    # the ``MASKING_FUNCS`` dict and receive ``channel_dim`` explicitly.
    def _crop(prediction, target, mask, channel_dim):
        """Select the masked positions; returns tensors of shape ``(num_selected, C)``."""
        if mask.shape[channel_dim] != 1:
            raise ValueError(
                "_crop only supports a mask with a singleton channel axis. "
                "Please consider using masking_method=multiply."
            )
        mask = mask.type(torch.bool)
        # remove singleton channel axis
        mask = mask.squeeze(channel_dim)
        # move channel axis to the end so boolean indexing keeps the channels of a position together
        prediction = prediction.moveaxis(channel_dim, -1)
        target = target.moveaxis(channel_dim, -1)
        # output has shape N x C
        # (per the original note, this layout is correct for torch_em.loss.dice.flatten_samples)
        return prediction[mask], target[mask]

    def _multiply(prediction, target, mask, channel_dim):
        """Multiply prediction and target with the mask (``channel_dim`` is unused)."""
        prediction = prediction * mask
        target = target * mask
        return prediction, target

    MASKING_FUNCS = {
        "crop": _crop,
        "multiply": _multiply,
    }

    def __init__(self, masking_method="crop", channel_dim=1):
        if masking_method not in self.MASKING_FUNCS.keys():
            raise ValueError(f"{masking_method} is not available, please use one of {list(self.MASKING_FUNCS.keys())}.")
        self.masking_func = self.MASKING_FUNCS[masking_method]
        self.channel_dim = channel_dim

        # Record the constructor arguments (presumably used elsewhere to re-create the object).
        self.init_kwargs = {
            "masking_method": masking_method,
            "channel_dim": channel_dim,
        }

    def __call__(self, prediction, target, mask):
        """Apply the configured masking function and return the masked ``(prediction, target)``."""
        # Detach instead of assigning ``mask.requires_grad = False``: that assignment mutates the
        # caller's tensor and raises a RuntimeError for non-leaf tensors (e.g. a slice of a target
        # that requires grad). detach() returns a gradient-free tensor without side effects.
        mask = mask.detach()
        return self.masking_func(prediction, target, mask, self.channel_dim)
ApplyMask(masking_method='crop', channel_dim=1)
74    def __init__(self, masking_method="crop", channel_dim=1):
75        if masking_method not in self.MASKING_FUNCS.keys():
76            raise ValueError(f"{masking_method} is not available, please use one of {list(self.MASKING_FUNCS.keys())}.")
77        self.masking_func = self.MASKING_FUNCS[masking_method]
78        self.channel_dim = channel_dim
79
80        self.init_kwargs = {
81            "masking_method": masking_method,
82            "channel_dim": channel_dim,
83        }
MASKING_FUNCS = {'crop': <function ApplyMask._crop>, 'multiply': <function ApplyMask._multiply>}
masking_func
channel_dim
init_kwargs
class ApplyAndRemoveMask(ApplyMask):
class ApplyAndRemoveMask(ApplyMask):
    """Masking where the mask is packed into the second half of the target channels.

    The target must have the same rank and spatial shape as the prediction and twice as many
    channels along axis 1. The first half of those channels is the actual target. The second half
    is the mask, which is applied through ``ApplyMask.__call__``.
    """
    def __call__(self, prediction, target):
        assert target.dim() == prediction.dim(), f"{target.dim()}, {prediction.dim()}"
        assert target.size(1) == 2 * prediction.size(1), f"{target.size(1)}, {prediction.size(1)}"
        assert target.shape[2:] == prediction.shape[2:], f"{str(target.shape)}, {str(prediction.shape)}"
        # NOTE(review): the split always uses axis 1, independent of self.channel_dim.
        n_target_channels = target.size(1) // 2
        actual_target, packed_mask = target[:, :n_target_channels], target[:, n_target_channels:]
        return super().__call__(prediction, actual_target, packed_mask)
class MaskIgnoreLabel(ApplyMask):
class MaskIgnoreLabel(ApplyMask):
    """Exclude all positions whose target value equals ``ignore_label``.

    Args:
        ignore_label: Target value that marks positions to drop.
        masking_method: Forwarded to ``ApplyMask``.
        channel_dim: Forwarded to ``ApplyMask``.
    """
    def __init__(self, ignore_label=-1, masking_method="crop", channel_dim=1):
        super().__init__(masking_method, channel_dim)
        self.ignore_label = ignore_label
        # extend the recorded constructor arguments with the subclass-specific one
        self.init_kwargs.update(ignore_label=ignore_label)

    def __call__(self, prediction, target):
        keep = target != self.ignore_label
        return super().__call__(prediction, target, keep)
MaskIgnoreLabel(ignore_label=-1, masking_method='crop', channel_dim=1)
103    def __init__(self, ignore_label=-1, masking_method="crop", channel_dim=1):
104        super().__init__(masking_method, channel_dim)
105        self.ignore_label = ignore_label
106        self.init_kwargs["ignore_label"] = ignore_label
ignore_label