torch_em.shallow2deep.shallow2deep_eval

import os
import pickle
from concurrent import futures
from glob import glob
from multiprocessing import cpu_count

import numpy as np
import pandas as pd
import elf.io as io
from tqdm import tqdm, trange

from .prepare_shallow2deep import _apply_filters, _get_filters
from .shallow2deep_model import IlastikPredicter


def visualize_pretrained_rfs(checkpoint, raw, n_forests,
                             sample_random=False, filter_config=None, n_threads=None):
    """Visualize pretrained random forests from a shallow2deep checkpoint.

    Arguments:
        checkpoint [str] - path to the checkpoint folder
        raw [np.ndarray] - the raw data for prediction
        n_forests [int] - the number of forests to use
        sample_random [bool] - whether to subsample forests randomly or regularly (default: False)
        filter_config [list] - the filter configuration (default: None)
        n_threads [int] - number of threads for parallel prediction of forests (default: None)
    """
    import napari

    # find the pickled random forests, either directly in the checkpoint folder
    # or in its "rfs" sub-folder
    rf_paths = glob(os.path.join(checkpoint, "*.pkl"))
    if len(rf_paths) == 0:
        rf_folder = os.path.join(checkpoint, "rfs")
        assert os.path.exists(rf_folder), rf_folder
        rf_paths = glob(os.path.join(rf_folder, "*.pkl"))
    assert len(rf_paths) > 0
    rf_paths.sort()
    # subsample the forests, either randomly or at a regular interval
    if sample_random:
        rf_paths = np.random.choice(rf_paths, size=n_forests)
    else:
        rf_paths = rf_paths[::(len(rf_paths) // n_forests)][:n_forests]

    print("Compute features for input of shape", raw.shape)
    filter_config = _get_filters(raw.ndim, filter_config)
    features = _apply_filters(raw, filter_config)

    def predict_rf(rf_path):
        with open(rf_path, "rb") as f:
            rf = pickle.load(f)
        pred = rf.predict_proba(features)
        # reshape the flat per-pixel prediction to the input shape, channels first
        pred = pred.reshape(raw.shape + (pred.shape[1],))
        pred = np.moveaxis(pred, -1, 0)
        assert pred.shape[1:] == raw.shape
        return pred

    n_threads = cpu_count() if n_threads is None else n_threads
    with futures.ThreadPoolExecutor(n_threads) as tp:
        preds = list(tqdm(tp.map(predict_rf, rf_paths), desc="Predict RFs", total=len(rf_paths)))

    print("Start viewer")
    v = napari.Viewer()
    for path, pred in zip(rf_paths, preds):
        name = os.path.basename(path)
        v.add_image(pred, name=name)
    v.add_image(raw)
    v.grid.enabled = True
    napari.run()

def evaluate_enhancers(data, labels, enhancers, ilastik_projects, metric,
                       prediction_function=None, rf_channel=1, is2d=False, save_path=None):
    """Evaluate enhancers on ilastik random forests from multiple projects.

    Arguments:
        data [np.ndarray] - the data for evaluation
        labels [np.ndarray] - the labels for evaluation
        enhancers [dict[str, str]] - map of enhancer names to filepaths with enhancer
            models saved in the bioimage.io model format
        ilastik_projects [dict[str, str]] - map of names to ilastik project paths
        metric [callable] - the metric used for evaluation
        prediction_function [callable] - function to run prediction with the enhancer.
            By default the bioimageio.prediction pipeline is called directly.
            If given, needs to take the prediction pipeline and data (as xarray)
            as input (default: None)
        rf_channel [int, list[int]] - the channel(s) of the random forest to be passed
            as input to the enhancer (default: 1)
        is2d [bool] - whether to process 3d data as individual slices and average the scores.
            Is ignored if the data is 2d (default: False)
        save_path [str] - optional path to a file for caching the random forest
            and enhancer predictions (default: None)
    Returns:
        [pd.DataFrame] - a table with the scores of the enhancers for the different forests
            and the scores of the raw forest predictions
    """
    import bioimageio.core
    import xarray

    assert data.shape == labels.shape
    ndim = data.ndim
    model_ndim = 2 if (data.ndim == 2 or is2d) else 3

    def load_enhancer(enh):
        model = bioimageio.core.load_resource_description(enh)
        return bioimageio.core.create_prediction_pipeline(model)

    # load the enhancers
    models = {name: load_enhancer(enh) for name, enh in enhancers.items()}

    # load the ilps
    ilps = {
        name: IlastikPredicter(path, model_ndim, ilastik_multi_thread=True, output_channel=rf_channel)
        for name, path in ilastik_projects.items()
    }

    def require_rf_prediction(rf, input_, name, axes):
        if save_path is None:
            return rf(input_)
        # if a save path is given, cache the prediction and only compute it if it's not cached yet
        with io.open_file(save_path, "a") as f:
            if name in f:
                pred = f[name][:]
            else:
                pred = rf(input_)
                # require len(axes) + 2 dimensions (additional batch and channel axis)
                pred = pred[(None,) * (len(axes) + 2 - pred.ndim)]
                assert pred.ndim == len(axes) + 2, f"{pred.ndim}, {len(axes) + 2}"
                f.create_dataset(name, data=pred, compression="gzip")
            return pred

    def require_enh_prediction(enh, rf_pred, name, prediction_function, axes):
        if save_path is None:
            pred = enh(rf_pred) if prediction_function is None else prediction_function(enh, rf_pred)
            pred = pred[0]
            return pred
        with io.open_file(save_path, "a") as f:
            if name in f:
                pred = f[name][:]
            else:
                rf_pred = xarray.DataArray(rf_pred, dims=("b", "c") + tuple(axes))
                pred = enh(rf_pred) if prediction_function is None else prediction_function(enh, rf_pred)
                pred = pred[0]
                f.create_dataset(name, data=pred, compression="gzip")
            return pred

    def process_chunk(x, y, axes, z=None):
        # rows: one score per enhancer plus one for the raw rf prediction;
        # columns: one per ilastik project
        scores = np.zeros((len(models) + 1, len(ilps)))
        for i, (rf_name, ilp) in enumerate(ilps.items()):
            rf_pred = require_rf_prediction(
                ilp, x,
                rf_name if z is None else f"{rf_name}/{z:04}",
                axes
            )
            for j, (enh_name, enh) in enumerate(models.items()):
                pred = require_enh_prediction(
                    enh, rf_pred,
                    f"{enh_name}/{rf_name}" if z is None else f"{enh_name}/{rf_name}/{z:04}",
                    prediction_function,
                    axes
                )
                score = metric(pred, y)
                scores[j, i] = score
            score = metric(rf_pred, y)
            scores[-1, i] = score

        scores = pd.DataFrame(scores, columns=list(ilps.keys()))
        scores.insert(loc=0, column="enhancer", value=list(models.keys()) + ["rf-score"])
        return scores

    # if we have 2d data, or 3d data that is processed en bloc,
    # we only have to process a single 'chunk'
    if ndim == 2 or (ndim == 3 and not is2d):
        scores = process_chunk(data, labels, "yx" if ndim == 2 else "zyx")
    elif ndim == 3 and is2d:
        # process the 3d data slice by slice and average the scores over the slices
        scores = []
        for z in trange(data.shape[0]):
            scores_z = process_chunk(data[z], labels[z], "yx", z)
            scores.append(scores_z)
        scores = pd.concat(scores).groupby("enhancer").mean()
    else:
        raise ValueError(f"Invalid data dimensions: {ndim}")

    return scores

def load_predictions(save_path, n_threads=1):
    """Helper function to load predictions from a save_path created by evaluate_enhancers.
    """
    predictions = {}

    def visit(name, node):
        if io.is_group(node):
            return
        node.n_threads = n_threads
        # if we store with 'is2d' individual slices are datasets
        try:
            data_name = "/".join(name.split("/")[:-1])
            z = int(name.split("/")[-1])
            data = node[:]
            pred = predictions.get(data_name, {})
            pred[z] = data
            predictions[data_name] = pred
        # otherwise the conversion to int raises a ValueError and we just load the array
        except ValueError:
            predictions[name] = node[:]

    with io.open_file(save_path, "r") as f:
        f.visititems(visit)

    def to_vol(pred):
        # stack per-slice predictions into a volume, sorted by slice index
        if isinstance(pred, np.ndarray):
            return pred
        pred = dict(sorted(pred.items()))
        return np.concatenate([pz[None] for pz in pred.values()], axis=0)

    predictions = {name: to_vol(pred) for name, pred in predictions.items()}

    return predictions

def visualize_pretrained_rfs(checkpoint, raw, n_forests, sample_random=False, filter_config=None, n_threads=None):

Visualize pretrained random forests from a shallow2deep checkpoint.

Arguments:
  • checkpoint [str] - path to the checkpoint folder
  • raw [np.ndarray] - the raw data for prediction
  • n_forests [int] - the number of forests to use
  • sample_random [bool] - whether to subsample forests randomly or regularly (default: False)
  • filter_config [list] - the filter configuration (default: None)
  • n_threads [int] - number of threads for parallel prediction of forests (default: None)
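
A minimal usage sketch; the checkpoint and data paths are placeholders, and any 2d or 3d raw array works:

import imageio.v3 as imageio
from torch_em.shallow2deep.shallow2deep_eval import visualize_pretrained_rfs

# load the raw data and inspect four of the pretrained forests in napari
raw = imageio.imread("/path/to/raw.tif")  # placeholder path
visualize_pretrained_rfs("checkpoints/shallow2deep", raw, n_forests=4, n_threads=4)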
def evaluate_enhancers(data, labels, enhancers, ilastik_projects, metric, prediction_function=None, rf_channel=1, is2d=False, save_path=None):

Evaluate enhancers on ilastik random forests from multiple projects.

Arguments:
  • data [np.ndarray] - the data for evaluation
  • labels [np.ndarray] - the labels for evaluation
  • enhancers [dict[str, str]] - map of enhancer names to filepaths with enhancer models saved in the bioimage.io model format
  • ilastik_projects [dict[str, str]] - map of names to ilastik project paths
  • metric [callable] - the metric used for evaluation
  • prediction_function [callable] - function to run prediction with the enhancer. By default the bioimageio.prediction pipeline is called directly. If given, needs to take the prediction pipeline and data (as xarray) as input (default: None)
  • rf_channel [int, list[int]] - the channel(s) of the random forest to be passed as input to the enhancer (default: 1)
  • is2d [bool] - whether to process 3d data as individual slices and average the scores. Is ignored if the data is 2d (default: False)
  • save_path [str] - optional path to a file for caching the random forest and enhancer predictions (default: None)
Returns:
  [pd.DataFrame] - a table with the scores of the enhancers for the different forests and the scores of the raw forest predictions
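
A usage sketch under stated assumptions: the paths are placeholders, and dice_metric is a made-up example metric that follows the metric(pred, y) call signature used internally:

import numpy as np
import imageio.v3 as imageio
from torch_em.shallow2deep.shallow2deep_eval import evaluate_enhancers

def dice_metric(pred, target):
    # example metric: foreground dice on thresholded predictions
    seg = np.asarray(pred).squeeze() > 0.5
    fg = target > 0
    return 2 * np.logical_and(seg, fg).sum() / (seg.sum() + fg.sum() + 1e-7)

data = imageio.imread("/path/to/data.tif")      # placeholder
labels = imageio.imread("/path/to/labels.tif")  # placeholder
scores = evaluate_enhancers(
    data, labels,
    enhancers={"enhancer-v1": "/path/to/enhancer.zip"},   # bioimage.io model, placeholder
    ilastik_projects={"rf-a": "/path/to/project-a.ilp"},  # placeholder
    metric=dice_metric,
    save_path="predictions.h5",  # cache rf and enhancer predictions for reuse
)
print(scores)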

def load_predictions(save_path, n_threads=1):

Helper function to load predictions from a save_path created by evaluate_enhancers.
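
A short sketch for reading the cached predictions back, assuming the predictions.h5 file written in the evaluate_enhancers example above:

from torch_em.shallow2deep.shallow2deep_eval import load_predictions

predictions = load_predictions("predictions.h5", n_threads=4)
# keys follow the group layout used by evaluate_enhancers, e.g. "rf-a"
# for a raw forest prediction and "enhancer-v1/rf-a" for an enhancer output
for name, pred in predictions.items():
    print(name, pred.shape)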