torch_em.shallow2deep.shallow2deep_eval
import os
import pickle
from concurrent import futures
from glob import glob
from multiprocessing import cpu_count
from typing import Callable, Dict, Optional, Tuple, Union

import bioimageio.core
import elf.io as io
import numpy as np
import pandas as pd
import xarray
from tqdm import tqdm, trange

from .prepare_shallow2deep import _apply_filters, _get_filters
from .shallow2deep_model import IlastikPredicter


def visualize_pretrained_rfs(
    checkpoint: str,
    raw: np.ndarray,
    n_forests: int,
    sample_random: bool = False,
    filter_config: Optional[Dict] = None,
    n_threads: Optional[int] = None,
) -> None:
    """Visualize pretrained random forests from a shallow2deep checkpoint.

    Args:
        checkpoint: Path to the checkpoint folder.
        raw: The input raw data for prediction.
        n_forests: The number of forests to use for visualization.
        sample_random: Whether to subsample forests randomly or regularly.
        filter_config: The filter configuration.
        n_threads: The number of threads for parallel prediction of forests.
    """
    import napari

    rf_paths = glob(os.path.join(checkpoint, "*.pkl"))
    if len(rf_paths) == 0:
        rf_folder = os.path.join(checkpoint, "rfs")
        assert os.path.exists(rf_folder), rf_folder
        rf_paths = glob(os.path.join(rf_folder, "*.pkl"))
    assert len(rf_paths) > 0
    rf_paths.sort()
    if sample_random:
        rf_paths = np.random.choice(rf_paths, size=n_forests)
    else:
        rf_paths = rf_paths[::(len(rf_paths) // n_forests)][:n_forests]

    print("Compute features for input of shape", raw.shape)
    filter_config = _get_filters(raw.ndim, filter_config)
    features = _apply_filters(raw, filter_config)

    def predict_rf(rf_path):
        with open(rf_path, "rb") as f:
            rf = pickle.load(f)
        pred = rf.predict_proba(features)
        pred = pred.reshape(raw.shape + (pred.shape[1],))
        pred = np.moveaxis(pred, -1, 0)
        assert pred.shape[1:] == raw.shape
        return pred

    n_threads = cpu_count() if n_threads is None else n_threads
    with futures.ThreadPoolExecutor(n_threads) as tp:
        preds = list(tqdm(tp.map(predict_rf, rf_paths), desc="Predict RFs", total=len(rf_paths)))

    print("Start viewer")
    v = napari.Viewer()
    for path, pred in zip(rf_paths, preds):
        name = os.path.basename(path)
        v.add_image(pred, name=name)
    v.add_image(raw)
    v.grid.enabled = True
    napari.run()


def evaluate_enhancers(
    data: np.ndarray,
    labels: np.ndarray,
    enhancers: Dict[str, str],
    ilastik_projects: Dict[str, str],
    metric: Callable,
    prediction_function: Optional[Callable] = None,
    rf_channel: Union[int, Tuple[int, ...]] = 1,
    is2d: bool = False,
    save_path: Optional[str] = None,
) -> pd.DataFrame:
    """Evaluate enhancers on ilastik random forests from multiple projects.

    Args:
        data: The data for evaluation.
        labels: The labels for evaluation.
        enhancers: Mapping of enhancer names to filepaths of enhancer models saved in the bioimage.io model format.
        ilastik_projects: Mapping of names to ilastik project paths.
        metric: The metric used for evaluation.
        prediction_function: Function to run prediction with the enhancer.
            By default the bioimageio prediction pipeline is called directly.
            If given, needs to take the prediction pipeline and data (as xarray) as input.
        rf_channel: The channel(s) of the random forest to be passed as input to the enhancer.
        is2d: Whether to process 3d data as individual slices and average the scores. Is ignored if the data is 2d.
        save_path: Save path for caching the random forest predictions.

    Returns:
        A table with the scores of the enhancers for the different forests and scores of the raw forest predictions.
    """
    assert data.shape == labels.shape
    ndim = data.ndim
    model_ndim = 2 if (data.ndim == 2 or is2d) else 3

    def load_enhancer(enh):
        model = bioimageio.core.load_resource_description(enh)
        return bioimageio.core.create_prediction_pipeline(model)

    # load the enhancers
    models = {name: load_enhancer(enh) for name, enh in enhancers.items()}

    # load the ilastik projects
    ilps = {
        name: IlastikPredicter(path, model_ndim, ilastik_multi_thread=True, output_channel=rf_channel)
        for name, path in ilastik_projects.items()
    }

    def require_rf_prediction(rf, input_, name, axes):
        if save_path is None:
            return rf(input_)
        with io.open_file(save_path, "a") as f:
            if name in f:
                pred = f[name][:]
            else:
                pred = rf(input_)
                # require len(axes) + 2 dimensions (additional batch and channel axis)
                pred = pred[(None,) * (len(axes) + 2 - pred.ndim)]
                assert pred.ndim == len(axes) + 2, f"{pred.ndim}, {len(axes) + 2}"
                f.create_dataset(name, data=pred, compression="gzip")
            return pred

    def require_enh_prediction(enh, rf_pred, name, prediction_function, axes):
        if save_path is None:
            pred = enh(rf_pred) if prediction_function is None else prediction_function(enh, rf_pred)
            pred = pred[0]
            return pred
        with io.open_file(save_path, "a") as f:
            if name in f:
                pred = f[name][:]
            else:
                rf_pred = xarray.DataArray(rf_pred, dims=("b", "c",) + tuple(axes))
                pred = enh(rf_pred) if prediction_function is None else prediction_function(enh, rf_pred)
                pred = pred[0]
                f.create_dataset(name, data=pred, compression="gzip")
            return pred

    def process_chunk(x, y, axes, z=None):
        scores = np.zeros((len(models) + 1, len(ilps)))
        for i, (rf_name, ilp) in enumerate(ilps.items()):
            rf_pred = require_rf_prediction(
                ilp, x,
                rf_name if z is None else f"{rf_name}/{z:04}",
                axes,
            )
            for j, (enh_name, enh) in enumerate(models.items()):
                pred = require_enh_prediction(
                    enh, rf_pred,
                    f"{enh_name}/{rf_name}" if z is None else f"{enh_name}/{rf_name}/{z:04}",
                    prediction_function,
                    axes,
                )
                score = metric(pred, y)
                scores[j, i] = score
            score = metric(rf_pred, y)
            scores[-1, i] = score

        scores = pd.DataFrame(scores, columns=list(ilps.keys()))
        scores.insert(loc=0, column="enhancer", value=list(models.keys()) + ["rf-score"])
        return scores

    # if we have 2d data, or 3d data that is processed en bloc,
    # we only have to process a single 'chunk'
    if ndim == 2 or (ndim == 3 and not is2d):
        scores = process_chunk(data, labels, "yx" if ndim == 2 else "zyx")
    elif ndim == 3 and is2d:
        scores = []
        for z in trange(data.shape[0]):
            scores_z = process_chunk(data[z], labels[z], "yx", z)
            scores.append(scores_z)
        scores = pd.concat(scores).groupby("enhancer").mean()
    else:
        raise ValueError(f"Invalid data dimensions: {ndim}")

    return scores


def load_predictions(save_path: str, n_threads: int = 1) -> Dict[str, np.ndarray]:
    """Load predictions from a save_path created by evaluate_enhancers.

    Args:
        save_path: The path where the predictions were saved.
        n_threads: The number of threads for loading data.

    Returns:
        A mapping of random forest names to the predictions.
    """
    predictions = {}

    def visit(name, node):
        if io.is_group(node):
            return
        node.n_threads = n_threads
        # if we store with 'is2d' individual slices are datasets
        try:
            data_name = "/".join(name.split("/")[:-1])
            z = int(name.split("/")[-1])
            data = node[:]
            pred = predictions.get(data_name, {})
            pred[z] = data
            predictions[data_name] = pred
        # otherwise the above will throw a ValueError and we just load the array
        except ValueError:
            predictions[name] = node[:]

    with io.open_file(save_path, "r") as f:
        f.visititems(visit)

    def to_vol(pred):
        if isinstance(pred, np.ndarray):
            return pred
        pred = dict(sorted(pred.items()))
        return np.concatenate([pz[None] for pz in pred.values()], axis=0)

    return {name: to_vol(pred) for name, pred in predictions.items()}
def visualize_pretrained_rfs(
    checkpoint: str,
    raw: numpy.ndarray,
    n_forests: int,
    sample_random: bool = False,
    filter_config: Optional[Dict] = None,
    n_threads: Optional[int] = None,
) -> None:
Visualize pretrained random forests from a shallow2deep checkpoint.
Arguments:
- checkpoint: Path to the checkpoint folder.
- raw: The input raw data for prediction.
- n_forests: The number of forests to use for visualization.
- sample_random: Whether to subsample forests randomly or regularly.
- filter_config: The filter configuration.
- n_threads: The number of threads for parallel prediction of forests.
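A minimal usage sketch (the file path, dataset name, and checkpoint folder below are placeholders): load a small raw crop, since the filters are computed densely for the whole input, and visualize four regularly sampled forests.

import numpy as np
from elf.io import open_file
from torch_em.shallow2deep.shallow2deep_eval import visualize_pretrained_rfs

# load a small crop of the raw data; feature computation on large inputs is slow
with open_file("/path/to/data.h5", "r") as f:  # hypothetical input file
    raw = f["raw"][:512, :512].astype("float32")  # hypothetical dataset name

# visualize four regularly sampled forests from the checkpoint folder
visualize_pretrained_rfs("./checkpoints/shallow2deep-rfs", raw, n_forests=4)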
def evaluate_enhancers(
    data: numpy.ndarray,
    labels: numpy.ndarray,
    enhancers: Dict[str, str],
    ilastik_projects: Dict[str, str],
    metric: Callable,
    prediction_function: Optional[Callable] = None,
    rf_channel: Union[int, Tuple[int, ...]] = 1,
    is2d: bool = False,
    save_path: Optional[str] = None,
) -> pandas.core.frame.DataFrame:
Evaluate enhancers on ilastik random forests from multiple projects.
Arguments:
- data: The data for evaluation.
- labels: The labels for evaluation.
- enhancers: Mapping of enhancer names to filepaths of enhancer models saved in the bioimage.io model format.
- ilastik_projects: Mapping of names to ilastik project paths.
- metric: The metric used for evaluation.
- prediction_function: Function to run prediction with the enhancer. By default the bioimageio prediction pipeline is called directly. If given, needs to take the prediction pipeline and data (as xarray) as input.
- rf_channel: The channel(s) of the random forest to be passed as input to the enhancer.
- is2d: Whether to process 3d data as individual slices and average the scores. Is ignored if the data is 2d.
- save_path: Save path for caching the random forest predictions.
Returns:
A table with the scores of the enhancers for the different forests and scores of the raw forest predictions.
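A hedged usage sketch: the model and project paths below are placeholders, the data is random stand-in data, and dice_metric is just one possible choice of metric (it assumes the prediction can be squeezed to the label shape and thresholded at 0.5; adapt it to your setup).

import numpy as np
from torch_em.shallow2deep.shallow2deep_eval import evaluate_enhancers

def dice_metric(pred, labels):
    # foreground dice score; assumes binary labels and a probability map
    seg = np.squeeze(np.asarray(pred)) > 0.5
    fg = labels > 0
    return 2.0 * np.logical_and(seg, fg).sum() / (seg.sum() + fg.sum() + 1e-7)

# stand-ins for real raw data and binary labels of the same shape
data = np.random.rand(8, 256, 256).astype("float32")
labels = (np.random.rand(8, 256, 256) > 0.5).astype("uint8")

scores = evaluate_enhancers(
    data, labels,
    enhancers={"enhancer-v1": "./enhancer-v1.zip"},         # hypothetical bioimage.io model
    ilastik_projects={"annotator-a": "./annotator-a.ilp"},  # hypothetical ilastik project
    metric=dice_metric,
    rf_channel=1,
    is2d=True,                     # score each z-slice separately and average
    save_path="./predictions.h5",  # cache predictions for load_predictions
)
print(scores)

If a custom prediction_function is given, it receives the loaded prediction pipeline and the random forest prediction, and must return an indexable sequence of predictions (the first element is used), like the pipeline itself.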
def load_predictions(save_path: str, n_threads: int = 1) -> Dict[str, numpy.ndarray]:
Load predictions from a save_path created by evaluate_enhancers.
Arguments:
- save_path: The path where the predictions were saved.
- n_threads: The number of threads for loading data.
Returns:
A mapping of random forest names to the predictions.
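A short sketch of loading the cached predictions (the path is a placeholder matching the save_path used above):

from torch_em.shallow2deep.shallow2deep_eval import load_predictions

# keys mirror the layout written by evaluate_enhancers: "<forest>" for raw
# random forest predictions and "<enhancer>/<forest>" for enhanced predictions;
# per-slice datasets written with is2d=True are stacked back into volumes
predictions = load_predictions("./predictions.h5", n_threads=4)
for name, pred in predictions.items():
    print(name, pred.shape)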