torch_em.loss.affinity_side_loss

import numpy as np
import torch
import torch.nn as nn
from .dice import dice_score


def shift_tensor(tensor, offset):
    """ Shift a tensor by the given (spatial) offset.
    Arguments:
        tensor [torch.Tensor] - 4D (=2 spatial dims) or 5D (=3 spatial dims) tensor.
            Needs to be of float type.
        offset (tuple) - 2d or 3d spatial offset used for shifting the tensor
    """

    ndim = len(offset)
    assert ndim in (2, 3)
    diff = tensor.dim() - ndim

    # don't pad for the first dimensions
    # (usually batch and/or channel dimension)
    slice_ = diff * [slice(None)]

    # torch padding behaviour is a bit weird.
    # we use nn.ReplicationPadND
    # (torch.nn.functional.pad is even weirder and ReflectionPad is not supported in 3d)
    # still, padding needs to be given in the inverse spatial order

    # add padding in inverse spatial order
    padding = []
    for off in offset[::-1]:
        # if we have a negative offset, we need to shift "to the left",
        # which means padding at the right border
        # if we have a positive offset, we need to shift "to the right",
        # which means padding at the left border
        padding.extend([max(0, off), max(0, -off)])

    # add slicing in the normal spatial order
    for off in offset:
        if off == 0:
            slice_.append(slice(None))
        elif off > 0:
            slice_.append(slice(None, -off))
        else:
            slice_.append(slice(-off, None))

    # pad the spatial part of the tensor with replication padding
    slice_ = tuple(slice_)
    padding = tuple(padding)
    padder = nn.ReplicationPad2d if ndim == 2 else nn.ReplicationPad3d
    padder = padder(padding)
    shifted = padder(tensor)

    # slice the padded tensor to get the spatially shifted tensor
    shifted = shifted[slice_]
    assert shifted.shape == tensor.shape

    return shifted


def invert_offsets(offsets):
    return [[-off for off in offset] for offset in offsets]


def segmentation_to_affinities(segmentation, offsets):
    """ Transform segmentation to affinities.
    Arguments:
        segmentation [torch.tensor] - 4D (2 spatial dims) or 5D (3 spatial dims) segmentation tensor.
            The channel axis (= dimension 1) needs to be a singleton.
        offsets [list[tuple]] - list of offsets for which to compute the affinities.
    """
    assert segmentation.shape[1] == 1, f"{segmentation.shape}"
    # shift the segmentation and subtract the shifted tensor from the segmentation
    # we need to shift in the opposite direction of the offsets, so we invert them
    # before applying the shift
    offsets_ = invert_offsets(offsets)
    shifted = torch.cat([shift_tensor(segmentation.float(), off) for off in offsets_], dim=1)
    affs = (segmentation - shifted)
    # the affinities are 1, where we had the same segment id (the difference is 0)
    # and 0 otherwise
    affs.eq_(0.)
    return affs


def embeddings_to_affinities(embeddings, offsets, delta):
    """ Transform embeddings to affinities.
    """
    # shift the embeddings by the offsets and stack them along a new axis
    # we need to shift in the opposite direction of the offsets, so we invert them
    # before applying the shift
    offsets_ = invert_offsets(offsets)
    shifted = torch.cat([shift_tensor(embeddings, off).unsqueeze(1) for off in offsets_], dim=1)
    # subtract the embeddings from the shifted embeddings, take the norm and
    # transform to affinities based on the delta distance
    affs = (2 * delta - torch.norm(embeddings.unsqueeze(1) - shifted, dim=2)) / (2 * delta)
    affs = torch.clamp(affs, min=0) ** 2
    return affs


class AffinitySideLoss(nn.Module):
    def __init__(self, offset_ranges, n_samples, delta):
        assert all(len(orange) == 2 for orange in offset_ranges)
        super().__init__()
        self.ndim = len(offset_ranges)
        self.offset_ranges = offset_ranges
        self.n_samples = n_samples
        self.delta = delta

    def __call__(
        self,
        input_,
        target,
        ignore_labels=None,
        ignore_in_variance_term=None,
        ignore_in_distance_term=None,
    ):
        assert input_.dim() == target.dim(), f"{input_.dim()}, {target.dim()}"
        assert input_.shape[2:] == target.shape[2:]

        # sample offsets
        offsets = [[np.random.randint(orange[0], orange[1]) for orange in self.offset_ranges]
                   for _ in range(self.n_samples)]

        # we invert the affinities and the target affinities
        # so that we get boundaries as foreground, which is beneficial for the dice loss.
        # compute affinities from embeddings
        affs = 1. - embeddings_to_affinities(input_, offsets, self.delta)

        # compute ground-truth affinities from target
        target_affs = 1. - segmentation_to_affinities(target, offsets)
        assert affs.shape == target_affs.shape, f"{affs.shape}, {target_affs.shape}"

        # TODO implement masking the ignore labels
        # compute the dice score between affinities and target affinities
        loss = dice_score(affs, target_affs, invert=True)
        return loss
def shift_tensor(tensor, offset):

Shift a tensor by the given (spatial) offset.

Arguments:
  • tensor [torch.Tensor] - 4D (=2 spatial dims) or 5D (=3 spatial dims) tensor. Needs to be of float type.
  • offset (tuple) - 2d or 3d spatial offset used for shifting the tensor
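
For example, shifting a small 2d tensor by one pixel along the last spatial axis (a minimal sketch; the tensor values and the offset are illustrative):

    import torch
    from torch_em.loss.affinity_side_loss import shift_tensor

    x = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)
    # positive offset: values move towards higher indices,
    # the left border is filled by replication padding
    shifted = shift_tensor(x, (0, 1))
    assert shifted.shape == x.shape
    assert torch.equal(shifted[0, 0, :, 1:], x[0, 0, :, :-1])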
def invert_offsets(offsets):
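
Negates every component of every offset; this is used to shift tensors in the direction opposite to an offset. For illustration:

    from torch_em.loss.affinity_side_loss import invert_offsets

    assert invert_offsets([[0, 1], [-1, 0]]) == [[0, -1], [1, 0]]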
def segmentation_to_affinities(segmentation, offsets):

Transform segmentation to affinities.

Arguments:
  • segmentation [torch.tensor] - 4D (2 spatial dims) or 5D (3 spatial dims) segmentation tensor. The channel axis (= dimension 1) needs to be a singleton.
  • offsets [list[tuple]] - list of offsets for which to compute the affinities.
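
A minimal 2d example (the segment labels are chosen for illustration): the affinity of each pixel to its direct right neighbor is 0 exactly across the segment boundary.

    import torch
    from torch_em.loss.affinity_side_loss import segmentation_to_affinities

    # segmentation with two segments; shape is (batch, channel=1, height, width)
    seg = torch.tensor([[0, 0, 1, 1],
                        [0, 0, 1, 1]]).reshape(1, 1, 2, 4)
    # affinity of each pixel to its direct right neighbor
    affs = segmentation_to_affinities(seg, [(0, 1)])
    # 1 where pixel and neighbor share a segment id, 0 across the boundary
    print(affs[0, 0])
    # tensor([[1., 0., 1., 1.],
    #         [1., 0., 1., 1.]])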
def embeddings_to_affinities(embeddings, offsets, delta):

Transform embeddings to affinities.

The affinity between a pixel and its offset neighbor is computed from the distance of their embeddings e_p and e_q as clamp((2 * delta - ||e_p - e_q||) / (2 * delta), min=0) ** 2, i.e. it is 1 for identical embeddings and falls to 0 once the distance reaches 2 * delta.

Arguments:
  • embeddings [torch.Tensor] - 4D (2 spatial dims) or 5D (3 spatial dims) embedding tensor.
  • offsets [list[tuple]] - list of offsets for which to compute the affinities.
  • delta (float) - the distance value used to transform embedding distances to affinities.
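
A usage sketch with random embeddings (the embedding dimension, offsets, and delta value are arbitrary choices for illustration):

    import torch
    from torch_em.loss.affinity_side_loss import embeddings_to_affinities

    # random 8d embeddings for a 2d image: (batch, embedding_dim, height, width)
    embeddings = torch.randn(1, 8, 32, 32)
    affs = embeddings_to_affinities(embeddings, offsets=[(0, 1), (1, 0)], delta=2.0)
    # one affinity channel per offset, with values in [0, 1]
    assert affs.shape == (1, 2, 32, 32)
    assert affs.min() >= 0 and affs.max() <= 1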

class AffinitySideLoss(torch.nn.modules.module.Module):

Side loss that compares affinities computed from embedding predictions with affinities derived from the target segmentation. On each call, n_samples offsets are drawn at random from offset_ranges; affinities for these offsets are computed from the input (via embeddings_to_affinities) and from the target (via segmentation_to_affinities), both are inverted so that boundaries become the foreground, and the inverted dice score between them is returned as the loss. The ignore_labels, ignore_in_variance_term and ignore_in_distance_term arguments are accepted but not yet used (see the TODO in the source).

AffinitySideLoss(offset_ranges, n_samples, delta)
Arguments:
  • offset_ranges [list[tuple]] - list of (lower, upper) ranges, one per spatial dimension, from which the offset components are sampled.
  • n_samples (int) - number of offsets to sample per call.
  • delta (float) - the delta distance value used to compute affinities from the embeddings.
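A usage sketch (the shapes, offset ranges, and delta are illustrative choices):

    import torch
    from torch_em.loss.affinity_side_loss import AffinitySideLoss

    loss_fn = AffinitySideLoss(
        offset_ranges=[(-8, 8), (-8, 8)],  # 2d: one sampling range per spatial dim
        n_samples=4,                       # offsets sampled per call
        delta=2.0,                         # distance value for the embedding affinities
    )
    embeddings = torch.randn(1, 8, 64, 64)        # predicted embeddings
    target = torch.randint(0, 5, (1, 1, 64, 64))  # segmentation with singleton channel
    loss = loss_fn(embeddings, target)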