torch_em.data.segmentation_dataset

import os
import warnings
from typing import List, Union, Tuple, Optional, Any, Callable

import numpy as np
from math import ceil

import torch

from elf.wrapper import RoiWrapper

from ..util import ensure_spatial_array, ensure_tensor_with_channels, load_data, ensure_patch_shape


class SegmentationDataset(torch.utils.data.Dataset):
    """Dataset that provides raw data and labels stored in a container data format for segmentation training.

    The dataset loads a patch from the raw and label data and returns a sample for a batch.
    Image data and label data must have the same shape, except for potential channels.
    The dataset supports all file formats that can be opened with `elf.io.open_file`, such as hdf5, zarr or n5.
    Use `raw_path` / `label_path` to specify the file path and `raw_key` / `label_key` to specify the internal dataset.
    It also supports regular image formats, such as .tif. For these cases set `raw_key=None` / `label_key=None`.

    Args:
        raw_path: The file path to the raw image data. May also be a list of file paths.
        raw_key: The key to the internal dataset containing the raw data.
        label_path: The file path to the label data. May also be a list of file paths.
        label_key: The key to the internal dataset containing the label data.
        patch_shape: The patch shape for a training sample.
        raw_transform: Transformation applied to the raw data of a sample.
        label_transform: Transformation applied to the label data of a sample,
            before applying augmentations via `transform`.
        label_transform2: Transformation applied to the label data of a sample,
            after applying augmentations via `transform`.
        transform: Transformation applied to both the raw data and label data of a sample.
            This can be used to implement data augmentations.
        roi: Region of interest in the data. If given, the data will only be loaded from the corresponding area.
        dtype: The return data type of the raw data.
        label_dtype: The return data type of the label data.
        n_samples: The length of this dataset. If None, the length will be computed
            from the data shape and the patch shape via `compute_len`.
        sampler: Sampler for rejecting samples according to a defined criterion.
            The sampler must be a callable that accepts the raw data and label data (as numpy arrays) as input.
        ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
        with_channels: Whether the raw data has channels.
        with_label_channels: Whether the label data has channels.
        with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
        z_ext: Extent of the bounding box along the z-axis; only used if `patch_shape` is None.
        pre_label_transform: Transformation applied to the label data of a randomly chosen sample,
            before checking sample validity via the `sampler`.
    """
    max_sampling_attempts = 500
    """The maximal number of sampling attempts, for loading a sample via `__getitem__`.
    This is used when `sampler` rejects a sample, to avoid an infinite loop if no valid sample can be found.
    """

    @staticmethod
    def compute_len(shape, patch_shape):
        if patch_shape is None:
            return 1
        else:
            n_samples = ceil(np.prod([float(sh / csh) for sh, csh in zip(shape, patch_shape)]))
            return n_samples

    def __init__(
        self,
        raw_path: Union[List[Any], str, os.PathLike],
        raw_key: Optional[str],
        label_path: Union[List[Any], str, os.PathLike],
        label_key: Optional[str],
        patch_shape: Tuple[int, ...],
        raw_transform: Optional[Callable] = None,
        label_transform: Optional[Callable] = None,
        label_transform2: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        roi: Optional[Union[slice, Tuple[slice, ...]]] = None,
        dtype: torch.dtype = torch.float32,
        label_dtype: torch.dtype = torch.float32,
        n_samples: Optional[int] = None,
        sampler: Optional[Callable] = None,
        ndim: Optional[int] = None,
        with_channels: bool = False,
        with_label_channels: bool = False,
        with_padding: bool = True,
        z_ext: Optional[int] = None,
        pre_label_transform: Optional[Callable] = None,
    ):
        self.raw_path = raw_path
        self.raw_key = raw_key
        self.raw = load_data(raw_path, raw_key)

        self.label_path = label_path
        self.label_key = label_key
        self.labels = load_data(label_path, label_key)

        self._with_channels = with_channels
        self._with_label_channels = with_label_channels

        if roi is not None:
            if isinstance(roi, slice):
                roi = (roi,)

            self.raw = RoiWrapper(self.raw, (slice(None),) + roi) if self._with_channels else RoiWrapper(self.raw, roi)
            self.labels = RoiWrapper(self.labels, (slice(None),) + roi) if self._with_label_channels else\
                RoiWrapper(self.labels, roi)

        shape_raw = self.raw.shape[1:] if self._with_channels else self.raw.shape
        shape_label = self.labels.shape[1:] if self._with_label_channels else self.labels.shape
        assert shape_raw == shape_label, f"{shape_raw}, {shape_label}"

        self.shape = shape_raw
        self.roi = roi

        self._ndim = len(shape_raw) if ndim is None else ndim
        assert self._ndim in (2, 3, 4), f"Invalid data dimensions: {self._ndim}. Only 2d, 3d or 4d data is supported"

        if patch_shape is not None:
            assert len(patch_shape) in (self._ndim, self._ndim + 1), f"{patch_shape}, {self._ndim}"

        self.patch_shape = patch_shape

        self.raw_transform = raw_transform
        self.label_transform = label_transform
        self.label_transform2 = label_transform2
        self.transform = transform
        self.sampler = sampler
        self.with_padding = with_padding
        self.pre_label_transform = pre_label_transform

        self.dtype = dtype
        self.label_dtype = label_dtype

        self._len = self.compute_len(self.shape, self.patch_shape) if n_samples is None else n_samples

        self.z_ext = z_ext

        self.sample_shape = patch_shape
        self.trafo_halo = None
        # TODO add support for trafo halo: asking for a bigger bounding box before applying the trafo,
        # which is then cut. See code below; but this needs to be properly tested.

        # self.trafo_halo = None if self.transform is None else self.transform.halo(self.patch_shape)
        # if self.trafo_halo is not None:
        #     if len(self.trafo_halo) == 2 and self._ndim == 3:
        #         self.trafo_halo = (0,) + self.trafo_halo
        #     assert len(self.trafo_halo) == self._ndim
        #     self.sample_shape = tuple(sh + ha for sh, ha in zip(self.patch_shape, self.trafo_halo))
        #     self.inner_bb = tuple(slice(ha, sh - ha) for sh, ha in zip(self.patch_shape, self.trafo_halo))

    def __len__(self):
        return self._len

    @property
    def ndim(self):
        return self._ndim

    def _sample_bounding_box(self):
        if self.sample_shape is None:
            if self.z_ext is None:
                bb_start = [0] * len(self.shape)
                patch_shape_for_bb = self.shape
            else:
                z_diff = self.shape[0] - self.z_ext
                bb_start = [np.random.randint(0, z_diff) if z_diff > 0 else 0] + [0] * len(self.shape[1:])
                patch_shape_for_bb = (self.z_ext, *self.shape[1:])

        else:
            bb_start = [
                np.random.randint(0, sh - psh) if sh - psh > 0 else 0 for sh, psh in zip(self.shape, self.sample_shape)
            ]
            patch_shape_for_bb = self.sample_shape

        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, patch_shape_for_bb))

    def _get_desired_raw_and_labels(self):
        bb = self._sample_bounding_box()
        bb_raw = (slice(None),) + bb if self._with_channels else bb
        bb_labels = (slice(None),) + bb if self._with_label_channels else bb
        raw, labels = self.raw[bb_raw], self.labels[bb_labels]

        # Additional label transform on top, so that the sampler considers the expected labels
        # (e.g. run connected components on disconnected semantic labels).
        pre_label_transform = getattr(self, "pre_label_transform", None)
        if pre_label_transform is not None:
            labels = pre_label_transform(labels)

        return raw, labels

    def _get_sample(self, index):
        if self.raw is None or self.labels is None:
            raise RuntimeError("SegmentationDataset has not been properly deserialized.")

        raw, labels = self._get_desired_raw_and_labels()

        if self.sampler is not None:
            sample_id = 0
            while not self.sampler(raw, labels):
                raw, labels = self._get_desired_raw_and_labels()
                sample_id += 1
                if sample_id > self.max_sampling_attempts:
                    raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")

        # Padding the patch to match the expected input shape.
        if self.patch_shape is not None and self.with_padding:
            raw, labels = ensure_patch_shape(
                raw=raw,
                labels=labels,
                patch_shape=self.patch_shape,
                have_raw_channels=self._with_channels,
                have_label_channels=self._with_label_channels,
            )

        # squeeze the singleton spatial axis if we have a spatial shape that is larger by one than self._ndim
        if self.patch_shape is not None and len(self.patch_shape) == self._ndim + 1:
            raw = raw.squeeze(1 if self._with_channels else 0)
            labels = labels.squeeze(1 if self._with_label_channels else 0)

        return raw, labels

    def crop(self, tensor):
        """@private
        """
        bb = self.inner_bb
        if tensor.ndim > len(bb):
            bb = (tensor.ndim - len(bb)) * (slice(None),) + bb
        return tensor[bb]

    def __getitem__(self, index):
        raw, labels = self._get_sample(index)
        initial_label_dtype = labels.dtype

        if self.raw_transform is not None:
            raw = self.raw_transform(raw)

        if self.label_transform is not None:
            labels = self.label_transform(labels)

        if self.transform is not None:
            raw, labels = self.transform(raw, labels)
            if self.trafo_halo is not None:
                raw = self.crop(raw)
                labels = self.crop(labels)

        # support enlarging bounding box here as well (for affinity transform)?
        if self.label_transform2 is not None:
            labels = ensure_spatial_array(labels, self.ndim, dtype=initial_label_dtype)
            labels = self.label_transform2(labels)

        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
        labels = ensure_tensor_with_channels(labels, ndim=self._ndim, dtype=self.label_dtype)
        return raw, labels

    # need to overwrite pickle to support h5py
    def __getstate__(self):
        state = self.__dict__.copy()
        del state["raw"]
        del state["labels"]
        return state

    def __setstate__(self, state):
        raw_path, raw_key = state["raw_path"], state["raw_key"]
        label_path, label_key = state["label_path"], state["label_key"]
        roi = state["roi"]
        try:
            raw = load_data(raw_path, raw_key)
            if roi is not None:
                raw = RoiWrapper(raw, (slice(None),) + roi) if state["_with_channels"] else RoiWrapper(raw, roi)
            state["raw"] = raw
        except Exception:
            msg = f"SegmentationDataset could not be deserialized because of missing {raw_path}, {raw_key}.\n"
            msg += "The dataset is deserialized in order to allow loading trained models from a checkpoint.\n"
            msg += "But it cannot be used for further training and will throw an error."
            warnings.warn(msg)
            state["raw"] = None

        try:
            labels = load_data(label_path, label_key)
            if roi is not None:
                labels = RoiWrapper(labels, (slice(None),) + roi) if state["_with_label_channels"] else\
                    RoiWrapper(labels, roi)
            state["labels"] = labels
        except Exception:
            msg = f"SegmentationDataset could not be deserialized because of missing {label_path}, {label_key}.\n"
            msg += "The dataset is deserialized in order to allow loading trained models from a checkpoint.\n"
            msg += "But it cannot be used for further training and will throw an error."
            warnings.warn(msg)
            state["labels"] = None

        self.__dict__.update(state)
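
To make the loading conventions from the class docstring concrete, here is a minimal usage sketch: it instantiates the dataset on an HDF5 volume, passes a simple rejection sampler, and wraps the dataset in a PyTorch DataLoader. The file path, dataset keys, patch shape, and the foreground threshold are hypothetical placeholders, not part of this module.

import torch

from torch_em.data.segmentation_dataset import SegmentationDataset

# Hypothetical inputs: an hdf5 file with a "raw" and a "labels" dataset.
RAW_PATH = LABEL_PATH = "/data/sample.h5"


def min_foreground_sampler(raw, labels, min_fraction=0.05):
    # Reject patches whose labeled (non-zero) fraction is below min_fraction.
    return (labels != 0).mean() > min_fraction


dataset = SegmentationDataset(
    raw_path=RAW_PATH, raw_key="raw",
    label_path=LABEL_PATH, label_key="labels",
    patch_shape=(32, 256, 256),      # 3d patches: z, y, x
    sampler=min_foreground_sampler,  # re-draws patches with too little foreground
)

loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)
raw_batch, label_batch = next(iter(loader))
print(raw_batch.shape, label_batch.shape)  # e.g. (2, 1, 32, 256, 256) for single-channel data

Because __getstate__ / __setstate__ drop the open file handles and reload the data from raw_path / label_path on unpickling, a dataset constructed like this can also be used with DataLoader workers (num_workers > 0).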