torch_em.data.raw_dataset

  1import os
  2import warnings
  3import numpy as np
  4from typing import List, Union, Tuple, Optional, Any, Callable
  5
  6import torch
  7
  8from elf.wrapper import RoiWrapper
  9
 10from ..util import ensure_tensor_with_channels, ensure_patch_shape, load_data
 11
 12
 13class RawDataset(torch.utils.data.Dataset):
 14    """Dataset that provides raw data stored in a container data format for unsupervised training.
 15
 16    The dataset loads a patch from the raw data and returns a sample for a batch.
 17    The dataset supports all file formats that can be opened with `elf.io.open_file`, such as hdf5, zarr or n5.
 18    Use `raw_path` to specify the path to the file and `raw_key` to specify the internal dataset.
 19    It also supports regular image formats, such as .tif. For these cases set `raw_key=None`.
 20
 21    The dataset can also be used for contrastive learning that relies on two different views of the same data.
 22    You can use the `augmentations` argument for this.
 23
 24    Args:
 25        raw_path: The file path to the raw image data. May also be a list of file paths.
 26        raw_key: The key to the internal dataset containing the raw data.
 27        patch_shape: The patch shape for a training sample.
 28        raw_transform: Transformation applied to the raw data of a sample.
 29        transform: Transformation to the raw data. This can be used to implement data augmentations.
 30        roi: Region of interest in the raw data.
 31            If given, the raw data will only be loaded from the corresponding area.
 32        dtype: The return data type of the raw data.
 33        n_samples: The length of this dataset. If None, the length will be set to `len(raw_image_paths)`.
 34        sampler: Sampler for rejecting samples according to a defined criterion.
 35            The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
 36        ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
 37        with_channels: Whether the raw data has channels.
 38        augmentations: Augmentations for contrastive learning. If given, these need to be two different callables.
 39            They will be applied to the sampled raw data to return two independent views of the raw data.
 40    """
 41    max_sampling_attempts = 500
 42    """The maximal number of sampling attempts, for loading a sample via `__getitem__`.
 43    This is used when `sampler` rejects a sample, to avoid an infinite loop if no valid sample can be found.
 44    """
 45
 46    @staticmethod
 47    def compute_len(shape, patch_shape):
 48        n_samples = int(np.prod([float(sh / csh) for sh, csh in zip(shape, patch_shape)]))
 49        return n_samples
 50
 51    def __init__(
 52        self,
 53        raw_path: Union[List[Any], str, os.PathLike],
 54        raw_key: Optional[str],
 55        patch_shape: Tuple[int, ...],
 56        raw_transform: Optional[Callable] = None,
 57        transform: Optional[Callable] = None,
 58        roi: Optional[Union[slice, Tuple[slice, ...]]] = None,
 59        dtype: torch.dtype = torch.float32,
 60        n_samples: Optional[int] = None,
 61        sampler: Optional[Callable] = None,
 62        ndim: Optional[int] = None,
 63        with_channels: bool = False,
 64        augmentations: Optional[Tuple[Callable, Callable]] = None,
 65    ):
 66        self.raw_path = raw_path
 67        self.raw_key = raw_key
 68        self.raw = load_data(raw_path, raw_key)
 69
 70        self._with_channels = with_channels
 71
 72        if roi is not None:
 73            if isinstance(roi, slice):
 74                roi = (roi,)
 75
 76            self.raw = RoiWrapper(self.raw, (slice(None),) + roi) if self._with_channels else RoiWrapper(self.raw, roi)
 77
 78        self.shape = self.raw.shape[1:] if self._with_channels else self.raw.shape
 79        self.roi = roi
 80
 81        self._ndim = len(self.shape) if ndim is None else ndim
 82        assert self._ndim in (2, 3, 4), f"Invalid data dimensions: {self._ndim}. Only 2d, 3d or 4d data is supported"
 83
 84        assert len(patch_shape) in (self._ndim, self._ndim + 1), f"{patch_shape}, {self._ndim}"
 85        self.patch_shape = patch_shape
 86
 87        self.raw_transform = raw_transform
 88        self.transform = transform
 89        self.sampler = sampler
 90        self.dtype = dtype
 91
 92        if augmentations is not None:
 93            assert len(augmentations) == 2
 94        self.augmentations = augmentations
 95
 96        self._len = self.compute_len(self.shape, self.patch_shape) if n_samples is None else n_samples
 97
 98        self.sample_shape = patch_shape
 99        self.trafo_halo = None
100        # TODO add support for trafo halo: asking for a bigger bounding box before applying the trafo,
101        # which is then cut. See code below; but this ne needs to be properly tested
102
103        # self.trafo_halo = None if self.transform is None else self.transform.halo(self.patch_shape)
104        # if self.trafo_halo is not None:
105        #     if len(self.trafo_halo) == 2 and self._ndim == 3:
106        #         self.trafo_halo = (0,) + self.trafo_halo
107        #     assert len(self.trafo_halo) == self._ndim
108        #     self.sample_shape = tuple(sh + ha for sh, ha in zip(self.patch_shape, self.trafo_halo))
109        #     self.inner_bb = tuple(slice(ha, sh - ha) for sh, ha in zip(self.patch_shape, self.trafo_halo))
110
111    def __len__(self):
112        return self._len
113
114    @property
115    def ndim(self):
116        return self._ndim
117
118    def _sample_bounding_box(self):
119        bb_start = [
120            np.random.randint(0, sh - psh) if sh - psh > 0 else 0
121            for sh, psh in zip(self.shape, self.sample_shape)
122        ]
123        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, self.sample_shape))
124
125    def _get_sample(self, index):
126        if self.raw is None:
127            raise RuntimeError("RawDataset has not been properly deserialized.")
128        bb = self._sample_bounding_box()
129        raw = self.raw[(slice(None),) + bb] if self._with_channels else self.raw[bb]
130
131        if self.sampler is not None:
132            sample_id = 0
133            while not self.sampler(raw):
134                bb = self._sample_bounding_box()
135                raw = self.raw[(slice(None),) + bb] if self._with_channels else self.raw[bb]
136                sample_id += 1
137                if sample_id > self.max_sampling_attempts:
138                    raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")
139
140        if self.patch_shape is not None:
141            raw = ensure_patch_shape(
142                raw=raw, labels=None, patch_shape=self.patch_shape, have_raw_channels=self._with_channels
143            )
144
145        # squeeze the singleton spatial axis if we have a spatial shape that is larger by one than self._ndim
146        if len(self.patch_shape) == self._ndim + 1:
147            raw = raw.squeeze(1 if self._with_channels else 0)
148
149        return raw
150
151    def crop(self, tensor):
152        bb = self.inner_bb
153        if tensor.ndim > len(bb):
154            bb = (tensor.ndim - len(bb)) * (slice(None),) + bb
155        return tensor[bb]
156
157    def __getitem__(self, index):
158        raw = self._get_sample(index)
159
160        if self.raw_transform is not None:
161            raw = self.raw_transform(raw)
162
163        if self.transform is not None:
164            raw = self.transform(raw)
165            if isinstance(raw, list):
166                assert len(raw) == 1
167                raw = raw[0]
168
169            if self.trafo_halo is not None:
170                raw = self.crop(raw)
171
172        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
173        if self.augmentations is not None:
174            aug1, aug2 = self.augmentations
175            raw1, raw2 = aug1(raw), aug2(raw)
176            return raw1, raw2
177
178        return raw
179
180    # need to overwrite pickle to support h5py
181    def __getstate__(self):
182        state = self.__dict__.copy()
183        del state["raw"]
184        return state
185
186    def __setstate__(self, state):
187        raw_path, raw_key = state["raw_path"], state["raw_key"]
188        roi = state["roi"]
189        try:
190            raw = load_data(raw_path, raw_key)
191            if roi is not None:
192                raw = RoiWrapper(raw, (slice(None),) + roi) if state["_with_channels"] else RoiWrapper(raw, roi)
193            state["raw"] = raw
194        except Exception:
195            msg = f"RawDataset could not be deserialized because of missing {raw_path}, {raw_key}.\n"
196            msg += "The dataset is deserialized in order to allow loading trained models from a checkpoint.\n"
197            msg += "But it cannot be used for further training and wil throw an error."
198            warnings.warn(msg)
199            state["raw"] = None
200
201        self.__dict__.update(state)
class RawDataset(typing.Generic[+_T_co]):
 14class RawDataset(torch.utils.data.Dataset):
 15    """Dataset that provides raw data stored in a container data format for unsupervised training.
 16
 17    The dataset loads a patch from the raw data and returns a sample for a batch.
 18    The dataset supports all file formats that can be opened with `elf.io.open_file`, such as hdf5, zarr or n5.
 19    Use `raw_path` to specify the path to the file and `raw_key` to specify the internal dataset.
 20    It also supports regular image formats, such as .tif. For these cases set `raw_key=None`.
 21
 22    The dataset can also be used for contrastive learning that relies on two different views of the same data.
 23    You can use the `augmentations` argument for this.
 24
 25    Args:
 26        raw_path: The file path to the raw image data. May also be a list of file paths.
 27        raw_key: The key to the internal dataset containing the raw data.
 28        patch_shape: The patch shape for a training sample.
 29        raw_transform: Transformation applied to the raw data of a sample.
 30        transform: Transformation to the raw data. This can be used to implement data augmentations.
 31        roi: Region of interest in the raw data.
 32            If given, the raw data will only be loaded from the corresponding area.
 33        dtype: The return data type of the raw data.
 34        n_samples: The length of this dataset. If None, the length will be set to `len(raw_image_paths)`.
 35        sampler: Sampler for rejecting samples according to a defined criterion.
 36            The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
 37        ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
 38        with_channels: Whether the raw data has channels.
 39        augmentations: Augmentations for contrastive learning. If given, these need to be two different callables.
 40            They will be applied to the sampled raw data to return two independent views of the raw data.
 41    """
 42    max_sampling_attempts = 500
 43    """The maximal number of sampling attempts, for loading a sample via `__getitem__`.
 44    This is used when `sampler` rejects a sample, to avoid an infinite loop if no valid sample can be found.
 45    """
 46
 47    @staticmethod
 48    def compute_len(shape, patch_shape):
 49        n_samples = int(np.prod([float(sh / csh) for sh, csh in zip(shape, patch_shape)]))
 50        return n_samples
 51
 52    def __init__(
 53        self,
 54        raw_path: Union[List[Any], str, os.PathLike],
 55        raw_key: Optional[str],
 56        patch_shape: Tuple[int, ...],
 57        raw_transform: Optional[Callable] = None,
 58        transform: Optional[Callable] = None,
 59        roi: Optional[Union[slice, Tuple[slice, ...]]] = None,
 60        dtype: torch.dtype = torch.float32,
 61        n_samples: Optional[int] = None,
 62        sampler: Optional[Callable] = None,
 63        ndim: Optional[int] = None,
 64        with_channels: bool = False,
 65        augmentations: Optional[Tuple[Callable, Callable]] = None,
 66    ):
 67        self.raw_path = raw_path
 68        self.raw_key = raw_key
 69        self.raw = load_data(raw_path, raw_key)
 70
 71        self._with_channels = with_channels
 72
 73        if roi is not None:
 74            if isinstance(roi, slice):
 75                roi = (roi,)
 76
 77            self.raw = RoiWrapper(self.raw, (slice(None),) + roi) if self._with_channels else RoiWrapper(self.raw, roi)
 78
 79        self.shape = self.raw.shape[1:] if self._with_channels else self.raw.shape
 80        self.roi = roi
 81
 82        self._ndim = len(self.shape) if ndim is None else ndim
 83        assert self._ndim in (2, 3, 4), f"Invalid data dimensions: {self._ndim}. Only 2d, 3d or 4d data is supported"
 84
 85        assert len(patch_shape) in (self._ndim, self._ndim + 1), f"{patch_shape}, {self._ndim}"
 86        self.patch_shape = patch_shape
 87
 88        self.raw_transform = raw_transform
 89        self.transform = transform
 90        self.sampler = sampler
 91        self.dtype = dtype
 92
 93        if augmentations is not None:
 94            assert len(augmentations) == 2
 95        self.augmentations = augmentations
 96
 97        self._len = self.compute_len(self.shape, self.patch_shape) if n_samples is None else n_samples
 98
 99        self.sample_shape = patch_shape
100        self.trafo_halo = None
101        # TODO add support for trafo halo: asking for a bigger bounding box before applying the trafo,
102        # which is then cut. See code below; but this ne needs to be properly tested
103
104        # self.trafo_halo = None if self.transform is None else self.transform.halo(self.patch_shape)
105        # if self.trafo_halo is not None:
106        #     if len(self.trafo_halo) == 2 and self._ndim == 3:
107        #         self.trafo_halo = (0,) + self.trafo_halo
108        #     assert len(self.trafo_halo) == self._ndim
109        #     self.sample_shape = tuple(sh + ha for sh, ha in zip(self.patch_shape, self.trafo_halo))
110        #     self.inner_bb = tuple(slice(ha, sh - ha) for sh, ha in zip(self.patch_shape, self.trafo_halo))
111
112    def __len__(self):
113        return self._len
114
115    @property
116    def ndim(self):
117        return self._ndim
118
119    def _sample_bounding_box(self):
120        bb_start = [
121            np.random.randint(0, sh - psh) if sh - psh > 0 else 0
122            for sh, psh in zip(self.shape, self.sample_shape)
123        ]
124        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, self.sample_shape))
125
126    def _get_sample(self, index):
127        if self.raw is None:
128            raise RuntimeError("RawDataset has not been properly deserialized.")
129        bb = self._sample_bounding_box()
130        raw = self.raw[(slice(None),) + bb] if self._with_channels else self.raw[bb]
131
132        if self.sampler is not None:
133            sample_id = 0
134            while not self.sampler(raw):
135                bb = self._sample_bounding_box()
136                raw = self.raw[(slice(None),) + bb] if self._with_channels else self.raw[bb]
137                sample_id += 1
138                if sample_id > self.max_sampling_attempts:
139                    raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")
140
141        if self.patch_shape is not None:
142            raw = ensure_patch_shape(
143                raw=raw, labels=None, patch_shape=self.patch_shape, have_raw_channels=self._with_channels
144            )
145
146        # squeeze the singleton spatial axis if we have a spatial shape that is larger by one than self._ndim
147        if len(self.patch_shape) == self._ndim + 1:
148            raw = raw.squeeze(1 if self._with_channels else 0)
149
150        return raw
151
152    def crop(self, tensor):
153        bb = self.inner_bb
154        if tensor.ndim > len(bb):
155            bb = (tensor.ndim - len(bb)) * (slice(None),) + bb
156        return tensor[bb]
157
158    def __getitem__(self, index):
159        raw = self._get_sample(index)
160
161        if self.raw_transform is not None:
162            raw = self.raw_transform(raw)
163
164        if self.transform is not None:
165            raw = self.transform(raw)
166            if isinstance(raw, list):
167                assert len(raw) == 1
168                raw = raw[0]
169
170            if self.trafo_halo is not None:
171                raw = self.crop(raw)
172
173        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
174        if self.augmentations is not None:
175            aug1, aug2 = self.augmentations
176            raw1, raw2 = aug1(raw), aug2(raw)
177            return raw1, raw2
178
179        return raw
180
181    # need to overwrite pickle to support h5py
182    def __getstate__(self):
183        state = self.__dict__.copy()
184        del state["raw"]
185        return state
186
187    def __setstate__(self, state):
188        raw_path, raw_key = state["raw_path"], state["raw_key"]
189        roi = state["roi"]
190        try:
191            raw = load_data(raw_path, raw_key)
192            if roi is not None:
193                raw = RoiWrapper(raw, (slice(None),) + roi) if state["_with_channels"] else RoiWrapper(raw, roi)
194            state["raw"] = raw
195        except Exception:
196            msg = f"RawDataset could not be deserialized because of missing {raw_path}, {raw_key}.\n"
197            msg += "The dataset is deserialized in order to allow loading trained models from a checkpoint.\n"
198            msg += "But it cannot be used for further training and wil throw an error."
199            warnings.warn(msg)
200            state["raw"] = None
201
202        self.__dict__.update(state)

Dataset that provides raw data stored in a container data format for unsupervised training.

The dataset loads a patch from the raw data and returns a sample for a batch. The dataset supports all file formats that can be opened with elf.io.open_file, such as hdf5, zarr or n5. Use raw_path to specify the path to the file and raw_key to specify the internal dataset. It also supports regular image formats, such as .tif. For these cases set raw_key=None.

The dataset can also be used for contrastive learning that relies on two different views of the same data. You can use the augmentations argument for this.

Arguments:
  • raw_path: The file path to the raw image data. May also be a list of file paths.
  • raw_key: The key to the internal dataset containing the raw data.
  • patch_shape: The patch shape for a training sample.
  • raw_transform: Transformation applied to the raw data of a sample.
  • transform: Transformation to the raw data. This can be used to implement data augmentations.
  • roi: Region of interest in the raw data. If given, the raw data will only be loaded from the corresponding area.
  • dtype: The return data type of the raw data.
  • n_samples: The length of this dataset. If None, the length is derived from the data shape and the patch shape (the product of the per-axis ratios between them, see compute_len).
  • sampler: Sampler for rejecting samples according to a defined criterion. The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
  • ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
  • with_channels: Whether the raw data has channels.
  • augmentations: Augmentations for contrastive learning. If given, these need to be two different callables. They will be applied to the sampled raw data to return two independent views of the raw data.
RawDataset( raw_path: Union[List[Any], str, os.PathLike], raw_key: Optional[str], patch_shape: Tuple[int, ...], raw_transform: Optional[Callable] = None, transform: Optional[Callable] = None, roi: Union[slice, Tuple[slice, ...], NoneType] = None, dtype: torch.dtype = torch.float32, n_samples: Optional[int] = None, sampler: Optional[Callable] = None, ndim: Optional[int] = None, with_channels: bool = False, augmentations: Optional[Tuple[Callable, Callable]] = None)
 52    def __init__(
 53        self,
 54        raw_path: Union[List[Any], str, os.PathLike],
 55        raw_key: Optional[str],
 56        patch_shape: Tuple[int, ...],
 57        raw_transform: Optional[Callable] = None,
 58        transform: Optional[Callable] = None,
 59        roi: Optional[Union[slice, Tuple[slice, ...]]] = None,
 60        dtype: torch.dtype = torch.float32,
 61        n_samples: Optional[int] = None,
 62        sampler: Optional[Callable] = None,
 63        ndim: Optional[int] = None,
 64        with_channels: bool = False,
 65        augmentations: Optional[Tuple[Callable, Callable]] = None,
 66    ):
 67        self.raw_path = raw_path
 68        self.raw_key = raw_key
 69        self.raw = load_data(raw_path, raw_key)
 70
 71        self._with_channels = with_channels
 72
 73        if roi is not None:
 74            if isinstance(roi, slice):
 75                roi = (roi,)
 76
 77            self.raw = RoiWrapper(self.raw, (slice(None),) + roi) if self._with_channels else RoiWrapper(self.raw, roi)
 78
 79        self.shape = self.raw.shape[1:] if self._with_channels else self.raw.shape
 80        self.roi = roi
 81
 82        self._ndim = len(self.shape) if ndim is None else ndim
 83        assert self._ndim in (2, 3, 4), f"Invalid data dimensions: {self._ndim}. Only 2d, 3d or 4d data is supported"
 84
 85        assert len(patch_shape) in (self._ndim, self._ndim + 1), f"{patch_shape}, {self._ndim}"
 86        self.patch_shape = patch_shape
 87
 88        self.raw_transform = raw_transform
 89        self.transform = transform
 90        self.sampler = sampler
 91        self.dtype = dtype
 92
 93        if augmentations is not None:
 94            assert len(augmentations) == 2
 95        self.augmentations = augmentations
 96
 97        self._len = self.compute_len(self.shape, self.patch_shape) if n_samples is None else n_samples
 98
 99        self.sample_shape = patch_shape
100        self.trafo_halo = None
101        # TODO add support for trafo halo: asking for a bigger bounding box before applying the trafo,
102        # which is then cut. See code below; but this ne needs to be properly tested
103
104        # self.trafo_halo = None if self.transform is None else self.transform.halo(self.patch_shape)
105        # if self.trafo_halo is not None:
106        #     if len(self.trafo_halo) == 2 and self._ndim == 3:
107        #         self.trafo_halo = (0,) + self.trafo_halo
108        #     assert len(self.trafo_halo) == self._ndim
109        #     self.sample_shape = tuple(sh + ha for sh, ha in zip(self.patch_shape, self.trafo_halo))
110        #     self.inner_bb = tuple(slice(ha, sh - ha) for sh, ha in zip(self.patch_shape, self.trafo_halo))
max_sampling_attempts = 500

The maximal number of sampling attempts, for loading a sample via __getitem__. This is used when sampler rejects a sample, to avoid an infinite loop if no valid sample can be found.

@staticmethod
def compute_len(shape, patch_shape):
47    @staticmethod
48    def compute_len(shape, patch_shape):
49        n_samples = int(np.prod([float(sh / csh) for sh, csh in zip(shape, patch_shape)]))
50        return n_samples
raw_path
raw_key
raw
shape
roi
patch_shape
raw_transform
transform
sampler
dtype
augmentations
sample_shape
trafo_halo
ndim
115    @property
116    def ndim(self):
117        return self._ndim
def crop(self, tensor):
152    def crop(self, tensor):
153        bb = self.inner_bb
154        if tensor.ndim > len(bb):
155            bb = (tensor.ndim - len(bb)) * (slice(None),) + bb
156        return tensor[bb]