
  1from functools import partial
  2from typing import List, Optional
  4import numpy as np
  5import elf.evaluation as elfval
  6import elf.segmentation as elfseg
  7import elf.segmentation.embeddings as elfemb
  8import torch
  9import torch.nn as nn
 10import vigra
 11from elf.segmentation.watershed import apply_size_filter
 14class BaseInstanceSegmentationMetric(nn.Module):
 15    """@private
 16    """
 17    def __init__(self, segmenter, metric, to_numpy=True):
 18        super().__init__()
 19        self.segmenter = segmenter
 20        self.metric = metric
 21        self.to_numpy = to_numpy
 23    def forward(self, input_, target):
 24        if self.to_numpy:
 25            input_ = input_.detach().cpu().numpy().astype("float32")
 26            target = target.detach().cpu().numpy()
 27        assert input_.ndim == target.ndim
 28        assert len(input_) == len(target)
 29        scores = []
 30        # compute the metric per batch
 31        for pred, trgt in zip(input_, target):
 32            seg = self.segmenter(pred)
 33            # by convention we assume that the segmentation channel is always in the last channel of trgt
 34            scores.append(self.metric(seg, trgt[-1].astype("uint32")))
 35        return torch.tensor(scores).mean()
 39# Segmenters
 42def filter_sizes(seg, min_seg_size, hmap=None):
 43    """@private
 44    """
 45    seg_ids, counts = np.unique(seg, return_counts=True)
 46    if hmap is None:
 47        bg_ids = seg_ids[counts < min_seg_size]
 48        seg[np.isin(seg, bg_ids)] = 0
 49    else:
 50        ndim = seg.ndim
 51        hmap_ = hmap if hmap.ndim == ndim else np.max(hmap, axis=0)
 52        seg, _ = apply_size_filter(seg, hmap_, min_seg_size)
 53    return seg
 56class MWS:
 57    """@private
 58    """
 59    def __init__(self, offsets, with_background, min_seg_size, strides=None):
 60        self.offsets = offsets
 61        self.with_background = with_background
 62        self.min_seg_size = min_seg_size
 63        if strides is None:
 64            strides = [4] * len(offsets[0])
 65        assert len(strides) == len(offsets[0])
 66        self.strides = strides
 68    def __call__(self, affinities):
 69        if self.with_background:
 70            assert len(affinities) == len(self.offsets) + 1
 71            mask, affinities = affinities[0], affinities[1:]
 72        else:
 73            assert len(affinities) == len(self.offsets)
 74            mask = None
 75        seg = elfseg.mutex_watershed.mutex_watershed(affinities, self.offsets, self.strides,
 76                                                     randomize_strides=True, mask=mask).astype("uint32")
 77        if self.min_seg_size > 0:
 78            seg = filter_sizes(seg, self.min_seg_size,
 79                               hmap=None if self.with_background else affinities)
 80        return seg
 83class EmbeddingMWS:
 84    """@private
 85    """
 86    def __init__(self, delta, offsets, with_background, min_seg_size, strides=None):
 87 = delta
 88        self.offsets = offsets
 89        self.with_background = with_background
 90        self.min_seg_size = min_seg_size
 91        if strides is None:
 92            strides = [4] * len(offsets[0])
 93        assert len(strides) == len(offsets[0])
 94        self.strides = strides
 96    def merge_background(self, seg, embeddings):
 97        seg += 1
 98        seg_ids, counts = np.unique(seg, return_counts=True)
 99        bg_seg = seg_ids[np.argmax(counts)]
100        mean_embeddings = []
101        for emb in embeddings:
102            mean_embeddings.append(vigra.analysis.extractRegionFeatures(emb, seg, features=["mean"])["mean"][None])
103        mean_embeddings = np.concatenate(mean_embeddings, axis=0)
104        bg_embed = mean_embeddings[:, bg_seg][:, None]
105        bg_probs = elfemb._embeddings_to_probabilities(mean_embeddings, bg_embed,, 0)
106        bg_ids = np.where(bg_probs > 0.5)
107        seg[np.isin(seg, bg_ids)] = 0
108        vigra.analysis.relabelConsecutive(seg, out=seg)
109        return seg
111    def __call__(self, embeddings):
112        weight = partial(elfemb.discriminative_loss_weight,
113        seg = elfemb.segment_embeddings_mws(
114            embeddings, "l2", self.offsets, strides=self.strides, weight_function=weight
115        ).astype("uint32")
116        if self.with_background:
117            seg = self.merge_background(seg, embeddings)
118        if self.min_seg_size > 0:
119            seg = filter_sizes(seg, self.min_seg_size)
120        return seg
class Multicut:
    """@private

    Segmenter that solves a multicut problem on top of a watershed over-segmentation
    computed from boundary predictions.
    """
    def __init__(self, min_seg_size, anisotropic=False, dt_threshold=0.25, sigma_seeds=2.0, solver="decomposition"):
        self.solver = solver
        self.sigma_seeds = sigma_seeds
        self.dt_threshold = dt_threshold
        self.anisotropic = anisotropic
        self.min_seg_size = min_seg_size

    def __call__(self, boundaries):
        # Drop a leading singleton axis.
        if boundaries.shape[0] == 1:
            boundaries = boundaries[0]
        assert boundaries.ndim in (2, 3), f"{boundaries.ndim}"

        if not (self.anisotropic and boundaries.ndim == 3):
            ws, max_id = elfseg.distance_transform_watershed(
                boundaries,
                threshold=self.dt_threshold,
                sigma_seeds=self.sigma_seeds,
                sigma_weights=self.sigma_seeds,
            )
        else:
            # NOTE(review): this call uses the keyword `sigma_seed`, while the call above uses
            # `sigma_seeds` - confirm against the elf API that this is intended.
            ws, max_id = elfseg.stacked_watershed(
                boundaries,
                threshold=self.dt_threshold,
                sigma_seed=self.sigma_seeds,
                sigma_weights=self.sigma_seeds,
                n_threads=1,
            )

        # Region adjacency graph of the watershed and the edge costs from the boundary features.
        graph = elfseg.compute_rag(ws, max_id + 1, n_threads=1)
        edge_features = elfseg.compute_boundary_mean_and_length(graph, boundaries, n_threads=1)[:, 0]
        edge_costs = elfseg.compute_edge_costs(edge_features)

        solve = elfseg.get_multicut_solver(self.solver)
        labels = solve(graph, edge_costs, n_threads=1)
        seg = elfseg.project_node_labels_to_pixels(graph, labels, n_threads=1).astype("uint32")

        if self.min_seg_size > 0:
            seg = filter_sizes(seg, self.min_seg_size, hmap=boundaries)
        return seg
class HDBScan:
    """@private

    Segmenter that forwards embeddings to `elfemb.segment_hdbscan` with stored settings.
    """
    def __init__(self, min_size, eps, remove_largest):
        self.min_size, self.eps, self.remove_largest = min_size, eps, remove_largest

    def __call__(self, embeddings):
        segmentation = elfemb.segment_hdbscan(embeddings, self.min_size, self.eps, self.remove_largest)
        return segmentation
170# Metrics
class IOUError:
    """@private

    Error measure computed as one minus the selected entry of `elfval.matching`.
    """
    def __init__(self, threshold=0.5, metric="precision"):
        self.metric = metric
        self.threshold = threshold

    def __call__(self, seg, target):
        matching_scores = elfval.matching(seg, target, threshold=self.threshold)
        return 1.0 - matching_scores[self.metric]
class VariationOfInformation:
    """@private

    Error measure that adds up the two values returned by `elfval.variation_of_information`.
    """
    def __call__(self, seg, target):
        vi_scores = elfval.variation_of_information(seg, target)
        vi_split, vi_merge = vi_scores
        return vi_split + vi_merge
class AdaptedRandError:
    """@private

    Error measure that returns the first of the two values given by `elfval.rand_index`.
    """
    def __call__(self, seg, target):
        adapted_rand_error, _unused = elfval.rand_index(seg, target)
        return adapted_rand_error
class SymmetricBestDice:
    """@private

    Error measure computed as one minus `elfval.symmetric_best_dice_score`.
    """
    def __call__(self, seg, target):
        return 1.0 - elfval.symmetric_best_dice_score(seg, target)
210# Prefab Full Metrics
class EmbeddingMWSIOUMetric(BaseInstanceSegmentationMetric):
    """Intersection over union metric based on mutex watershed computed from embedding-derived affinities.

    This class can be used as validation metric when training a network for instance segmentation.

    Args:
        delta: The hinge distance of the contrastive loss for training the embeddings.
        offsets: The offsets for deriving the affinities from the embeddings.
        min_seg_size: Size for filtering the segmentation objects.
        iou_threshold: Threshold for the intersection over union metric.
        strides: The strides for the mutex watershed.
    """
    def __init__(
        self,
        delta: float,
        offsets: List[List[int]],
        min_seg_size: int,
        iou_threshold: float = 0.5,
        strides: Optional[List[int]] = None,
    ):
        # Forward the strides to the segmenter; with None it falls back to its own default.
        segmenter = EmbeddingMWS(
            delta, offsets, with_background=True, min_seg_size=min_seg_size, strides=strides
        )
        metric = IOUError(iou_threshold)
        super().__init__(segmenter, metric)
        self.init_kwargs = {"delta": delta, "offsets": offsets, "min_seg_size": min_seg_size,
                            "iou_threshold": iou_threshold, "strides": strides}
class EmbeddingMWSSBDMetric(BaseInstanceSegmentationMetric):
    """Symmetric best dice metric based on mutex watershed computed from embedding-derived affinities.

    This class can be used as validation metric when training a network for instance segmentation.

    Args:
        delta: The hinge distance of the contrastive loss for training the embeddings.
        offsets: The offsets for deriving the affinities from the embeddings.
        min_seg_size: Size for filtering the segmentation objects.
        strides: The strides for the mutex watershed.
    """
    def __init__(self, delta: float, offsets: List[List[int]], min_seg_size: int, strides: Optional[List[int]] = None):
        # Forward the strides to the segmenter; with None it falls back to its own default.
        segmenter = EmbeddingMWS(
            delta, offsets, with_background=True, min_seg_size=min_seg_size, strides=strides
        )
        metric = SymmetricBestDice()
        super().__init__(segmenter, metric)
        self.init_kwargs = {"delta": delta, "offsets": offsets, "min_seg_size": min_seg_size, "strides": strides}
class EmbeddingMWSVOIMetric(BaseInstanceSegmentationMetric):
    """Variation of information metric based on mutex watershed computed from embedding-derived affinities.

    This class can be used as validation metric when training a network for instance segmentation.

    Args:
        delta: The hinge distance of the contrastive loss for training the embeddings.
        offsets: The offsets for deriving the affinities from the embeddings.
        min_seg_size: Size for filtering the segmentation objects.
        strides: The strides for the mutex watershed.
    """
    def __init__(self, delta: float, offsets: List[List[int]], min_seg_size: int, strides: Optional[List[int]] = None):
        # Forward the strides to the segmenter; with None it falls back to its own default.
        segmenter = EmbeddingMWS(
            delta, offsets, with_background=False, min_seg_size=min_seg_size, strides=strides
        )
        metric = VariationOfInformation()
        super().__init__(segmenter, metric)
        self.init_kwargs = {"delta": delta, "offsets": offsets, "min_seg_size": min_seg_size, "strides": strides}
class EmbeddingMWSRandMetric(BaseInstanceSegmentationMetric):
    """Rand index metric based on mutex watershed computed from embedding-derived affinities.

    This class can be used as validation metric when training a network for instance segmentation.

    Args:
        delta: The hinge distance of the contrastive loss for training the embeddings.
        offsets: The offsets for deriving the affinities from the embeddings.
        min_seg_size: Size for filtering the segmentation objects.
        strides: The strides for the mutex watershed.
    """
    def __init__(self, delta: float, offsets: List[List[int]], min_seg_size: int, strides: Optional[List[int]] = None):
        # Forward the strides to the segmenter; with None it falls back to its own default.
        segmenter = EmbeddingMWS(
            delta, offsets, with_background=False, min_seg_size=min_seg_size, strides=strides
        )
        metric = AdaptedRandError()
        super().__init__(segmenter, metric)
        self.init_kwargs = {"delta": delta, "offsets": offsets, "min_seg_size": min_seg_size, "strides": strides}
class HDBScanIOUMetric(BaseInstanceSegmentationMetric):
    """Intersection over union metric based on HDBScan computed from embeddings.

    This class can be used as validation metric when training a network for instance segmentation.

    Args:
        min_size: The minimal segment size.
        eps: The epsilon value for HDBScan.
        iou_threshold: The threshold for the intersection over union value.
    """
    def __init__(self, min_size: int, eps: float, iou_threshold: float = 0.5):
        super().__init__(
            HDBScan(min_size=min_size, eps=eps, remove_largest=True),
            IOUError(iou_threshold),
        )
        # The constructor arguments are recorded on the instance.
        self.init_kwargs = dict(min_size=min_size, eps=eps, iou_threshold=iou_threshold)
class HDBScanSBDMetric(BaseInstanceSegmentationMetric):
    """Symmetric best dice metric based on HDBScan computed from embeddings.

    This class can be used as validation metric when training a network for instance segmentation.

    Args:
        min_size: The minimal segment size.
        eps: The epsilon value for HDBScan.
    """
    def __init__(self, min_size: int, eps: float):
        super().__init__(
            HDBScan(min_size=min_size, eps=eps, remove_largest=True),
            SymmetricBestDice(),
        )
        # The constructor arguments are recorded on the instance.
        self.init_kwargs = dict(min_size=min_size, eps=eps)
class HDBScanRandMetric(BaseInstanceSegmentationMetric):
    """Rand index metric based on HDBScan computed from embeddings.

    This class can be used as validation metric when training a network for instance segmentation.

    Args:
        min_size: The minimal segment size.
        eps: The epsilon value for HDBScan.
    """
    def __init__(self, min_size: int, eps: float):
        super().__init__(
            HDBScan(min_size=min_size, eps=eps, remove_largest=True),
            AdaptedRandError(),
        )
        # The constructor arguments are recorded on the instance.
        self.init_kwargs = dict(min_size=min_size, eps=eps)
class HDBScanVOIMetric(BaseInstanceSegmentationMetric):
    """Variation of information metric based on HDBScan computed from embeddings.

    This class can be used as validation metric when training a network for instance segmentation.

    Args:
        min_size: The minimal segment size.
        eps: The epsilon value for HDBScan.
    """
    def __init__(self, min_size: int, eps: float):
        super().__init__(
            HDBScan(min_size=min_size, eps=eps, remove_largest=True),
            VariationOfInformation(),
        )
        # The constructor arguments are recorded on the instance.
        self.init_kwargs = dict(min_size=min_size, eps=eps)
class MulticutVOIMetric(BaseInstanceSegmentationMetric):
    """Variation of information metric based on a multicut computed from boundary predictions.

    This class can be used as validation metric when training a network for instance segmentation.

    Args:
        min_seg_size: The minimal segment size.
        anisotropic: Whether to compute the watersheds in 2d for volumetric data.
        dt_threshold: The threshold to apply to the boundary predictions before computing the distance transform.
        sigma_seeds: The sigma value for smoothing the distance transform before computing seeds.
    """
    def __init__(
        self, min_seg_size: int, anisotropic: bool = False, dt_threshold: float = 0.25, sigma_seeds: float = 2.0
    ):
        # Multicut's first positional parameter is min_seg_size, so the remaining
        # settings are passed by keyword to keep every value in its own slot.
        segmenter = Multicut(
            min_seg_size, anisotropic=anisotropic, dt_threshold=dt_threshold, sigma_seeds=sigma_seeds
        )
        metric = VariationOfInformation()
        super().__init__(segmenter, metric)
        self.init_kwargs = {"anisotropic": anisotropic, "min_seg_size": min_seg_size,
                            "dt_threshold": dt_threshold, "sigma_seeds": sigma_seeds}
class MulticutRandMetric(BaseInstanceSegmentationMetric):
    """Rand index metric based on a multicut computed from boundary predictions.

    This class can be used as validation metric when training a network for instance segmentation.

    Args:
        min_seg_size: The minimal segment size.
        anisotropic: Whether to compute the watersheds in 2d for volumetric data.
        dt_threshold: The threshold to apply to the boundary predictions before computing the distance transform.
        sigma_seeds: The sigma value for smoothing the distance transform before computing seeds.
    """
    def __init__(
        self, min_seg_size: int, anisotropic: bool = False, dt_threshold: float = 0.25, sigma_seeds: float = 2.0
    ):
        # Multicut's first positional parameter is min_seg_size, so the remaining
        # settings are passed by keyword to keep every value in its own slot.
        segmenter = Multicut(
            min_seg_size, anisotropic=anisotropic, dt_threshold=dt_threshold, sigma_seeds=sigma_seeds
        )
        metric = AdaptedRandError()
        super().__init__(segmenter, metric)
        self.init_kwargs = {"anisotropic": anisotropic, "min_seg_size": min_seg_size,
                            "dt_threshold": dt_threshold, "sigma_seeds": sigma_seeds}
class MWSIOUMetric(BaseInstanceSegmentationMetric):
    """Intersection over union metric based on a mutex watershed computed from affinity predictions.

    This class can be used as validation metric when training a network for instance segmentation.

    Args:
        offsets: The offsets corresponding to the affinity channels.
        min_seg_size: The minimal segment size.
        iou_threshold: The threshold for the intersection over union value.
        strides: The strides for the mutex watershed.
    """
    def __init__(
        self,
        offsets: List[List[int]],
        min_seg_size: int,
        iou_threshold: float = 0.5,
        strides: Optional[List[int]] = None
    ):
        super().__init__(
            MWS(offsets, with_background=True, min_seg_size=min_seg_size, strides=strides),
            IOUError(iou_threshold),
        )
        # The constructor arguments are recorded on the instance.
        self.init_kwargs = dict(
            offsets=offsets, min_seg_size=min_seg_size, iou_threshold=iou_threshold, strides=strides
        )
class MWSSBDMetric(BaseInstanceSegmentationMetric):
    """Symmetric best dice score metric based on a mutex watershed computed from affinity predictions.

    This class can be used as validation metric when training a network for instance segmentation.

    Args:
        offsets: The offsets corresponding to the affinity channels.
        min_seg_size: The minimal segment size.
        strides: The strides for the mutex watershed.
    """
    def __init__(self, offsets: List[List[int]], min_seg_size: int, strides: Optional[List[int]] = None):
        super().__init__(
            MWS(offsets, with_background=True, min_seg_size=min_seg_size, strides=strides),
            SymmetricBestDice(),
        )
        # The constructor arguments are recorded on the instance.
        self.init_kwargs = dict(offsets=offsets, min_seg_size=min_seg_size, strides=strides)
class MWSVOIMetric(BaseInstanceSegmentationMetric):
    """Variation of information metric based on a mutex watershed computed from affinity predictions.

    This class can be used as validation metric when training a network for instance segmentation.

    Args:
        offsets: The offsets corresponding to the affinity channels.
        min_seg_size: The minimal segment size.
        strides: The strides for the mutex watershed.
    """
    def __init__(self, offsets: List[List[int]], min_seg_size: int, strides: Optional[List[int]] = None):
        super().__init__(
            MWS(offsets, with_background=False, min_seg_size=min_seg_size, strides=strides),
            VariationOfInformation(),
        )
        # The constructor arguments are recorded on the instance.
        self.init_kwargs = dict(offsets=offsets, min_seg_size=min_seg_size, strides=strides)
class MWSRandMetric(BaseInstanceSegmentationMetric):
    """Rand index metric based on a mutex watershed computed from affinity predictions.

    This class can be used as validation metric when training a network for instance segmentation.

    Args:
        offsets: The offsets corresponding to the affinity channels.
        min_seg_size: The minimal segment size.
        strides: The strides for the mutex watershed.
    """
    def __init__(self, offsets: List[List[int]], min_seg_size: int, strides: Optional[List[int]] = None):
        super().__init__(
            MWS(offsets, with_background=False, min_seg_size=min_seg_size, strides=strides),
            AdaptedRandError(),
        )
        # The constructor arguments are recorded on the instance.
        self.init_kwargs = dict(offsets=offsets, min_seg_size=min_seg_size, strides=strides)
