
import os
import pickle
from concurrent import futures
from glob import glob
from multiprocessing import cpu_count
from typing import Callable, Dict, Optional, Tuple, Union

import bioimageio.core
import elf.io as io
import numpy as np
import pandas as pd
import xarray
from tqdm import tqdm, trange

from .prepare_shallow2deep import _apply_filters, _get_filters
from .shallow2deep_model import IlastikPredicter
 19def visualize_pretrained_rfs(
 20    checkpoint: str,
 21    raw: np.ndarray,
 22    n_forests: int,
 23    sample_random: bool = False,
 24    filter_config: Optional[Dict] = None,
 25    n_threads: Optional[int] = None,
 26) -> None:
 27    """Visualize pretrained random forests from a shallow2depp checkpoint.
 29    Args:
 30        checkpoint: Path to the checkpoint folder.
 31        raw: The input raw data for prediction.
 32        n_forests: The number of forests to use for visualization.
 33        sample_random: Whether to subsample forests randomly or regularly.
 34        filter_config: The filter configuration.
 35        n_threads: The number of threads for parallel prediction of forests.
 36    """
 37    import napari
 39    rf_paths = glob(os.path.join(checkpoint, "*.pkl"))
 40    if len(rf_paths) == 0:
 41        rf_folder = os.path.join(checkpoint, "rfs")
 42        assert os.path.exists(rf_folder), rf_folder
 43        rf_paths = glob(os.path.join(rf_folder, "*.pkl"))
 44    assert len(rf_paths) > 0
 45    rf_paths.sort()
 46    if sample_random:
 47        rf_paths = np.random.choice(rf_paths, size=n_forests)
 48    else:
 49        rf_paths = rf_paths[::(len(rf_paths) // n_forests)][:n_forests]
 51    print("Compute features for input of shape", raw.shape)
 52    filter_config = _get_filters(raw.ndim, filter_config)
 53    features = _apply_filters(raw, filter_config)
 55    def predict_rf(rf_path):
 56        with open(rf_path, "rb") as f:
 57            rf = pickle.load(f)
 58        pred = rf.predict_proba(features)
 59        pred = pred.reshape(raw.shape + (pred.shape[1],))
 60        pred = np.moveaxis(pred, -1, 0)
 61        assert pred.shape[1:] == raw.shape
 62        return pred
 64    n_threads = cpu_count() if n_threads is None else n_threads
 65    with futures.ThreadPoolExecutor(n_threads) as tp:
 66        preds = list(tqdm(, rf_paths), desc="Predict RFs", total=len(rf_paths)))
 68    print("Start viewer")
 69    v = napari.Viewer()
 70    for path, pred in zip(rf_paths, preds):
 71        name = os.path.basename(path)
 72        v.add_image(pred, name=name)
 73    v.add_image(raw)
 74    v.grid.enabled = True
 78def evaluate_enhancers(
 79    data: np.ndarray,
 80    labels: np.ndarray,
 81    enhancers: Dict[str, str],
 82    ilastik_projects: Dict[str, str],
 83    metric: Callable,
 84    prediction_function: Optional[Callable] = None,
 85    rf_channel: Union[int, Tuple[int, ...]] = 1,
 86    is2d: bool = False,
 87    save_path: Optional[str] = None
 88) -> pd.DataFrame:
 89    """Evaluate enhancers on ilastik random forests from multiple projects.
 91    Args:
 92        data: The data for evaluation.
 93        labels: The labels for evaluation.
 94        enhancers: Mapping of enhancer names to filepath with enhancer models saved in the model format.
 95        ilastik_projects: Mapping of names to ilastik project paths.
 96        metric: The metric used for evaluation.
 97        prediction_function: Function to run prediction with the enhancer.
 98            By default the bioimageio.prediction pipeline is called directly.
 99            If given, needs to take the prediction pipeline and data (as xarray) as input.
100        rf_channel: The channel(s) of the random forest to be passed as input to the enhancer.
101        is2d: Whether to process 3d data as individual slices and average the scores. Is ignored if the data is 2d.
102        save_path: Save path for caching the random forest predictions.
104    Returns:
105        A table with the scores of the enhancers for the different forests and scores of the raw forest predictions.
106    """
108    assert data.shape == labels.shape
109    ndim = data.ndim
110    model_ndim = 2 if (data.ndim == 2 or is2d) else 3
112    def load_enhancer(enh):
113        model = bioimageio.core.load_resource_description(enh)
114        return bioimageio.core.create_prediction_pipeline(model)
116    # load the enhancers
117    models = {name: load_enhancer(enh) for name, enh in enhancers.items()}
119    # load the ilps
120    ilps = {
121        name: IlastikPredicter(path, model_ndim, ilastik_multi_thread=True, output_channel=rf_channel)
122        for name, path in ilastik_projects.items()
123    }
125    def require_rf_prediction(rf, input_, name, axes):
126        if save_path is None:
127            return rf(input_)
128        with io.open_file(save_path, "a") as f:
129            if name in f:
130                pred = f[name][:]
131            else:
132                pred = rf(input_)
133                # require len(axes) + 2 dimensions (additional batch and channel axis)
134                pred = pred[(None,) * (len(axes) + 2 - pred.ndim)]
135                assert pred.ndim == len(axes) + 2, f"{pred.ndim}, {len(axes) + 2}"
136                f.create_dataset(name, data=pred, compression="gzip")
137            return pred
139    def require_enh_prediction(enh, rf_pred, name, prediction_function, axes):
140        if save_path is None:
141            pred = enh(rf_pred) if prediction_function is None else prediction_function(enh, rf_pred)
142            pred = pred[0]
143            return pred
144        with io.open_file(save_path, "a") as f:
145            if name in f:
146                pred = f[name][:]
147            else:
148                rf_pred = xarray.DataArray(rf_pred, dims=("b", "c",) + tuple(axes))
149                pred = enh(rf_pred) if prediction_function is None else prediction_function(enh, rf_pred)
150                pred = pred[0]
151                f.create_dataset(name, data=pred, compression="gzip")
152            return pred
154    def process_chunk(x, y, axes, z=None):
155        scores = np.zeros((len(models) + 1, len(ilps)))
156        for i, (rf_name, ilp) in enumerate(ilps.items()):
157            rf_pred = require_rf_prediction(
158                ilp, x,
159                rf_name if z is None else f"{rf_name}/{z:04}",
160                axes
161            )
162            for j, (enh_name, enh) in enumerate(models.items()):
163                pred = require_enh_prediction(
164                    enh, rf_pred,
165                    f"{enh_name}/{rf_name}" if z is None else f"{enh_name}/{rf_name}/{z:04}",
166                    prediction_function,
167                    axes
168                )
169                score = metric(pred, y)
170                scores[j, i] = score
171            score = metric(rf_pred, y)
172            scores[-1, i] = score
174        scores = pd.DataFrame(scores, columns=list(ilps.keys()))
175        scores.insert(loc=0, column="enhancer", value=list(models.keys()) + ["rf-score"])
176        return scores
178    # if we have 2d data, or 3d data that is processed en block,
179    # we only have to process a single 'chunk'
180    if ndim == 2 or (ndim == 3 and not is2d):
181        scores = process_chunk(data, labels, "yx" if ndim == 2 else "zyx")
182    elif ndim == 3 and is2d:
183        scores = []
184        for z in trange(data.shape[0]):
185            scores_z = process_chunk(data[z], labels[z], "yx", z)
186            scores.append(scores_z)
187        scores = pd.concat(scores).groupby("enhancer").mean()
188    else:
189        raise ValueError("Invalid data dimensions: {ndim}")
191    return scores
def load_predictions(save_path: str, n_threads: int = 1) -> Dict[str, np.ndarray]:
    """Read back the predictions cached in a save_path written by evaluate_enhancers.

    Args:
        save_path: The path of the file holding the cached predictions.
        n_threads: The thread count that is set on every dataset before it is read.

    Returns:
        A mapping of dataset names to the predictions. Datasets whose last path
        component is an integer are treated as slices of their parent name
        and are stacked along a new first axis, ordered by that integer.
    """
    loaded = {}

    def collect(name, node):
        # Groups carry no data of their own; only datasets are read.
        if io.is_group(node):
            return
        node.n_threads = n_threads
        parent, _, leaf = name.rpartition("/")
        try:
            # Slice-wise storage ('is2d'): the last path component is the slice index.
            slice_index = int(leaf)
            loaded.setdefault(parent, {})[slice_index] = node[:]
        except ValueError:
            # The last component is not an integer, so this is a complete array.
            loaded[name] = node[:]

    with io.open_file(save_path, "r") as f:
        f.visititems(collect)

    def assemble(entry):
        if not isinstance(entry, np.ndarray):
            # Per-slice dictionary: order by slice index and stack into a volume.
            return np.stack([entry[z] for z in sorted(entry)], axis=0)
        return entry

    result = {}
    for key, entry in loaded.items():
        result[key] = assemble(entry)
    return result
