torch_em.util.grid_search

  1import numpy as np
  2import torch.nn as nn
  3import xarray
  4
  5import bioimageio.core
  6
  7from micro_sam.instance_segmentation import InstanceSegmentationWithDecoder
  8from micro_sam.evaluation.instance_segmentation import (
  9    default_grid_search_values_instance_segmentation_with_decoder,
 10    evaluate_instance_segmentation_grid_search,
 11    run_instance_segmentation_grid_search,
 12    _get_range_of_search_values,
 13)
 14
 15from ..transform.raw import standardize
 16from .prediction import predict_with_padding, predict_with_halo
 17from .segmentation import watershed_from_components
 18
 19
 20def default_grid_search_values_boundary_based_instance_segmentation(
 21    threshold1_values=None,
 22    threshold2_values=None,
 23    min_size_values=None,
 24):
 25    if threshold1_values is None:
 26        threshold1_values = [0.5, 0.55, 0.6]
 27    if threshold2_values is None:
 28        threshold2_values = _get_range_of_search_values(
 29            [0.5, 0.9], step=0.1
 30        )
 31    if min_size_values is None:
 32        min_size_values = [25, 50, 75, 100, 200]
 33
 34    return {
 35        "min_size": min_size_values,
 36        "threshold1": threshold1_values,
 37        "threshold2": threshold2_values,
 38    }
 39
 40
 41class _InstanceSegmentationBase(InstanceSegmentationWithDecoder):
 42    def __init__(self, model, preprocess=None, block_shape=None, halo=None):
 43        self._model = model
 44        self._preprocess = standardize if preprocess is None else preprocess
 45
 46        assert (block_shape is None) == (halo is None)
 47        self._block_shape = block_shape
 48        self._halo = halo
 49
 50        self._is_initialized = False
 51
 52    def _initialize_torch(self, data):
 53        device = next(iter(self._model.parameters())).device
 54
 55        if self._block_shape is None:
 56            if hasattr(self._model, "scale_factors"):
 57                scale_factors = self._model.init_kwargs["scale_factors"]
 58                min_divisible = [int(np.prod([sf[i] for sf in scale_factors])) for i in range(3)]
 59            elif hasattr(self._model, "depth"):
 60                depth = self._model.depth
 61                min_divisible = [2**depth, 2**depth]
 62            else:
 63                raise RuntimeError
 64            input_ = self._preprocess(data)
 65            output = predict_with_padding(self._model, input_, min_divisible, device)
 66        else:
 67            output = predict_with_halo(
 68                data, self._model, [device], self._block_shape, self._halo,
 69                preprocess=self._preprocess,
 70            )
 71        return output
 72
 73    def _initialize_modelzoo(self, data):
 74        if self._block_shape is None:
 75            with bioimageio.core.create_prediction_pipeline(self._model) as pp:
 76                dims = tuple("bcyx") if data.ndim == 2 else tuple("bczyx")
 77                input_ = xarray.DataArray(data[None, None], dims=dims)
 78                output = bioimageio.core.prediction.predict_with_padding(pp, input_, padding=True)[0]
 79                output = output.squeeze().values
 80        else:
 81            raise NotImplementedError
 82        return output
 83
 84
 85class BoundaryBasedInstanceSegmentation(_InstanceSegmentationBase):
 86    def __init__(self, model, preprocess=None, block_shape=None, halo=None):
 87        super().__init__(
 88            model=model, preprocess=preprocess, block_shape=block_shape, halo=halo
 89        )
 90
 91        self._foreground = None
 92        self._boundaries = None
 93
 94    def initialize(self, data):
 95        if isinstance(self._model, nn.Module):
 96            output = self._initialize_torch(data)
 97        else:
 98            output = self._initialize_modelzoo(data)
 99
100        assert output.shape[0] == 2
101
102        self._foreground = output[0]
103        self._boundaries = output[1]
104
105        self._is_initialized = True
106
107    def generate(self, min_size=50, threshold1=0.5, threshold2=0.5, output_mode="binary_mask"):
108        segmentation = watershed_from_components(
109            self._boundaries, self._foreground,
110            min_size=min_size, threshold1=threshold1, threshold2=threshold2,
111        )
112        if output_mode is not None:
113            segmentation = self._to_masks(segmentation, output_mode)
114        return segmentation
115
116
class DistanceBasedInstanceSegmentation(_InstanceSegmentationBase):
    """Adapt micro_sam's decoder-based instance segmentation to a distance-predicting U-Net.

    The network must predict three channels: foreground, center distances and boundary
    distances. `generate` is inherited from micro_sam's `InstanceSegmentationWithDecoder`
    and presumably post-processes these cached predictions.
    """
    def __init__(self, model, preprocess=None, block_shape=None, halo=None):
        super().__init__(model=model, preprocess=preprocess, block_shape=block_shape, halo=halo)
        # Prediction caches, filled by `initialize`.
        self._foreground, self._center_distances, self._boundary_distances = None, None, None

    def initialize(self, data):
        """Predict the three distance-segmentation channels for `data` and cache them."""
        run_prediction = (
            self._initialize_torch if isinstance(self._model, nn.Module) else self._initialize_modelzoo
        )
        prediction = run_prediction(data)

        assert prediction.shape[0] == 3
        self._foreground, self._center_distances, self._boundary_distances = (
            prediction[0], prediction[1], prediction[2]
        )

        self._is_initialized = True
142
143
def instance_segmentation_grid_search(
    segmenter, image_paths, gt_paths, result_dir,
    grid_search_values=None, rois=None,
    image_key=None, gt_key=None,
):
    """Grid-search the `generate` parameters of a segmenter and return the best ones.

    Args:
        segmenter: A `DistanceBasedInstanceSegmentation` or `BoundaryBasedInstanceSegmentation`.
        image_paths: Paths to the input images.
        gt_paths: Paths to the corresponding ground-truth segmentations.
        result_dir: Directory handed to micro_sam's grid search run and evaluation.
        grid_search_values: Dict mapping `generate` keyword arguments to candidate values.
            If None, the default values for the segmenter type are used.
        rois: Optional regions of interest, passed on to the grid search.
        image_key: Optional key for loading the images, passed on to the grid search.
        gt_key: Optional key for loading the ground truth, passed on to the grid search.

    Returns:
        Tuple of the best parameters and the best score, as reported by
        `evaluate_instance_segmentation_grid_search`.

    Raises:
        ValueError: If no default grid search values exist for the segmenter type.
    """
    search_values = grid_search_values
    if search_values is None:
        if isinstance(segmenter, DistanceBasedInstanceSegmentation):
            search_values = default_grid_search_values_instance_segmentation_with_decoder()
        elif isinstance(segmenter, BoundaryBasedInstanceSegmentation):
            search_values = default_grid_search_values_boundary_based_instance_segmentation()
        else:
            raise ValueError(f"Could not derive default grid search values for segmenter of type {type(segmenter)}")

    # Run all parameter combinations; the evaluation below then reads back from result_dir
    # (presumably where the run stored its per-combination results).
    run_instance_segmentation_grid_search(
        segmenter, search_values, image_paths, gt_paths, result_dir,
        embedding_dir=None, verbose_gs=True,
        image_key=image_key, gt_key=gt_key, rois=rois,
    )
    parameter_names = list(search_values.keys())
    best_kwargs, best_score = evaluate_instance_segmentation_grid_search(result_dir, parameter_names)
    return best_kwargs, best_score
def default_grid_search_values_boundary_based_instance_segmentation(
    threshold1_values=None,
    threshold2_values=None,
    min_size_values=None,
):
    """Return the default grid search values for boundary based instance segmentation.

    Every argument left as None is replaced by its default candidate values.

    Args:
        threshold1_values: Candidate values for `threshold1`. Defaults to [0.5, 0.55, 0.6].
        threshold2_values: Candidate values for `threshold2`. Defaults to the result of
            micro_sam's `_get_range_of_search_values([0.5, 0.9], step=0.1)`.
        min_size_values: Candidate values for `min_size`. Defaults to [25, 50, 75, 100, 200].

    Returns:
        Dict mapping the `generate` keyword arguments "min_size", "threshold1" and
        "threshold2" to their candidate values.
    """
    first_thresholds = [0.5, 0.55, 0.6] if threshold1_values is None else threshold1_values

    if threshold2_values is None:
        second_thresholds = _get_range_of_search_values([0.5, 0.9], step=0.1)
    else:
        second_thresholds = threshold2_values

    min_sizes = [25, 50, 75, 100, 200] if min_size_values is None else min_size_values

    return {"min_size": min_sizes, "threshold1": first_thresholds, "threshold2": second_thresholds}
class BoundaryBasedInstanceSegmentation(_InstanceSegmentationBase):
    """Instance segmentation from the foreground and boundary predictions of a U-Net.

    Follows micro_sam's segmenter interface: call `initialize` once per image, then `generate`
    (possibly repeatedly, e.g. during a grid search) to obtain the segmentation.
    """
    def __init__(self, model, preprocess=None, block_shape=None, halo=None):
        super().__init__(
            model=model, preprocess=preprocess, block_shape=block_shape, halo=halo
        )

        # Prediction caches, filled by `initialize`.
        self._foreground = None
        self._boundaries = None

    def initialize(self, data):
        """Predict foreground and boundaries for `data` and cache them for `generate`.

        Args:
            data: The input image (2D) or volume (3D).
        """
        if isinstance(self._model, nn.Module):
            output = self._initialize_torch(data)
        else:
            output = self._initialize_modelzoo(data)

        # The network must predict two channels: foreground (0) and boundaries (1).
        assert output.shape[0] == 2

        self._foreground = output[0]
        self._boundaries = output[1]

        self._is_initialized = True

    def generate(self, min_size=50, threshold1=0.5, threshold2=0.5, output_mode="binary_mask"):
        """Generate the instance segmentation for the currently initialized image.

        Args:
            min_size: Minimal object size, passed to `watershed_from_components`.
            threshold1: First threshold passed to `watershed_from_components`.
            threshold2: Second threshold passed to `watershed_from_components`.
            output_mode: Mask format passed to `_to_masks`. Pass None to return the label image.

        Returns:
            The instance segmentation, converted with `_to_masks` unless `output_mode` is None.

        Raises:
            RuntimeError: If `initialize` has not been called.
        """
        # Without this check the cached predictions would still be None here.
        if not self._is_initialized:
            raise RuntimeError(
                "BoundaryBasedInstanceSegmentation has not been initialized. "
                "Call 'initialize' before 'generate'."
            )
        segmentation = watershed_from_components(
            self._boundaries, self._foreground,
            min_size=min_size, threshold1=threshold1, threshold2=threshold2,
        )
        if output_mode is not None:
            segmentation = self._to_masks(segmentation, output_mode)
        return segmentation

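Generates an instance segmentation without prompts, from the foreground and boundary predictions of a U-Net.

Implements the same interface as AutomaticMaskGenerator (via micro_sam's InstanceSegmentationWithDecoder).

Use this class as follows:

segmenter = BoundaryBasedInstanceSegmentation(model)
segmenter.initialize(image)   # Predict the foreground and boundaries for the image.
masks = segmenter.generate(min_size=50, threshold1=0.5, threshold2=0.5)  # Generate the instance segmentation.
Arguments:
  • model: The U-Net. Either a torch.nn.Module or a model accepted by bioimageio.core.create_prediction_pipeline.
  • preprocess: Function applied to the input before prediction. Defaults to standardize.
  • block_shape: Block shape for tiled prediction. Must be given together with halo; only supported for torch models.
  • halo: Halo for tiled prediction. Must be given together with block_shape.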
BoundaryBasedInstanceSegmentation(model, preprocess=None, block_shape=None, halo=None)
def __init__(self, model, preprocess=None, block_shape=None, halo=None):
    """Store the model and prediction settings; the prediction caches start out empty."""
    super().__init__(model=model, preprocess=preprocess, block_shape=block_shape, halo=halo)
    self._foreground, self._boundaries = None, None
def initialize(self, data):
    """Run the model on `data` and cache the foreground and boundary predictions."""
    use_torch = isinstance(self._model, nn.Module)
    prediction = self._initialize_torch(data) if use_torch else self._initialize_modelzoo(data)

    # Expected channels: foreground (0) and boundaries (1).
    assert prediction.shape[0] == 2
    self._foreground, self._boundaries = prediction[0], prediction[1]

    self._is_initialized = True

Predict the foreground and boundaries for an image and cache them for generate.

Arguments:
  • data: The input image (2D) or volume (3D). It is passed to the torch model (after preprocessing) or, for non-torch models, to the bioimage.io prediction pipeline. The model must predict exactly two channels: foreground and boundaries.
def generate(self, min_size=50, threshold1=0.5, threshold2=0.5, output_mode="binary_mask"):
    """Generate the instance segmentation for the currently initialized image.

    Args:
        min_size: Minimal object size, passed to `watershed_from_components`.
        threshold1: First threshold passed to `watershed_from_components`.
        threshold2: Second threshold passed to `watershed_from_components`.
        output_mode: Mask format passed to `_to_masks`. Pass None to return the label image.

    Returns:
        The instance segmentation, converted with `_to_masks` unless `output_mode` is None.

    Raises:
        RuntimeError: If `initialize` has not been called.
    """
    # Without this check the cached predictions would still be None here.
    if not self._is_initialized:
        raise RuntimeError(
            "BoundaryBasedInstanceSegmentation has not been initialized. "
            "Call 'initialize' before 'generate'."
        )
    segmentation = watershed_from_components(
        self._boundaries, self._foreground,
        min_size=min_size, threshold1=threshold1, threshold2=threshold2,
    )
    if output_mode is not None:
        segmentation = self._to_masks(segmentation, output_mode)
    return segmentation

Generate instance segmentation for the currently initialized image.

Arguments:
  • min_size: Minimal object size, passed to watershed_from_components.
  • threshold1: First threshold passed to watershed_from_components together with the boundary and foreground predictions.
  • threshold2: Second threshold passed to watershed_from_components.
  • output_mode: The form masks are returned in. Pass None to directly return the instance segmentation.
Returns:

The instance segmentation masks in the form given by output_mode, or the instance segmentation itself if output_mode is None.

Inherited Members
micro_sam.instance_segmentation.InstanceSegmentationWithDecoder
is_initialized
get_state
set_state
clear_state
class DistanceBasedInstanceSegmentation(_InstanceSegmentationBase):
    """Adapt micro_sam's decoder-based instance segmentation to a distance-predicting U-Net.

    The network must predict three channels: foreground, center distances and boundary
    distances. `generate` is inherited from micro_sam's `InstanceSegmentationWithDecoder`
    and presumably post-processes these cached predictions.
    """
    def __init__(self, model, preprocess=None, block_shape=None, halo=None):
        super().__init__(model=model, preprocess=preprocess, block_shape=block_shape, halo=halo)
        # Prediction caches, filled by `initialize`.
        self._foreground, self._center_distances, self._boundary_distances = None, None, None

    def initialize(self, data):
        """Predict the three distance-segmentation channels for `data` and cache them."""
        run_prediction = (
            self._initialize_torch if isinstance(self._model, nn.Module) else self._initialize_modelzoo
        )
        prediction = run_prediction(data)

        assert prediction.shape[0] == 3
        self._foreground, self._center_distances, self._boundary_distances = (
            prediction[0], prediction[1], prediction[2]
        )

        self._is_initialized = True

Overrides micro_sam functionality so that it works for distance-based segmentation with a U-Net.

DistanceBasedInstanceSegmentation(model, preprocess=None, block_shape=None, halo=None)
def __init__(self, model, preprocess=None, block_shape=None, halo=None):
    """Store the model and prediction settings; the prediction caches start out empty."""
    super().__init__(model=model, preprocess=preprocess, block_shape=block_shape, halo=halo)
    self._foreground, self._center_distances, self._boundary_distances = None, None, None
def initialize(self, data):
    """Run the model on `data` and cache foreground, center and boundary distance predictions."""
    run_prediction = (
        self._initialize_torch if isinstance(self._model, nn.Module) else self._initialize_modelzoo
    )
    prediction = run_prediction(data)

    # Expected channels: foreground (0), center distances (1), boundary distances (2).
    assert prediction.shape[0] == 3
    self._foreground, self._center_distances, self._boundary_distances = (
        prediction[0], prediction[1], prediction[2]
    )

    self._is_initialized = True

Predict the foreground, center distances and boundary distances for an image and cache them for generate.

Arguments:
  • data: The input image (2D) or volume (3D). It is passed to the torch model (after preprocessing) or, for non-torch models, to the bioimage.io prediction pipeline. The model must predict exactly three channels: foreground, center distances and boundary distances.
Inherited Members
micro_sam.instance_segmentation.InstanceSegmentationWithDecoder
is_initialized
generate
get_state
set_state
clear_state