torch_em.shallow2deep.shallow2deep_model
import os
import pickle
import torch
from torch_em.util import get_trainer
from torch_em.util.modelzoo import import_bioimageio_model
from .prepare_shallow2deep import _get_filters, _apply_filters

# optional imports only needed for using the ilastik api for the prediction
try:
    import lazyflow
    from ilastik.experimental.api import from_project_file
    # ilastik's thread pool does not work inside of the torch loader (and we want to limit the number
    # of threads anyways), see https://github.com/ilastik/ilastik/issues/2517.
    # The pool is therefore reset to 0 threads in IlastikPredicter unless ilastik_multi_thread is set.
except ImportError:
    from_project_file = None
try:
    from xarray import DataArray
except ImportError:
    DataArray = None


class RFWithFilters:
    """Predict with a pickled random forest on filter features computed from a raw image.

    Args:
        rf_path: Path to the pickled classifier. It must provide `predict_proba` and `n_features_in_`
            (presumably a scikit-learn classifier).
        ndim: Dimensionality of the input, passed on to `_get_filters`.
        filter_config: Filter configuration, passed on to `_get_filters`.
        output_channel: Which probability channel(s) to return. None returns all channels; a single
            index returns an array with the shape of the input; a sequence of indices returns a
            channel-first array.
    """

    def __init__(self, rf_path, ndim, filter_config, output_channel=None):
        # NOTE: unpickling can execute arbitrary code; only load random forests from trusted sources.
        with open(rf_path, "rb") as f:
            self.rf = pickle.load(f)
        self.filters_and_sigmas = _get_filters(ndim, filter_config)
        self.output_channel = output_channel

    def __call__(self, x):
        """Return the float32 class probabilities for the image `x`."""
        features = _apply_filters(x, self.filters_and_sigmas)
        assert features.shape[1] == self.rf.n_features_in_, f"{features.shape[1]}, {self.rf.n_features_in_}"
        # one row of probabilities per pixel: (n_pixels, n_classes).
        # assumes _apply_filters emits the pixel rows in C-order of x - TODO confirm
        out = self.rf.predict_proba(features)
        if self.output_channel is not None:
            out = out[:, self.output_channel]
        if out.ndim == 1:
            # a single channel was selected: (n_pixels,) -> x.shape
            out = out.reshape(x.shape)
        else:
            # several channels: move the channel axis in front of the pixel axis before reshaping,
            # otherwise the reshape would interleave the channels of neighboring pixels
            out = out.T.reshape((out.shape[1],) + x.shape)
        return out.astype("float32")


# TODO need installation that does not downgrade numpy; talk to Dominik about this
# currently ilastik-api deps are installed via:
# conda install --strict-channel-priority -c ilastik-forge/label/freepy -c conda-forge ilastik-core
# print hint on how to install it once this is more stable
class IlastikPredicter:
    """Predict with the pixel classifier of an ilastik project.

    Requires the optional ilastik api (`ilastik.experimental.api`) and `xarray`.

    Args:
        ilp_path: Path to the ilastik project file.
        ndim: Dimensionality of the input data, 2 or 3.
        ilastik_multi_thread: If False, ilastik's thread pool is reset to 0 threads.
        output_channel: Channel(s) selected from the last axis of the ilastik prediction, or None for all.
    """

    def __init__(self, ilp_path, ndim, ilastik_multi_thread, output_channel=None):
        assert from_project_file is not None, "IlastikPredicter requires the ilastik api (ilastik.experimental.api)"
        assert DataArray is not None, "IlastikPredicter requires xarray"
        assert ndim in (2, 3), f"Expected ndim to be 2 or 3, got {ndim}"
        if not ilastik_multi_thread:
            lazyflow.request.Request.reset_thread_pool(0)
        self.ilp = from_project_file(ilp_path)
        self.dims = ("y", "x") if ndim == 2 else ("z", "y", "x")
        self.output_channel = output_channel

    def __call__(self, x):
        """Return the ilastik prediction for the array `x` (`x.ndim` must match `ndim`)."""
        assert x.ndim == len(self.dims), f"{x.ndim}, {self.dims}"
        try:
            out = self.ilp.predict(DataArray(x, dims=self.dims)).values
        except ValueError:
            if x.ndim != 2:
                raise
            # this is a bit of a dirty hack for projects that are trained to classify in 2d, but with 3d data
            # and thus need a singleton z axis. It would be better to ask this of the ilastik classifier, see
            # https://github.com/ilastik/ilastik/issues/2530
            out = self.ilp.predict(DataArray(x[None], dims=("z",) + self.dims)).values
            assert out.shape[0] == 1
            # get rid of the singleton z-axis
            out = out[0]
        if self.output_channel is not None:
            out = out[..., self.output_channel]
        return out


class Shallow2DeepModel:
    """Apply a shallow classifier (random forest or ilastik project) followed by a pretrained network."""

    @staticmethod
    def load_model(checkpoint, device):
        """Load a torch_em checkpoint, falling back to a bioimageio model; the model is put in eval mode."""
        try:
            model = get_trainer(checkpoint, device=device).model
            model.eval()
            return model
        except Exception as e:
            print("Could not load torch_em checkpoint from", checkpoint, "due to exception:", e)
            print("Trying to load as bioimageio model instead")
            model = import_bioimageio_model(checkpoint, device=device)[0]
            model.eval()
            return model

    @staticmethod
    def load_rf(rf_config, rf_channel=1, ilastik_multi_thread=False):
        """Create the shallow predicter.

        `rf_config` is either (rf_path, ndim, filter_config) for a random forest
        or (ilp_path, ndim) for an ilastik project.

        Raises:
            ValueError: If `rf_config` has neither 2 nor 3 entries.
        """
        if len(rf_config) == 3:  # random forest path and feature config
            rf_path, ndim, filter_config = rf_config
            assert os.path.exists(rf_path)
            return RFWithFilters(rf_path, ndim, filter_config, rf_channel)
        elif len(rf_config) == 2:  # ilastik project and dimensionality
            ilp_path, ndim = rf_config
            return IlastikPredicter(ilp_path, ndim, ilastik_multi_thread, rf_channel)
        else:
            raise ValueError(f"Invalid rf config: {rf_config}")

    def __init__(self, checkpoint, rf_config, device, rf_channel=1, ilastik_multi_thread=False):
        self.model = self.load_model(checkpoint, device)
        self.rf_predicter = self.load_rf(rf_config, rf_channel, ilastik_multi_thread)

        self.checkpoint = checkpoint
        self.rf_config = rf_config
        self.device = device
        # stored so that __setstate__ can rebuild an identical predicter after unpickling
        self.rf_channel = rf_channel
        self.ilastik_multi_thread = ilastik_multi_thread

    def __call__(self, x):
        """Predict the shallow classifier on x[0, 0], then apply the network to its output."""
        # TODO support batch axis and multiple input channels
        out = self.rf_predicter(x[0, 0].cpu().detach().numpy())
        out = torch.from_numpy(out[None, None]).to(self.device)
        out = self.model(out)
        return out

    # need to overwrite pickle to support the rf / ilastik predicter
    def __getstate__(self):
        state = self.__dict__.copy()
        del state["rf_predicter"]
        return state

    def __setstate__(self, state):
        # pickles created before rf_channel / ilastik_multi_thread were stored fall back to the defaults
        state["rf_predicter"] = self.load_rf(
            state["rf_config"],
            state.get("rf_channel", 1),
            state.get("ilastik_multi_thread", False),
        )
        self.__dict__.update(state)
class
RFWithFilters:
class RFWithFilters:
    """Predict with a pickled random forest on filter features computed from a raw image.

    Args:
        rf_path: Path to the pickled classifier. It must provide `predict_proba` and `n_features_in_`
            (presumably a scikit-learn classifier).
        ndim: Dimensionality of the input, passed on to `_get_filters`.
        filter_config: Filter configuration, passed on to `_get_filters`.
        output_channel: Which probability channel(s) to return. None returns all channels; a single
            index returns an array with the shape of the input; a sequence of indices returns a
            channel-first array.
    """

    def __init__(self, rf_path, ndim, filter_config, output_channel=None):
        # NOTE: unpickling can execute arbitrary code; only load random forests from trusted sources.
        with open(rf_path, "rb") as f:
            self.rf = pickle.load(f)
        self.filters_and_sigmas = _get_filters(ndim, filter_config)
        self.output_channel = output_channel

    def __call__(self, x):
        """Return the float32 class probabilities for the image `x`."""
        features = _apply_filters(x, self.filters_and_sigmas)
        assert features.shape[1] == self.rf.n_features_in_, f"{features.shape[1]}, {self.rf.n_features_in_}"
        # one row of probabilities per pixel: (n_pixels, n_classes).
        # assumes _apply_filters emits the pixel rows in C-order of x - TODO confirm
        out = self.rf.predict_proba(features)
        if self.output_channel is not None:
            out = out[:, self.output_channel]
        if out.ndim == 1:
            # a single channel was selected (also works for numpy integer indices): (n_pixels,) -> x.shape
            out = out.reshape(x.shape)
        else:
            # several channels: move the channel axis in front of the pixel axis before reshaping,
            # otherwise the reshape would interleave the channels of neighboring pixels
            out = out.T.reshape((out.shape[1],) + x.shape)
        return out.astype("float32")
class
IlastikPredicter:
class IlastikPredicter:
    """Predict with the pixel classifier of an ilastik project.

    Requires the optional ilastik api (`ilastik.experimental.api`) and `xarray`.

    Args:
        ilp_path: Path to the ilastik project file.
        ndim: Dimensionality of the input data, 2 or 3.
        ilastik_multi_thread: If False, ilastik's thread pool is reset to 0 threads.
        output_channel: Channel(s) selected from the last axis of the ilastik prediction, or None for all.
    """

    def __init__(self, ilp_path, ndim, ilastik_multi_thread, output_channel=None):
        assert from_project_file is not None, "IlastikPredicter requires the ilastik api (ilastik.experimental.api)"
        assert DataArray is not None, "IlastikPredicter requires xarray"
        assert ndim in (2, 3), f"Expected ndim to be 2 or 3, got {ndim}"
        if not ilastik_multi_thread:
            # ilastik's thread pool does not work inside of the torch loader,
            # see https://github.com/ilastik/ilastik/issues/2517
            lazyflow.request.Request.reset_thread_pool(0)
        self.ilp = from_project_file(ilp_path)
        self.dims = ("y", "x") if ndim == 2 else ("z", "y", "x")
        self.output_channel = output_channel

    def __call__(self, x):
        """Return the ilastik prediction for the array `x` (`x.ndim` must match `ndim`)."""
        assert x.ndim == len(self.dims), f"{x.ndim}, {self.dims}"
        try:
            out = self.ilp.predict(DataArray(x, dims=self.dims)).values
        except ValueError:
            if x.ndim != 2:
                raise
            # this is a bit of a dirty hack for projects that are trained to classify in 2d, but with 3d data
            # and thus need a singleton z axis. It would be better to ask this of the ilastik classifier, see
            # https://github.com/ilastik/ilastik/issues/2530
            out = self.ilp.predict(DataArray(x[None], dims=("z",) + self.dims)).values
            assert out.shape[0] == 1
            # get rid of the singleton z-axis
            out = out[0]
        if self.output_channel is not None:
            out = out[..., self.output_channel]
        return out
IlastikPredicter(ilp_path, ndim, ilastik_multi_thread, output_channel=None)
def __init__(self, ilp_path, ndim, ilastik_multi_thread, output_channel=None):
    """Load the ilastik project.

    Args:
        ilp_path: Path to the ilastik project file.
        ndim: Dimensionality of the input data, 2 or 3.
        ilastik_multi_thread: If False, ilastik's thread pool is reset to 0 threads.
        output_channel: Channel(s) to select from the prediction, or None for all.
    """
    # the optional dependencies are None when their import failed; say which one is missing
    assert from_project_file is not None, "IlastikPredicter requires the ilastik api (ilastik.experimental.api)"
    assert DataArray is not None, "IlastikPredicter requires xarray"
    assert ndim in (2, 3), f"Expected ndim to be 2 or 3, got {ndim}"
    if not ilastik_multi_thread:
        # ilastik's thread pool does not work inside of the torch loader,
        # see https://github.com/ilastik/ilastik/issues/2517
        lazyflow.request.Request.reset_thread_pool(0)
    self.ilp = from_project_file(ilp_path)
    self.dims = ("y", "x") if ndim == 2 else ("z", "y", "x")
    self.output_channel = output_channel
class
Shallow2DeepModel:
class Shallow2DeepModel:
    """Apply a shallow classifier (random forest or ilastik project) followed by a pretrained network.

    Args:
        checkpoint: torch_em checkpoint or bioimageio model to load the network from.
        rf_config: (rf_path, ndim, filter_config) for a random forest or (ilp_path, ndim) for ilastik.
        device: Device the network and its inputs are placed on.
        rf_channel: Output channel(s) of the shallow classifier that are passed to the network.
        ilastik_multi_thread: Whether ilastik may use its own thread pool.
    """

    @staticmethod
    def load_model(checkpoint, device):
        """Load a torch_em checkpoint, falling back to a bioimageio model; the model is put in eval mode."""
        try:
            model = get_trainer(checkpoint, device=device).model
            model.eval()
            return model
        except Exception as e:
            print("Could not load torch_em checkpoint from", checkpoint, "due to exception:", e)
            print("Trying to load as bioimageio model instead")
            model = import_bioimageio_model(checkpoint, device=device)[0]
            model.eval()
            return model

    @staticmethod
    def load_rf(rf_config, rf_channel=1, ilastik_multi_thread=False):
        """Create the shallow predicter described by `rf_config`.

        Raises:
            ValueError: If `rf_config` has neither 2 nor 3 entries.
        """
        if len(rf_config) == 3:  # random forest path and feature config
            rf_path, ndim, filter_config = rf_config
            assert os.path.exists(rf_path)
            return RFWithFilters(rf_path, ndim, filter_config, rf_channel)
        elif len(rf_config) == 2:  # ilastik project and dimensionality
            ilp_path, ndim = rf_config
            return IlastikPredicter(ilp_path, ndim, ilastik_multi_thread, rf_channel)
        else:
            raise ValueError(f"Invalid rf config: {rf_config}")

    def __init__(self, checkpoint, rf_config, device, rf_channel=1, ilastik_multi_thread=False):
        self.model = self.load_model(checkpoint, device)
        self.rf_predicter = self.load_rf(rf_config, rf_channel, ilastik_multi_thread)

        self.checkpoint = checkpoint
        self.rf_config = rf_config
        self.device = device
        # stored so that __setstate__ can rebuild an identical predicter after unpickling
        self.rf_channel = rf_channel
        self.ilastik_multi_thread = ilastik_multi_thread

    def __call__(self, x):
        """Predict the shallow classifier on x[0, 0], then apply the network to its output."""
        # TODO support batch axis and multiple input channels
        out = self.rf_predicter(x[0, 0].cpu().detach().numpy())
        out = torch.from_numpy(out[None, None]).to(self.device)
        out = self.model(out)
        return out

    # need to overwrite pickle to support the rf / ilastik predicter
    def __getstate__(self):
        state = self.__dict__.copy()
        del state["rf_predicter"]
        return state

    def __setstate__(self, state):
        # pickles created before rf_channel / ilastik_multi_thread were stored fall back to the defaults
        state["rf_predicter"] = self.load_rf(
            state["rf_config"],
            state.get("rf_channel", 1),
            state.get("ilastik_multi_thread", False),
        )
        self.__dict__.update(state)
Shallow2DeepModel(checkpoint, rf_config, device, rf_channel=1, ilastik_multi_thread=False)
def __init__(self, checkpoint, rf_config, device, rf_channel=1, ilastik_multi_thread=False):
    """Load the network and the shallow predicter.

    Args:
        checkpoint: torch_em checkpoint or bioimageio model to load the network from.
        rf_config: (rf_path, ndim, filter_config) for a random forest or (ilp_path, ndim) for ilastik.
        device: Device the network and its inputs are placed on.
        rf_channel: Output channel(s) of the shallow classifier that are passed to the network.
        ilastik_multi_thread: Whether ilastik may use its own thread pool.
    """
    self.model = self.load_model(checkpoint, device)
    self.rf_predicter = self.load_rf(rf_config, rf_channel, ilastik_multi_thread)

    self.checkpoint = checkpoint
    self.rf_config = rf_config
    self.device = device
    # kept on the instance so the predicter settings are part of the pickled state next to rf_config
    self.rf_channel = rf_channel
    self.ilastik_multi_thread = ilastik_multi_thread
@staticmethod
def
load_model(checkpoint, device):
@staticmethod
def load_model(checkpoint, device):
    """Return the network stored at `checkpoint`, switched to eval mode.

    The checkpoint is first read as a torch_em trainer checkpoint. If that raises anything,
    the reason is printed and the checkpoint is imported as a bioimageio model instead.
    """
    try:
        net = get_trainer(checkpoint, device=device).model
        net.eval()
    except Exception as err:
        print("Could not load torch_em checkpoint from", checkpoint, "due to exception:", err)
        print("Trying to load as bioimageio model instead")
        # the bioimageio import returns a sequence; the network is its first entry
        net = import_bioimageio_model(checkpoint, device=device)[0]
        net.eval()
    return net
@staticmethod
def
load_rf(rf_config, rf_channel=1, ilastik_multi_thread=False):
@staticmethod
def load_rf(rf_config, rf_channel=1, ilastik_multi_thread=False):
    """Build the shallow predicter described by `rf_config`.

    A 2-tuple (ilp_path, ndim) selects an ilastik project. A 3-tuple (rf_path, ndim, filter_config)
    selects a pickled random forest, whose file must exist.

    Raises:
        ValueError: If `rf_config` has neither 2 nor 3 entries.
    """
    n_entries = len(rf_config)
    if n_entries == 2:
        ilp_path, ndim = rf_config
        return IlastikPredicter(ilp_path, ndim, ilastik_multi_thread, rf_channel)
    if n_entries != 3:
        raise ValueError(f"Invalid rf config: {rf_config}")
    rf_path, ndim, filter_config = rf_config
    assert os.path.exists(rf_path)
    return RFWithFilters(rf_path, ndim, filter_config, rf_channel)