torch_em.shallow2deep.shallow2deep_dataset
import os
import pickle
import warnings
from glob import glob

import numpy as np
import torch
from torch_em.segmentation import (check_paths, is_segmentation_dataset,
                                   get_data_loader, get_raw_transform,
                                   samples_to_datasets, _get_default_transform)
from torch_em.data import ConcatDataset, ImageCollectionDataset, SegmentationDataset
from .prepare_shallow2deep import _get_filters, _apply_filters
from ..util import ensure_tensor_with_channels, ensure_spatial_array


class _Shallow2DeepBase:
    """Mixin that replaces the raw data of a sample by a random forest prediction.

    The mixin stores the paths to pickled random forests, the filter configuration
    used to compute the input features and the prediction channels that are kept.
    """
    _rf_paths = None
    _filter_config = None

    @property
    def rf_paths(self):
        return self._rf_paths

    @rf_paths.setter
    def rf_paths(self, value):
        self._rf_paths = value

    @property
    def filter_config(self):
        return self._filter_config

    @filter_config.setter
    def filter_config(self, value):
        self._filter_config = value

    @property
    def rf_channels(self):
        return self._rf_channels

    @rf_channels.setter
    def rf_channels(self, value):
        # A single channel index is normalized to a one-element tuple.
        if isinstance(value, int):
            value = (value,)
        assert isinstance(value, tuple)
        self._rf_channels = value

    def _load_random_forest(self):
        """Unpickle one of the random forests in `rf_paths`, chosen uniformly at random."""
        n_rfs = len(self._rf_paths)
        rf_path = self._rf_paths[np.random.randint(0, n_rfs)]
        # NOTE(security): unpickling executes arbitrary code, so `rf_paths` must only
        # point to files from a trusted source.
        with open(rf_path, "rb") as f:
            return pickle.load(f)

    def _predict(self, raw, rf, filters_and_sigmas):
        """Apply the random forest to the filter responses of `raw`.

        Args:
            raw: The input array; its shape is used as spatial shape of the prediction.
            rf: The random forest; must provide `n_features_in_` and `predict_proba`.
            filters_and_sigmas: The filter specification passed on to `_apply_filters`.

        Returns:
            A float32 array of shape `(len(self.rf_channels),) + raw.shape`.
        """
        features = _apply_filters(raw, filters_and_sigmas)
        assert rf.n_features_in_ == features.shape[1], f"{rf.n_features_in_}, {features.shape[1]}"

        # NOTE(review): the assertion inside the try block raises an AssertionError, which
        # is not caught below, so a forest that predicts too few classes aborts instead of
        # reaching the zero-filled fallback. Left unchanged; confirm the intended behavior.
        try:
            pred_ = rf.predict_proba(features)
            assert pred_.shape[1] > max(self.rf_channels), f"{pred_.shape}, {self.rf_channels}"
            pred_ = pred_[:, self.rf_channels]
        except IndexError:
            warnings.warn(f"Random forest prediction failed for input features of shape: {features.shape}")
            pred_shape = (len(features), len(self.rf_channels))
            pred_ = np.zeros(pred_shape, dtype="float32")

        # Bring the per-pixel predictions back into the spatial layout, channel first.
        spatial_shape = raw.shape
        out_shape = (len(self.rf_channels),) + spatial_shape
        prediction = np.zeros(out_shape, dtype="float32")
        for chan in range(pred_.shape[1]):
            prediction[chan] = pred_[:, chan].reshape(spatial_shape)

        return prediction

    def _predict_rf(self, raw):
        """Predict `raw` with a randomly chosen forest, using filters for `self.ndim` dimensions."""
        rf = self._load_random_forest()
        # `self.ndim` is expected to be provided by the dataset class this mixin is combined with.
        filters_and_sigmas = _get_filters(self.ndim, self._filter_config)
        return self._predict(raw, rf, filters_and_sigmas)

    def _predict_rf_anisotropic(self, raw):
        """Predict `raw` slice by slice along its first axis with 2d filters.

        A single randomly chosen forest is used for all slices.
        """
        rf = self._load_random_forest()
        filters_and_sigmas = _get_filters(2, self._filter_config)

        n_channels = len(self.rf_channels)
        prediction = np.zeros((n_channels,) + raw.shape, dtype="float32")
        for z in range(raw.shape[0]):
            prediction[:, z] = self._predict(raw[z], rf, filters_and_sigmas)

        return prediction


class Shallow2DeepDataset(SegmentationDataset, _Shallow2DeepBase):
    """Segmentation dataset that returns random forest predictions instead of the raw data."""

    def __getitem__(self, index):
        assert self._rf_paths is not None
        raw, labels = self._get_sample(index)
        initial_label_dtype = labels.dtype

        if self.raw_transform is not None:
            raw = self.raw_transform(raw)
        if self.label_transform is not None:
            labels = self.label_transform(labels)
        if self.transform is not None:
            raw, labels = self.transform(raw, labels)
            # NOTE(review): the nesting of this crop under the transform branch is
            # reconstructed; the listing this was taken from had lost its indentation.
            if self.trafo_halo is not None:
                raw = self.crop(raw)
                labels = self.crop(labels)
        # support enlarging bounding box here as well (for affinity transform) ?
        if self.label_transform2 is not None:
            labels = ensure_spatial_array(labels, self.ndim, dtype=initial_label_dtype)
            labels = self.label_transform2(labels)

        if isinstance(raw, (list, tuple)):  # this can be a list or tuple due to transforms
            assert len(raw) == 1
            raw = raw[0]
        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
        if raw.shape[0] > 1:
            raise NotImplementedError(
                f"Shallow2Deep training not implemented for multi-channel input yet; got {raw.shape[0]} channels"
            )

        # NOTE we assume single channel raw data here; this needs to be changed for multi-channel
        if getattr(self, "is_anisotropic", False):
            prediction = self._predict_rf_anisotropic(raw[0].numpy())
        else:
            prediction = self._predict_rf(raw[0].numpy())
        prediction = ensure_tensor_with_channels(prediction, ndim=self._ndim, dtype=self.dtype)
        labels = ensure_tensor_with_channels(labels, ndim=self._ndim, dtype=self.label_dtype)
        return prediction, labels


class Shallow2DeepImageCollectionDataset(ImageCollectionDataset, _Shallow2DeepBase):
    """Image collection dataset that returns random forest predictions instead of the raw data."""

    def __getitem__(self, index):
        # Fail before loading any data if the forests were not configured,
        # in the same way as Shallow2DeepDataset does.
        assert self._rf_paths is not None
        raw, labels = self._get_sample(index)
        initial_label_dtype = labels.dtype

        if self.raw_transform is not None:
            raw = self.raw_transform(raw)

        if self.label_transform is not None:
            labels = self.label_transform(labels)

        if self.transform is not None:
            raw, labels = self.transform(raw, labels)

        # support enlarging bounding box here as well (for affinity transform) ?
        if self.label_transform2 is not None:
            labels = ensure_spatial_array(labels, self.ndim, dtype=initial_label_dtype)
            labels = self.label_transform2(labels)

        if isinstance(raw, (list, tuple)):  # this can be a list or tuple due to transforms
            assert len(raw) == 1
            raw = raw[0]
        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
        if raw.shape[0] > 1:
            raise NotImplementedError(
                f"Shallow2Deep training not implemented for multi-channel input yet; got {raw.shape[0]} channels"
            )

        # NOTE we assume single channel raw data here; this needs to be changed for multi-channel
        prediction = self._predict_rf(raw[0].numpy())
        prediction = ensure_tensor_with_channels(prediction, ndim=self._ndim, dtype=self.dtype)
        labels = ensure_tensor_with_channels(labels, ndim=self._ndim, dtype=self.label_dtype)
        return prediction, labels


def _attach_rf_settings(ds, rf_paths, filter_config, rf_channels):
    """Store the random forest settings on a shallow2deep dataset and return it."""
    ds.rf_paths = rf_paths
    ds.filter_config = filter_config
    ds.rf_channels = rf_channels
    return ds


def _load_shallow2deep_segmentation_dataset(
    raw_paths, raw_key, label_paths, label_key, rf_paths, rf_channels, ndim, **kwargs
):
    """Create a Shallow2DeepDataset, or a ConcatDataset of them for multiple paths.

    `ndim="anisotropic"` selects a 3d dataset whose predictions are computed slice by slice.
    The keyword arguments `rois` and `filter_config` are consumed here (and `n_samples` when
    several paths are given); everything else is passed on to the dataset constructor.
    """
    rois = kwargs.pop("rois", None)
    filter_config = kwargs.pop("filter_config", None)
    is_anisotropic = ndim == "anisotropic"
    if is_anisotropic:
        ndim = 3

    if isinstance(raw_paths, str):
        if rois is not None:
            assert len(rois) == 3 and all(isinstance(roi, slice) for roi in rois)
        ds = Shallow2DeepDataset(raw_paths, raw_key, label_paths, label_key, roi=rois, ndim=ndim, **kwargs)
        _attach_rf_settings(ds, rf_paths, filter_config, rf_channels)
        ds.is_anisotropic = is_anisotropic
        return ds

    assert len(raw_paths) > 0
    if rois is not None:
        assert len(rois) == len(label_paths), f"{len(rois)}, {len(label_paths)}"
        assert all(isinstance(roi, tuple) for roi in rois)
    n_samples = kwargs.pop("n_samples", None)

    # Distribute the requested number of samples over the individual datasets.
    samples_per_ds = (
        [None] * len(raw_paths) if n_samples is None else samples_to_datasets(n_samples, raw_paths, raw_key)
    )
    datasets = []
    for i, (raw_path, label_path) in enumerate(zip(raw_paths, label_paths)):
        roi = None if rois is None else rois[i]
        dset = Shallow2DeepDataset(
            raw_path, raw_key, label_path, label_key, roi=roi, n_samples=samples_per_ds[i], ndim=ndim, **kwargs
        )
        _attach_rf_settings(dset, rf_paths, filter_config, rf_channels)
        dset.is_anisotropic = is_anisotropic
        datasets.append(dset)
    return ConcatDataset(*datasets)


def _load_shallow2deep_image_collection_dataset(
    raw_paths, raw_key, label_paths, label_key, rf_paths, rf_channels, patch_shape, **kwargs
):
    """Create a Shallow2DeepImageCollectionDataset.

    The images are either given as folders plus glob patterns (`raw_paths` / `label_paths`
    are strings and `raw_key` / `label_key` the patterns) or as explicit lists of existing
    files with both keys set to None.

    Raises:
        NotImplementedError: For any other combination of paths and keys.
    """
    # filter_config is a setting of the shallow2deep mixin, so it has to be removed
    # from the keyword arguments BEFORE they are forwarded to the dataset constructor.
    filter_config = kwargs.pop("filter_config", None)

    if isinstance(raw_paths, str):
        assert isinstance(label_paths, str)
        raw_file_paths = sorted(glob(os.path.join(raw_paths, raw_key)))
        label_file_paths = sorted(glob(os.path.join(label_paths, label_key)))
        ds = Shallow2DeepImageCollectionDataset(raw_file_paths, label_file_paths, patch_shape, **kwargs)
    elif isinstance(raw_paths, list) and raw_key is None:
        assert isinstance(label_paths, list)
        assert label_key is None
        assert all(os.path.exists(pp) for pp in raw_paths)
        assert all(os.path.exists(pp) for pp in label_paths)
        ds = Shallow2DeepImageCollectionDataset(raw_paths, label_paths, patch_shape, **kwargs)
    else:
        raise NotImplementedError

    return _attach_rf_settings(ds, rf_paths, filter_config, rf_channels)


def get_shallow2deep_dataset(
    raw_paths,
    raw_key,
    label_paths,
    label_key,
    rf_paths,
    patch_shape,
    raw_transform=None,
    label_transform=None,
    transform=None,
    dtype=torch.float32,
    rois=None,
    n_samples=None,
    sampler=None,
    ndim=None,
    is_seg_dataset=None,
    with_channels=False,
    filter_config=None,
    rf_channels=(1,),
):
    """Create a dataset that yields random forest predictions together with the labels.

    Args:
        raw_paths: Path or list of paths to the raw data.
        raw_key: Key (or glob pattern for image collections) of the raw data.
        label_paths: Path or list of paths to the label data.
        label_key: Key (or glob pattern for image collections) of the label data.
        rf_paths: Paths to the pickled random forests; one is drawn at random per sample.
        patch_shape: The patch shape passed to the dataset.
        raw_transform: Transform for the raw data; `get_raw_transform()` is used if None.
        label_transform: Transform for the labels.
        transform: Joint transform of raw data and labels; a default is used if None.
        dtype: The data type of the returned predictions.
        rois: Regions of interest; only supported for segmentation datasets.
        n_samples: The number of samples per epoch.
        sampler: Sampler passed to the segmentation dataset.
        ndim: The dimensionality, or "anisotropic" for slice-wise prediction of 3d data.
        is_seg_dataset: Whether to build a segmentation dataset; derived from the paths if None.
        with_channels: Passed to the segmentation dataset.
        filter_config: The filter configuration used to compute the random forest features.
        rf_channels: The channels of the random forest prediction to return.

    Returns:
        The dataset.

    Raises:
        NotImplementedError: If `rois` is given for an image collection dataset.
    """
    check_paths(raw_paths, label_paths)
    if is_seg_dataset is None:
        is_seg_dataset = is_segmentation_dataset(raw_paths, raw_key,
                                                 label_paths, label_key)

    # we always use a raw transform in the convenience function
    if raw_transform is None:
        raw_transform = get_raw_transform()

    # we always use augmentations in the convenience function
    if transform is None:
        transform = _get_default_transform(
            raw_paths if isinstance(raw_paths, str) else raw_paths[0], raw_key, is_seg_dataset,
            3 if ndim == "anisotropic" else ndim
        )

    if is_seg_dataset:
        ds = _load_shallow2deep_segmentation_dataset(
            raw_paths,
            raw_key,
            label_paths,
            label_key,
            rf_paths,
            patch_shape=patch_shape,
            raw_transform=raw_transform,
            label_transform=label_transform,
            transform=transform,
            rois=rois,
            n_samples=n_samples,
            sampler=sampler,
            ndim=ndim,
            dtype=dtype,
            with_channels=with_channels,
            filter_config=filter_config,
            rf_channels=rf_channels,
        )
    else:
        if rois is not None:
            raise NotImplementedError
        # NOTE(review): sampler and with_channels are not forwarded in this branch.
        ds = _load_shallow2deep_image_collection_dataset(
            raw_paths, raw_key, label_paths, label_key, rf_paths, rf_channels, patch_shape,
            raw_transform=raw_transform, label_transform=label_transform,
            transform=transform, dtype=dtype, n_samples=n_samples,
            filter_config=filter_config,
        )
    return ds


def get_shallow2deep_loader(
    raw_paths,
    raw_key,
    label_paths,
    label_key,
    rf_paths,
    batch_size,
    patch_shape,
    filter_config=None,
    raw_transform=None,
    label_transform=None,
    transform=None,
    rois=None,
    n_samples=None,
    sampler=None,
    ndim=None,
    is_seg_dataset=None,
    with_channels=False,
    rf_channels=(1,),
    **loader_kwargs,
):
    """Create a data loader for the dataset built by `get_shallow2deep_dataset`.

    All arguments except `batch_size` and `loader_kwargs` are passed on to
    `get_shallow2deep_dataset`; `batch_size` and `loader_kwargs` go to `get_data_loader`.
    """
    ds = get_shallow2deep_dataset(
        raw_paths=raw_paths,
        raw_key=raw_key,
        label_paths=label_paths,
        label_key=label_key,
        rf_paths=rf_paths,
        patch_shape=patch_shape,
        raw_transform=raw_transform,
        label_transform=label_transform,
        transform=transform,
        rois=rois,
        n_samples=n_samples,
        sampler=sampler,
        ndim=ndim,
        is_seg_dataset=is_seg_dataset,
        with_channels=with_channels,
        filter_config=filter_config,
        rf_channels=rf_channels,
    )
    return get_data_loader(ds, batch_size=batch_size, **loader_kwargs)
class Shallow2DeepDataset(SegmentationDataset, _Shallow2DeepBase):
    """Segmentation dataset whose samples are random forest predictions paired with labels."""

    def __getitem__(self, index):
        assert self._rf_paths is not None
        raw, labels = self._get_sample(index)
        # Remember the dtype of the loaded labels; it is restored before label_transform2.
        label_dtype_on_load = labels.dtype

        if self.raw_transform is not None:
            raw = self.raw_transform(raw)
        if self.label_transform is not None:
            labels = self.label_transform(labels)
        if self.transform is not None:
            raw, labels = self.transform(raw, labels)
            # NOTE(review): nesting of the crop under the transform branch is reconstructed;
            # the listing this was taken from had lost its indentation.
            if self.trafo_halo is not None:
                raw, labels = self.crop(raw), self.crop(labels)
        # support enlarging bounding box here as well (for affinity transform) ?
        if self.label_transform2 is not None:
            labels = self.label_transform2(
                ensure_spatial_array(labels, self.ndim, dtype=label_dtype_on_load)
            )

        # The transforms may wrap the raw data in a list or tuple of length one.
        if isinstance(raw, (list, tuple)):
            assert len(raw) == 1
            raw = raw[0]
        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
        n_raw_channels = raw.shape[0]
        if n_raw_channels > 1:
            raise NotImplementedError(
                f"Shallow2Deep training not implemented for multi-channel input yet; got {raw.shape[0]} channels"
            )

        # NOTE we assume single channel raw data here; this needs to be changed for multi-channel
        slice_wise = getattr(self, "is_anisotropic", False)
        predict = self._predict_rf_anisotropic if slice_wise else self._predict_rf
        prediction = ensure_tensor_with_channels(predict(raw[0].numpy()), ndim=self._ndim, dtype=self.dtype)
        labels = ensure_tensor_with_channels(labels, ndim=self._ndim, dtype=self.label_dtype)
        return prediction, labels
Inherited Members
class Shallow2DeepImageCollectionDataset(ImageCollectionDataset, _Shallow2DeepBase):
    """Image collection dataset whose samples are random forest predictions paired with labels."""

    def __getitem__(self, index):
        """Return the random forest prediction and the labels for the sample at `index`.

        Raises:
            AssertionError: If no random forest paths were set on the dataset.
            NotImplementedError: If the raw data has more than one channel.
        """
        # Fail before loading any data if the forests were not configured,
        # in the same way as Shallow2DeepDataset does.
        assert self._rf_paths is not None
        raw, labels = self._get_sample(index)
        initial_label_dtype = labels.dtype

        if self.raw_transform is not None:
            raw = self.raw_transform(raw)

        if self.label_transform is not None:
            labels = self.label_transform(labels)

        if self.transform is not None:
            raw, labels = self.transform(raw, labels)

        # support enlarging bounding box here as well (for affinity transform) ?
        if self.label_transform2 is not None:
            labels = ensure_spatial_array(labels, self.ndim, dtype=initial_label_dtype)
            labels = self.label_transform2(labels)

        if isinstance(raw, (list, tuple)):  # this can be a list or tuple due to transforms
            assert len(raw) == 1
            raw = raw[0]
        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
        if raw.shape[0] > 1:
            raise NotImplementedError(
                f"Shallow2Deep training not implemented for multi-channel input yet; got {raw.shape[0]} channels"
            )

        # NOTE we assume single channel raw data here; this needs to be changed for multi-channel
        prediction = self._predict_rf(raw[0].numpy())
        prediction = ensure_tensor_with_channels(prediction, ndim=self._ndim, dtype=self.dtype)
        labels = ensure_tensor_with_channels(labels, ndim=self._ndim, dtype=self.label_dtype)
        return prediction, labels
An abstract class representing a `Dataset`.
All datasets that represent a map from keys to data samples should subclass
it. All subclasses should overwrite `__getitem__()`, supporting fetching a
data sample for a given key. Subclasses could also optionally overwrite
`__len__()`, which is expected to return the size of the dataset by many
`torch.utils.data.Sampler` implementations and the default options
of `torch.utils.data.DataLoader`. Subclasses could also
optionally implement `__getitems__()` to speed up the loading of batched
samples. This method accepts a list of sample indices for a batch and returns
a list of samples.
`torch.utils.data.DataLoader` by default constructs an index
sampler that yields integral indices. To make it work with a map-style
dataset with non-integral indices/keys, a custom sampler must be provided.
Inherited Members
def get_shallow2deep_dataset(
    raw_paths,
    raw_key,
    label_paths,
    label_key,
    rf_paths,
    patch_shape,
    raw_transform=None,
    label_transform=None,
    transform=None,
    dtype=torch.float32,
    rois=None,
    n_samples=None,
    sampler=None,
    ndim=None,
    is_seg_dataset=None,
    with_channels=False,
    filter_config=None,
    rf_channels=(1,),
):
    """Build a shallow2deep dataset for the given raw and label data.

    Depending on `is_seg_dataset` (derived from the paths and keys if None), either a
    segmentation dataset or an image collection dataset is created. A raw transform and
    augmentations are always used: defaults are filled in if none are given.

    Raises:
        NotImplementedError: If `rois` is given for an image collection dataset.
    """
    check_paths(raw_paths, label_paths)
    if is_seg_dataset is None:
        is_seg_dataset = is_segmentation_dataset(raw_paths, raw_key, label_paths, label_key)

    # The convenience function never runs without a raw transform ...
    if raw_transform is None:
        raw_transform = get_raw_transform()

    # ... and never without augmentations.
    if transform is None:
        first_raw_path = raw_paths if isinstance(raw_paths, str) else raw_paths[0]
        transform_ndim = 3 if ndim == "anisotropic" else ndim
        transform = _get_default_transform(first_raw_path, raw_key, is_seg_dataset, transform_ndim)

    if not is_seg_dataset:
        if rois is not None:
            raise NotImplementedError
        return _load_shallow2deep_image_collection_dataset(
            raw_paths, raw_key, label_paths, label_key, rf_paths, rf_channels, patch_shape,
            raw_transform=raw_transform,
            label_transform=label_transform,
            transform=transform,
            dtype=dtype,
            n_samples=n_samples,
        )

    return _load_shallow2deep_segmentation_dataset(
        raw_paths, raw_key, label_paths, label_key, rf_paths,
        patch_shape=patch_shape,
        raw_transform=raw_transform,
        label_transform=label_transform,
        transform=transform,
        rois=rois,
        n_samples=n_samples,
        sampler=sampler,
        ndim=ndim,
        dtype=dtype,
        with_channels=with_channels,
        filter_config=filter_config,
        rf_channels=rf_channels,
    )
def get_shallow2deep_loader(
    raw_paths,
    raw_key,
    label_paths,
    label_key,
    rf_paths,
    batch_size,
    patch_shape,
    filter_config=None,
    raw_transform=None,
    label_transform=None,
    transform=None,
    rois=None,
    n_samples=None,
    sampler=None,
    ndim=None,
    is_seg_dataset=None,
    with_channels=False,
    rf_channels=(1,),
    **loader_kwargs,
):
    """Build a shallow2deep dataset and wrap it in a data loader.

    `batch_size` and `loader_kwargs` are handed to `get_data_loader`; all other
    arguments are handed to `get_shallow2deep_dataset` unchanged.
    """
    # Collect everything that configures the dataset; the rest configures the loader.
    dataset_kwargs = dict(
        raw_paths=raw_paths,
        raw_key=raw_key,
        label_paths=label_paths,
        label_key=label_key,
        rf_paths=rf_paths,
        patch_shape=patch_shape,
        raw_transform=raw_transform,
        label_transform=label_transform,
        transform=transform,
        rois=rois,
        n_samples=n_samples,
        sampler=sampler,
        ndim=ndim,
        is_seg_dataset=is_seg_dataset,
        with_channels=with_channels,
        filter_config=filter_config,
        rf_channels=rf_channels,
    )
    dataset = get_shallow2deep_dataset(**dataset_kwargs)
    return get_data_loader(dataset, batch_size=batch_size, **loader_kwargs)