torch_em.shallow2deep.shallow2deep_model
```python
import os
import pickle
from typing import Dict, Optional, Tuple, Union

import torch
import numpy as np

from torch_em.util import get_trainer
from torch_em.util.modelzoo import import_bioimageio_model
from .prepare_shallow2deep import _get_filters, _apply_filters

# optional imports only needed for using the ilastik api for the prediction
try:
    import lazyflow
    from ilastik.experimental.api import from_project_file
    # set the number of threads used by ilastik to 0.
    # otherwise it does not work inside of the torch loader (and we want to limit number of threads anyways)
    # see https://github.com/ilastik/ilastik/issues/2517
    # lazyflow.request.Request.reset_thread_pool(0)
    # added this to the constructors via boolean flag
except ImportError:
    from_project_file = None

try:
    from xarray import DataArray
except ImportError:
    DataArray = None


class RFWithFilters:
    """Wrapper to apply feature computation and random forest prediction.

    Args:
        rf_path: The path to the trained random forest.
        ndim: The dimensionality of the input data.
        filter_config: The configuration of the filters.
        output_channel: The output channel of the random forest prediction to keep.
    """
    def __init__(
        self, rf_path: str, ndim: int, filter_config: Dict, output_channel: Optional[Union[Tuple[int, ...], int]] = None
    ):
        with open(rf_path, "rb") as f:
            self.rf = pickle.load(f)
        self.filters_and_sigmas = _get_filters(ndim, filter_config)
        self.output_channel = output_channel

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Apply feature computation and random forest prediction to input data.

        Args:
            x: The input data.

        Returns:
            The random forest prediction.
        """
        features = _apply_filters(x, self.filters_and_sigmas)
        assert features.shape[1] == self.rf.n_features_in_, f"{features.shape[1]}, {self.rf.n_features_in_}"
        out = self.rf.predict_proba(features)
        if self.output_channel is None:
            out_shape = (out.shape[1],) + x.shape
        else:
            out = out[:, self.output_channel]
            out_shape = x.shape if isinstance(self.output_channel, int) else (len(self.output_channel),) + x.shape
        out = out.reshape(out_shape).astype("float32")
        return out


class IlastikPredicter:
    """Wrapper to apply a trained ilastik pixel classification project.

    Args:
        ilp_path: The file path to the ilastik project.
        ndim: The dimensionality of the input data.
        ilastik_multi_thread: Whether to use multi-threaded prediction with ilastik.
        output_channel: The output channel of the ilastik prediction to keep.
    """
    def __init__(
        self,
        ilp_path: str,
        ndim: int,
        ilastik_multi_thread: bool,
        output_channel: Optional[Union[int, Tuple[int, ...]]] = None,
    ):
        assert from_project_file is not None
        assert DataArray is not None
        assert ndim in (2, 3)
        if not ilastik_multi_thread:
            lazyflow.request.Request.reset_thread_pool(0)
        self.ilp = from_project_file(ilp_path)
        self.dims = ("y", "x") if ndim == 2 else ("z", "y", "x")
        self.output_channel = output_channel

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Apply ilastik project to input data.

        Args:
            x: The input data.

        Returns:
            The ilastik prediction.
        """
        assert x.ndim == len(self.dims), f"{x.ndim}, {self.dims}"
        try:
            out = self.ilp.predict(DataArray(x, dims=self.dims)).values
        except ValueError as e:
            # this is a bit of a dirty hack for projects that are trained to classify in 2d, but with 3d data
            # and thus need a singleton z axis. It would be better to ask this of the ilastik classifier, see
            # https://github.com/ilastik/ilastik/issues/2530
            if x.ndim == 2:
                x = x[None]
                dims = ("z",) + self.dims
                out = self.ilp.predict(DataArray(x, dims=dims)).values
                assert out.shape[0] == 1
                # get rid of the singleton z-axis
                out = out[0]
            else:
                raise e
        if self.output_channel is not None:
            out = out[..., self.output_channel]
        return out


class Shallow2DeepModel:
    """Wrapper to apply a shallow2deep enhancer model to raw data.

    First runs prediction with the random forest and then applies the enhancer model
    to the random forest predictions.

    Args:
        checkpoint: The checkpoint of the enhancer model.
        rf_config: The feature configuration of the random forest.
        device: The device for the enhancer model.
        rf_channel: The channel of the random forest prediction to use as input to the enhancer.
        ilastik_multi_thread: Whether to use ilastik multi-threaded prediction.
    """

    @staticmethod
    def load_model(checkpoint, device):
        """@private
        """
        try:
            model = get_trainer(checkpoint, device=device).model
            model.eval()
            return model
        except Exception as e:
            print("Could not load torch_em checkpoint from", checkpoint, "due to exception:", e)
            print("Trying to load as bioimageio model instead")
            model = import_bioimageio_model(checkpoint, device=device)[0]
            model.eval()
            return model

    @staticmethod
    def load_rf(rf_config, rf_channel=1, ilastik_multi_thread=False):
        """@private
        """
        if len(rf_config) == 3:  # random forest path and feature config
            rf_path, ndim, filter_config = rf_config
            assert os.path.exists(rf_path)
            return RFWithFilters(rf_path, ndim, filter_config, rf_channel)
        elif len(rf_config) == 2:  # ilastik project and dimensionality
            ilp_path, ndim = rf_config
            return IlastikPredicter(ilp_path, ndim, ilastik_multi_thread, rf_channel)
        else:
            raise ValueError(f"Invalid rf config: {rf_config}")

    def __init__(
        self,
        checkpoint: str,
        rf_config: Dict,
        device: str,
        rf_channel: Optional[int] = 1,
        ilastik_multi_thread: bool = False,
    ):
        self.model = self.load_model(checkpoint, device)
        self.rf_predicter = self.load_rf(rf_config, rf_channel, ilastik_multi_thread)
        self.device = device

        self.checkpoint = checkpoint
        self.rf_config = rf_config
        self.device = device

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Apply the Shallow2Deep Model to the input data.

        Args:
            x: The input data.

        Returns:
            The shallow2deep predictions.
        """
        # TODO support batch axis and multiple input channels
        out = self.rf_predicter(x[0, 0].cpu().detach().numpy())
        out = torch.from_numpy(out[None, None]).to(self.device)
        out = self.model(out)
        return out

    # need to overwrite pickle to support the rf / ilastik predicter
    def __getstate__(self):
        state = self.__dict__.copy()
        del state["rf_predicter"]
        return state

    def __setstate__(self, state):
        state["rf_predicter"] = self.load_rf(state["rf_config"])
        self.__dict__.update(state)
```
class RFWithFilters:
Wrapper to apply feature computation and random forest prediction.
Arguments:
- rf_path: The path to the trained random forest.
- ndim: The dimensionality of the input data.
- filter_config: The configuration of the filters.
- output_channel: The output channel of the random forest prediction to keep.
RFWithFilters(rf_path: str, ndim: int, filter_config: Dict, output_channel: Optional[Union[Tuple[int, ...], int]] = None)
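A minimal usage sketch with placeholder paths. Passing `None` as `filter_config` to fall back to the default filter set of `_get_filters` is an assumption for illustration, not something this page states:

```python
import numpy as np

# hypothetical paths and settings, for illustration only
rf = RFWithFilters(
    rf_path="./checkpoints/rf.pkl",  # pickled scikit-learn random forest (placeholder path)
    ndim=2,
    filter_config=None,              # assumption: None selects the default filters and sigmas
    output_channel=1,                # keep only the foreground probability channel
)

raw = np.random.rand(256, 256).astype("float32")
foreground = rf(raw)  # with an int output_channel the result has the same shape as the input
```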
class IlastikPredicter:
Wrapper to apply a trained ilastik pixel classification project.
Arguments:
- ilp_path: The file path to the ilastik project.
- ndim: The dimensionality of the input data.
- ilastik_multi_thread: Whether to use multi-threaded prediction with ilastik.
- output_channel: The output channel of the ilastik prediction to keep.
IlastikPredicter(ilp_path: str, ndim: int, ilastik_multi_thread: bool, output_channel: Optional[Union[int, Tuple[int, ...]]] = None)
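A minimal usage sketch, assuming the optional `ilastik` and `xarray` dependencies are installed and that a trained pixel classification project exists at the placeholder path below:

```python
import numpy as np

predicter = IlastikPredicter(
    ilp_path="./pixel_classification.ilp",  # placeholder project path
    ndim=2,
    ilastik_multi_thread=False,  # single-threaded so it also works inside a torch data loader
    output_channel=1,            # keep only the foreground probability channel
)

raw = np.random.rand(512, 512).astype("float32")
foreground = predicter(raw)
```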
class Shallow2DeepModel:
Wrapper to apply a shallow2deep enhancer model to raw data.
First runs prediction with the random forest and then applies the enhancer model to the random forest predictions.
Arguments:
- checkpoint: The checkpoint of the enhancer model.
- rf_config: The feature configuration of the random forest.
- device: The device for the enhancer model.
- rf_channel: The channel of the random forest prediction to use as input to the enhancer.
- ilastik_multi_thread: Whether to use ilastik multi-threaded prediction.
Shallow2DeepModel(checkpoint: str, rf_config: Dict, device: str, rf_channel: Optional[int] = 1, ilastik_multi_thread: bool = False)
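A minimal usage sketch with placeholder paths: a length-3 `rf_config` selects a pickled random forest (path, ndim, filter config), while a length-2 config would select an ilastik project instead. The checkpoint path and the `None` filter config are assumptions for illustration:

```python
import numpy as np
import torch

enhancer = Shallow2DeepModel(
    checkpoint="./checkpoints/enhancer",          # torch_em checkpoint or bioimageio model (placeholder)
    rf_config=("./checkpoints/rf.pkl", 2, None),  # random forest path, ndim, filter config
    device="cpu",
    rf_channel=1,
)

# the wrapper expects a tensor with batch and channel axes (see the __call__ implementation above)
raw = torch.from_numpy(np.random.rand(256, 256).astype("float32"))[None, None]
prediction = enhancer(raw)
```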