torch_em.util.validation
import os

import imageio
import numpy as np

from elf.io import open_file
from elf.util import normalize_index

from ..data import ConcatDataset, ImageCollectionDataset, SegmentationDataset
from .util import get_trainer, get_normalizer
from .prediction import predict_with_halo

try:
    import napari
except ImportError:
    napari = None

# TODO implement prefab metrics


class SampleGenerator:
    def __init__(self, trainer, max_samples, need_gt, n_threads):
        self.need_gt = need_gt
        self.n_threads = n_threads

        dataset = trainer.val_loader.dataset
        self.ndim = dataset.ndim

        (n_samples, load_2d_from_3d, rois,
         raw_paths, raw_key,
         label_paths, label_key) = self.paths_from_ds(dataset)

        if max_samples is None:
            self.n_samples = n_samples
        else:
            self.n_samples = min(max_samples, n_samples)
        self.load_2d_from_3d = load_2d_from_3d
        self.rois = rois
        self.raw_paths, self.raw_key = raw_paths, raw_key
        self.label_paths, self.label_key = label_paths, label_key

        if self.load_2d_from_3d:
            # for 2d samples cut from 3d volumes, compute the cumulative
            # z-slice counts so that a flat sample id can be mapped back
            # to the volume it was cut from
            shapes = [
                open_file(rp, 'r')[self.raw_key].shape if roi is None else tuple(r.stop - r.start for r in roi)
                for rp, roi in zip(self.raw_paths, self.rois)
            ]
            lens = [shape[0] for shape in shapes]
            self.offsets = np.cumsum(lens)

    def paths_from_ds(self, dataset):
        if isinstance(dataset, ConcatDataset):
            datasets = dataset.datasets
            (n_samples, load_2d_from_3d, rois,
             raw_paths, raw_key,
             label_paths, label_key) = self.paths_from_ds(datasets[0])

            for ds in datasets[1:]:
                ns, l2d3d, bb, rp, rk, lp, lk = self.paths_from_ds(ds)
                assert rk == raw_key
                assert lk == label_key
                assert l2d3d == load_2d_from_3d
                raw_paths.extend(rp)
                label_paths.extend(lp)
                # bb is a list of rois, so it needs to be extended, not
                # appended, to keep rois aligned with the path lists
                rois.extend(bb)
                n_samples += ns

        elif isinstance(dataset, ImageCollectionDataset):
            raw_paths, label_paths = dataset.raw_images, dataset.label_images
            raw_key, label_key = None, None
            n_samples = len(raw_paths)
            load_2d_from_3d = False
            rois = [None] * n_samples

        elif isinstance(dataset, SegmentationDataset):
            raw_paths, label_paths = [dataset.raw_path], [dataset.label_path]
            raw_key, label_key = dataset.raw_key, dataset.label_key
            shape = open_file(raw_paths[0], 'r')[raw_key].shape

            roi = getattr(dataset, 'roi', None)
            if roi is not None:
                roi = normalize_index(roi, shape)
                shape = tuple(r.stop - r.start for r in roi)
            rois = [roi]

            if self.ndim == len(shape):
                n_samples = len(raw_paths)
                load_2d_from_3d = False
            elif self.ndim == 2 and len(shape) == 3:
                n_samples = shape[0]
                load_2d_from_3d = True
            else:
                raise RuntimeError(f"Incompatible dataset ndim {self.ndim} and data shape {shape}")

        else:
            raise RuntimeError(f"No support for dataset of type {type(dataset)}")

        return (n_samples, load_2d_from_3d, rois,
                raw_paths, raw_key, label_paths, label_key)

    def load_data(self, path, key, roi, z):
        if key is None:
            assert roi is None and z is None
            return imageio.imread(path)

        # build the bounding box in a mutable list, so that the z-slice
        # can be selected for 2d samples cut from a 3d volume
        bb = [slice(None)] * 3 if roi is None else list(roi)
        if z is not None:
            bb[0] = z if roi is None else roi[0].start + z

        with open_file(path, 'r') as f:
            ds = f[key]
            ds.n_threads = self.n_threads
            data = ds[tuple(bb)]
        return data

    def load_sample(self, sample_id):
        if self.load_2d_from_3d:
            # find the volume this sample id falls into and
            # the z-slice within that volume
            ds_id = 0
            while sample_id >= self.offsets[ds_id]:
                ds_id += 1
            offset = self.offsets[ds_id - 1] if ds_id > 0 else 0
            z = sample_id - offset
        else:
            ds_id = sample_id
            z = None

        roi = self.rois[ds_id]
        raw = self.load_data(self.raw_paths[ds_id], self.raw_key, roi, z)
        if not self.need_gt:
            return raw
        gt = self.load_data(self.label_paths[ds_id], self.label_key, roi, z)
        return raw, gt

    def __iter__(self):
        for sample_id in range(self.n_samples):
            sample = self.load_sample(sample_id)
            yield sample


def _predict(model, raw, trainer, gpu_ids, save_path, sample_id):

    save_key = f"sample{sample_id}"
    if save_path is not None and os.path.exists(save_path):
        with open_file(save_path, 'r') as f:
            if save_key in f:
                print("Loading predictions for sample", sample_id, "from file")
                ds = f[save_key]
                ds.n_threads = 8
                return ds[:]

    normalizer = get_normalizer(trainer)
    dataset = trainer.val_loader.dataset
    ndim = dataset.ndim
    if isinstance(dataset, ConcatDataset):
        patch_shape = dataset.datasets[0].patch_shape
    else:
        patch_shape = dataset.patch_shape

    if ndim == 2 and len(patch_shape) == 3:
        patch_shape = patch_shape[1:]
    assert len(patch_shape) == ndim

    # choose a small halo and set the correct block shape
    halo = (32, 32) if ndim == 2 else (8, 16, 16)
    block_shape = tuple(psh - 2 * ha for psh, ha in zip(patch_shape, halo))

    if save_path is None:
        output = None
    else:
        f = open_file(save_path, 'a')
        out_shape = (trainer.model.out_channels,) + raw.shape
        chunks = (1,) + block_shape
        output = f.create_dataset(save_key, shape=out_shape, chunks=chunks,
                                  compression='gzip', dtype='float32')

    gpu_ids = [int(gpu) if gpu != 'cpu' else gpu for gpu in gpu_ids]
    pred = predict_with_halo(
        raw, model, gpu_ids, block_shape, halo,
        preprocess=normalizer,
        output=output
    )
    if output is not None:
        f.close()

    return pred


def _visualize(raw, prediction, ground_truth):
    with napari.gui_qt():
        viewer = napari.Viewer()
        viewer.add_image(raw)
        viewer.add_image(prediction)
        if ground_truth is not None:
            viewer.add_labels(ground_truth)


def validate_checkpoint(
    checkpoint,
    gpu_ids,
    save_path=None,
    samples=None,
    max_samples=None,
    visualize=True,
    metrics=None,
    n_threads=None
):
    """Validate model for the given checkpoint visually and/or via metrics.
    """
    if visualize and napari is None:
        raise RuntimeError("Visualization requires napari, which could not be imported")

    trainer = get_trainer(checkpoint, device='cpu')
    n_threads = trainer.train_loader.num_workers if n_threads is None else n_threads
    model = trainer.model
    model.eval()

    need_gt = metrics is not None
    if samples is None:
        samples = SampleGenerator(trainer, max_samples, need_gt, n_threads)
    else:
        assert isinstance(samples, (list, tuple))
        if need_gt:
            assert all(len(sample) == 2 for sample in samples)
        else:
            assert all(isinstance(sample, np.ndarray) for sample in samples)

    results = []
    for sample_id, sample in enumerate(samples):
        raw, gt = sample if need_gt else (sample, None)
        pred = _predict(model, raw, trainer, gpu_ids, save_path, sample_id)
        if visualize:
            _visualize(raw, pred, gt)
        if metrics is not None:
            res = metrics(gt, pred)
            results.append(res)
    return results


def main():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('-p', '--path', required=True,
                        help="Path to the checkpoint")
    parser.add_argument('-g', '--gpus', type=str, nargs='+', required=True)
    parser.add_argument('-n', '--max_samples', type=int, default=None)
    parser.add_argument('-d', '--data', default=None)
    parser.add_argument('-s', '--save_path', default=None)
    parser.add_argument('-k', '--key', default=None)
    parser.add_argument('-t', '--n_threads', type=int, default=None)

    args = parser.parse_args()
    # TODO implement loading data
    assert args.data is None
    validate_checkpoint(args.path, args.gpus, args.save_path,
                        max_samples=args.max_samples,
                        n_threads=args.n_threads)
class SampleGenerator:
Generates the validation samples for a trainer by loading them directly from file, instead of going through the data loader. Iterating over a SampleGenerator yields the raw data per sample, or (raw, ground-truth) pairs if need_gt is set; for 2d models trained on 3d data each z-slice counts as one sample.
SampleGenerator(trainer, max_samples, need_gt, n_threads)
The constructor derives the sample layout from trainer.val_loader.dataset: it collects file paths, dataset keys and regions of interest via paths_from_ds, caps the number of samples at max_samples, and, when 2d samples are cut from 3d volumes, precomputes the cumulative z-slice counts (offsets) needed to map a flat sample id back to its volume.
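A minimal usage sketch, assuming "./checkpoints/my-model" is a hypothetical checkpoint directory, that get_trainer is importable from torch_em.util, and that the checkpoint's validation loader wraps one of the supported dataset types:

from torch_em.util import get_trainer
from torch_em.util.validation import SampleGenerator

# load the trainer for the checkpoint on the cpu and iterate over
# up to four validation samples, without ground-truth labels
trainer = get_trainer("./checkpoints/my-model", device="cpu")
sample_gen = SampleGenerator(trainer, max_samples=4, need_gt=False, n_threads=4)
for raw in sample_gen:
    print(raw.shape)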
def paths_from_ds(self, dataset):
Recursively collects the sample layout for the given dataset and returns the tuple (n_samples, load_2d_from_3d, rois, raw_paths, raw_key, label_paths, label_key). A ConcatDataset is flattened into its child datasets, whose keys must agree; an ImageCollectionDataset contributes one image file per sample; a SegmentationDataset contributes either the whole volume as a single sample or, for a 2d model on 3d data, one sample per z-slice. Any other dataset type raises a RuntimeError.
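For illustration, the return value for a single 3d SegmentationDataset used with a 2d trainer might look as follows (path, keys and shape are hypothetical):

# raw volume of shape (100, 512, 512), ndim == 2, no roi set:
# every z-slice becomes one sample
(n_samples, load_2d_from_3d, rois,
 raw_paths, raw_key, label_paths, label_key) = sample_gen.paths_from_ds(dataset)
# -> (100, True, [None], ["volume.h5"], "raw", ["volume.h5"], "labels")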
def load_data(self, path, key, roi, z):
Loads one sample from file. Plain image files (key is None) are read with imageio; container formats are opened with elf.io.open_file and read with n_threads, restricted to the region of interest and, for 2d samples from a 3d volume, to the requested z-slice.
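The slicing logic can be sketched in isolation; the roi and z values here are made up:

import numpy as np

roi = np.s_[10:50, 0:256, 0:256]  # hypothetical region of interest
z = 5                             # z-slice relative to the roi

bb = list(roi)            # collect the slices in a mutable list
bb[0] = roi[0].start + z  # select the absolute z-slice in the volume
print(tuple(bb))          # (15, slice(0, 256, None), slice(0, 256, None))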
def load_sample(self, sample_id):
Loads the raw data, and also the labels if need_gt is set, for the given flat sample id. When 2d samples are cut from 3d volumes, the cumulative offsets computed in the constructor determine which volume the id falls into and which z-slice to read.
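The offset lookup is equivalent to a searchsorted over the cumulative slice counts. A small example, assuming three volumes with 10, 20 and 15 z-slices:

import numpy as np

offsets = np.cumsum([10, 20, 15])  # array([10, 30, 45])
sample_id = 17
# first dataset whose cumulative slice count exceeds the sample id
ds_id = int(np.searchsorted(offsets, sample_id, side="right"))  # 1
z = sample_id - (offsets[ds_id - 1] if ds_id > 0 else 0)        # 7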
def validate_checkpoint(checkpoint, gpu_ids, save_path=None, samples=None, max_samples=None, visualize=True, metrics=None, n_threads=None):
Validate the model of the given checkpoint visually and/or via metrics.

Loads trainer and model from the checkpoint and runs tiled inference with predict_with_halo on each validation sample; predictions can be cached in save_path and are reused on the next run. If visualize is set, raw data, prediction and ground truth are shown in napari (which must be installed); if a metrics callable is given, it is evaluated as metrics(gt, pred) per sample and the list of results is returned. Samples can also be passed explicitly as a list of arrays, or of (raw, ground-truth) pairs when metrics are requested, instead of being generated from the validation loader.
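A sketch of programmatic use. The checkpoint path is hypothetical, and the dice-based metric is just one possible callable with the metrics(gt, pred) signature used above, where the prediction has the channel axis first:

import numpy as np
from torch_em.util.validation import validate_checkpoint

def dice_error(gt, pred):
    # foreground dice between the first prediction channel and the labels
    seg = (pred[0] > 0.5).astype("float32")
    fg = (gt > 0).astype("float32")
    eps = 1e-7
    dice = (2 * (seg * fg).sum() + eps) / (seg.sum() + fg.sum() + eps)
    return 1.0 - dice

scores = validate_checkpoint(
    "./checkpoints/my-model", gpu_ids=["cpu"],
    max_samples=2, visualize=False, metrics=dice_error,
)
print("mean dice error:", np.mean(scores))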
def main():
Command line interface for validate_checkpoint: -p/--path selects the checkpoint, -g/--gpus the devices used for prediction, and -n/--max_samples, -s/--save_path and -t/--n_threads are forwarded as the corresponding keyword arguments. Loading custom validation data via -d/--data and -k/--key is not implemented yet.
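Assuming the module can be run as a script (the invocation form and checkpoint path are hypothetical), validating at most four samples on gpu 0 while caching predictions could look like:

python -m torch_em.util.validation -p ./checkpoints/my-model -g 0 -n 4 -s predictions.h5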