torch_em.util.validation

import os
import imageio
import numpy as np

from elf.io import open_file
from elf.util import normalize_index

from ..data import ConcatDataset, ImageCollectionDataset, SegmentationDataset
from .util import get_trainer, get_normalizer
from .prediction import predict_with_halo

try:
    import napari
except ImportError:
    napari = None

# TODO implement prefab metrics

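# SampleGenerator iterates over the samples of a trainer's validation dataset,
# yielding the raw data (and the ground-truth labels when need_gt is set).
# For 3d data combined with a 2d model it yields individual z-slices instead of volumes.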
class SampleGenerator:
    def __init__(self, trainer, max_samples, need_gt, n_threads):
        self.need_gt = need_gt
        self.n_threads = n_threads

        dataset = trainer.val_loader.dataset
        self.ndim = dataset.ndim

        (n_samples, load_2d_from_3d, rois,
         raw_paths, raw_key,
         label_paths, label_key) = self.paths_from_ds(dataset)

        if max_samples is None:
            self.n_samples = n_samples
        else:
            self.n_samples = min(max_samples, n_samples)
        self.load_2d_from_3d = load_2d_from_3d
        self.rois = rois
        self.raw_paths, self.raw_key = raw_paths, raw_key
        self.label_paths, self.label_key = label_paths, label_key

        if self.load_2d_from_3d:
            shapes = [
                open_file(rp, 'r')[self.raw_key].shape if roi is None else tuple(r.stop - r.start for r in roi)
                for rp, roi in zip(self.raw_paths, self.rois)
            ]
            lens = [shape[0] for shape in shapes]
            # cumulative number of z-slices per volume; used by load_sample
            # to map a flat sample id to a (volume, z-slice) pair
            self.offsets = np.cumsum(lens)

    def paths_from_ds(self, dataset):
        if isinstance(dataset, ConcatDataset):
            datasets = dataset.datasets
            (n_samples, load_2d_from_3d, rois,
             raw_paths, raw_key,
             label_paths, label_key) = self.paths_from_ds(datasets[0])

            for ds in datasets[1:]:
                ns, l2d3d, bb, rp, rk, lp, lk = self.paths_from_ds(ds)
                assert rk == raw_key
                assert lk == label_key
                assert l2d3d == load_2d_from_3d
                raw_paths.extend(rp)
                label_paths.extend(lp)
                # keep the rois aligned with raw_paths / label_paths
                rois.extend(bb)
                n_samples += ns

        elif isinstance(dataset, ImageCollectionDataset):
            raw_paths, label_paths = dataset.raw_images, dataset.label_images
            raw_key, label_key = None, None
            n_samples = len(raw_paths)
            load_2d_from_3d = False
            rois = [None] * n_samples

        elif isinstance(dataset, SegmentationDataset):
            raw_paths, label_paths = [dataset.raw_path], [dataset.label_path]
            raw_key, label_key = dataset.raw_key, dataset.label_key
            shape = open_file(raw_paths[0], 'r')[raw_key].shape

            roi = getattr(dataset, 'roi', None)
            if roi is not None:
                roi = normalize_index(roi, shape)
                shape = tuple(r.stop - r.start for r in roi)
            rois = [roi]

            if self.ndim == len(shape):
                n_samples = len(raw_paths)
                load_2d_from_3d = False
            elif self.ndim == 2 and len(shape) == 3:
                n_samples = shape[0]
                load_2d_from_3d = True
            else:
                raise RuntimeError(f"Dataset dimensionality {self.ndim} is incompatible with data of shape {shape}")

        else:
            raise RuntimeError(f"No support for dataset of type {type(dataset)}")

        return (n_samples, load_2d_from_3d, rois,
                raw_paths, raw_key, label_paths, label_key)

    def load_data(self, path, key, roi, z):
        if key is None:
            assert roi is None and z is None
            return imageio.imread(path)

        # build the bounding box as a list so the z-slice can be set,
        # then convert back to a tuple for indexing
        bb = list(np.s_[:, :, :]) if roi is None else list(roi)
        if z is not None:
            bb[0] = z if roi is None else roi[0].start + z
        bb = tuple(bb)

        with open_file(path, 'r') as f:
            ds = f[key]
            ds.n_threads = self.n_threads
            data = ds[bb]
        return data

    def load_sample(self, sample_id):
        if self.load_2d_from_3d:
            ds_id = 0
            while True:
                if sample_id < self.offsets[ds_id]:
                    break
                ds_id += 1
            offset = self.offsets[ds_id - 1] if ds_id > 0 else 0
            z = sample_id - offset
        else:
            ds_id = sample_id
            z = None

        roi = self.rois[ds_id]
        raw = self.load_data(self.raw_paths[ds_id], self.raw_key, roi, z)
        if not self.need_gt:
            return raw
        gt = self.load_data(self.label_paths[ds_id], self.label_key, roi, z)
        return raw, gt

    def __iter__(self):
        for sample_id in range(self.n_samples):
            sample = self.load_sample(sample_id)
            yield sample

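# Tiled prediction for a single validation sample via predict_with_halo.
# If save_path is given, predictions are cached there (one dataset per sample)
# and reloaded on subsequent runs. The block shape is the patch shape minus
# twice the halo, e.g. a (512, 512) patch with a (32, 32) halo gives
# (448, 448) blocks.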
def _predict(model, raw, trainer, gpu_ids, save_path, sample_id):

    save_key = f"sample{sample_id}"
    if save_path is not None and os.path.exists(save_path):
        with open_file(save_path, 'r') as f:
            if save_key in f:
                print("Loading predictions for sample", sample_id, "from file")
                ds = f[save_key]
                ds.n_threads = 8
                return ds[:]

    normalizer = get_normalizer(trainer)
    dataset = trainer.val_loader.dataset
    ndim = dataset.ndim
    if isinstance(dataset, ConcatDataset):
        patch_shape = dataset.datasets[0].patch_shape
    else:
        patch_shape = dataset.patch_shape

    if ndim == 2 and len(patch_shape) == 3:
        patch_shape = patch_shape[1:]
    assert len(patch_shape) == ndim

    # choose a small halo and set the correct block shape
    halo = (32, 32) if ndim == 2 else (8, 16, 16)
    block_shape = tuple(psh - 2 * ha for psh, ha in zip(patch_shape, halo))

    if save_path is None:
        output = None
    else:
        f = open_file(save_path, 'a')
        out_shape = (trainer.model.out_channels,) + raw.shape
        chunks = (1,) + block_shape
        output = f.create_dataset(save_key, shape=out_shape, chunks=chunks,
                                  compression='gzip', dtype='float32')

    gpu_ids = [int(gpu) if gpu != 'cpu' else gpu for gpu in gpu_ids]
    pred = predict_with_halo(
        raw, model, gpu_ids, block_shape, halo,
        preprocess=normalizer,
        output=output
    )
    if output is not None:
        f.close()

    return pred


def _visualize(raw, prediction, ground_truth):
    with napari.gui_qt():
        viewer = napari.Viewer()
        viewer.add_image(raw)
        viewer.add_image(prediction)
        if ground_truth is not None:
            viewer.add_labels(ground_truth)

def validate_checkpoint(
    checkpoint,
    gpu_ids,
    save_path=None,
    samples=None,
    max_samples=None,
    visualize=True,
    metrics=None,
    n_threads=None
):
    """Validate the model for the given checkpoint visually and/or via metrics.
    """
    if visualize and napari is None:
        raise RuntimeError("Visualization requires napari, which could not be imported.")

    trainer = get_trainer(checkpoint, device='cpu')
    n_threads = trainer.train_loader.num_workers if n_threads is None else n_threads
    model = trainer.model
    model.eval()

    need_gt = metrics is not None
    if samples is None:
        samples = SampleGenerator(trainer, max_samples, need_gt, n_threads)
    else:
        assert isinstance(samples, (list, tuple))
        if need_gt:
            assert all(len(sample) == 2 for sample in samples)
        else:
            assert all(isinstance(sample, np.ndarray) for sample in samples)

    results = []
    for sample_id, sample in enumerate(samples):
        raw, gt = sample if need_gt else (sample, None)
        pred = _predict(model, raw, trainer, gpu_ids, save_path, sample_id)
        if visualize:
            _visualize(raw, pred, gt)
        if metrics is not None:
            res = metrics(gt, pred)
            results.append(res)
    return results

def main():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('-p', '--path', required=True,
                        help="Path to the checkpoint")
    parser.add_argument('-g', '--gpus', type=str, nargs='+', required=True)
    parser.add_argument('-n', '--max_samples', type=int, default=None)
    parser.add_argument('-d', '--data', default=None)
    parser.add_argument('-s', '--save_path', default=None)
    parser.add_argument('-k', '--key', default=None)
    parser.add_argument('-t', '--n_threads', type=int, default=None)

    args = parser.parse_args()
    # TODO implement loading data
    assert args.data is None
    validate_checkpoint(args.path, args.gpus, args.save_path,
                        max_samples=args.max_samples,
                        n_threads=args.n_threads)
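
Example: a minimal sketch of calling validate_checkpoint directly from Python. The checkpoint directory, GPU id, and save path below are placeholders, not part of this module.

    from torch_em.util.validation import validate_checkpoint

    # visually inspect up to two validation samples of a trained model,
    # running prediction on GPU 0 and caching the results in an hdf5 file
    validate_checkpoint(
        "./checkpoints/my-model",        # hypothetical checkpoint directory
        gpu_ids=[0],
        save_path="./val_predictions.h5",
        max_samples=2,
        visualize=True,
    )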