torch_em.shallow2deep.shallow2deep_eval

import os
import pickle
from concurrent import futures
from glob import glob
from multiprocessing import cpu_count
from typing import Callable, Dict, Optional, Tuple, Union

import bioimageio.core
import elf.io as io
import numpy as np
import pandas as pd
import xarray
from tqdm import tqdm, trange

from .prepare_shallow2deep import _apply_filters, _get_filters
from .shallow2deep_model import IlastikPredicter


def visualize_pretrained_rfs(
    checkpoint: str,
    raw: np.ndarray,
    n_forests: int,
    sample_random: bool = False,
    filter_config: Optional[Dict] = None,
    n_threads: Optional[int] = None,
) -> None:
    """Visualize pretrained random forests from a shallow2deep checkpoint.

    Args:
        checkpoint: Path to the checkpoint folder.
        raw: The input raw data for prediction.
        n_forests: The number of forests to use for visualization.
        sample_random: Whether to subsample forests randomly or regularly.
        filter_config: The filter configuration.
        n_threads: The number of threads for parallel prediction of forests.
    """
    import napari

    rf_paths = glob(os.path.join(checkpoint, "*.pkl"))
    if len(rf_paths) == 0:
        rf_folder = os.path.join(checkpoint, "rfs")
        assert os.path.exists(rf_folder), rf_folder
        rf_paths = glob(os.path.join(rf_folder, "*.pkl"))
    assert len(rf_paths) > 0
    rf_paths.sort()
    if sample_random:
        # sample without replacement so that each forest is shown at most once
        rf_paths = np.random.choice(rf_paths, size=n_forests, replace=False)
    else:
        rf_paths = rf_paths[::(len(rf_paths) // n_forests)][:n_forests]

    print("Compute features for input of shape", raw.shape)
    filter_config = _get_filters(raw.ndim, filter_config)
    features = _apply_filters(raw, filter_config)

    def predict_rf(rf_path):
        with open(rf_path, "rb") as f:
            rf = pickle.load(f)
        pred = rf.predict_proba(features)
        pred = pred.reshape(raw.shape + (pred.shape[1],))
        pred = np.moveaxis(pred, -1, 0)
        assert pred.shape[1:] == raw.shape
        return pred

    n_threads = cpu_count() if n_threads is None else n_threads
    with futures.ThreadPoolExecutor(n_threads) as tp:
        preds = list(tqdm(tp.map(predict_rf, rf_paths), desc="Predict RFs", total=len(rf_paths)))

    print("Start viewer")
    v = napari.Viewer()
    for path, pred in zip(rf_paths, preds):
        name = os.path.basename(path)
        v.add_image(pred, name=name)
    v.add_image(raw)
    v.grid.enabled = True
    napari.run()


def evaluate_enhancers(
    data: np.ndarray,
    labels: np.ndarray,
    enhancers: Dict[str, str],
    ilastik_projects: Dict[str, str],
    metric: Callable,
    prediction_function: Optional[Callable] = None,
    rf_channel: Union[int, Tuple[int, ...]] = 1,
    is2d: bool = False,
    save_path: Optional[str] = None
) -> pd.DataFrame:
    """Evaluate enhancers on ilastik random forests from multiple projects.

    Args:
        data: The data for evaluation.
        labels: The labels for evaluation.
        enhancers: Mapping of enhancer names to filepaths of enhancer models saved in the bioimage.io model format.
        ilastik_projects: Mapping of names to ilastik project paths.
        metric: The metric used for evaluation.
        prediction_function: Function to run prediction with the enhancer.
            By default the bioimageio.core prediction pipeline is called directly.
            If given, needs to take the prediction pipeline and data (as xarray) as input.
        rf_channel: The channel(s) of the random forest prediction to be passed as input to the enhancer.
        is2d: Whether to process 3d data as individual slices and average the scores. Is ignored if the data is 2d.
        save_path: Save path for caching the random forest predictions.

    Returns:
        A table with the scores of the enhancers for the different forests and the scores of the raw forest predictions.
    """

    assert data.shape == labels.shape
    ndim = data.ndim
    model_ndim = 2 if (data.ndim == 2 or is2d) else 3

    def load_enhancer(enh):
        model = bioimageio.core.load_resource_description(enh)
        return bioimageio.core.create_prediction_pipeline(model)

    # load the enhancers
    models = {name: load_enhancer(enh) for name, enh in enhancers.items()}

    # load the ilps
    ilps = {
        name: IlastikPredicter(path, model_ndim, ilastik_multi_thread=True, output_channel=rf_channel)
        for name, path in ilastik_projects.items()
    }

    def require_rf_prediction(rf, input_, name, axes):
        if save_path is None:
            return rf(input_)
        with io.open_file(save_path, "a") as f:
            if name in f:
                pred = f[name][:]
            else:
                pred = rf(input_)
                # require len(axes) + 2 dimensions (additional batch and channel axis)
                pred = pred[(None,) * (len(axes) + 2 - pred.ndim)]
                assert pred.ndim == len(axes) + 2, f"{pred.ndim}, {len(axes) + 2}"
                f.create_dataset(name, data=pred, compression="gzip")
            return pred

    def require_enh_prediction(enh, rf_pred, name, prediction_function, axes):
        if save_path is None:
            # add batch and channel axes and wrap as xarray, as expected by the prediction pipeline
            rf_pred = rf_pred[(None,) * (len(axes) + 2 - rf_pred.ndim)]
            rf_pred = xarray.DataArray(rf_pred, dims=("b", "c",) + tuple(axes))
            pred = enh(rf_pred) if prediction_function is None else prediction_function(enh, rf_pred)
            pred = pred[0]
            return pred
        with io.open_file(save_path, "a") as f:
            if name in f:
                pred = f[name][:]
            else:
                rf_pred = xarray.DataArray(rf_pred, dims=("b", "c",) + tuple(axes))
                pred = enh(rf_pred) if prediction_function is None else prediction_function(enh, rf_pred)
                pred = pred[0]
                f.create_dataset(name, data=pred, compression="gzip")
            return pred

    def process_chunk(x, y, axes, z=None):
        scores = np.zeros((len(models) + 1, len(ilps)))
        for i, (rf_name, ilp) in enumerate(ilps.items()):
            rf_pred = require_rf_prediction(
                ilp, x,
                rf_name if z is None else f"{rf_name}/{z:04}",
                axes
            )
            for j, (enh_name, enh) in enumerate(models.items()):
                pred = require_enh_prediction(
                    enh, rf_pred,
                    f"{enh_name}/{rf_name}" if z is None else f"{enh_name}/{rf_name}/{z:04}",
                    prediction_function,
                    axes
                )
                score = metric(pred, y)
                scores[j, i] = score
            score = metric(rf_pred, y)
            scores[-1, i] = score

        scores = pd.DataFrame(scores, columns=list(ilps.keys()))
        scores.insert(loc=0, column="enhancer", value=list(models.keys()) + ["rf-score"])
        return scores

    # if we have 2d data, or 3d data that is processed en bloc,
    # we only have to process a single 'chunk'
    if ndim == 2 or (ndim == 3 and not is2d):
        scores = process_chunk(data, labels, "yx" if ndim == 2 else "zyx")
    elif ndim == 3 and is2d:
        scores = []
        for z in trange(data.shape[0]):
            scores_z = process_chunk(data[z], labels[z], "yx", z)
            scores.append(scores_z)
        scores = pd.concat(scores).groupby("enhancer").mean()
    else:
        raise ValueError(f"Invalid data dimensions: {ndim}")

    return scores


def load_predictions(save_path: str, n_threads: int = 1) -> Dict[str, np.ndarray]:
    """Load predictions from a save_path created by evaluate_enhancers.

    Args:
        save_path: The path where the predictions were saved.
        n_threads: The number of threads for loading data.

    Returns:
        A mapping of random forest names to the predictions.
    """
    predictions = {}

    def visit(name, node):
        if io.is_group(node):
            return
        node.n_threads = n_threads
        # if we store with 'is2d', individual slices are datasets
        try:
            data_name = "/".join(name.split("/")[:-1])
            z = int(name.split("/")[-1])
            data = node[:]
            pred = predictions.get(data_name, {})
            pred[z] = data
            predictions[data_name] = pred
        # otherwise the above will throw a ValueError and we just load the array
        except ValueError:
            predictions[name] = node[:]

    with io.open_file(save_path, "r") as f:
        f.visititems(visit)

    def to_vol(pred):
        if isinstance(pred, np.ndarray):
            return pred
        pred = dict(sorted(pred.items()))
        return np.concatenate([pz[None] for pz in pred.values()], axis=0)

    return {name: to_vol(pred) for name, pred in predictions.items()}
def visualize_pretrained_rfs( checkpoint: str, raw: numpy.ndarray, n_forests: int, sample_random: bool = False, filter_config: Optional[Dict] = None, n_threads: Optional[int] = None) -> None:

Visualize pretrained random forests from a shallow2deep checkpoint.

Arguments:
  • checkpoint: Path to the checkpoint folder.
  • raw: The input raw data for prediction.
  • n_forests: The number of forests to use for visualization.
  • sample_random: Whether to subsample forests randomly or regularly.
  • filter_config: The filter configuration.
  • n_threads: The number of threads for parallel prediction of forests.
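
For example, to inspect the forests of a checkpoint on raw data stored as HDF5 (a minimal sketch; the file path, dataset name and checkpoint folder are hypothetical placeholders):

import h5py
from torch_em.shallow2deep.shallow2deep_eval import visualize_pretrained_rfs

# load the raw data (placeholder path and dataset name)
with h5py.File("data/raw.h5", "r") as f:
    raw = f["raw"][:]

# predict with four regularly sampled forests from the checkpoint and show them in napari
visualize_pretrained_rfs("checkpoints/shallow2deep", raw, n_forests=4)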
def evaluate_enhancers( data: numpy.ndarray, labels: numpy.ndarray, enhancers: Dict[str, str], ilastik_projects: Dict[str, str], metric: Callable, prediction_function: Optional[Callable] = None, rf_channel: Union[int, Tuple[int, ...]] = 1, is2d: bool = False, save_path: Optional[str] = None) -> pandas.core.frame.DataFrame:

Evaluate enhancers on ilastik random forests from multiple projects.

Arguments:
  • data: The data for evaluation.
  • labels: The labels for evaluation.
  • enhancers: Mapping of enhancer names to filepaths of enhancer models saved in the bioimage.io model format.
  • ilastik_projects: Mapping of names to ilastik project paths.
  • metric: The metric used for evaluation.
  • prediction_function: Function to run prediction with the enhancer. By default the bioimageio.core prediction pipeline is called directly. If given, needs to take the prediction pipeline and data (as xarray) as input.
  • rf_channel: The channel(s) of the random forest to be passed as input to the enhancer.
  • is2d: Whether to process 3d data as individual slices and average the scores. Is ignored if the data is 2d.
  • save_path: Save path for caching the random forest predictions.
Returns:

A table with the scores of the enhancers for the different forests and the scores of the raw forest predictions.
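
A minimal usage sketch for a 3d volume evaluated slice-wise. The dice metric below and all paths are illustrative assumptions, and the metric assumes single-channel foreground probabilities (adapt the channel handling to your enhancer's outputs). Passing save_path caches the random forest and enhancer predictions; prediction_function could be supplied to override the direct pipeline call, e.g. for tiled prediction:

import numpy as np
from torch_em.shallow2deep.shallow2deep_eval import evaluate_enhancers

def dice_score(pred, labels):
    # foreground dice; assumes a single-channel probability prediction
    seg = np.asarray(pred).squeeze() > 0.5
    labels = labels.astype("bool")
    return 2 * np.logical_and(seg, labels).sum() / (seg.sum() + labels.sum())

scores = evaluate_enhancers(
    data, labels,  # 3d raw data and binary labels of the same shape
    enhancers={"enhancer-v1": "models/enhancer/rdf.yaml"},     # bioimage.io model (placeholder)
    ilastik_projects={"rf-sparse": "ilps/sparse_labels.ilp"},  # ilastik project (placeholder)
    metric=dice_score,
    is2d=True,                  # evaluate slice-wise and average the scores
    save_path="predictions.h5",
)
print(scores)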

def load_predictions(save_path: str, n_threads: int = 1) -> Dict[str, numpy.ndarray]:

Load predictions from a save_path created by evaluate_enhancers.

Arguments:
  • save_path: The path where the predictions were saved.
  • n_threads: The number of threads for loading data.
Returns:

A mapping of random forest names to the predictions.
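
The cached predictions can then be loaded again, e.g. to visually compare the enhancers (assuming the same save path as in the sketch above; slice-wise datasets are concatenated back into volumes):

from torch_em.shallow2deep.shallow2deep_eval import load_predictions

predictions = load_predictions("predictions.h5", n_threads=4)
for name, pred in predictions.items():
    print(name, pred.shape)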