torch_em.shallow2deep.shallow2deep_dataset

  1import os
  2import pickle
  3import warnings
  4from glob import glob
  5
  6import numpy as np
  7import torch
  8from torch_em.segmentation import (check_paths, is_segmentation_dataset,
  9                                   get_data_loader, get_raw_transform,
 10                                   samples_to_datasets, _get_default_transform)
 11from torch_em.data import ConcatDataset, ImageCollectionDataset, SegmentationDataset
 12from .prepare_shallow2deep import _get_filters, _apply_filters
 13from ..util import ensure_tensor_with_channels, ensure_spatial_array
 14
 15
 16class _Shallow2DeepBase:
 17    _rf_paths = None
 18    _filter_config = None
 19
 20    @property
 21    def rf_paths(self):
 22        return self._rf_paths
 23
 24    @rf_paths.setter
 25    def rf_paths(self, value):
 26        self._rf_paths = value
 27
 28    @property
 29    def filter_config(self):
 30        return self._filter_config
 31
 32    @filter_config.setter
 33    def filter_config(self, value):
 34        self._filter_config = value
 35
 36    @property
 37    def rf_channels(self):
 38        return self._rf_channels
 39
 40    @rf_channels.setter
 41    def rf_channels(self, value):
 42        if isinstance(value, int):
 43            self.rf_channels = (value,)
 44        else:
 45            assert isinstance(value, tuple)
 46            self._rf_channels = value
 47
 48    def _predict(self, raw, rf, filters_and_sigmas):
 49        features = _apply_filters(raw, filters_and_sigmas)
 50        assert rf.n_features_in_ == features.shape[1], f"{rf.n_features_in_}, {features.shape[1]}"
 51
 52        try:
 53            pred_ = rf.predict_proba(features)
 54            assert pred_.shape[1] > max(self.rf_channels), f"{pred_.shape}, {self.rf_channels}"
 55            pred_ = pred_[:, self.rf_channels]
 56        except IndexError:
 57            warnings.warn(f"Random forest prediction failed for input features of shape: {features.shape}")
 58            pred_shape = (len(features), len(self.rf_channels))
 59            pred_ = np.zeros(pred_shape, dtype="float32")
 60
 61        spatial_shape = raw.shape
 62        out_shape = (len(self.rf_channels),) + spatial_shape
 63        prediction = np.zeros(out_shape, dtype="float32")
 64        for chan in range(pred_.shape[1]):
 65            prediction[chan] = pred_[:, chan].reshape(spatial_shape)
 66
 67        return prediction
 68
 69    def _predict_rf(self, raw):
 70        n_rfs = len(self._rf_paths)
 71        rf_path = self._rf_paths[np.random.randint(0, n_rfs)]
 72        with open(rf_path, "rb") as f:
 73            rf = pickle.load(f)
 74        filters_and_sigmas = _get_filters(self.ndim, self._filter_config)
 75        return self._predict(raw, rf, filters_and_sigmas)
 76
 77    def _predict_rf_anisotropic(self, raw):
 78        n_rfs = len(self._rf_paths)
 79        rf_path = self._rf_paths[np.random.randint(0, n_rfs)]
 80        with open(rf_path, "rb") as f:
 81            rf = pickle.load(f)
 82        filters_and_sigmas = _get_filters(2, self._filter_config)
 83
 84        n_channels = len(self.rf_channels)
 85        prediction = np.zeros((n_channels,) + raw.shape, dtype="float32")
 86        for z in range(raw.shape[0]):
 87            pred = self._predict(raw[z], rf, filters_and_sigmas)
 88            prediction[:, z] = pred
 89
 90        return prediction
 91
 92
 93class Shallow2DeepDataset(SegmentationDataset, _Shallow2DeepBase):
 94    def __getitem__(self, index):
 95        assert self._rf_paths is not None
 96        raw, labels = self._get_sample(index)
 97        initial_label_dtype = labels.dtype
 98
 99        if self.raw_transform is not None:
100            raw = self.raw_transform(raw)
101        if self.label_transform is not None:
102            labels = self.label_transform(labels)
103        if self.transform is not None:
104            raw, labels = self.transform(raw, labels)
105            if self.trafo_halo is not None:
106                raw = self.crop(raw)
107                labels = self.crop(labels)
108        # support enlarging bounding box here as well (for affinity transform) ?
109        if self.label_transform2 is not None:
110            labels = ensure_spatial_array(labels, self.ndim, dtype=initial_label_dtype)
111            labels = self.label_transform2(labels)
112
113        if isinstance(raw, (list, tuple)):  # this can be a list or tuple due to transforms
114            assert len(raw) == 1
115            raw = raw[0]
116        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
117        if raw.shape[0] > 1:
118            raise NotImplementedError(
119                f"Shallow2Deep training not implemented for multi-channel input yet; got {raw.shape[0]} channels"
120            )
121
122        # NOTE we assume single channel raw data here; this needs to be changed for multi-channel
123        if getattr(self, "is_anisotropic", False):
124            prediction = self._predict_rf_anisotropic(raw[0].numpy())
125        else:
126            prediction = self._predict_rf(raw[0].numpy())
127        prediction = ensure_tensor_with_channels(prediction, ndim=self._ndim, dtype=self.dtype)
128        labels = ensure_tensor_with_channels(labels, ndim=self._ndim, dtype=self.label_dtype)
129        return prediction, labels
130
131
class Shallow2DeepImageCollectionDataset(ImageCollectionDataset, _Shallow2DeepBase):
    """Image collection dataset that yields random forest predictions instead of raw images.

    Samples are loaded with ``_get_sample`` and transformed; the raw image is then replaced by the
    predictions of a random forest (see ``_Shallow2DeepBase``). Prediction always uses
    ``_predict_rf``; this dataset has no slice-wise anisotropic mode.
    """

    def __getitem__(self, index):
        # Fail early if no random forests were configured, consistent with Shallow2DeepDataset.
        assert self._rf_paths is not None
        raw, labels = self._get_sample(index)
        initial_label_dtype = labels.dtype

        if self.raw_transform is not None:
            raw = self.raw_transform(raw)

        if self.label_transform is not None:
            labels = self.label_transform(labels)

        if self.transform is not None:
            raw, labels = self.transform(raw, labels)

        # support enlarging bounding box here as well (for affinity transform) ?
        if self.label_transform2 is not None:
            labels = ensure_spatial_array(labels, self.ndim, dtype=initial_label_dtype)
            labels = self.label_transform2(labels)

        # Transforms may wrap the raw data into a single-element list or tuple.
        if isinstance(raw, (list, tuple)):
            assert len(raw) == 1
            raw = raw[0]
        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
        if raw.shape[0] > 1:
            raise NotImplementedError(
                f"Shallow2Deep training not implemented for multi-channel input yet; got {raw.shape[0]} channels"
            )

        # NOTE we assume single channel raw data here; this needs to be changed for multi-channel
        prediction = self._predict_rf(raw[0].numpy())
        prediction = ensure_tensor_with_channels(prediction, ndim=self._ndim, dtype=self.dtype)
        labels = ensure_tensor_with_channels(labels, ndim=self._ndim, dtype=self.label_dtype)
        return prediction, labels
165
166
def _load_shallow2deep_segmentation_dataset(
    raw_paths, raw_key, label_paths, label_key, rf_paths, rf_channels, ndim, **kwargs
):
    """Create a shallow2deep dataset for segmentation data given by one or more paths.

    Args:
        raw_paths: A single path or a list of paths to the raw data.
        raw_key: Key for the raw data, forwarded to the dataset.
        label_paths: A single path or a list of paths to the labels, parallel to ``raw_paths``.
        label_key: Key for the labels, forwarded to the dataset.
        rf_paths: Paths to the pickled random forests.
        rf_channels: Random forest output channel(s) to use.
        ndim: Dimensionality forwarded to the dataset, or "anisotropic" for 3d data on which the
            random forest is applied slice by slice.
        kwargs: Forwarded to ``Shallow2DeepDataset``. ``rois`` and ``filter_config`` are consumed
            here, as is ``n_samples`` when a list of paths is given.

    Returns:
        A ``Shallow2DeepDataset`` for a single path, otherwise a ``ConcatDataset`` of them.
    """
    rois = kwargs.pop("rois", None)
    filter_config = kwargs.pop("filter_config", None)
    # "anisotropic" loads 3d data, but the random forest is applied per slice with 2d filters.
    is_anisotropic = ndim == "anisotropic"
    if is_anisotropic:
        ndim = 3

    def _attach_rf_settings(dset):
        # The shallow2deep settings are not constructor arguments; set them on the instance.
        dset.rf_paths = rf_paths
        dset.filter_config = filter_config
        dset.rf_channels = rf_channels
        dset.is_anisotropic = is_anisotropic
        return dset

    if isinstance(raw_paths, str):
        if rois is not None:
            # One slice per axis; the number of axes is not fixed, so 2d rois are accepted as well.
            assert all(isinstance(roi, slice) for roi in rois), f"{rois}"
        return _attach_rf_settings(
            Shallow2DeepDataset(raw_paths, raw_key, label_paths, label_key, roi=rois, ndim=ndim, **kwargs)
        )

    assert len(raw_paths) > 0
    if rois is not None:
        assert len(rois) == len(label_paths), f"{len(rois)}, {len(label_paths)}"
        assert all(isinstance(roi, tuple) for roi in rois)
    n_samples = kwargs.pop("n_samples", None)

    samples_per_ds = (
        [None] * len(raw_paths) if n_samples is None else samples_to_datasets(n_samples, raw_paths, raw_key)
    )
    datasets = [
        _attach_rf_settings(
            Shallow2DeepDataset(
                raw_path, raw_key, label_path, label_key,
                roi=None if rois is None else rois[i], n_samples=samples_per_ds[i], ndim=ndim, **kwargs
            )
        )
        for i, (raw_path, label_path) in enumerate(zip(raw_paths, label_paths))
    ]
    return ConcatDataset(*datasets)
209
210
def _load_shallow2deep_image_collection_dataset(
    raw_paths, raw_key, label_paths, label_key, rf_paths, rf_channels, patch_shape, **kwargs
):
    """Create a ``Shallow2DeepImageCollectionDataset`` from image files.

    Two input forms are supported:
    - ``raw_paths``/``label_paths`` are folders and ``raw_key``/``label_key`` are glob patterns
      inside them; the matched files are sorted (presumably so raw and label files pair up by position).
    - ``raw_paths``/``label_paths`` are lists of existing files and both keys are None.

    ``filter_config`` may be given in ``kwargs``; all remaining kwargs go to the dataset constructor.

    Raises:
        NotImplementedError: For any other combination of path and key types.
    """
    # Remove filter_config before kwargs are forwarded to the dataset constructor; it is set as an
    # attribute below, as in _load_shallow2deep_segmentation_dataset.
    filter_config = kwargs.pop("filter_config", None)

    if isinstance(raw_paths, str):
        assert isinstance(label_paths, str)
        raw_file_paths = sorted(glob(os.path.join(raw_paths, raw_key)))
        label_file_paths = sorted(glob(os.path.join(label_paths, label_key)))
        ds = Shallow2DeepImageCollectionDataset(raw_file_paths, label_file_paths, patch_shape, **kwargs)
    elif isinstance(raw_paths, list) and raw_key is None:
        assert isinstance(label_paths, list)
        assert label_key is None
        assert all(os.path.exists(pp) for pp in raw_paths)
        assert all(os.path.exists(pp) for pp in label_paths)
        ds = Shallow2DeepImageCollectionDataset(raw_paths, label_paths, patch_shape, **kwargs)
    else:
        raise NotImplementedError

    ds.rf_paths = rf_paths
    ds.filter_config = filter_config
    ds.rf_channels = rf_channels
    return ds
235
236
def get_shallow2deep_dataset(
    raw_paths,
    raw_key,
    label_paths,
    label_key,
    rf_paths,
    patch_shape,
    raw_transform=None,
    label_transform=None,
    transform=None,
    dtype=torch.float32,
    rois=None,
    n_samples=None,
    sampler=None,
    ndim=None,
    is_seg_dataset=None,
    with_channels=False,
    filter_config=None,
    rf_channels=(1,),
):
    """Create a shallow2deep training dataset that yields random forest predictions and labels.

    Args:
        raw_paths: Path or list of paths to the raw data.
        raw_key: Key for the raw data (for image collections: a glob pattern inside ``raw_paths``, or None).
        label_paths: Path or list of paths to the labels.
        label_key: Key for the labels, analogous to ``raw_key``.
        rf_paths: Paths to pickled random forests; one is drawn at random per sample.
        patch_shape: Patch shape, forwarded to the dataset.
        raw_transform: Transform for the raw data; defaults to ``get_raw_transform()``.
        label_transform: Transform for the labels.
        transform: Joint transform; defaults to ``_get_default_transform(...)``.
        dtype: Data type, forwarded to the dataset.
        rois: Regions of interest; only supported for segmentation datasets.
        n_samples: Number of samples, forwarded to the dataset(s).
        sampler: Forwarded to segmentation datasets only.
        ndim: Dimensionality of the data, or "anisotropic" for 3d data predicted slice by slice.
        is_seg_dataset: Whether to build a segmentation dataset or an image collection dataset;
            determined with ``is_segmentation_dataset`` if None.
        with_channels: Forwarded to segmentation datasets only.
        filter_config: Filter configuration for computing the random forest features.
        rf_channels: Random forest output channel(s) to use.

    Returns:
        The dataset.

    Raises:
        NotImplementedError: If ``rois`` is given for an image collection dataset.
    """
    check_paths(raw_paths, label_paths)
    if is_seg_dataset is None:
        is_seg_dataset = is_segmentation_dataset(raw_paths, raw_key, label_paths, label_key)

    # we always use a raw transform in the convenience function
    if raw_transform is None:
        raw_transform = get_raw_transform()

    # we always use augmentations in the convenience function
    if transform is None:
        transform = _get_default_transform(
            raw_paths if isinstance(raw_paths, str) else raw_paths[0], raw_key, is_seg_dataset,
            3 if ndim == "anisotropic" else ndim
        )

    if is_seg_dataset:
        ds = _load_shallow2deep_segmentation_dataset(
            raw_paths,
            raw_key,
            label_paths,
            label_key,
            rf_paths,
            patch_shape=patch_shape,
            raw_transform=raw_transform,
            label_transform=label_transform,
            transform=transform,
            rois=rois,
            n_samples=n_samples,
            sampler=sampler,
            ndim=ndim,
            dtype=dtype,
            with_channels=with_channels,
            filter_config=filter_config,
            rf_channels=rf_channels,
        )
    else:
        if rois is not None:
            raise NotImplementedError
        ds = _load_shallow2deep_image_collection_dataset(
            raw_paths, raw_key, label_paths, label_key, rf_paths, rf_channels, patch_shape,
            raw_transform=raw_transform, label_transform=label_transform,
            transform=transform, dtype=dtype, n_samples=n_samples,
        )
        # The image collection loader is not given filter_config; set it here so that a
        # user-provided filter configuration is also used for image collection datasets.
        ds.filter_config = filter_config
    return ds
302
303
def get_shallow2deep_loader(
    raw_paths,
    raw_key,
    label_paths,
    label_key,
    rf_paths,
    batch_size,
    patch_shape,
    filter_config=None,
    raw_transform=None,
    label_transform=None,
    transform=None,
    rois=None,
    n_samples=None,
    sampler=None,
    ndim=None,
    is_seg_dataset=None,
    with_channels=False,
    rf_channels=(1,),
    **loader_kwargs,
):
    """Create a data loader for shallow2deep training.

    ``batch_size`` and ``loader_kwargs`` are passed to ``get_data_loader``; every other argument is
    forwarded unchanged to ``get_shallow2deep_dataset``.

    Returns:
        The data loader returned by ``get_data_loader``.
    """
    dataset_kwargs = dict(
        raw_paths=raw_paths,
        raw_key=raw_key,
        label_paths=label_paths,
        label_key=label_key,
        rf_paths=rf_paths,
        patch_shape=patch_shape,
        raw_transform=raw_transform,
        label_transform=label_transform,
        transform=transform,
        rois=rois,
        n_samples=n_samples,
        sampler=sampler,
        ndim=ndim,
        is_seg_dataset=is_seg_dataset,
        with_channels=with_channels,
        filter_config=filter_config,
        rf_channels=rf_channels,
    )
    shallow2deep_ds = get_shallow2deep_dataset(**dataset_kwargs)
    return get_data_loader(shallow2deep_ds, batch_size=batch_size, **loader_kwargs)
class Shallow2DeepDataset(SegmentationDataset, _Shallow2DeepBase):
 94class Shallow2DeepDataset(SegmentationDataset, _Shallow2DeepBase):
 95    def __getitem__(self, index):
 96        assert self._rf_paths is not None
 97        raw, labels = self._get_sample(index)
 98        initial_label_dtype = labels.dtype
 99
100        if self.raw_transform is not None:
101            raw = self.raw_transform(raw)
102        if self.label_transform is not None:
103            labels = self.label_transform(labels)
104        if self.transform is not None:
105            raw, labels = self.transform(raw, labels)
106            if self.trafo_halo is not None:
107                raw = self.crop(raw)
108                labels = self.crop(labels)
109        # support enlarging bounding box here as well (for affinity transform) ?
110        if self.label_transform2 is not None:
111            labels = ensure_spatial_array(labels, self.ndim, dtype=initial_label_dtype)
112            labels = self.label_transform2(labels)
113
114        if isinstance(raw, (list, tuple)):  # this can be a list or tuple due to transforms
115            assert len(raw) == 1
116            raw = raw[0]
117        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
118        if raw.shape[0] > 1:
119            raise NotImplementedError(
120                f"Shallow2Deep training not implemented for multi-channel input yet; got {raw.shape[0]} channels"
121            )
122
123        # NOTE we assume single channel raw data here; this needs to be changed for multi-channel
124        if getattr(self, "is_anisotropic", False):
125            prediction = self._predict_rf_anisotropic(raw[0].numpy())
126        else:
127            prediction = self._predict_rf(raw[0].numpy())
128        prediction = ensure_tensor_with_channels(prediction, ndim=self._ndim, dtype=self.dtype)
129        labels = ensure_tensor_with_channels(labels, ndim=self._ndim, dtype=self.label_dtype)
130        return prediction, labels
class Shallow2DeepImageCollectionDataset(ImageCollectionDataset, _Shallow2DeepBase):
class Shallow2DeepImageCollectionDataset(ImageCollectionDataset, _Shallow2DeepBase):
    """Image collection dataset that yields random forest predictions instead of raw images.

    Samples are loaded with ``_get_sample`` and transformed; the raw image is then replaced by the
    predictions of a random forest (see ``_Shallow2DeepBase``). Prediction always uses
    ``_predict_rf``; this dataset has no slice-wise anisotropic mode.
    """

    def __getitem__(self, index):
        # Fail early if no random forests were configured, consistent with Shallow2DeepDataset.
        assert self._rf_paths is not None
        raw, labels = self._get_sample(index)
        initial_label_dtype = labels.dtype

        if self.raw_transform is not None:
            raw = self.raw_transform(raw)

        if self.label_transform is not None:
            labels = self.label_transform(labels)

        if self.transform is not None:
            raw, labels = self.transform(raw, labels)

        # support enlarging bounding box here as well (for affinity transform) ?
        if self.label_transform2 is not None:
            labels = ensure_spatial_array(labels, self.ndim, dtype=initial_label_dtype)
            labels = self.label_transform2(labels)

        # Transforms may wrap the raw data into a single-element list or tuple.
        if isinstance(raw, (list, tuple)):
            assert len(raw) == 1
            raw = raw[0]
        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
        if raw.shape[0] > 1:
            raise NotImplementedError(
                f"Shallow2Deep training not implemented for multi-channel input yet; got {raw.shape[0]} channels"
            )

        # NOTE we assume single channel raw data here; this needs to be changed for multi-channel
        prediction = self._predict_rf(raw[0].numpy())
        prediction = ensure_tensor_with_channels(prediction, ndim=self._ndim, dtype=self.dtype)
        labels = ensure_tensor_with_channels(labels, ndim=self._ndim, dtype=self.label_dtype)
        return prediction, labels

An abstract class representing a Dataset.

All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite `__getitem__()`, supporting fetching a data sample for a given key. Subclasses can also optionally overwrite `__len__()`, which many `torch.utils.data.Sampler` implementations and the default options of `torch.utils.data.DataLoader` expect to return the size of the dataset. Subclasses can also optionally implement `__getitems__()` to speed up batched sample loading. This method accepts a list of sample indices for a batch and returns a list of samples.

By default, `torch.utils.data.DataLoader` constructs an index sampler that yields integral indices. To use it with a map-style dataset that has non-integral indices/keys, a custom sampler must be provided.

def get_shallow2deep_dataset( raw_paths, raw_key, label_paths, label_key, rf_paths, patch_shape, raw_transform=None, label_transform=None, transform=None, dtype=torch.float32, rois=None, n_samples=None, sampler=None, ndim=None, is_seg_dataset=None, with_channels=False, filter_config=None, rf_channels=(1,)):
def get_shallow2deep_dataset(
    raw_paths,
    raw_key,
    label_paths,
    label_key,
    rf_paths,
    patch_shape,
    raw_transform=None,
    label_transform=None,
    transform=None,
    dtype=torch.float32,
    rois=None,
    n_samples=None,
    sampler=None,
    ndim=None,
    is_seg_dataset=None,
    with_channels=False,
    filter_config=None,
    rf_channels=(1,),
):
    """Create a shallow2deep training dataset that yields random forest predictions and labels.

    Args:
        raw_paths: Path or list of paths to the raw data.
        raw_key: Key for the raw data (for image collections: a glob pattern inside ``raw_paths``, or None).
        label_paths: Path or list of paths to the labels.
        label_key: Key for the labels, analogous to ``raw_key``.
        rf_paths: Paths to pickled random forests; one is drawn at random per sample.
        patch_shape: Patch shape, forwarded to the dataset.
        raw_transform: Transform for the raw data; defaults to ``get_raw_transform()``.
        label_transform: Transform for the labels.
        transform: Joint transform; defaults to ``_get_default_transform(...)``.
        dtype: Data type, forwarded to the dataset.
        rois: Regions of interest; only supported for segmentation datasets.
        n_samples: Number of samples, forwarded to the dataset(s).
        sampler: Forwarded to segmentation datasets only.
        ndim: Dimensionality of the data, or "anisotropic" for 3d data predicted slice by slice.
        is_seg_dataset: Whether to build a segmentation dataset or an image collection dataset;
            determined with ``is_segmentation_dataset`` if None.
        with_channels: Forwarded to segmentation datasets only.
        filter_config: Filter configuration for computing the random forest features.
        rf_channels: Random forest output channel(s) to use.

    Returns:
        The dataset.

    Raises:
        NotImplementedError: If ``rois`` is given for an image collection dataset.
    """
    check_paths(raw_paths, label_paths)
    if is_seg_dataset is None:
        is_seg_dataset = is_segmentation_dataset(raw_paths, raw_key, label_paths, label_key)

    # we always use a raw transform in the convenience function
    if raw_transform is None:
        raw_transform = get_raw_transform()

    # we always use augmentations in the convenience function
    if transform is None:
        transform = _get_default_transform(
            raw_paths if isinstance(raw_paths, str) else raw_paths[0], raw_key, is_seg_dataset,
            3 if ndim == "anisotropic" else ndim
        )

    if is_seg_dataset:
        ds = _load_shallow2deep_segmentation_dataset(
            raw_paths,
            raw_key,
            label_paths,
            label_key,
            rf_paths,
            patch_shape=patch_shape,
            raw_transform=raw_transform,
            label_transform=label_transform,
            transform=transform,
            rois=rois,
            n_samples=n_samples,
            sampler=sampler,
            ndim=ndim,
            dtype=dtype,
            with_channels=with_channels,
            filter_config=filter_config,
            rf_channels=rf_channels,
        )
    else:
        if rois is not None:
            raise NotImplementedError
        ds = _load_shallow2deep_image_collection_dataset(
            raw_paths, raw_key, label_paths, label_key, rf_paths, rf_channels, patch_shape,
            raw_transform=raw_transform, label_transform=label_transform,
            transform=transform, dtype=dtype, n_samples=n_samples,
        )
        # The image collection loader is not given filter_config; set it here so that a
        # user-provided filter configuration is also used for image collection datasets.
        ds.filter_config = filter_config
    return ds
def get_shallow2deep_loader( raw_paths, raw_key, label_paths, label_key, rf_paths, batch_size, patch_shape, filter_config=None, raw_transform=None, label_transform=None, transform=None, rois=None, n_samples=None, sampler=None, ndim=None, is_seg_dataset=None, with_channels=False, rf_channels=(1,), **loader_kwargs):
def get_shallow2deep_loader(
    raw_paths,
    raw_key,
    label_paths,
    label_key,
    rf_paths,
    batch_size,
    patch_shape,
    filter_config=None,
    raw_transform=None,
    label_transform=None,
    transform=None,
    rois=None,
    n_samples=None,
    sampler=None,
    ndim=None,
    is_seg_dataset=None,
    with_channels=False,
    rf_channels=(1,),
    **loader_kwargs,
):
    """Create a data loader for shallow2deep training.

    ``batch_size`` and ``loader_kwargs`` are passed to ``get_data_loader``; every other argument is
    forwarded unchanged to ``get_shallow2deep_dataset``.

    Returns:
        The data loader returned by ``get_data_loader``.
    """
    dataset_kwargs = dict(
        raw_paths=raw_paths,
        raw_key=raw_key,
        label_paths=label_paths,
        label_key=label_key,
        rf_paths=rf_paths,
        patch_shape=patch_shape,
        raw_transform=raw_transform,
        label_transform=label_transform,
        transform=transform,
        rois=rois,
        n_samples=n_samples,
        sampler=sampler,
        ndim=ndim,
        is_seg_dataset=is_seg_dataset,
        with_channels=with_channels,
        filter_config=filter_config,
        rf_channels=rf_channels,
    )
    shallow2deep_ds = get_shallow2deep_dataset(**dataset_kwargs)
    return get_data_loader(shallow2deep_ds, batch_size=batch_size, **loader_kwargs)