torch_em.util.grid_search

  1from typing import Callable, Dict, List, Optional, Tuple
  2
  3import bioimageio.core
  4import numpy as np
  5import torch.nn as nn
  6import xarray
  7
  8try:
  9    from micro_sam.instance_segmentation import InstanceSegmentationWithDecoder
 10    from micro_sam.evaluation.instance_segmentation import (
 11        default_grid_search_values_instance_segmentation_with_decoder,
 12        evaluate_instance_segmentation_grid_search,
 13        run_instance_segmentation_grid_search,
 14        _get_range_of_search_values,
 15    )
 16
 17    HAVE_MICRO_SAM = True
 18except ImportError:
 19    class InstanceSegmentationWithDecoder:
 20        def __init__(self, *args, **kwargs):
 21            pass
 22
 23    HAVE_MICRO_SAM = False
 24
 25from ..transform.raw import standardize
 26from .prediction import predict_with_padding, predict_with_halo
 27from .segmentation import watershed_from_components
 28
 29
 30def default_grid_search_values_boundary_based_instance_segmentation(
 31    threshold1_values=None, threshold2_values=None, min_size_values=None,
 32):
 33    """@private
 34    """
 35    if threshold1_values is None:
 36        threshold1_values = [0.5, 0.55, 0.6]
 37    if threshold2_values is None:
 38        threshold2_values = _get_range_of_search_values(
 39            [0.5, 0.9], step=0.1
 40        )
 41    if min_size_values is None:
 42        min_size_values = [25, 50, 75, 100, 200]
 43
 44    return {"min_size": min_size_values, "threshold1": threshold1_values, "threshold2": threshold2_values}
 45
 46
 47class _InstanceSegmentationBase(InstanceSegmentationWithDecoder):
 48    """Over-write micro_sam functionality so that it works for distance based segmentation with a U-net.
 49    """
 50    def __init__(self, model, preprocess=None, block_shape=None, halo=None):
 51        self._model = model
 52        self._preprocess = standardize if preprocess is None else preprocess
 53
 54        assert (block_shape is None) == (halo is None)
 55        self._block_shape = block_shape
 56        self._halo = halo
 57        self._is_initialized = False
 58
 59    def _initialize_torch(self, data):
 60        device = next(iter(self._model.parameters())).device
 61
 62        if self._block_shape is None:
 63            if hasattr(self._model, "scale_factors"):
 64                scale_factors = self._model.init_kwargs["scale_factors"]
 65                min_divisible = [int(np.prod([sf[i] for sf in scale_factors])) for i in range(3)]
 66            elif hasattr(self._model, "depth"):
 67                depth = self._model.depth
 68                min_divisible = [2**depth, 2**depth]
 69            else:
 70                raise RuntimeError
 71            input_ = self._preprocess(data)
 72            output = predict_with_padding(self._model, input_, min_divisible, device)
 73        else:
 74            output = predict_with_halo(
 75                data, self._model, [device], self._block_shape, self._halo,
 76                preprocess=self._preprocess,
 77            )
 78        return output
 79
 80    def _initialize_modelzoo(self, data):
 81        if self._block_shape is None:
 82            with bioimageio.core.create_prediction_pipeline(self._model) as pp:
 83                dims = tuple("bcyx") if data.ndim == 2 else tuple("bczyx")
 84                input_ = xarray.DataArray(data[None, None], dims=dims)
 85                output = bioimageio.core.prediction.predict_with_padding(pp, input_, padding=True)[0]
 86                output = output.squeeze().values
 87        else:
 88            raise NotImplementedError
 89        return output
 90
 91
 92class BoundaryBasedInstanceSegmentation(_InstanceSegmentationBase):
 93    """Wrapper for boundary based instance segmentation.
 94
 95    Instances of this class can be passed to `instance_segmentation_grid_search`.
 96
 97    Args:
 98        model: The model to evaluate. It must predict two channels:
 99            The first channel fpr foreground probabilities and the second for boundary probabilities.
100        preprocess: Optional preprocessing function to apply to the model inputs.
101        block_shape: Optional block shape for tiled prediction. If None, the inputs will be predicted en bloc.
102        halo: Halo for tiled prediction.
103    """
104    def __init__(
105        self,
106        model: nn.Module,
107        preprocess: Optional[Callable] = None,
108        block_shape: Tuple[int, ...] = None,
109        halo: Tuple[int, ...] = None,
110    ):
111        super().__init__(model=model, preprocess=preprocess, block_shape=block_shape, halo=halo)
112        self._foreground = None
113        self._boundaries = None
114
115    def initialize(self, data):
116        """@private
117        """
118        if isinstance(self._model, nn.Module):
119            output = self._initialize_torch(data)
120        else:
121            output = self._initialize_modelzoo(data)
122        assert output.shape[0] == 2
123
124        self._foreground = output[0]
125        self._boundaries = output[1]
126        self._is_initialized = True
127
128    def generate(self, min_size=50, threshold1=0.5, threshold2=0.5, output_mode="binary_mask"):
129        """@private
130        """
131        segmentation = watershed_from_components(
132            self._boundaries, self._foreground,
133            min_size=min_size, threshold1=threshold1, threshold2=threshold2,
134        )
135        if output_mode is not None:
136            segmentation = self._to_masks(segmentation, output_mode)
137        return segmentation
138
139
class DistanceBasedInstanceSegmentation(_InstanceSegmentationBase):
    """Wrapper for distance based instance segmentation.

    Instances of this class can be passed to `instance_segmentation_grid_search`.

    Args:
        model: The model to evaluate. It must predict three channels:
            The first channel for foreground probabilities, the second for center distances
            and the third for boundary distances.
        preprocess: Optional preprocessing function to apply to the model inputs.
        block_shape: Optional block shape for tiled prediction. If None, the inputs will be predicted en bloc.
        halo: Halo for tiled prediction.
    """
    def __init__(
        self,
        model: nn.Module,
        preprocess: Optional[Callable] = None,
        block_shape: Optional[Tuple[int, ...]] = None,
        halo: Optional[Tuple[int, ...]] = None,
    ):
        super().__init__(model=model, preprocess=preprocess, block_shape=block_shape, halo=halo)

        # Set by `initialize`. `generate` is not defined here; presumably the inherited
        # micro_sam implementation reads these attributes.
        self._foreground = None
        self._center_distances = None
        self._boundary_distances = None

    def initialize(self, data):
        """@private

        Run the model on `data` and store the foreground, center distance and boundary distance predictions.
        """
        if isinstance(self._model, nn.Module):
            output = self._initialize_torch(data)
        else:
            output = self._initialize_modelzoo(data)

        assert output.shape[0] == 3, \
            f"Expected 3 output channels (foreground, center distances, boundary distances), got {output.shape[0]}."
        self._foreground = output[0]
        self._center_distances = output[1]
        self._boundary_distances = output[2]
        self._is_initialized = True
177        self._boundary_distances = output[2]
178        self._is_initialized = True
179
180
def instance_segmentation_grid_search(
    segmenter,
    image_paths: List[str],
    gt_paths: List[str],
    result_dir: str,
    grid_search_values: Optional[Dict] = None,
    rois: Optional[List[Tuple[slice, ...]]] = None,
    image_key: Optional[str] = None,
    gt_key: Optional[str] = None,
) -> Tuple[Dict, float]:
    """Run grid search for instance segmentation.

    Args:
        segmenter: The segmentation wrapper. Needs to provide a 'initialize' and 'generate' function.
            The class `DistanceBasedInstanceSegmentation` can be used for models predicting distances
            for instance segmentation, `BoundaryBasedInstanceSegmentation` for models predicting boundaries.
        image_paths: The paths to the images to use for the grid search.
        gt_paths: The paths to the labels to use for the grid search.
        result_dir: The directory for caching the grid search results.
        grid_search_values: The values to test in the grid search. If None, the defaults for the
            type of `segmenter` are used.
        rois: Region of interests to use for the evaluation. If given, must have the same length as `image_paths`.
        image_key: The key to the internal dataset with the image data.
            Leave None if the images are in a regular image format such as tif.
        gt_key: The key to the internal dataset with the label data.
            Leave None if the labels are in a regular image format such as tif.

    Returns:
        The best parameters found by the grid search.
        The best score of the grid search.

    Raises:
        RuntimeError: If micro_sam is not installed.
        ValueError: If `grid_search_values` is None and no defaults exist for the type of `segmenter`.
    """
    if not HAVE_MICRO_SAM:
        raise RuntimeError(
            "The gridsearch functionality requires micro_sam. Install it via `conda install -c conda-forge micro_sam`."
        )

    if grid_search_values is None:
        if isinstance(segmenter, DistanceBasedInstanceSegmentation):
            grid_search_values = default_grid_search_values_instance_segmentation_with_decoder()
        elif isinstance(segmenter, BoundaryBasedInstanceSegmentation):
            grid_search_values = default_grid_search_values_boundary_based_instance_segmentation()
        else:
            raise ValueError(f"Could not derive default grid search values for segmenter of type {type(segmenter)}")

    # Run the segmenter for the parameter combinations; results are cached in `result_dir`.
    run_instance_segmentation_grid_search(
        segmenter, grid_search_values, image_paths, gt_paths, result_dir,
        embedding_dir=None, verbose_gs=True,
        image_key=image_key, gt_key=gt_key, rois=rois,
    )
    # Select the best parameter combination from the results in `result_dir`.
    best_kwargs, best_score = evaluate_instance_segmentation_grid_search(
        result_dir, list(grid_search_values.keys())
    )
    return best_kwargs, best_score
class BoundaryBasedInstanceSegmentation(_InstanceSegmentationBase):
 93class BoundaryBasedInstanceSegmentation(_InstanceSegmentationBase):
 94    """Wrapper for boundary based instance segmentation.
 95
 96    Instances of this class can be passed to `instance_segmentation_grid_search`.
 97
 98    Args:
 99        model: The model to evaluate. It must predict two channels:
100            The first channel fpr foreground probabilities and the second for boundary probabilities.
101        preprocess: Optional preprocessing function to apply to the model inputs.
102        block_shape: Optional block shape for tiled prediction. If None, the inputs will be predicted en bloc.
103        halo: Halo for tiled prediction.
104    """
105    def __init__(
106        self,
107        model: nn.Module,
108        preprocess: Optional[Callable] = None,
109        block_shape: Tuple[int, ...] = None,
110        halo: Tuple[int, ...] = None,
111    ):
112        super().__init__(model=model, preprocess=preprocess, block_shape=block_shape, halo=halo)
113        self._foreground = None
114        self._boundaries = None
115
116    def initialize(self, data):
117        """@private
118        """
119        if isinstance(self._model, nn.Module):
120            output = self._initialize_torch(data)
121        else:
122            output = self._initialize_modelzoo(data)
123        assert output.shape[0] == 2
124
125        self._foreground = output[0]
126        self._boundaries = output[1]
127        self._is_initialized = True
128
129    def generate(self, min_size=50, threshold1=0.5, threshold2=0.5, output_mode="binary_mask"):
130        """@private
131        """
132        segmentation = watershed_from_components(
133            self._boundaries, self._foreground,
134            min_size=min_size, threshold1=threshold1, threshold2=threshold2,
135        )
136        if output_mode is not None:
137            segmentation = self._to_masks(segmentation, output_mode)
138        return segmentation

Wrapper for boundary based instance segmentation.

Instances of this class can be passed to instance_segmentation_grid_search.

Arguments:
  • model: The model to evaluate. It must predict two channels: The first channel for foreground probabilities and the second for boundary probabilities.
  • preprocess: Optional preprocessing function to apply to the model inputs.
  • block_shape: Optional block shape for tiled prediction. If None, the inputs will be predicted en bloc.
  • halo: Halo for tiled prediction.
BoundaryBasedInstanceSegmentation( model: torch.nn.modules.module.Module, preprocess: Optional[Callable] = None, block_shape: Tuple[int, ...] = None, halo: Tuple[int, ...] = None)
105    def __init__(
106        self,
107        model: nn.Module,
108        preprocess: Optional[Callable] = None,
109        block_shape: Tuple[int, ...] = None,
110        halo: Tuple[int, ...] = None,
111    ):
112        super().__init__(model=model, preprocess=preprocess, block_shape=block_shape, halo=halo)
113        self._foreground = None
114        self._boundaries = None
class DistanceBasedInstanceSegmentation(_InstanceSegmentationBase):
class DistanceBasedInstanceSegmentation(_InstanceSegmentationBase):
    """Wrapper for distance based instance segmentation.

    Instances of this class can be passed to `instance_segmentation_grid_search`.

    Args:
        model: The model to evaluate. It must predict three channels:
            The first channel for foreground probabilities, the second for center distances
            and the third for boundary distances.
        preprocess: Optional preprocessing function to apply to the model inputs.
        block_shape: Optional block shape for tiled prediction. If None, the inputs will be predicted en bloc.
        halo: Halo for tiled prediction.
    """
    def __init__(
        self,
        model: nn.Module,
        preprocess: Optional[Callable] = None,
        block_shape: Optional[Tuple[int, ...]] = None,
        halo: Optional[Tuple[int, ...]] = None,
    ):
        super().__init__(model=model, preprocess=preprocess, block_shape=block_shape, halo=halo)

        # Set by `initialize`. `generate` is not defined here; presumably the inherited
        # micro_sam implementation reads these attributes.
        self._foreground = None
        self._center_distances = None
        self._boundary_distances = None

    def initialize(self, data):
        """@private

        Run the model on `data` and store the foreground, center distance and boundary distance predictions.
        """
        if isinstance(self._model, nn.Module):
            output = self._initialize_torch(data)
        else:
            output = self._initialize_modelzoo(data)

        assert output.shape[0] == 3, \
            f"Expected 3 output channels (foreground, center distances, boundary distances), got {output.shape[0]}."
        self._foreground = output[0]
        self._center_distances = output[1]
        self._boundary_distances = output[2]
        self._is_initialized = True

Wrapper for distance based instance segmentation.

Instances of this class can be passed to instance_segmentation_grid_search.

Arguments:
  • model: The model to evaluate. It must predict three channels: The first channel for foreground probabilities, the second for center distances and the third for boundary distances.
  • preprocess: Optional preprocessing function to apply to the model inputs.
  • block_shape: Optional block shape for tiled prediction. If None, the inputs will be predicted en bloc.
  • halo: Halo for tiled prediction.
DistanceBasedInstanceSegmentation( model: torch.nn.modules.module.Module, preprocess: Optional[Callable] = None, block_shape: Tuple[int, ...] = None, halo: Tuple[int, ...] = None)
154    def __init__(
155        self,
156        model: nn.Module,
157        preprocess: Optional[Callable] = None,
158        block_shape: Tuple[int, ...] = None,
159        halo: Tuple[int, ...] = None,
160    ):
161        super().__init__(model=model, preprocess=preprocess, block_shape=block_shape, halo=halo)
162
163        self._foreground = None
164        self._center_distances = None
165        self._boundary_distances = None
class InstanceSegmentationWithDecoder:
class InstanceSegmentationWithDecoder:
    """Stand-in base class used when micro_sam cannot be imported."""

    def __init__(self, *args, **kwargs):
        """Accept and ignore all arguments."""
InstanceSegmentationWithDecoder(*args, **kwargs)
21        def __init__(self, *args, **kwargs):
22            pass