torch_em.util.validation

import os
from typing import Callable, List, Optional, Sequence, Tuple, Union

import imageio.v3 as imageio
import numpy as np

from elf.io import open_file
from elf.util import normalize_index

from .prediction import predict_with_halo
from .util import get_trainer, get_normalizer
from ..data import ConcatDataset, ImageCollectionDataset, SegmentationDataset

try:
    import napari
except ImportError:
    napari = None


class SampleGenerator:
    """@private
    """
    def __init__(self, trainer, max_samples, need_gt, n_threads):
        self.need_gt = need_gt
        self.n_threads = n_threads

        dataset = trainer.val_loader.dataset
        self.ndim = dataset.ndim

        (n_samples, load_2d_from_3d, rois,
         raw_paths, raw_key,
         label_paths, label_key) = self.paths_from_ds(dataset)

        if max_samples is None:
            self.n_samples = n_samples
        else:
            self.n_samples = min(max_samples, n_samples)
        self.load_2d_from_3d = load_2d_from_3d
        self.rois = rois
        self.raw_paths, self.raw_key = raw_paths, raw_key
        self.label_paths, self.label_key = label_paths, label_key

        if self.load_2d_from_3d:
            # Compute cumulative slice counts per volume, so that a flat sample id
            # can be mapped back to (volume, z-slice) in load_sample.
            shapes = [
                open_file(rp, "r")[self.raw_key].shape if roi is None else tuple(r.stop - r.start for r in roi)
                for rp, roi in zip(self.raw_paths, self.rois)
            ]
            lens = [shape[0] for shape in shapes]
            self.offsets = np.cumsum(lens)
    def paths_from_ds(self, dataset):
        if isinstance(dataset, ConcatDataset):
            datasets = dataset.datasets
            (
                n_samples, load_2d_from_3d, rois, raw_paths,
                raw_key, label_paths, label_key
            ) = self.paths_from_ds(datasets[0])

            for ds in datasets[1:]:
                ns, l2d3d, bb, rp, rk, lp, lk = self.paths_from_ds(ds)
                assert rk == raw_key
                assert lk == label_key
                assert l2d3d == load_2d_from_3d
                raw_paths.extend(rp)
                label_paths.extend(lp)
                # bb is a list of rois, so it must be merged with extend to stay
                # aligned with raw_paths and label_paths.
                rois.extend(bb)
                n_samples += ns

        elif isinstance(dataset, ImageCollectionDataset):
            raw_paths, label_paths = dataset.raw_images, dataset.label_images
            raw_key, label_key = None, None
            n_samples = len(raw_paths)
            load_2d_from_3d = False
            rois = [None] * n_samples

        elif isinstance(dataset, SegmentationDataset):
            raw_paths, label_paths = [dataset.raw_path], [dataset.label_path]
            raw_key, label_key = dataset.raw_key, dataset.label_key
            shape = open_file(raw_paths[0], "r")[raw_key].shape

            roi = getattr(dataset, "roi", None)
            if roi is not None:
                roi = normalize_index(roi, shape)
                shape = tuple(r.stop - r.start for r in roi)
            rois = [roi]

            if self.ndim == len(shape):
                n_samples = len(raw_paths)
                load_2d_from_3d = False
            elif self.ndim == 2 and len(shape) == 3:
                n_samples = shape[0]
                load_2d_from_3d = True
            else:
                raise RuntimeError(f"Unsupported combination of ndim {self.ndim} and data shape {shape}.")

        else:
            raise RuntimeError(f"No support for dataset of type {type(dataset)}")

        return (n_samples, load_2d_from_3d, rois,
                raw_paths, raw_key, label_paths, label_key)

    def load_data(self, path, key, roi, z):
        if key is None:
            assert roi is None and z is None
            return imageio.imread(path)

        # np.s_ yields an (immutable) tuple, so when a single z-slice is requested
        # the bounding box has to be rebuilt as a new tuple rather than mutated in place.
        bb = np.s_[:, :, :] if roi is None else roi
        if z is not None:
            z_abs = z if roi is None else roi[0].start + z
            bb = (z_abs,) + tuple(bb[1:])

        with open_file(path, "r") as f:
            ds = f[key]
            ds.n_threads = self.n_threads
            data = ds[bb]
        return data

    def load_sample(self, sample_id):
        if self.load_2d_from_3d:
            # Map the flat sample id to the volume it falls into
            # and to the z-slice within that volume.
            ds_id = int(np.searchsorted(self.offsets, sample_id, side="right"))
            offset = self.offsets[ds_id - 1] if ds_id > 0 else 0
            z = sample_id - offset
        else:
            ds_id = sample_id
            z = None

        roi = self.rois[ds_id]
        raw = self.load_data(self.raw_paths[ds_id], self.raw_key, roi, z)
        if not self.need_gt:
            return raw
        gt = self.load_data(self.label_paths[ds_id], self.label_key, roi, z)
        return raw, gt

    def __iter__(self):
        for sample_id in range(self.n_samples):
            sample = self.load_sample(sample_id)
            yield sample


def _predict(model, raw, trainer, gpu_ids, save_path, sample_id):
    save_key = f"sample{sample_id}"
    if save_path is not None and os.path.exists(save_path):
        with open_file(save_path, "r") as f:
            if save_key in f:
                print("Loading predictions for sample", sample_id, "from file")
                ds = f[save_key]
                ds.n_threads = 8
                return ds[:]

    normalizer = get_normalizer(trainer)
    dataset = trainer.val_loader.dataset
    ndim = dataset.ndim
    if isinstance(dataset, ConcatDataset):
        patch_shape = dataset.datasets[0].patch_shape
    else:
        patch_shape = dataset.patch_shape

    if ndim == 2 and len(patch_shape) == 3:
        patch_shape = patch_shape[1:]
    assert len(patch_shape) == ndim

    # Choose a small halo and derive the block shape from it, so that
    # block_shape + 2 * halo matches the patch shape the model was trained on.
    halo = (32, 32) if ndim == 2 else (8, 16, 16)
    block_shape = tuple(psh - 2 * ha for psh, ha in zip(patch_shape, halo))

    if save_path is None:
        output = None
    else:
        f = open_file(save_path, "a")
        out_shape = (trainer.model.out_channels,) + raw.shape
        chunks = (1,) + block_shape
        output = f.create_dataset(save_key, shape=out_shape, chunks=chunks, compression="gzip", dtype="float32")

    gpu_ids = [int(gpu) if gpu != "cpu" else gpu for gpu in gpu_ids]
    pred = predict_with_halo(raw, model, gpu_ids, block_shape, halo, preprocess=normalizer, output=output)
    if output is not None:
        f.close()

    return pred


def _visualize(raw, prediction, ground_truth):
    # napari.gui_qt() was removed in recent napari versions;
    # napari.run() starts the event loop instead.
    viewer = napari.Viewer()
    viewer.add_image(raw)
    viewer.add_image(prediction)
    if ground_truth is not None:
        viewer.add_labels(ground_truth)
    napari.run()


def validate_checkpoint(
    checkpoint: str,
    gpu_ids: List[int],
    save_path: Optional[str] = None,
    samples: Optional[Sequence[Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]]] = None,
    max_samples: Optional[int] = None,
    visualize: bool = True,
    metrics: Optional[Callable] = None,
    n_threads: Optional[int] = None,
) -> List[float]:
    """Validate model for the given checkpoint visually and/or via metrics.

    Args:
        checkpoint: The path to the checkpoint to evaluate.
        gpu_ids: The gpu ids to use for prediction.
        save_path: Optional path for saving the predictions.
        samples: The samples to use for evaluation. If None, the validation loader of the trainer is used.
        max_samples: The maximum number of samples to evaluate.
        visualize: Whether to visualize the predictions with napari.
        metrics: The metric to use for evaluating the samples.
        n_threads: The number of threads to use for parallelization.

    Returns:
        A list of metric scores.
    """
    if visualize and napari is None:
        raise RuntimeError("Visualization requires napari, but it could not be imported.")

    trainer = get_trainer(checkpoint, device="cpu")
    n_threads = trainer.train_loader.num_workers if n_threads is None else n_threads
    model = trainer.model
    model.eval()

    need_gt = metrics is not None
    if samples is None:
        samples = SampleGenerator(trainer, max_samples, need_gt, n_threads)
    else:
        assert isinstance(samples, (list, tuple))
        if need_gt:
            assert all(len(sample) == 2 for sample in samples)
        else:
            assert all(isinstance(sample, np.ndarray) for sample in samples)

    results = []
    for sample_id, sample in enumerate(samples):
        raw, gt = sample if need_gt else (sample, None)
        pred = _predict(model, raw, trainer, gpu_ids, save_path, sample_id)
        if visualize:
            _visualize(raw, pred, gt)
        if metrics is not None:
            res = metrics(gt, pred)
            results.append(res)
    return results


def main():
    """@private
    """
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("-p", "--path", required=True, help="Path to the checkpoint")
    parser.add_argument("-g", "--gpus", type=str, nargs="+", required=True)
    parser.add_argument("-n", "--max_samples", type=int, default=None)
    parser.add_argument("-d", "--data", default=None)
    parser.add_argument("-s", "--save_path", default=None)
    parser.add_argument("-k", "--key", default=None)
    parser.add_argument("-t", "--n_threads", type=int, default=None)

    args = parser.parse_args()
    # TODO implement loading data
    assert args.data is None
    validate_checkpoint(args.path, args.gpus, args.save_path, max_samples=args.max_samples, n_threads=args.n_threads)
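
The `main` function above backs a small command line interface around `validate_checkpoint`. Assuming it is registered as a console script (the exact command name depends on the package's entry point configuration), an invocation could look like `<script-name> -p ./checkpoints/my-model -g 0 cpu -s ./predictions.h5 -n 4`, where `-g` accepts GPU ids and/or the literal `cpu`.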
def validate_checkpoint(
    checkpoint: str,
    gpu_ids: List[int],
    save_path: Optional[str] = None,
    samples: Optional[Sequence[Union[numpy.ndarray, Tuple[numpy.ndarray, numpy.ndarray]]]] = None,
    max_samples: Optional[int] = None,
    visualize: bool = True,
    metrics: Optional[Callable] = None,
    n_threads: Optional[int] = None
) -> List[float]:

Validate model for the given checkpoint visually and/or via metrics.

Arguments:
  • checkpoint: The path to the checkpoint to evaluate.
  • gpu_ids: The gpu ids to use for prediction.
  • save_path: Optional path for saving the predictions.
  • samples: The samples to use for evaluation. If None, the validation loader of the trainer is used.
  • max_samples: The maximum number of samples to evaluate.
  • visualize: Whether to visualize the predictions with napari.
  • metrics: The metric to use for evaluating the samples.
  • n_threads: The number of threads to use for parallelization.
Returns:
  • A list of metric scores.
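
As a usage sketch (not part of the module): the checkpoint path and the `iou_metric` callable below are hypothetical. Any callable with the signature `metric(gt, pred)` works as `metrics`, where `pred` has shape `(out_channels,) + raw.shape`.

    import numpy as np

    from torch_em.util.validation import validate_checkpoint


    def iou_metric(gt: np.ndarray, pred: np.ndarray) -> float:
        """Hypothetical metric: foreground IoU of the first prediction channel."""
        seg = pred[0] > 0.5  # binarize the first prediction channel
        fg = gt > 0          # binarize the ground truth
        union = np.logical_or(seg, fg).sum()
        if union == 0:
            return 0.0
        return float(np.logical_and(seg, fg).sum()) / float(union)


    scores = validate_checkpoint(
        "./checkpoints/my-model",      # hypothetical checkpoint path
        gpu_ids=[0],                   # predict on the first GPU
        save_path="./predictions.h5",  # cache predictions for repeated runs
        visualize=False,               # skip the napari viewer, only compute scores
        metrics=iou_metric,
    )
    print("Mean validation score:", np.mean(scores))

Because no `samples` are passed, the validation loader stored with the checkpoint is used; pass a list of arrays (or of `(raw, gt)` tuples when a metric is given) to evaluate on custom data instead.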