torch_em.util.validation
import os
from typing import Callable, List, Optional, Sequence, Tuple, Union

import imageio.v3 as imageio
import numpy as np

from elf.io import open_file
from elf.util import normalize_index

from .prediction import predict_with_halo
from .util import get_trainer, get_normalizer
from ..data import ConcatDataset, ImageCollectionDataset, SegmentationDataset

try:
    import napari
except ImportError:
    napari = None


class SampleGenerator:
    """@private
    """
    def __init__(self, trainer, max_samples, need_gt, n_threads):
        self.need_gt = need_gt
        self.n_threads = n_threads

        dataset = trainer.val_loader.dataset
        self.ndim = dataset.ndim

        (n_samples, load_2d_from_3d, rois,
         raw_paths, raw_key,
         label_paths, label_key) = self.paths_from_ds(dataset)

        if max_samples is None:
            self.n_samples = n_samples
        else:
            self.n_samples = min(max_samples, n_samples)
        self.load_2d_from_3d = load_2d_from_3d
        self.rois = rois
        self.raw_paths, self.raw_key = raw_paths, raw_key
        self.label_paths, self.label_key = label_paths, label_key

        if self.load_2d_from_3d:
            # compute the cumulative number of slices per volume, so that a flat
            # sample id can be mapped back to (volume, z-slice) in load_sample
            shapes = [
                open_file(rp, "r")[self.raw_key].shape if roi is None else tuple(r.stop - r.start for r in roi)
                for rp, roi in zip(self.raw_paths, self.rois)
            ]
            lens = [shape[0] for shape in shapes]
            self.offsets = np.cumsum(lens)

    def paths_from_ds(self, dataset):
        if isinstance(dataset, ConcatDataset):
            datasets = dataset.datasets
            (
                n_samples, load_2d_from_3d, rois, raw_paths,
                raw_key, label_paths, label_key
            ) = self.paths_from_ds(datasets[0])

            for ds in datasets[1:]:
                ns, l2d3d, bb, rp, rk, lp, lk = self.paths_from_ds(ds)
                assert rk == raw_key
                assert lk == label_key
                assert l2d3d == load_2d_from_3d
                raw_paths.extend(rp)
                label_paths.extend(lp)
                rois.append(bb)
                n_samples += ns

        elif isinstance(dataset, ImageCollectionDataset):
            raw_paths, label_paths = dataset.raw_images, dataset.label_images
            raw_key, label_key = None, None
            n_samples = len(raw_paths)
            load_2d_from_3d = False
            rois = [None] * n_samples

        elif isinstance(dataset, SegmentationDataset):
            raw_paths, label_paths = [dataset.raw_path], [dataset.label_path]
            raw_key, label_key = dataset.raw_key, dataset.label_key
            shape = open_file(raw_paths[0], 'r')[raw_key].shape

            roi = getattr(dataset, 'roi', None)
            if roi is not None:
                roi = normalize_index(roi, shape)
                shape = tuple(r.stop - r.start for r in roi)
            rois = [roi]

            if self.ndim == len(shape):
                n_samples = len(raw_paths)
                load_2d_from_3d = False
            elif self.ndim == 2 and len(shape) == 3:
                n_samples = shape[0]
                load_2d_from_3d = True
            else:
                raise RuntimeError(f"Cannot create {self.ndim}D samples from data of shape {shape}.")

        else:
            raise RuntimeError(f"No support for dataset of type {type(dataset)}")

        return (n_samples, load_2d_from_3d, rois,
                raw_paths, raw_key, label_paths, label_key)

    def load_data(self, path, key, roi, z):
        if key is None:
            assert roi is None and z is None
            return imageio.imread(path)

        bb = np.s_[:, :, :] if roi is None else roi
        if z is not None:
            # bb is a tuple of slices, so rebuild it to select the requested z-slice
            bb = (z if roi is None else roi[0].start + z,) + tuple(bb[1:])

        with open_file(path, 'r') as f:
            ds = f[key]
            ds.n_threads = self.n_threads
            data = ds[bb]
        return data

    def load_sample(self, sample_id):
        if self.load_2d_from_3d:
            # map the flat sample id to the volume it belongs to and the z-slice within it
            ds_id = 0
            while True:
                if sample_id < self.offsets[ds_id]:
                    break
                ds_id += 1
            offset = self.offsets[ds_id - 1] if ds_id > 0 else 0
            z = sample_id - offset
        else:
            ds_id = sample_id
            z = None

        roi = self.rois[ds_id]
        raw = self.load_data(self.raw_paths[ds_id], self.raw_key, roi, z)
        if not self.need_gt:
            return raw
        gt = self.load_data(self.label_paths[ds_id], self.label_key, roi, z)
        return raw, gt

    def __iter__(self):
        for sample_id in range(self.n_samples):
            sample = self.load_sample(sample_id)
            yield sample


def _predict(model, raw, trainer, gpu_ids, save_path, sample_id):
    save_key = f"sample{sample_id}"
    if save_path is not None and os.path.exists(save_path):
        # re-use cached predictions if they have already been saved for this sample
        with open_file(save_path, "r") as f:
            if save_key in f:
                print("Loading predictions for sample", sample_id, "from file")
                ds = f[save_key]
                ds.n_threads = 8
                return ds[:]

    normalizer = get_normalizer(trainer)
    dataset = trainer.val_loader.dataset
    ndim = dataset.ndim
    if isinstance(dataset, ConcatDataset):
        patch_shape = dataset.datasets[0].patch_shape
    else:
        patch_shape = dataset.patch_shape

    if ndim == 2 and len(patch_shape) == 3:
        patch_shape = patch_shape[1:]
    assert len(patch_shape) == ndim

    # choose a small halo and set the correct block shape
    halo = (32, 32) if ndim == 2 else (8, 16, 16)
    block_shape = tuple(psh - 2 * ha for psh, ha in zip(patch_shape, halo))

    if save_path is None:
        output = None
    else:
        f = open_file(save_path, "a")
        out_shape = (trainer.model.out_channels,) + raw.shape
        chunks = (1,) + block_shape
        output = f.create_dataset(save_key, shape=out_shape, chunks=chunks, compression="gzip", dtype="float32")

    # gpu ids may be passed as strings (e.g. from the command line), so convert them to int
    gpu_ids = [int(gpu) if gpu != "cpu" else gpu for gpu in gpu_ids]
    pred = predict_with_halo(raw, model, gpu_ids, block_shape, halo, preprocess=normalizer, output=output)
    if output is not None:
        f.close()

    return pred


def _visualize(raw, prediction, ground_truth):
    with napari.gui_qt():
        viewer = napari.Viewer()
        viewer.add_image(raw)
        viewer.add_image(prediction)
        if ground_truth is not None:
            viewer.add_labels(ground_truth)


def validate_checkpoint(
    checkpoint: str,
    gpu_ids: List[int],
    save_path: Optional[str] = None,
    samples: Optional[Sequence[Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]]] = None,
    max_samples: Optional[int] = None,
    visualize: bool = True,
    metrics: Optional[Callable] = None,
    n_threads: Optional[int] = None,
) -> List[float]:
    """Validate model for the given checkpoint visually and/or via metrics.

    Args:
        checkpoint: The path to the checkpoint to evaluate.
        gpu_ids: The gpu ids to use for prediction.
        save_path: Optional path for saving the predictions.
        samples: The samples to use for evaluation. If None, the validation loader of the trainer is used.
        max_samples: The maximum number of samples to evaluate.
        visualize: Whether to visualize the predictions with napari.
        metrics: The metric to use for evaluating the samples.
        n_threads: The number of threads to use for parallelization.

    Returns:
        A list of metric scores.
    """
    if visualize and napari is None:
        raise RuntimeError("napari is required for visualization, but it could not be imported.")

    trainer = get_trainer(checkpoint, device="cpu")
    n_threads = trainer.train_loader.num_workers if n_threads is None else n_threads
    model = trainer.model
    model.eval()

    need_gt = metrics is not None
    if samples is None:
        samples = SampleGenerator(trainer, max_samples, need_gt, n_threads)
    else:
        assert isinstance(samples, (list, tuple))
        if need_gt:
            assert all(len(sample) == 2 for sample in samples)
        else:
            assert all(isinstance(sample, np.ndarray) for sample in samples)

    results = []
    for sample_id, sample in enumerate(samples):
        raw, gt = sample if need_gt else (sample, None)
        pred = _predict(model, raw, trainer, gpu_ids, save_path, sample_id)
        if visualize:
            _visualize(raw, pred, gt)
        if metrics is not None:
            res = metrics(gt, pred)
            results.append(res)
    return results


def main():
    """@private
    """
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("-p", "--path", required=True, help="Path to the checkpoint")
    parser.add_argument("-g", "--gpus", type=str, nargs="+", required=True)
    parser.add_argument("-n", "--max_samples", type=int, default=None)
    parser.add_argument("-d", "--data", default=None)
    parser.add_argument("-s", "--save_path", default=None)
    parser.add_argument("-k", "--key", default=None)
    parser.add_argument("-t", "--n_threads", type=int, default=None)

    args = parser.parse_args()
    # TODO implement loading data
    assert args.data is None
    validate_checkpoint(args.path, args.gpus, args.save_path, max_samples=args.max_samples, n_threads=args.n_threads)
def validate_checkpoint(
    checkpoint: str,
    gpu_ids: List[int],
    save_path: Optional[str] = None,
    samples: Optional[Sequence[Union[numpy.ndarray, Tuple[numpy.ndarray, numpy.ndarray]]]] = None,
    max_samples: Optional[int] = None,
    visualize: bool = True,
    metrics: Optional[Callable] = None,
    n_threads: Optional[int] = None,
) -> List[float]:
Validate model for the given checkpoint visually and/or via metrics.
Arguments:
- checkpoint: The path to the checkpoint to evaluate.
- gpu_ids: The gpu ids to use for prediction.
- save_path: Optional path for saving the predictions.
- samples: The samples to use for evaluation. If None, the validation loader of the trainer is used.
- max_samples: The maximum number of samples to evaluate.
- visualize: Whether to visualize the predictions with napari.
- metrics: The metric to use for evaluating the samples.
- n_threads: The number of threads to use for parallelization.
Returns:
A list of metric scores.
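A minimal usage sketch, assuming a hypothetical checkpoint directory and a simple foreground dice score as the metric; any callable taking (ground_truth, prediction) and returning a float can be passed as metrics:

    import numpy as np
    from torch_em.util.validation import validate_checkpoint

    def dice_metric(gt, pred):
        # threshold the first output channel to a binary foreground mask
        seg = (pred[0] > 0.5).astype("float32")
        fg = (gt > 0).astype("float32")
        denom = seg.sum() + fg.sum()
        return float(2 * (seg * fg).sum() / denom) if denom > 0 else 0.0

    scores = validate_checkpoint(
        checkpoint="checkpoints/my-model",      # hypothetical checkpoint folder
        gpu_ids=[0],                            # run prediction on GPU 0
        save_path="validation_predictions.h5",  # cache predictions between runs
        max_samples=4,
        visualize=False,                        # skip the napari viewer
        metrics=dice_metric,
    )
    print(scores)                               # one score per validated sample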