torch_em.shallow2deep.shallow2deep_model

  1import os
  2import pickle
  3from typing import Dict, Optional, Tuple, Union
  4
  5import torch
  6import numpy as np
  7
  8from torch_em.util import get_trainer
  9from torch_em.util.modelzoo import import_bioimageio_model
 10from .prepare_shallow2deep import _get_filters, _apply_filters
 11
 12# optional imports only needed for using ilastik api for the predictio
 13try:
 14    import lazyflow
 15    from ilastik.experimental.api import from_project_file
 16    # set the number of threads used by ilastik to 0.
 17    # otherwise it does not work inside of the torch loader (and we want to limit number of threads anyways)
 18    # see https://github.com/ilastik/ilastik/issues/2517
 19    # lazyflow.request.Request.reset_thread_pool(0)
 20    # added this to the constructors via boolean flag
 21except ImportError:
 22    from_project_file = None
 23
 24try:
 25    from xarray import DataArray
 26except ImportError:
 27    DataArray = None
 28
 29
 30class RFWithFilters:
 31    """Wrapper to apply feature computation and random forest prediction.
 32
 33    Args:
 34        rf_path: The path to the trained random forest.
 35        ndim: The dimensionality of the input data.
 36        filter_config: The configuration of the filters.
 37        output_channel: The output channel of the random forest prediction to keep.
 38    """
 39    def __init__(
 40        self, rf_path: str, ndim: int, filter_config: Dict, output_channel: Optional[Union[Tuple[int, ...], int]] = None
 41    ):
 42        with open(rf_path, "rb") as f:
 43            self.rf = pickle.load(f)
 44        self.filters_and_sigmas = _get_filters(ndim, filter_config)
 45        self.output_channel = output_channel
 46
 47    def __call__(self, x: np.ndarray) -> np.ndarray:
 48        """Apply feature computation and random forest prediction to input data.
 49
 50        Args:
 51            x: The input data.
 52
 53        Returns:
 54            The random forest prediction.
 55        """
 56        features = _apply_filters(x, self.filters_and_sigmas)
 57        assert features.shape[1] == self.rf.n_features_in_, f"{features.shape[1]}, {self.rf.n_features_in_}"
 58        out = self.rf.predict_proba(features)
 59        if self.output_channel is None:
 60            out_shape = (out.shape[1],) + x.shape
 61        else:
 62            out = out[:, self.output_channel]
 63            out_shape = x.shape if isinstance(self.output_channel, int) else (len(self.output_channel),) + x.shape
 64        out = out.reshape(out_shape).astype("float32")
 65        return out
 66
 67
 68class IlastikPredicter:
 69    """Wrapper to apply a trained ilastik pixel classification project.
 70
 71    Args:
 72        ilp_path: The file path to the ilastik project.
 73        ndim: The dimensionality of the input data.
 74        ilastik_multi_thread: Whether to use multi-threaded prediction with ilastik.
 75        output_channel: The output channel of the ilastik prediction to keep.
 76    """
 77    def __init__(
 78        self,
 79        ilp_path: str,
 80        ndim: int,
 81        ilastik_multi_thread: bool,
 82        output_channel: Optional[Union[int, Tuple[int, ...]]] = None,
 83    ):
 84        assert from_project_file is not None
 85        assert DataArray is not None
 86        assert ndim in (2, 3)
 87        if not ilastik_multi_thread:
 88            lazyflow.request.Request.reset_thread_pool(0)
 89        self.ilp = from_project_file(ilp_path)
 90        self.dims = ("y", "x") if ndim == 2 else ("z", "y", "x")
 91        self.output_channel = output_channel
 92
 93    def __call__(self, x: np.ndarray) -> np.ndarray:
 94        """Apply ilastik project to input data.
 95
 96        Args:
 97            x: The input data.
 98
 99        Returns:
100            The ilastik prediction.
101        """
102        assert x.ndim == len(self.dims), f"{x.ndim}, {self.dims}"
103        try:
104            out = self.ilp.predict(DataArray(x, dims=self.dims)).values
105        except ValueError as e:
106            # this is a bit of a dirty hack for projects that are trained to classify in 2d, but with 3d data
107            # and thus need a singleton z axis. It would be better to ask this of the ilastik classifier, see
108            # https://github.com/ilastik/ilastik/issues/2530
109            if x.ndim == 2:
110                x = x[None]
111                dims = ("z",) + self.dims
112                out = self.ilp.predict(DataArray(x, dims=dims)).values
113                assert out.shape[0] == 1
114                # get rid of the singleton z-axis
115                out = out[0]
116            else:
117                raise e
118        if self.output_channel is not None:
119            out = out[..., self.output_channel]
120        return out
121
122
class Shallow2DeepModel:
    """Wrapper to apply a shallow2deep enhancer model to raw data.

    First runs prediction with the random forest and then applies the enhancer model
    to the random forest predictions.

    Args:
        checkpoint: The checkpoint of the enhancer model. It is loaded as a torch_em checkpoint,
            falling back to a bioimageio model if that fails.
        rf_config: The configuration of the random forest. Either `(rf_path, ndim, filter_config)`
            for a pickled random forest with filter features, or `(ilp_path, ndim)` for an ilastik project.
        device: The device for the enhancer.
        rf_channel: The channel of the random forest prediction to use as input to the enhancer.
        ilastik_multi_thread: Whether to use ilastik multi-threaded prediction.
    """

    @staticmethod
    def load_model(checkpoint, device):
        """@private
        """
        try:
            model = get_trainer(checkpoint, device=device).model
            model.eval()
            return model
        except Exception as e:
            print("Could not load torch_em checkpoint from", checkpoint, "due to exception:", e)
            print("Trying to load as bioimageio model instead")
        model = import_bioimageio_model(checkpoint, device=device)[0]
        model.eval()
        return model

    @staticmethod
    def load_rf(rf_config, rf_channel=1, ilastik_multi_thread=False):
        """@private
        """
        if len(rf_config) == 3:  # random forest path and feature config
            rf_path, ndim, filter_config = rf_config
            assert os.path.exists(rf_path)
            return RFWithFilters(rf_path, ndim, filter_config, rf_channel)
        elif len(rf_config) == 2:  # ilastik project and dimensionality
            ilp_path, ndim = rf_config
            return IlastikPredicter(ilp_path, ndim, ilastik_multi_thread, rf_channel)
        else:
            raise ValueError(f"Invalid rf config: {rf_config}")

    def __init__(
        self,
        checkpoint: str,
        rf_config: Tuple,
        device: str,
        rf_channel: Optional[int] = 1,
        ilastik_multi_thread: bool = False,
    ):
        """Load the enhancer model and build the random forest / ilastik predicter."""
        self.model = self.load_model(checkpoint, device)
        self.rf_predicter = self.load_rf(rf_config, rf_channel, ilastik_multi_thread)

        self.checkpoint = checkpoint
        self.rf_config = rf_config
        # Kept so that unpickling rebuilds the predicter with the same output channel.
        self.rf_channel = rf_channel
        self.device = device

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the Shallow2Deep Model to the input data.

        Args:
            x: The input data as a tensor with leading batch and channel axes.
                Only x[0, 0] is used.

        Returns:
            The shallow2deep predictions, i.e. the enhancer output for the random forest prediction.
        """
        # TODO support batch axis and multiple input channels
        out = self.rf_predicter(x[0, 0].cpu().detach().numpy())
        out = torch.from_numpy(out[None, None]).to(self.device)
        out = self.model(out)
        return out

    # need to overwrite pickle to support the rf / ilastik predicter
    def __getstate__(self):
        # The predicter is dropped here and rebuilt from rf_config in __setstate__.
        state = self.__dict__.copy()
        del state["rf_predicter"]
        return state

    def __setstate__(self, state):
        state = dict(state)
        # Pickles written before rf_channel was stored fall back to the constructor default.
        rf_channel = state.get("rf_channel", 1)
        state["rf_channel"] = rf_channel
        # ilastik is (as before) rebuilt single-threaded: multi-threaded ilastik does not work
        # inside the torch loader, see https://github.com/ilastik/ilastik/issues/2517
        state["rf_predicter"] = self.load_rf(state["rf_config"], rf_channel, False)
        self.__dict__.update(state)
class RFWithFilters:
class RFWithFilters:
    """Wrapper to apply feature computation and random forest prediction.

    Args:
        rf_path: The path to the trained random forest. It is unpickled and must provide
            `predict_proba` and `n_features_in_`.
        ndim: The dimensionality of the input data.
        filter_config: The configuration of the filters.
        output_channel: The output channel(s) of the random forest prediction to keep.
            If None, all channels are kept. A single integer keeps one channel and drops the
            channel axis. A tuple of channels keeps them as a leading channel axis.
    """
    def __init__(
        self, rf_path: str, ndim: int, filter_config: Dict, output_channel: Optional[Union[Tuple[int, ...], int]] = None
    ):
        # NOTE: pickle.load can execute arbitrary code; only load random forests from trusted paths.
        with open(rf_path, "rb") as rf_file:
            forest = pickle.load(rf_file)
        self.rf = forest
        self.filters_and_sigmas = _get_filters(ndim, filter_config)
        self.output_channel = output_channel

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Apply feature computation and random forest prediction to input data.

        Args:
            x: The input data.

        Returns:
            The random forest prediction as float32. Its shape is `x.shape` if a single
            channel was selected, otherwise `(n_channels,) + x.shape`.
        """
        features = _apply_filters(x, self.filters_and_sigmas)
        assert features.shape[1] == self.rf.n_features_in_, f"{features.shape[1]}, {self.rf.n_features_in_}"
        # One row per pixel, one column per class: (n_pixels, n_classes).
        # Assumes the feature rows follow the C-order flattening of x (TODO confirm in prepare_shallow2deep).
        out = self.rf.predict_proba(features)
        if self.output_channel is not None:
            out = out[:, self.output_channel]
        if out.ndim == 1:
            # A single channel was selected: one value per pixel.
            out = out.reshape(x.shape)
        else:
            # Several channels: transpose to (n_channels, n_pixels) before restoring the spatial shape.
            # Reshaping (n_pixels, n_channels) directly to (n_channels,) + x.shape would mix up
            # values from different channels and pixels.
            out = out.T.reshape((out.shape[1],) + x.shape)
        return out.astype("float32")

Wrapper to apply feature computation and random forest prediction.

Arguments:
  • rf_path: The path to the trained random forest.
  • ndim: The dimensionality of the input data.
  • filter_config: The configuration of the filters.
  • output_channel: The output channel of the random forest prediction to keep.
RFWithFilters( rf_path: str, ndim: int, filter_config: Dict, output_channel: Union[Tuple[int, ...], int, NoneType] = None)
def __init__(
    self, rf_path: str, ndim: int, filter_config: Dict, output_channel: Optional[Union[Tuple[int, ...], int]] = None
):
    # NOTE: pickle.load can execute arbitrary code; only load random forests from trusted paths.
    with open(rf_path, "rb") as rf_file:
        forest = pickle.load(rf_file)
    self.rf = forest
    self.filters_and_sigmas = _get_filters(ndim, filter_config)
    self.output_channel = output_channel
filters_and_sigmas
output_channel
class IlastikPredicter:
 69class IlastikPredicter:
 70    """Wrapper to apply a trained ilastik pixel classification project.
 71
 72    Args:
 73        ilp_path: The file path to the ilastik project.
 74        ndim: The dimensionality of the input data.
 75        ilastik_multi_thread: Whether to use multi-threaded prediction with ilastik.
 76        output_channel: The output channel of the ilastik prediction to keep.
 77    """
 78    def __init__(
 79        self,
 80        ilp_path: str,
 81        ndim: int,
 82        ilastik_multi_thread: bool,
 83        output_channel: Optional[Union[int, Tuple[int, ...]]] = None,
 84    ):
 85        assert from_project_file is not None
 86        assert DataArray is not None
 87        assert ndim in (2, 3)
 88        if not ilastik_multi_thread:
 89            lazyflow.request.Request.reset_thread_pool(0)
 90        self.ilp = from_project_file(ilp_path)
 91        self.dims = ("y", "x") if ndim == 2 else ("z", "y", "x")
 92        self.output_channel = output_channel
 93
 94    def __call__(self, x: np.ndarray) -> np.ndarray:
 95        """Apply ilastik project to input data.
 96
 97        Args:
 98            x: The input data.
 99
100        Returns:
101            The ilastik prediction.
102        """
103        assert x.ndim == len(self.dims), f"{x.ndim}, {self.dims}"
104        try:
105            out = self.ilp.predict(DataArray(x, dims=self.dims)).values
106        except ValueError as e:
107            # this is a bit of a dirty hack for projects that are trained to classify in 2d, but with 3d data
108            # and thus need a singleton z axis. It would be better to ask this of the ilastik classifier, see
109            # https://github.com/ilastik/ilastik/issues/2530
110            if x.ndim == 2:
111                x = x[None]
112                dims = ("z",) + self.dims
113                out = self.ilp.predict(DataArray(x, dims=dims)).values
114                assert out.shape[0] == 1
115                # get rid of the singleton z-axis
116                out = out[0]
117            else:
118                raise e
119        if self.output_channel is not None:
120            out = out[..., self.output_channel]
121        return out

Wrapper to apply a trained ilastik pixel classification project.

Arguments:
  • ilp_path: The file path to the ilastik project.
  • ndim: The dimensionality of the input data.
  • ilastik_multi_thread: Whether to use multi-threaded prediction with ilastik.
  • output_channel: The output channel of the ilastik prediction to keep.
IlastikPredicter( ilp_path: str, ndim: int, ilastik_multi_thread: bool, output_channel: Union[Tuple[int, ...], int, NoneType] = None)
def __init__(
    self,
    ilp_path: str,
    ndim: int,
    ilastik_multi_thread: bool,
    output_channel: Optional[Union[int, Tuple[int, ...]]] = None,
):
    # ilastik and xarray are optional imports; both have to be available here.
    assert from_project_file is not None
    assert DataArray is not None
    assert ndim in (2, 3)

    use_single_thread = not ilastik_multi_thread
    if use_single_thread:
        # See https://github.com/ilastik/ilastik/issues/2517
        lazyflow.request.Request.reset_thread_pool(0)

    self.ilp = from_project_file(ilp_path)
    spatial_dims = ("y", "x")
    self.dims = spatial_dims if ndim == 2 else ("z",) + spatial_dims
    self.output_channel = output_channel
ilp
dims
output_channel
class Shallow2DeepModel:
class Shallow2DeepModel:
    """Wrapper to apply a shallow2deep enhancer model to raw data.

    First runs prediction with the random forest and then applies the enhancer model
    to the random forest predictions.

    Args:
        checkpoint: The checkpoint of the enhancer model. It is loaded as a torch_em checkpoint,
            falling back to a bioimageio model if that fails.
        rf_config: The configuration of the random forest. Either `(rf_path, ndim, filter_config)`
            for a pickled random forest with filter features, or `(ilp_path, ndim)` for an ilastik project.
        device: The device for the enhancer.
        rf_channel: The channel of the random forest prediction to use as input to the enhancer.
        ilastik_multi_thread: Whether to use ilastik multi-threaded prediction.
    """

    @staticmethod
    def load_model(checkpoint, device):
        """@private
        """
        try:
            model = get_trainer(checkpoint, device=device).model
            model.eval()
            return model
        except Exception as e:
            print("Could not load torch_em checkpoint from", checkpoint, "due to exception:", e)
            print("Trying to load as bioimageio model instead")
        model = import_bioimageio_model(checkpoint, device=device)[0]
        model.eval()
        return model

    @staticmethod
    def load_rf(rf_config, rf_channel=1, ilastik_multi_thread=False):
        """@private
        """
        if len(rf_config) == 3:  # random forest path and feature config
            rf_path, ndim, filter_config = rf_config
            assert os.path.exists(rf_path)
            return RFWithFilters(rf_path, ndim, filter_config, rf_channel)
        elif len(rf_config) == 2:  # ilastik project and dimensionality
            ilp_path, ndim = rf_config
            return IlastikPredicter(ilp_path, ndim, ilastik_multi_thread, rf_channel)
        else:
            raise ValueError(f"Invalid rf config: {rf_config}")

    def __init__(
        self,
        checkpoint: str,
        rf_config: Tuple,
        device: str,
        rf_channel: Optional[int] = 1,
        ilastik_multi_thread: bool = False,
    ):
        """Load the enhancer model and build the random forest / ilastik predicter."""
        self.model = self.load_model(checkpoint, device)
        self.rf_predicter = self.load_rf(rf_config, rf_channel, ilastik_multi_thread)

        self.checkpoint = checkpoint
        self.rf_config = rf_config
        # Kept so that unpickling rebuilds the predicter with the same output channel.
        self.rf_channel = rf_channel
        self.device = device

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the Shallow2Deep Model to the input data.

        Args:
            x: The input data as a tensor with leading batch and channel axes.
                Only x[0, 0] is used.

        Returns:
            The shallow2deep predictions, i.e. the enhancer output for the random forest prediction.
        """
        # TODO support batch axis and multiple input channels
        out = self.rf_predicter(x[0, 0].cpu().detach().numpy())
        out = torch.from_numpy(out[None, None]).to(self.device)
        out = self.model(out)
        return out

    # need to overwrite pickle to support the rf / ilastik predicter
    def __getstate__(self):
        # The predicter is dropped here and rebuilt from rf_config in __setstate__.
        state = self.__dict__.copy()
        del state["rf_predicter"]
        return state

    def __setstate__(self, state):
        state = dict(state)
        # Pickles written before rf_channel was stored fall back to the constructor default.
        rf_channel = state.get("rf_channel", 1)
        state["rf_channel"] = rf_channel
        # ilastik is (as before) rebuilt single-threaded: multi-threaded ilastik does not work
        # inside the torch loader, see https://github.com/ilastik/ilastik/issues/2517
        state["rf_predicter"] = self.load_rf(state["rf_config"], rf_channel, False)
        self.__dict__.update(state)

Wrapper to apply a shallow2deep enhancer model to raw data.

First runs prediction with the random forest and then applies the enhancer model to the random forest predictions.

Arguments:
  • checkpoint: The checkpoint of the enhancer model.
  • rf_config: The configuration of the random forest: either (rf_path, ndim, filter_config) for a pickled random forest with filter features, or (ilp_path, ndim) for an ilastik project.
  • device: The device for the enhancer.
  • rf_channel: The channel of the random forest prediction to use as input to the enhancer.
  • ilastik_multi_thread: Whether to use ilastik multi-threaded prediction.
Shallow2DeepModel( checkpoint: str, rf_config: Dict, device: str, rf_channel: Optional[int] = 1, ilastik_multi_thread: bool = False)
def __init__(
    self,
    checkpoint: str,
    rf_config: Tuple,
    device: str,
    rf_channel: Optional[int] = 1,
    ilastik_multi_thread: bool = False,
):
    """Load the enhancer model and build the random forest / ilastik predicter."""
    self.model = self.load_model(checkpoint, device)
    self.rf_predicter = self.load_rf(rf_config, rf_channel, ilastik_multi_thread)

    self.checkpoint = checkpoint
    self.rf_config = rf_config
    # Kept so that unpickling rebuilds the predicter with the same output channel.
    self.rf_channel = rf_channel
    self.device = device
model
rf_predicter
device
checkpoint
rf_config