torch_em.shallow2deep.shallow2deep_eval
import os
import pickle
from concurrent import futures
from glob import glob
from multiprocessing import cpu_count

import numpy as np
import pandas as pd
import elf.io as io
from tqdm import tqdm, trange

from .prepare_shallow2deep import _apply_filters, _get_filters
from .shallow2deep_model import IlastikPredicter


def visualize_pretrained_rfs(checkpoint, raw, n_forests,
                             sample_random=False, filter_config=None, n_threads=None):
    """Visualize pretrained random forests from a shallow2deep checkpoint.

    Arguments:
        checkpoint [str] - path to the checkpoint folder
        raw [np.ndarray] - the raw data for prediction
        n_forests [int] - the number of forests to use; clamped to the number of available forests
        sample_random [bool] - whether to subsample forests randomly (without repetition)
            or regularly (default: False)
        filter_config [list] - the filter configuration (default: None)
        n_threads [int] - number of threads for parallel prediction of forests (default: None)
    """
    import napari

    # the forests are either stored directly in the checkpoint folder or in its "rfs" subfolder
    rf_paths = glob(os.path.join(checkpoint, "*.pkl"))
    if len(rf_paths) == 0:
        rf_folder = os.path.join(checkpoint, "rfs")
        assert os.path.exists(rf_folder), rf_folder
        rf_paths = glob(os.path.join(rf_folder, "*.pkl"))
    assert len(rf_paths) > 0
    rf_paths.sort()

    # we can't show more forests than are stored in the checkpoint
    n_forests = min(n_forests, len(rf_paths))
    if sample_random:
        # sample without replacement so that no forest is shown twice
        rf_paths = np.random.choice(rf_paths, size=n_forests, replace=False)
    else:
        # clamp the step to 1; a step of 0 is not a valid slice step
        step = max(len(rf_paths) // n_forests, 1)
        rf_paths = rf_paths[::step][:n_forests]

    print("Compute features for input of shape", raw.shape)
    filter_config = _get_filters(raw.ndim, filter_config)
    features = _apply_filters(raw, filter_config)

    def predict_rf(rf_path):
        # NOTE: unpickling can execute arbitrary code; only load checkpoints from trusted sources
        with open(rf_path, "rb") as f:
            rf = pickle.load(f)
        pred = rf.predict_proba(features)
        # restore the spatial shape and move the class axis to the front
        pred = pred.reshape(raw.shape + (pred.shape[1],))
        pred = np.moveaxis(pred, -1, 0)
        assert pred.shape[1:] == raw.shape
        return pred

    n_threads = cpu_count() if n_threads is None else n_threads
    with futures.ThreadPoolExecutor(n_threads) as tp:
        preds = list(tqdm(tp.map(predict_rf, rf_paths), desc="Predict RFs", total=len(rf_paths)))

    print("Start viewer")
    v = napari.Viewer()
    for path, pred in zip(rf_paths, preds):
        name = os.path.basename(path)
        v.add_image(pred, name=name)
    v.add_image(raw)
    v.grid.enabled = True
    napari.run()


def evaluate_enhancers(data, labels, enhancers, ilastik_projects, metric,
                       prediction_function=None, rf_channel=1, is2d=False, save_path=None):
    """Evaluate enhancers on ilastik random forests from multiple projects.

    Arguments:
        data [np.ndarray] - the data for evaluation
        labels [np.ndarray] - the labels for evaluation
        enhancers [dict[str, str]] - map of enhancer names to filepaths with enhancer
            models saved in the bioimage.io model format
        ilastik_projects [dict[str, str]] - map of names to ilastik project paths
        metric [callable] - the metric used for evaluation, called as metric(prediction, labels)
        prediction_function [callable] - function to run prediction with the enhancer.
            By default the bioimageio prediction pipeline is called directly.
            If given, needs to take the prediction pipeline and data (as xarray)
            as input (default: None)
        rf_channel [int, list[int]] - the channel(s) of the random forest to be passed
            as input to the enhancer (default: 1)
        is2d [bool] - whether to process 3d data as individual slices and average the scores.
            Is ignored if the data is 2d (default: False)
        save_path [str] - path to a file for caching the random forest and enhancer predictions.
            Predictions already stored there are loaded instead of recomputed (default: None)
    Returns:
        [pd.DataFrame] - a table with the scores of the enhancers for the different forests
            and scores of the raw forest predictions (row "rf-score"). If 3d data is processed
            slice-wise the scores are averaged over the slices and "enhancer" becomes the index.
    """
    import bioimageio.core
    import xarray

    assert data.shape == labels.shape
    ndim = data.ndim
    model_ndim = 2 if (data.ndim == 2 or is2d) else 3

    def load_enhancer(enh):
        model = bioimageio.core.load_resource_description(enh)
        return bioimageio.core.create_prediction_pipeline(model)

    # load the enhancers
    models = {name: load_enhancer(enh) for name, enh in enhancers.items()}

    # load the ilps
    ilps = {
        name: IlastikPredicter(path, model_ndim, ilastik_multi_thread=True, output_channel=rf_channel)
        for name, path in ilastik_projects.items()
    }

    def predict_rf(rf, input_, axes):
        pred = rf(input_)
        # require len(axes) + 2 dimensions (additional batch and channel axis);
        # done for cached and uncached predictions alike so both paths agree
        pred = pred[(None,) * (len(axes) + 2 - pred.ndim)]
        assert pred.ndim == len(axes) + 2, f"{pred.ndim}, {len(axes) + 2}"
        return pred

    def predict_enhancer(enh, rf_pred, prediction_function, axes):
        # the prediction pipeline (and a custom prediction_function) expects xarray input
        rf_pred = xarray.DataArray(rf_pred, dims=("b", "c",) + tuple(axes))
        pred = enh(rf_pred) if prediction_function is None else prediction_function(enh, rf_pred)
        return pred[0]

    def require_rf_prediction(rf, input_, name, axes):
        if save_path is None:
            return predict_rf(rf, input_, axes)
        with io.open_file(save_path, "a") as f:
            if name in f:
                pred = f[name][:]
            else:
                pred = predict_rf(rf, input_, axes)
                f.create_dataset(name, data=pred, compression="gzip")
        return pred

    def require_enh_prediction(enh, rf_pred, name, prediction_function, axes):
        if save_path is None:
            return predict_enhancer(enh, rf_pred, prediction_function, axes)
        with io.open_file(save_path, "a") as f:
            if name in f:
                pred = f[name][:]
            else:
                pred = predict_enhancer(enh, rf_pred, prediction_function, axes)
                f.create_dataset(name, data=pred, compression="gzip")
        return pred

    def process_chunk(x, y, axes, z=None):
        # one row per enhancer plus a final row for the raw forest predictions,
        # one column per ilastik project
        scores = np.zeros((len(models) + 1, len(ilps)))
        for i, (rf_name, ilp) in enumerate(ilps.items()):
            rf_pred = require_rf_prediction(
                ilp, x,
                rf_name if z is None else f"{rf_name}/{z:04}",
                axes
            )
            for j, (enh_name, enh) in enumerate(models.items()):
                pred = require_enh_prediction(
                    enh, rf_pred,
                    f"{enh_name}/{rf_name}" if z is None else f"{enh_name}/{rf_name}/{z:04}",
                    prediction_function,
                    axes
                )
                scores[j, i] = metric(pred, y)
            scores[-1, i] = metric(rf_pred, y)

        scores = pd.DataFrame(scores, columns=list(ilps.keys()))
        scores.insert(loc=0, column="enhancer", value=list(models.keys()) + ["rf-score"])
        return scores

    # if we have 2d data, or 3d data that is processed en block,
    # we only have to process a single 'chunk'
    if ndim == 2 or (ndim == 3 and not is2d):
        scores = process_chunk(data, labels, "yx" if ndim == 2 else "zyx")
    elif ndim == 3 and is2d:
        scores = []
        for z in trange(data.shape[0]):
            scores_z = process_chunk(data[z], labels[z], "yx", z)
            scores.append(scores_z)
        scores = pd.concat(scores).groupby("enhancer").mean()
    else:
        raise ValueError(f"Invalid data dimensions: {ndim}")

    return scores


def load_predictions(save_path, n_threads=1):
    """Helper function to load predictions from a save_path created by evaluate_enhancers.

    Arguments:
        save_path [str] - path to the file written by evaluate_enhancers
        n_threads [int] - value assigned to the `n_threads` attribute of each dataset
            before reading it (default: 1)
    Returns:
        [dict[str, np.ndarray]] - map of dataset names to predictions. Predictions that were
            stored slice-wise (evaluate_enhancers with 'is2d') are stacked along a new first axis.
    """
    predictions = {}

    def visit(name, node):
        if io.is_group(node):
            return
        node.n_threads = n_threads
        # if we store with 'is2d' individual slices are datasets named by their slice index
        data_name, _, last = name.rpartition("/")
        try:
            z = int(last)
        except ValueError:
            # otherwise the dataset holds the full prediction and we just load the array
            predictions[name] = node[:]
            return
        predictions.setdefault(data_name, {})[z] = node[:]

    with io.open_file(save_path, "r") as f:
        f.visititems(visit)

    def to_vol(pred):
        if isinstance(pred, np.ndarray):
            return pred
        # stack the slices in ascending slice order
        return np.stack([pred[z] for z in sorted(pred)], axis=0)

    return {name: to_vol(pred) for name, pred in predictions.items()}
def
visualize_pretrained_rfs( checkpoint, raw, n_forests, sample_random=False, filter_config=None, n_threads=None):
def visualize_pretrained_rfs(checkpoint, raw, n_forests,
                             sample_random=False, filter_config=None, n_threads=None):
    """Visualize pretrained random forests from a shallow2deep checkpoint.

    Arguments:
        checkpoint [str] - path to the checkpoint folder
        raw [np.ndarray] - the raw data for prediction
        n_forests [int] - the number of forests to use; clamped to the number of available forests
        sample_random [bool] - whether to subsample forests randomly (without repetition)
            or regularly (default: False)
        filter_config [list] - the filter configuration (default: None)
        n_threads [int] - number of threads for parallel prediction of forests (default: None)
    """
    import napari

    # the forests are either stored directly in the checkpoint folder or in its "rfs" subfolder
    rf_paths = glob(os.path.join(checkpoint, "*.pkl"))
    if len(rf_paths) == 0:
        rf_folder = os.path.join(checkpoint, "rfs")
        assert os.path.exists(rf_folder), rf_folder
        rf_paths = glob(os.path.join(rf_folder, "*.pkl"))
    assert len(rf_paths) > 0
    rf_paths.sort()

    # we can't show more forests than are stored in the checkpoint
    n_forests = min(n_forests, len(rf_paths))
    if sample_random:
        # sample without replacement so that no forest is shown twice
        rf_paths = np.random.choice(rf_paths, size=n_forests, replace=False)
    else:
        # clamp the step to 1; a step of 0 is not a valid slice step
        step = max(len(rf_paths) // n_forests, 1)
        rf_paths = rf_paths[::step][:n_forests]

    print("Compute features for input of shape", raw.shape)
    filter_config = _get_filters(raw.ndim, filter_config)
    features = _apply_filters(raw, filter_config)

    def predict_rf(rf_path):
        # NOTE: unpickling can execute arbitrary code; only load checkpoints from trusted sources
        with open(rf_path, "rb") as f:
            rf = pickle.load(f)
        pred = rf.predict_proba(features)
        # restore the spatial shape and move the class axis to the front
        pred = pred.reshape(raw.shape + (pred.shape[1],))
        pred = np.moveaxis(pred, -1, 0)
        assert pred.shape[1:] == raw.shape
        return pred

    n_threads = cpu_count() if n_threads is None else n_threads
    with futures.ThreadPoolExecutor(n_threads) as tp:
        preds = list(tqdm(tp.map(predict_rf, rf_paths), desc="Predict RFs", total=len(rf_paths)))

    print("Start viewer")
    v = napari.Viewer()
    for path, pred in zip(rf_paths, preds):
        name = os.path.basename(path)
        v.add_image(pred, name=name)
    v.add_image(raw)
    v.grid.enabled = True
    napari.run()
Visualize pretrained random forests from a shallow2deep checkpoint.
Arguments:
- checkpoint [str] - path to the checkpoint folder
- raw [np.ndarray] - the raw data for prediction
- n_forests [int] - the number of forests to use
- sample_random [bool] - whether to subsample forests randomly or regularly (default: False)
- filter_config [list] - the filter configuration (default: None)
- n_threads [int] - number of threads for parallel prediction of forests (default: None)
def
evaluate_enhancers( data, labels, enhancers, ilastik_projects, metric, prediction_function=None, rf_channel=1, is2d=False, save_path=None):
def evaluate_enhancers(data, labels, enhancers, ilastik_projects, metric,
                       prediction_function=None, rf_channel=1, is2d=False, save_path=None):
    """Evaluate enhancers on ilastik random forests from multiple projects.

    Arguments:
        data [np.ndarray] - the data for evaluation
        labels [np.ndarray] - the labels for evaluation
        enhancers [dict[str, str]] - map of enhancer names to filepaths with enhancer
            models saved in the bioimage.io model format
        ilastik_projects [dict[str, str]] - map of names to ilastik project paths
        metric [callable] - the metric used for evaluation, called as metric(prediction, labels)
        prediction_function [callable] - function to run prediction with the enhancer.
            By default the bioimageio prediction pipeline is called directly.
            If given, needs to take the prediction pipeline and data (as xarray)
            as input (default: None)
        rf_channel [int, list[int]] - the channel(s) of the random forest to be passed
            as input to the enhancer (default: 1)
        is2d [bool] - whether to process 3d data as individual slices and average the scores.
            Is ignored if the data is 2d (default: False)
        save_path [str] - path to a file for caching the random forest and enhancer predictions.
            Predictions already stored there are loaded instead of recomputed (default: None)
    Returns:
        [pd.DataFrame] - a table with the scores of the enhancers for the different forests
            and scores of the raw forest predictions (row "rf-score"). If 3d data is processed
            slice-wise the scores are averaged over the slices and "enhancer" becomes the index.
    """
    import bioimageio.core
    import xarray

    assert data.shape == labels.shape
    ndim = data.ndim
    model_ndim = 2 if (data.ndim == 2 or is2d) else 3

    def load_enhancer(enh):
        model = bioimageio.core.load_resource_description(enh)
        return bioimageio.core.create_prediction_pipeline(model)

    # load the enhancers
    models = {name: load_enhancer(enh) for name, enh in enhancers.items()}

    # load the ilps
    ilps = {
        name: IlastikPredicter(path, model_ndim, ilastik_multi_thread=True, output_channel=rf_channel)
        for name, path in ilastik_projects.items()
    }

    def predict_rf(rf, input_, axes):
        pred = rf(input_)
        # require len(axes) + 2 dimensions (additional batch and channel axis);
        # done for cached and uncached predictions alike so both paths agree
        pred = pred[(None,) * (len(axes) + 2 - pred.ndim)]
        assert pred.ndim == len(axes) + 2, f"{pred.ndim}, {len(axes) + 2}"
        return pred

    def predict_enhancer(enh, rf_pred, prediction_function, axes):
        # the prediction pipeline (and a custom prediction_function) expects xarray input
        rf_pred = xarray.DataArray(rf_pred, dims=("b", "c",) + tuple(axes))
        pred = enh(rf_pred) if prediction_function is None else prediction_function(enh, rf_pred)
        return pred[0]

    def require_rf_prediction(rf, input_, name, axes):
        if save_path is None:
            return predict_rf(rf, input_, axes)
        with io.open_file(save_path, "a") as f:
            if name in f:
                pred = f[name][:]
            else:
                pred = predict_rf(rf, input_, axes)
                f.create_dataset(name, data=pred, compression="gzip")
        return pred

    def require_enh_prediction(enh, rf_pred, name, prediction_function, axes):
        if save_path is None:
            return predict_enhancer(enh, rf_pred, prediction_function, axes)
        with io.open_file(save_path, "a") as f:
            if name in f:
                pred = f[name][:]
            else:
                pred = predict_enhancer(enh, rf_pred, prediction_function, axes)
                f.create_dataset(name, data=pred, compression="gzip")
        return pred

    def process_chunk(x, y, axes, z=None):
        # one row per enhancer plus a final row for the raw forest predictions,
        # one column per ilastik project
        scores = np.zeros((len(models) + 1, len(ilps)))
        for i, (rf_name, ilp) in enumerate(ilps.items()):
            rf_pred = require_rf_prediction(
                ilp, x,
                rf_name if z is None else f"{rf_name}/{z:04}",
                axes
            )
            for j, (enh_name, enh) in enumerate(models.items()):
                pred = require_enh_prediction(
                    enh, rf_pred,
                    f"{enh_name}/{rf_name}" if z is None else f"{enh_name}/{rf_name}/{z:04}",
                    prediction_function,
                    axes
                )
                scores[j, i] = metric(pred, y)
            scores[-1, i] = metric(rf_pred, y)

        scores = pd.DataFrame(scores, columns=list(ilps.keys()))
        scores.insert(loc=0, column="enhancer", value=list(models.keys()) + ["rf-score"])
        return scores

    # if we have 2d data, or 3d data that is processed en block,
    # we only have to process a single 'chunk'
    if ndim == 2 or (ndim == 3 and not is2d):
        scores = process_chunk(data, labels, "yx" if ndim == 2 else "zyx")
    elif ndim == 3 and is2d:
        scores = []
        for z in trange(data.shape[0]):
            scores_z = process_chunk(data[z], labels[z], "yx", z)
            scores.append(scores_z)
        scores = pd.concat(scores).groupby("enhancer").mean()
    else:
        raise ValueError(f"Invalid data dimensions: {ndim}")

    return scores
Evaluate enhancers on ilastik random forests from multiple projects.
Arguments:
- data [np.ndarray] - the data for evaluation
- labels [np.ndarray] - the labels for evaluation
- enhancers [dict[str, str]] - map of enhancer names to filepaths with enhancer models saved in the bioimage.io model format
- ilastik_projects [dict[str, str]] - map of names to ilastik project paths
- metric [callable] - the metric used for evaluation
- prediction_function [callable] - function to run prediction with the enhancer. By default the bioimageio prediction pipeline is called directly. If given, it must take the prediction pipeline and the data (as xarray) as input (default: None)
- rf_channel [int, list[int]] - the channel(s) of the random forest to be passed as input to the enhancer (default: 1)
- is2d [bool] - whether to process 3d data as individual slices and average the scores. Is ignored if the data is 2d (default: False)
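- save_path [str] - path to a file for caching the random forest and enhancer predictions. Predictions already stored in the file are loaded instead of recomputed (default: None)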
Returns:
[pd.DataFrame] - a table with the scores of the enhancers for the different forests and scores of the raw forest predictions
def
load_predictions(save_path, n_threads=1):
def load_predictions(save_path, n_threads=1):
    """Helper function to load predictions from a save_path created by evaluate_enhancers.

    Arguments:
        save_path [str] - path to the file written by evaluate_enhancers
        n_threads [int] - value assigned to the `n_threads` attribute of each dataset
            before reading it (default: 1)
    Returns:
        [dict[str, np.ndarray]] - map of dataset names to predictions. Predictions that were
            stored slice-wise (evaluate_enhancers with 'is2d') are stacked along a new first axis.
    """
    predictions = {}

    def visit(name, node):
        if io.is_group(node):
            return
        node.n_threads = n_threads
        # if we store with 'is2d' individual slices are datasets named by their slice index
        data_name, _, last = name.rpartition("/")
        try:
            z = int(last)
        except ValueError:
            # otherwise the dataset holds the full prediction and we just load the array
            predictions[name] = node[:]
            return
        predictions.setdefault(data_name, {})[z] = node[:]

    with io.open_file(save_path, "r") as f:
        f.visititems(visit)

    def to_vol(pred):
        if isinstance(pred, np.ndarray):
            return pred
        # stack the slices in ascending slice order
        return np.stack([pred[z] for z in sorted(pred)], axis=0)

    return {name: to_vol(pred) for name, pred in predictions.items()}
Helper function to load predictions from a save_path created by evaluate_enhancers.