torch_em.shallow2deep.shallow2deep_model

  1import os
  2import pickle
  3import torch
  4from torch_em.util import get_trainer
  5from torch_em.util.modelzoo import import_bioimageio_model
  6from .prepare_shallow2deep import _get_filters, _apply_filters
  7
  8# optional imports only needed for using ilastik api for the predictio
  9try:
 10    import lazyflow
 11    from ilastik.experimental.api import from_project_file
 12    # set the number of threads used by ilastik to 0.
 13    # otherwise it does not work inside of the torch loader (and we want to limit number of threads anyways)
 14    # see https://github.com/ilastik/ilastik/issues/2517
 15    # lazyflow.request.Request.reset_thread_pool(0)
 16    # added this to the constructors via boolean flag
 17except ImportError:
 18    from_project_file = None
 19try:
 20    from xarray import DataArray
 21except ImportError:
 22    DataArray = None
 23
 24
 25class RFWithFilters:
 26    def __init__(self, rf_path, ndim, filter_config, output_channel=None):
 27        with open(rf_path, "rb") as f:
 28            self.rf = pickle.load(f)
 29        self.filters_and_sigmas = _get_filters(ndim, filter_config)
 30        self.output_channel = output_channel
 31
 32    def __call__(self, x):
 33        features = _apply_filters(x, self.filters_and_sigmas)
 34        assert features.shape[1] == self.rf.n_features_in_, f"{features.shape[1]}, {self.rf.n_features_in_}"
 35        out = self.rf.predict_proba(features)
 36        if self.output_channel is None:
 37            out_shape = (out.shape[1],) + x.shape
 38        else:
 39            out = out[:, self.output_channel]
 40            out_shape = x.shape if isinstance(self.output_channel, int) else (len(self.output_channel),) + x.shape
 41        out = out.reshape(out_shape).astype("float32")
 42        return out
 43
 44
 45# TODO need installation that does not downgrade numpy; talk to Dominik about this
 46# currently ilastik-api deps are installed via:
 47# conda install --strict-channel-priority -c ilastik-forge/label/freepy -c conda-forge ilastik-core
 48# print hint on how to install it once this is more stable
 49class IlastikPredicter:
 50    def __init__(self, ilp_path, ndim, ilastik_multi_thread, output_channel=None):
 51        assert from_project_file is not None
 52        assert DataArray is not None
 53        assert ndim in (2, 3)
 54        if not ilastik_multi_thread:
 55            lazyflow.request.Request.reset_thread_pool(0)
 56        self.ilp = from_project_file(ilp_path)
 57        self.dims = ("y", "x") if ndim == 2 else ("z", "y", "x")
 58        self.output_channel = output_channel
 59
 60    def __call__(self, x):
 61        assert x.ndim == len(self.dims), f"{x.ndim}, {self.dims}"
 62        try:
 63            out = self.ilp.predict(DataArray(x, dims=self.dims)).values
 64        except ValueError as e:
 65            # this is a bit of a dirty hack for projects that are trained to classify in 2d, but with 3d data
 66            # and thus need a singleton z axis. It would be better to ask this of the ilastik classifier, see
 67            # https://github.com/ilastik/ilastik/issues/2530
 68            if x.ndim == 2:
 69                x = x[None]
 70                dims = ("z",) + self.dims
 71                out = self.ilp.predict(DataArray(x, dims=dims)).values
 72                assert out.shape[0] == 1
 73                # get rid of the singleton z-axis
 74                out = out[0]
 75            else:
 76                raise e
 77        if self.output_channel is not None:
 78            out = out[..., self.output_channel]
 79        return out
 80
 81
 82class Shallow2DeepModel:
 83
 84    @staticmethod
 85    def load_model(checkpoint, device):
 86        try:
 87            model = get_trainer(checkpoint, device=device).model
 88            model.eval()
 89            return model
 90        except Exception as e:
 91            print("Could not load torch_em checkpoint from", checkpoint, "due to exception:", e)
 92            print("Trying to load as bioimageio model instead")
 93        model = import_bioimageio_model(checkpoint, device=device)[0]
 94        model.eval()
 95        return model
 96
 97    @staticmethod
 98    def load_rf(rf_config, rf_channel=1, ilastik_multi_thread=False):
 99        if len(rf_config) == 3:  # random forest path and feature config
100            rf_path, ndim, filter_config = rf_config
101            assert os.path.exists(rf_path)
102            return RFWithFilters(rf_path, ndim, filter_config, rf_channel)
103        elif len(rf_config) == 2:  # ilastik project and dimensionality
104            ilp_path, ndim = rf_config
105            return IlastikPredicter(ilp_path, ndim, ilastik_multi_thread, rf_channel)
106        else:
107            raise ValueError(f"Invalid rf config: {rf_config}")
108
109    def __init__(self, checkpoint, rf_config, device, rf_channel=1, ilastik_multi_thread=False):
110        self.model = self.load_model(checkpoint, device)
111        self.rf_predicter = self.load_rf(rf_config, rf_channel, ilastik_multi_thread)
112        self.device = device
113
114        self.checkpoint = checkpoint
115        self.rf_config = rf_config
116        self.device = device
117
118    def __call__(self, x):
119        # TODO support batch axis and multiple input channels
120        out = self.rf_predicter(x[0, 0].cpu().detach().numpy())
121        out = torch.from_numpy(out[None, None]).to(self.device)
122        out = self.model(out)
123        return out
124
125    # need to overwrite pickle to support the rf / ilastik predicter
126    def __getstate__(self):
127        state = self.__dict__.copy()
128        del state["rf_predicter"]
129        return state
130
131    def __setstate__(self, state):
132        state["rf_predicter"] = self.load_rf(state["rf_config"])
133        self.__dict__.update(state)
class RFWithFilters:
class RFWithFilters:
    """Random forest prediction on filter features computed from the input image.

    Args:
        rf_path: Path to the pickled classifier. It must provide `predict_proba` and
            `n_features_in_` (presumably a scikit-learn random forest).
        ndim: Dimensionality passed to `_get_filters`.
        filter_config: Filter configuration passed to `_get_filters`.
        output_channel: Class index (int) or sequence of class indices to return;
            None returns the probabilities of all classes.
    """

    def __init__(self, rf_path, ndim, filter_config, output_channel=None):
        # NOTE(security): unpickling can execute arbitrary code; only load classifiers from trusted sources.
        with open(rf_path, "rb") as f:
            self.rf = pickle.load(f)
        self.filters_and_sigmas = _get_filters(ndim, filter_config)
        self.output_channel = output_channel

    @staticmethod
    def _to_image_layout(out, spatial_shape):
        """Reshape per-pixel predictions to `spatial_shape`, with a leading channel axis if `out` is 2d.

        `out` is either (n_pixels,) or (n_pixels, n_channels); row i is assumed to belong to
        pixel i of the flattened image.
        """
        spatial_shape = tuple(spatial_shape)
        if out.ndim == 1:
            return out.reshape(spatial_shape)
        # transpose to (n_channels, n_pixels) first: reshaping the pixel-major array directly
        # to (n_channels,) + spatial_shape would interleave the values of different channels
        return out.T.reshape((out.shape[1],) + spatial_shape)

    def __call__(self, x):
        """Predict class probabilities for the image `x`.

        Returns:
            float32 array of shape `x.shape` if `output_channel` selects a single class,
            otherwise of shape `(n_channels,) + x.shape`.
        """
        features = _apply_filters(x, self.filters_and_sigmas)
        assert features.shape[1] == self.rf.n_features_in_, f"{features.shape[1]}, {self.rf.n_features_in_}"
        # (n_pixels, n_classes); assumes rows follow the flattened pixel order of x,
        # which the single-channel reshape relies on as well
        out = self.rf.predict_proba(features)
        if self.output_channel is not None:
            out = out[:, self.output_channel]
        return self._to_image_layout(out, x.shape).astype("float32")
RFWithFilters(rf_path, ndim, filter_config, output_channel=None)
def __init__(self, rf_path, ndim, filter_config, output_channel=None):
    """Load the pickled classifier from `rf_path` and set up the feature filters.

    NOTE(security): unpickling can execute arbitrary code; only load trusted files.
    """
    self.output_channel = output_channel
    with open(rf_path, mode="rb") as rf_file:
        classifier = pickle.load(rf_file)
    self.rf = classifier
    self.filters_and_sigmas = _get_filters(ndim, filter_config)
filters_and_sigmas
output_channel
class IlastikPredicter:
class IlastikPredicter:
    """Predict with a trained ilastik project via the ilastik python api.

    Args:
        ilp_path: Path to the ilastik project file.
        ndim: Dimensionality of the input images, either 2 or 3.
        ilastik_multi_thread: If False, lazyflow's request thread pool is reset to 0 threads,
            which is needed for ilastik to work inside of a torch data loader
            (see https://github.com/ilastik/ilastik/issues/2517).
        output_channel: Index into the last axis of ilastik's prediction to return;
            None returns the full prediction.

    Raises:
        ImportError: If the ilastik api or xarray could not be imported.
        ValueError: If `ndim` is not 2 or 3.
    """

    def __init__(self, ilp_path, ndim, ilastik_multi_thread, output_channel=None):
        # explicit checks instead of asserts, so they are not stripped under `python -O`
        if from_project_file is None or DataArray is None:
            raise ImportError(
                "IlastikPredicter requires the ilastik python api (ilastik.experimental.api, lazyflow) "
                "and xarray, which could not be imported."
            )
        if ndim not in (2, 3):
            raise ValueError(f"Expected ndim to be 2 or 3, got {ndim}")
        if not ilastik_multi_thread:
            lazyflow.request.Request.reset_thread_pool(0)
        self.ilp = from_project_file(ilp_path)
        self.dims = ("y", "x") if ndim == 2 else ("z", "y", "x")
        self.output_channel = output_channel

    def __call__(self, x):
        """Return ilastik's prediction for `x`, which must have `len(self.dims)` dimensions."""
        if x.ndim != len(self.dims):
            raise ValueError(f"Expected input with {len(self.dims)} dimensions {self.dims}, got {x.ndim}")
        try:
            out = self.ilp.predict(DataArray(x, dims=self.dims)).values
        except ValueError:
            # this is a bit of a dirty hack for projects that are trained to classify in 2d, but with 3d data
            # and thus need a singleton z axis. It would be better to ask this of the ilastik classifier, see
            # https://github.com/ilastik/ilastik/issues/2530
            if x.ndim != 2:
                raise
            out = self.ilp.predict(DataArray(x[None], dims=("z",) + self.dims)).values
            assert out.shape[0] == 1
            # get rid of the singleton z-axis
            out = out[0]
        if self.output_channel is not None:
            out = out[..., self.output_channel]
        return out
IlastikPredicter(ilp_path, ndim, ilastik_multi_thread, output_channel=None)
def __init__(self, ilp_path, ndim, ilastik_multi_thread, output_channel=None):
    """Load the ilastik project at `ilp_path` for prediction on `ndim`-dimensional (2 or 3) inputs.

    Raises:
        ImportError: If the ilastik api or xarray could not be imported.
        ValueError: If `ndim` is not 2 or 3.
    """
    # explicit checks instead of asserts, so they are not stripped under `python -O`
    if from_project_file is None or DataArray is None:
        raise ImportError(
            "IlastikPredicter requires the ilastik python api (ilastik.experimental.api, lazyflow) "
            "and xarray, which could not be imported."
        )
    if ndim not in (2, 3):
        raise ValueError(f"Expected ndim to be 2 or 3, got {ndim}")
    if not ilastik_multi_thread:
        # ilastik does not work inside of a torch data loader otherwise,
        # see https://github.com/ilastik/ilastik/issues/2517
        lazyflow.request.Request.reset_thread_pool(0)
    self.ilp = from_project_file(ilp_path)
    self.dims = ("y", "x") if ndim == 2 else ("z", "y", "x")
    self.output_channel = output_channel
ilp
dims
output_channel
class Shallow2DeepModel:
 83class Shallow2DeepModel:
 84
 85    @staticmethod
 86    def load_model(checkpoint, device):
 87        try:
 88            model = get_trainer(checkpoint, device=device).model
 89            model.eval()
 90            return model
 91        except Exception as e:
 92            print("Could not load torch_em checkpoint from", checkpoint, "due to exception:", e)
 93            print("Trying to load as bioimageio model instead")
 94        model = import_bioimageio_model(checkpoint, device=device)[0]
 95        model.eval()
 96        return model
 97
 98    @staticmethod
 99    def load_rf(rf_config, rf_channel=1, ilastik_multi_thread=False):
100        if len(rf_config) == 3:  # random forest path and feature config
101            rf_path, ndim, filter_config = rf_config
102            assert os.path.exists(rf_path)
103            return RFWithFilters(rf_path, ndim, filter_config, rf_channel)
104        elif len(rf_config) == 2:  # ilastik project and dimensionality
105            ilp_path, ndim = rf_config
106            return IlastikPredicter(ilp_path, ndim, ilastik_multi_thread, rf_channel)
107        else:
108            raise ValueError(f"Invalid rf config: {rf_config}")
109
110    def __init__(self, checkpoint, rf_config, device, rf_channel=1, ilastik_multi_thread=False):
111        self.model = self.load_model(checkpoint, device)
112        self.rf_predicter = self.load_rf(rf_config, rf_channel, ilastik_multi_thread)
113        self.device = device
114
115        self.checkpoint = checkpoint
116        self.rf_config = rf_config
117        self.device = device
118
119    def __call__(self, x):
120        # TODO support batch axis and multiple input channels
121        out = self.rf_predicter(x[0, 0].cpu().detach().numpy())
122        out = torch.from_numpy(out[None, None]).to(self.device)
123        out = self.model(out)
124        return out
125
126    # need to overwrite pickle to support the rf / ilastik predicter
127    def __getstate__(self):
128        state = self.__dict__.copy()
129        del state["rf_predicter"]
130        return state
131
132    def __setstate__(self, state):
133        state["rf_predicter"] = self.load_rf(state["rf_config"])
134        self.__dict__.update(state)
Shallow2DeepModel(checkpoint, rf_config, device, rf_channel=1, ilastik_multi_thread=False)
def __init__(self, checkpoint, rf_config, device, rf_channel=1, ilastik_multi_thread=False):
    """Load the network from `checkpoint` and the shallow predicter from `rf_config`."""
    self.model = self.load_model(checkpoint, device)
    self.rf_predicter = self.load_rf(rf_config, rf_channel, ilastik_multi_thread)
    self.device = device

    self.checkpoint = checkpoint
    self.rf_config = rf_config
    # keep the predicter settings next to rf_config, so the predicter can be recreated identically
    self.rf_channel = rf_channel
    self.ilastik_multi_thread = ilastik_multi_thread
@staticmethod
def load_model(checkpoint, device):
@staticmethod
def load_model(checkpoint, device):
    """Return the network from `checkpoint` in eval mode.

    A torch_em checkpoint is tried first; on any failure the checkpoint is loaded
    as a bioimageio model instead.
    """
    try:
        torch_em_model = get_trainer(checkpoint, device=device).model
        torch_em_model.eval()
    except Exception as err:
        print("Could not load torch_em checkpoint from", checkpoint, "due to exception:", err)
        print("Trying to load as bioimageio model instead")
    else:
        return torch_em_model
    bioimageio_model = import_bioimageio_model(checkpoint, device=device)[0]
    bioimageio_model.eval()
    return bioimageio_model
@staticmethod
def load_rf(rf_config, rf_channel=1, ilastik_multi_thread=False):
 98    @staticmethod
 99    def load_rf(rf_config, rf_channel=1, ilastik_multi_thread=False):
100        if len(rf_config) == 3:  # random forest path and feature config
101            rf_path, ndim, filter_config = rf_config
102            assert os.path.exists(rf_path)
103            return RFWithFilters(rf_path, ndim, filter_config, rf_channel)
104        elif len(rf_config) == 2:  # ilastik project and dimensionality
105            ilp_path, ndim = rf_config
106            return IlastikPredicter(ilp_path, ndim, ilastik_multi_thread, rf_channel)
107        else:
108            raise ValueError(f"Invalid rf config: {rf_config}")
model
rf_predicter
device
checkpoint
rf_config