torch_em.loss.affinity_side_loss

from typing import List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from .dice import dice_score


def shift_tensor(tensor: torch.Tensor, offset: List[int]) -> torch.Tensor:
    """Shift a tensor by the given spatial offset.

    Args:
        tensor: A 4D (2 spatial dims) or 5D (3 spatial dims) tensor. Needs to be of float type.
        offset: A 2d or 3d spatial offset used for shifting the tensor.

    Returns:
        The shifted tensor.
    """
    ndim = len(offset)
    assert ndim in (2, 3)
    diff = tensor.dim() - ndim

    # don't pad for the first dimensions
    # (usually batch and/or channel dimension)
    slice_ = diff * [slice(None)]

    # torch padding behaviour is a bit weird.
    # we use nn.ReplicationPadND
    # (torch.nn.functional.pad is even weirder and ReflectionPad is not supported in 3d)
    # still, padding needs to be given in the inverse spatial order

    # add padding in inverse spatial order
    padding = []
    for off in offset[::-1]:
        # if we have a negative offset, we need to shift "to the left",
        # which means padding at the right border
        # if we have a positive offset, we need to shift "to the right",
        # which means padding at the left border
        padding.extend([max(0, off), max(0, -off)])

    # add slicing in the normal spatial order
    for off in offset:
        if off == 0:
            slice_.append(slice(None))
        elif off > 0:
            slice_.append(slice(None, -off))
        else:
            slice_.append(slice(-off, None))

    # pad the spatial part of the tensor with replication padding
    slice_ = tuple(slice_)
    padding = tuple(padding)
    padder = nn.ReplicationPad2d if ndim == 2 else nn.ReplicationPad3d
    padder = padder(padding)
    shifted = padder(tensor)

    # slice the padded tensor to get the spatially shifted tensor
    shifted = shifted[slice_]
    assert shifted.shape == tensor.shape

    return shifted


def invert_offsets(offsets):
    """@private
    """
    return [[-off for off in offset] for offset in offsets]


def segmentation_to_affinities(segmentation: torch.Tensor, offsets: List[List[int]]) -> torch.Tensor:
    """Transform a segmentation to affinities.

    Args:
        segmentation: A 4D (2 spatial dims) or 5D (3 spatial dims) segmentation tensor.
            The channel axis (= dimension 1) needs to be a singleton.
        offsets: List of offsets for which to compute the affinities.

    Returns:
        The affinities.
    """
    assert segmentation.shape[1] == 1, f"{segmentation.shape}"
    # Shift the segmentation and subtract the shifted tensor from the segmentation.
    # We need to shift in the opposite direction of the offsets, so we invert them before applying the shift.
    offsets_ = invert_offsets(offsets)
    shifted = torch.cat([shift_tensor(segmentation.float(), off) for off in offsets_], dim=1)
    affs = segmentation - shifted
    # The affinities are 1 where the segment ids agree (the difference is 0) and 0 otherwise.
    affs.eq_(0.)
    return affs


def embeddings_to_affinities(embeddings: torch.Tensor, offsets: List[List[int]], delta: float) -> torch.Tensor:
    """Transform embeddings to affinities.

    Args:
        embeddings: The pixel-wise embeddings.
        offsets: The offsets for computing affinities.
        delta: The push force hinge used for training the embedding prediction network.

    Returns:
        The affinities.
    """
    # Shift the embeddings by the offsets and stack them along a new axis.
    # We need to shift in the opposite direction of the offsets, so we invert them before applying the shift.
    offsets_ = invert_offsets(offsets)
    shifted = torch.cat([shift_tensor(embeddings, off).unsqueeze(1) for off in offsets_], dim=1)
    # Subtract the embeddings from the shifted embeddings, take the norm and
    # transform to affinities based on the delta distance.
    affs = (2 * delta - torch.norm(embeddings.unsqueeze(1) - shifted, dim=2)) / (2 * delta)
    affs = torch.clamp(affs, min=0) ** 2
    return affs


class AffinitySideLoss(nn.Module):
    """Loss computed between affinities derived from predicted embeddings and a target segmentation.

    The offsets for the affinities are sampled randomly from the given `offset_ranges`.

    Args:
        offset_ranges: Ranges from which to sample the offsets.
        n_samples: Number of offsets to sample per loss computation.
        delta: The push force hinge used for training the embedding prediction network.
    """
    def __init__(self, offset_ranges: List[Tuple[int, int]], n_samples: int, delta: float):
        assert all(len(orange) == 2 for orange in offset_ranges)
        super().__init__()
        self.ndim = len(offset_ranges)
        self.offset_ranges = offset_ranges
        self.n_samples = n_samples
        self.delta = delta

    def __call__(
        self,
        input_: torch.Tensor,
        target: torch.Tensor,
        ignore_labels: Optional[List[int]] = None,
        ignore_in_variance_term: Optional[List[int]] = None,
        ignore_in_distance_term: Optional[List[int]] = None,
    ) -> torch.Tensor:
        """Compute the loss between affinities derived from predicted embeddings and a target segmentation.

        Note: Support for the ignore labels is currently not implemented.

        Args:
            input_: The predicted embeddings.
            target: The target segmentation.
            ignore_labels: Ignore labels for the loss computation.
            ignore_in_variance_term: Ignore labels for the variance term.
            ignore_in_distance_term: Ignore labels for the distance term.

        Returns:
            The affinity loss value.
        """
        assert input_.dim() == target.dim(), f"{input_.dim()}, {target.dim()}"
        assert input_.shape[2:] == target.shape[2:]

        # Sample the offsets.
        offsets = [[np.random.randint(orange[0], orange[1]) for orange in self.offset_ranges]
                   for _ in range(self.n_samples)]

        # We invert the affinities and the target affinities,
        # so that we get boundaries as foreground, which is beneficial for the dice loss.
        # Compute affinities from embeddings.
        affs = 1. - embeddings_to_affinities(input_, offsets, self.delta)

        # Compute ground-truth affinities from the target segmentation.
        target_affs = 1. - segmentation_to_affinities(target, offsets)
        assert affs.shape == target_affs.shape, f"{affs.shape}, {target_affs.shape}"

        # TODO implement masking the ignore labels
        # Compute the dice score between affinities and target affinities.
        return dice_score(affs, target_affs, invert=True)
def shift_tensor(tensor: torch.Tensor, offset: List[int]) -> torch.Tensor:

Shift a tensor by the given spatial offset.

Arguments:
  • tensor: A 4D (2 spatial dims) or 5D (3 spatial dims) tensor. Needs to be of float type.
  • offset: A 2d or 3d spatial offset used for shifting the tensor.

Returns:
  The shifted tensor.
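
A minimal usage sketch (the tensor values are illustrative): shifting by a positive offset moves values towards higher indices, with the border filled by replication.

    import torch

    # One batch and one channel dimension, 4x4 spatial extent.
    x = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)

    # Shift by +1 along the last spatial axis.
    shifted = shift_tensor(x, [0, 1])
    assert shifted.shape == x.shape
    # shifted[0, 0, 0] is tensor([0., 0., 1., 2.]): the first column was replicated.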

def segmentation_to_affinities(segmentation: torch.Tensor, offsets: List[List[int]]) -> torch.Tensor:

Transform a segmentation to affinities.

Arguments:
  • segmentation: A 4D (2 spatial dims) or 5D (3 spatial dims) segmentation tensor. The channel axis (= dimension 1) needs to be a singleton.
  • offsets: List of offsets for which to compute the affinities.

Returns:
  The affinities.
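
A minimal sketch on a toy 2D segmentation (values illustrative):

    import torch

    # Two segments side by side, shape (batch, channel=1, height, width).
    seg = torch.tensor([[[[1, 1, 2, 2],
                          [1, 1, 2, 2]]]])

    # Affinity of each pixel to its left neighbor (offset [0, -1]).
    affs = segmentation_to_affinities(seg, [[0, -1]])
    assert affs.shape == (1, 1, 2, 4)
    # affs[0, 0, 0] is tensor([1., 1., 0., 1.]): the affinity drops to 0 across the segment boundary.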

def embeddings_to_affinities(embeddings: torch.Tensor, offsets: List[List[int]], delta: float) -> torch.Tensor:

Transform embeddings to affinities.

Arguments:
  • embeddings: The pixel-wise embeddings.
  • offsets: The offsets for computing affinities.
  • delta: The push force hinge used for training the embedding prediction network.

Returns:
  The affinities.
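
As the source above shows, the affinity between a pixel embedding e_i and its neighbor e_j at a given offset is clamp((2 * delta - ||e_i - e_j||) / (2 * delta), min=0) ** 2, so identical embeddings yield affinity 1 and embeddings further apart than 2 * delta yield affinity 0. A minimal sketch (shapes and delta are illustrative):

    import torch

    # Pixel-wise embeddings, shape (batch, embedding_dim, height, width).
    embeddings = torch.randn(1, 8, 16, 16)

    # One affinity channel per offset.
    offsets = [[0, -1], [-1, 0]]
    affs = embeddings_to_affinities(embeddings, offsets, delta=1.5)
    assert affs.shape == (1, 2, 16, 16)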

class AffinitySideLoss(torch.nn.modules.module.Module):

Loss computed between affinities derived from predicted embeddings and a target segmentation.

The offsets for the affinities are sampled randomly from the given offset_ranges.

Arguments:
  • offset_ranges: Ranges from which to sample the offsets.
  • n_samples: Number of offsets to sample per loss computation.
  • delta: The push force hinge used for training the embedding prediction network.
AffinitySideLoss(offset_ranges: List[Tuple[int, int]], n_samples: int, delta: float)

Initialize the loss with the given offset ranges, number of offsets to sample, and delta hinge.

ndim
offset_ranges
n_samples
delta
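
A minimal usage sketch (shapes, offset ranges and delta are illustrative, not prescriptive):

    import torch

    # Sample 4 random 2D offsets per loss computation, with each component in [1, 8].
    loss_fn = AffinitySideLoss(offset_ranges=[(1, 9), (1, 9)], n_samples=4, delta=1.5)

    # Predicted pixel-wise embeddings and the corresponding target segmentation.
    embeddings = torch.randn(1, 8, 64, 64, requires_grad=True)
    target = torch.randint(0, 5, (1, 1, 64, 64))

    loss = loss_fn(embeddings, target)
    loss.backward()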