torch_em.data.segmentation_dataset

  1import os
  2import warnings
  3from typing import List, Union, Tuple, Optional, Any, Callable
  4
  5import numpy as np
  6from math import ceil
  7
  8import torch
  9
 10from elf.wrapper import RoiWrapper
 11
 12from ..util import ensure_spatial_array, ensure_tensor_with_channels, load_data, ensure_patch_shape
 13
 14
 15class SegmentationDataset(torch.utils.data.Dataset):
 16    """Dataset that provides raw data and labels stored in a container data format for segmentation training.
 17
 18    The dataset loads a patch from the raw and label data and returns a sample for a batch.
 19    Image data and label data must have the same shape, except for potential channels.
 20    The dataset supports all file formats that can be opened with `elf.io.open_file`, such as hdf5, zarr or n5.
 21    Use `raw_path` / `label_path` to specify the file path and `raw_key` / `label_key` to specify the internal dataset.
 22    It also supports regular image formats, such as .tif. For these cases set `raw_key=None` / `label_key=None`.
 23
 24    Args:
 25        raw_path: The file path to the raw image data. May also be a list of file paths.
 26        raw_key: The key to the internal dataset containing the raw data.
 27        label_path: The file path to the label data. May also be a list of file paths.
 28        label_key: The key to the internal dataset containing the label data
 29        patch_shape: The patch shape for a training sample.
 30        raw_transform: Transformation applied to the raw data of a sample.
 31        label_transform: Transformation applied to the label data of a sample,
 32            before applying augmentations via `transform`.
 33        label_transform2: Transformation applied to the label data of a sample,
 34            after applying augmentations via `transform`.
 35        transform: Transformation applied to both the raw data and label data of a sample.
 36            This can be used to implement data augmentations.
 37        roi: Region of interest in the data. If given, the data will only be loaded from the corresponding area.
 38        dtype: The return data type of the raw data.
 39        label_dtype: The return data type of the label data.
 40        n_samples: The length of this dataset. If None, the length will be set to `len(raw_image_paths)`.
 41        sampler: Sampler for rejecting samples according to a defined criterion.
 42            The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
 43        ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
 44        with_channels: Whether the raw data has channels.
 45        with_label_channels: Whether the label data has channels.
 46        with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
 47        z_ext: Extra bounding box for loading the data across z.
 48    """
 49    max_sampling_attempts = 500
 50    """The maximal number of sampling attempts, for loading a sample via `__getitem__`.
 51    This is used when `sampler` rejects a sample, to avoid an infinite loop if no valid sample can be found.
 52    """
 53
 54    @staticmethod
 55    def compute_len(shape, patch_shape):
 56        if patch_shape is None:
 57            return 1
 58        else:
 59            n_samples = ceil(np.prod([float(sh / csh) for sh, csh in zip(shape, patch_shape)]))
 60            return n_samples
 61
 62    def __init__(
 63        self,
 64        raw_path: Union[List[Any], str, os.PathLike],
 65        raw_key: Optional[str],
 66        label_path: Union[List[Any], str, os.PathLike],
 67        label_key: Optional[str],
 68        patch_shape: Tuple[int, ...],
 69        raw_transform: Optional[Callable] = None,
 70        label_transform: Optional[Callable] = None,
 71        label_transform2: Optional[Callable] = None,
 72        transform: Optional[Callable] = None,
 73        roi: Optional[Union[slice, Tuple[slice, ...]]] = None,
 74        dtype: torch.dtype = torch.float32,
 75        label_dtype: torch.dtype = torch.float32,
 76        n_samples: Optional[int] = None,
 77        sampler: Optional[Callable] = None,
 78        ndim: Optional[int] = None,
 79        with_channels: bool = False,
 80        with_label_channels: bool = False,
 81        with_padding: bool = True,
 82        z_ext: Optional[int] = None,
 83    ):
 84        self.raw_path = raw_path
 85        self.raw_key = raw_key
 86        self.raw = load_data(raw_path, raw_key)
 87
 88        self.label_path = label_path
 89        self.label_key = label_key
 90        self.labels = load_data(label_path, label_key)
 91
 92        self._with_channels = with_channels
 93        self._with_label_channels = with_label_channels
 94
 95        if roi is not None:
 96            if isinstance(roi, slice):
 97                roi = (roi,)
 98
 99            self.raw = RoiWrapper(self.raw, (slice(None),) + roi) if self._with_channels else RoiWrapper(self.raw, roi)
100            self.labels = RoiWrapper(self.labels, (slice(None),) + roi) if self._with_label_channels else\
101                RoiWrapper(self.labels, roi)
102
103        shape_raw = self.raw.shape[1:] if self._with_channels else self.raw.shape
104        shape_label = self.labels.shape[1:] if self._with_label_channels else self.labels.shape
105        assert shape_raw == shape_label, f"{shape_raw}, {shape_label}"
106
107        self.shape = shape_raw
108        self.roi = roi
109
110        self._ndim = len(shape_raw) if ndim is None else ndim
111        assert self._ndim in (2, 3, 4), f"Invalid data dimensions: {self._ndim}. Only 2d, 3d or 4d data is supported"
112
113        if patch_shape is not None:
114            assert len(patch_shape) in (self._ndim, self._ndim + 1), f"{patch_shape}, {self._ndim}"
115
116        self.patch_shape = patch_shape
117
118        self.raw_transform = raw_transform
119        self.label_transform = label_transform
120        self.label_transform2 = label_transform2
121        self.transform = transform
122        self.sampler = sampler
123        self.with_padding = with_padding
124
125        self.dtype = dtype
126        self.label_dtype = label_dtype
127
128        self._len = self.compute_len(self.shape, self.patch_shape) if n_samples is None else n_samples
129
130        self.z_ext = z_ext
131
132        self.sample_shape = patch_shape
133        self.trafo_halo = None
134        # TODO add support for trafo halo: asking for a bigger bounding box before applying the trafo,
135        # which is then cut. See code below; but this ne needs to be properly tested
136
137        # self.trafo_halo = None if self.transform is None else self.transform.halo(self.patch_shape)
138        # if self.trafo_halo is not None:
139        #     if len(self.trafo_halo) == 2 and self._ndim == 3:
140        #         self.trafo_halo = (0,) + self.trafo_halo
141        #     assert len(self.trafo_halo) == self._ndim
142        #     self.sample_shape = tuple(sh + ha for sh, ha in zip(self.patch_shape, self.trafo_halo))
143        #     self.inner_bb = tuple(slice(ha, sh - ha) for sh, ha in zip(self.patch_shape, self.trafo_halo))
144
145    def __len__(self):
146        return self._len
147
148    @property
149    def ndim(self):
150        return self._ndim
151
152    def _sample_bounding_box(self):
153        if self.sample_shape is None:
154            if self.z_ext is None:
155                bb_start = [0] * len(self.shape)
156                patch_shape_for_bb = self.shape
157            else:
158                z_diff = self.shape[0] - self.z_ext
159                bb_start = [np.random.randint(0, z_diff) if z_diff > 0 else 0] + [0] * len(self.shape[1:])
160                patch_shape_for_bb = (self.z_ext, *self.shape[1:])
161
162        else:
163            bb_start = [
164                np.random.randint(0, sh - psh) if sh - psh > 0 else 0 for sh, psh in zip(self.shape, self.sample_shape)
165            ]
166            patch_shape_for_bb = self.sample_shape
167
168        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, patch_shape_for_bb))
169
170    def _get_desired_raw_and_labels(self):
171        bb = self._sample_bounding_box()
172        bb_raw = (slice(None),) + bb if self._with_channels else bb
173        bb_labels = (slice(None),) + bb if self._with_label_channels else bb
174        raw, labels = self.raw[bb_raw], self.labels[bb_labels]
175        return raw, labels
176
177    def _get_sample(self, index):
178        if self.raw is None or self.labels is None:
179            raise RuntimeError("SegmentationDataset has not been properly deserialized.")
180
181        raw, labels = self._get_desired_raw_and_labels()
182
183        if self.sampler is not None:
184            sample_id = 0
185            while not self.sampler(raw, labels):
186                raw, labels = self._get_desired_raw_and_labels()
187                sample_id += 1
188                if sample_id > self.max_sampling_attempts:
189                    raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")
190
191        # Padding the patch to match the expected input shape.
192        if self.patch_shape is not None and self.with_padding:
193            raw, labels = ensure_patch_shape(
194                raw=raw,
195                labels=labels,
196                patch_shape=self.patch_shape,
197                have_raw_channels=self._with_channels,
198                have_label_channels=self._with_label_channels,
199            )
200
201        # squeeze the singleton spatial axis if we have a spatial shape that is larger by one than self._ndim
202        if self.patch_shape is not None and len(self.patch_shape) == self._ndim + 1:
203            raw = raw.squeeze(1 if self._with_channels else 0)
204            labels = labels.squeeze(1 if self._with_label_channels else 0)
205
206        return raw, labels
207
208    def crop(self, tensor):
209        """@private
210        """
211        bb = self.inner_bb
212        if tensor.ndim > len(bb):
213            bb = (tensor.ndim - len(bb)) * (slice(None),) + bb
214        return tensor[bb]
215
216    def __getitem__(self, index):
217        raw, labels = self._get_sample(index)
218        initial_label_dtype = labels.dtype
219
220        if self.raw_transform is not None:
221            raw = self.raw_transform(raw)
222
223        if self.label_transform is not None:
224            labels = self.label_transform(labels)
225
226        if self.transform is not None:
227            raw, labels = self.transform(raw, labels)
228            if self.trafo_halo is not None:
229                raw = self.crop(raw)
230                labels = self.crop(labels)
231
232        # support enlarging bounding box here as well (for affinity transform) ?
233        if self.label_transform2 is not None:
234            labels = ensure_spatial_array(labels, self.ndim, dtype=initial_label_dtype)
235            labels = self.label_transform2(labels)
236
237        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
238        labels = ensure_tensor_with_channels(labels, ndim=self._ndim, dtype=self.label_dtype)
239        return raw, labels
240
241    # need to overwrite pickle to support h5py
242    def __getstate__(self):
243        state = self.__dict__.copy()
244        del state["raw"]
245        del state["labels"]
246        return state
247
248    def __setstate__(self, state):
249        raw_path, raw_key = state["raw_path"], state["raw_key"]
250        label_path, label_key = state["label_path"], state["label_key"]
251        roi = state["roi"]
252        try:
253            raw = load_data(raw_path, raw_key)
254            if roi is not None:
255                raw = RoiWrapper(raw, (slice(None),) + roi) if state["_with_channels"] else RoiWrapper(raw, roi)
256            state["raw"] = raw
257        except Exception:
258            msg = f"SegmentationDataset could not be deserialized because of missing {raw_path}, {raw_key}.\n"
259            msg += "The dataset is deserialized in order to allow loading trained models from a checkpoint.\n"
260            msg += "But it cannot be used for further training and will throw an error."
261            warnings.warn(msg)
262            state["raw"] = None
263
264        try:
265            labels = load_data(label_path, label_key)
266            if roi is not None:
267                labels = RoiWrapper(labels, (slice(None),) + roi) if state["_with_label_channels"] else\
268                    RoiWrapper(labels, roi)
269            state["labels"] = labels
270        except Exception:
271            msg = f"SegmentationDataset could not be deserialized because of missing {label_path}, {label_key}.\n"
272            msg += "The dataset is deserialized in order to allow loading trained models from a checkpoint.\n"
273            msg += "But it cannot be used for further training and will throw an error."
274            warnings.warn(msg)
275            state["labels"] = None
276
277        self.__dict__.update(state)
class SegmentationDataset(typing.Generic[+_T_co]):
 16class SegmentationDataset(torch.utils.data.Dataset):
 17    """Dataset that provides raw data and labels stored in a container data format for segmentation training.
 18
 19    The dataset loads a patch from the raw and label data and returns a sample for a batch.
 20    Image data and label data must have the same shape, except for potential channels.
 21    The dataset supports all file formats that can be opened with `elf.io.open_file`, such as hdf5, zarr or n5.
 22    Use `raw_path` / `label_path` to specify the file path and `raw_key` / `label_key` to specify the internal dataset.
 23    It also supports regular image formats, such as .tif. For these cases set `raw_key=None` / `label_key=None`.
 24
 25    Args:
 26        raw_path: The file path to the raw image data. May also be a list of file paths.
 27        raw_key: The key to the internal dataset containing the raw data.
 28        label_path: The file path to the label data. May also be a list of file paths.
 29        label_key: The key to the internal dataset containing the label data
 30        patch_shape: The patch shape for a training sample.
 31        raw_transform: Transformation applied to the raw data of a sample.
 32        label_transform: Transformation applied to the label data of a sample,
 33            before applying augmentations via `transform`.
 34        label_transform2: Transformation applied to the label data of a sample,
 35            after applying augmentations via `transform`.
 36        transform: Transformation applied to both the raw data and label data of a sample.
 37            This can be used to implement data augmentations.
 38        roi: Region of interest in the data. If given, the data will only be loaded from the corresponding area.
 39        dtype: The return data type of the raw data.
 40        label_dtype: The return data type of the label data.
 41        n_samples: The length of this dataset. If None, the length will be set to `len(raw_image_paths)`.
 42        sampler: Sampler for rejecting samples according to a defined criterion.
 43            The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
 44        ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
 45        with_channels: Whether the raw data has channels.
 46        with_label_channels: Whether the label data has channels.
 47        with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
 48        z_ext: Extra bounding box for loading the data across z.
 49    """
 50    max_sampling_attempts = 500
 51    """The maximal number of sampling attempts, for loading a sample via `__getitem__`.
 52    This is used when `sampler` rejects a sample, to avoid an infinite loop if no valid sample can be found.
 53    """
 54
 55    @staticmethod
 56    def compute_len(shape, patch_shape):
 57        if patch_shape is None:
 58            return 1
 59        else:
 60            n_samples = ceil(np.prod([float(sh / csh) for sh, csh in zip(shape, patch_shape)]))
 61            return n_samples
 62
 63    def __init__(
 64        self,
 65        raw_path: Union[List[Any], str, os.PathLike],
 66        raw_key: Optional[str],
 67        label_path: Union[List[Any], str, os.PathLike],
 68        label_key: Optional[str],
 69        patch_shape: Tuple[int, ...],
 70        raw_transform: Optional[Callable] = None,
 71        label_transform: Optional[Callable] = None,
 72        label_transform2: Optional[Callable] = None,
 73        transform: Optional[Callable] = None,
 74        roi: Optional[Union[slice, Tuple[slice, ...]]] = None,
 75        dtype: torch.dtype = torch.float32,
 76        label_dtype: torch.dtype = torch.float32,
 77        n_samples: Optional[int] = None,
 78        sampler: Optional[Callable] = None,
 79        ndim: Optional[int] = None,
 80        with_channels: bool = False,
 81        with_label_channels: bool = False,
 82        with_padding: bool = True,
 83        z_ext: Optional[int] = None,
 84    ):
 85        self.raw_path = raw_path
 86        self.raw_key = raw_key
 87        self.raw = load_data(raw_path, raw_key)
 88
 89        self.label_path = label_path
 90        self.label_key = label_key
 91        self.labels = load_data(label_path, label_key)
 92
 93        self._with_channels = with_channels
 94        self._with_label_channels = with_label_channels
 95
 96        if roi is not None:
 97            if isinstance(roi, slice):
 98                roi = (roi,)
 99
100            self.raw = RoiWrapper(self.raw, (slice(None),) + roi) if self._with_channels else RoiWrapper(self.raw, roi)
101            self.labels = RoiWrapper(self.labels, (slice(None),) + roi) if self._with_label_channels else\
102                RoiWrapper(self.labels, roi)
103
104        shape_raw = self.raw.shape[1:] if self._with_channels else self.raw.shape
105        shape_label = self.labels.shape[1:] if self._with_label_channels else self.labels.shape
106        assert shape_raw == shape_label, f"{shape_raw}, {shape_label}"
107
108        self.shape = shape_raw
109        self.roi = roi
110
111        self._ndim = len(shape_raw) if ndim is None else ndim
112        assert self._ndim in (2, 3, 4), f"Invalid data dimensions: {self._ndim}. Only 2d, 3d or 4d data is supported"
113
114        if patch_shape is not None:
115            assert len(patch_shape) in (self._ndim, self._ndim + 1), f"{patch_shape}, {self._ndim}"
116
117        self.patch_shape = patch_shape
118
119        self.raw_transform = raw_transform
120        self.label_transform = label_transform
121        self.label_transform2 = label_transform2
122        self.transform = transform
123        self.sampler = sampler
124        self.with_padding = with_padding
125
126        self.dtype = dtype
127        self.label_dtype = label_dtype
128
129        self._len = self.compute_len(self.shape, self.patch_shape) if n_samples is None else n_samples
130
131        self.z_ext = z_ext
132
133        self.sample_shape = patch_shape
134        self.trafo_halo = None
135        # TODO add support for trafo halo: asking for a bigger bounding box before applying the trafo,
136        # which is then cut. See code below; but this ne needs to be properly tested
137
138        # self.trafo_halo = None if self.transform is None else self.transform.halo(self.patch_shape)
139        # if self.trafo_halo is not None:
140        #     if len(self.trafo_halo) == 2 and self._ndim == 3:
141        #         self.trafo_halo = (0,) + self.trafo_halo
142        #     assert len(self.trafo_halo) == self._ndim
143        #     self.sample_shape = tuple(sh + ha for sh, ha in zip(self.patch_shape, self.trafo_halo))
144        #     self.inner_bb = tuple(slice(ha, sh - ha) for sh, ha in zip(self.patch_shape, self.trafo_halo))
145
146    def __len__(self):
147        return self._len
148
149    @property
150    def ndim(self):
151        return self._ndim
152
153    def _sample_bounding_box(self):
154        if self.sample_shape is None:
155            if self.z_ext is None:
156                bb_start = [0] * len(self.shape)
157                patch_shape_for_bb = self.shape
158            else:
159                z_diff = self.shape[0] - self.z_ext
160                bb_start = [np.random.randint(0, z_diff) if z_diff > 0 else 0] + [0] * len(self.shape[1:])
161                patch_shape_for_bb = (self.z_ext, *self.shape[1:])
162
163        else:
164            bb_start = [
165                np.random.randint(0, sh - psh) if sh - psh > 0 else 0 for sh, psh in zip(self.shape, self.sample_shape)
166            ]
167            patch_shape_for_bb = self.sample_shape
168
169        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, patch_shape_for_bb))
170
171    def _get_desired_raw_and_labels(self):
172        bb = self._sample_bounding_box()
173        bb_raw = (slice(None),) + bb if self._with_channels else bb
174        bb_labels = (slice(None),) + bb if self._with_label_channels else bb
175        raw, labels = self.raw[bb_raw], self.labels[bb_labels]
176        return raw, labels
177
178    def _get_sample(self, index):
179        if self.raw is None or self.labels is None:
180            raise RuntimeError("SegmentationDataset has not been properly deserialized.")
181
182        raw, labels = self._get_desired_raw_and_labels()
183
184        if self.sampler is not None:
185            sample_id = 0
186            while not self.sampler(raw, labels):
187                raw, labels = self._get_desired_raw_and_labels()
188                sample_id += 1
189                if sample_id > self.max_sampling_attempts:
190                    raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")
191
192        # Padding the patch to match the expected input shape.
193        if self.patch_shape is not None and self.with_padding:
194            raw, labels = ensure_patch_shape(
195                raw=raw,
196                labels=labels,
197                patch_shape=self.patch_shape,
198                have_raw_channels=self._with_channels,
199                have_label_channels=self._with_label_channels,
200            )
201
202        # squeeze the singleton spatial axis if we have a spatial shape that is larger by one than self._ndim
203        if self.patch_shape is not None and len(self.patch_shape) == self._ndim + 1:
204            raw = raw.squeeze(1 if self._with_channels else 0)
205            labels = labels.squeeze(1 if self._with_label_channels else 0)
206
207        return raw, labels
208
209    def crop(self, tensor):
210        """@private
211        """
212        bb = self.inner_bb
213        if tensor.ndim > len(bb):
214            bb = (tensor.ndim - len(bb)) * (slice(None),) + bb
215        return tensor[bb]
216
217    def __getitem__(self, index):
218        raw, labels = self._get_sample(index)
219        initial_label_dtype = labels.dtype
220
221        if self.raw_transform is not None:
222            raw = self.raw_transform(raw)
223
224        if self.label_transform is not None:
225            labels = self.label_transform(labels)
226
227        if self.transform is not None:
228            raw, labels = self.transform(raw, labels)
229            if self.trafo_halo is not None:
230                raw = self.crop(raw)
231                labels = self.crop(labels)
232
233        # support enlarging bounding box here as well (for affinity transform) ?
234        if self.label_transform2 is not None:
235            labels = ensure_spatial_array(labels, self.ndim, dtype=initial_label_dtype)
236            labels = self.label_transform2(labels)
237
238        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
239        labels = ensure_tensor_with_channels(labels, ndim=self._ndim, dtype=self.label_dtype)
240        return raw, labels
241
242    # need to overwrite pickle to support h5py
243    def __getstate__(self):
244        state = self.__dict__.copy()
245        del state["raw"]
246        del state["labels"]
247        return state
248
249    def __setstate__(self, state):
250        raw_path, raw_key = state["raw_path"], state["raw_key"]
251        label_path, label_key = state["label_path"], state["label_key"]
252        roi = state["roi"]
253        try:
254            raw = load_data(raw_path, raw_key)
255            if roi is not None:
256                raw = RoiWrapper(raw, (slice(None),) + roi) if state["_with_channels"] else RoiWrapper(raw, roi)
257            state["raw"] = raw
258        except Exception:
259            msg = f"SegmentationDataset could not be deserialized because of missing {raw_path}, {raw_key}.\n"
260            msg += "The dataset is deserialized in order to allow loading trained models from a checkpoint.\n"
261            msg += "But it cannot be used for further training and will throw an error."
262            warnings.warn(msg)
263            state["raw"] = None
264
265        try:
266            labels = load_data(label_path, label_key)
267            if roi is not None:
268                labels = RoiWrapper(labels, (slice(None),) + roi) if state["_with_label_channels"] else\
269                    RoiWrapper(labels, roi)
270            state["labels"] = labels
271        except Exception:
272            msg = f"SegmentationDataset could not be deserialized because of missing {label_path}, {label_key}.\n"
273            msg += "The dataset is deserialized in order to allow loading trained models from a checkpoint.\n"
274            msg += "But it cannot be used for further training and will throw an error."
275            warnings.warn(msg)
276            state["labels"] = None
277
278        self.__dict__.update(state)

Dataset that provides raw data and labels stored in a container data format for segmentation training.

The dataset loads a patch from the raw and label data and returns a sample for a batch. Image data and label data must have the same shape, except for potential channels. The dataset supports all file formats that can be opened with elf.io.open_file, such as hdf5, zarr or n5. Use raw_path / label_path to specify the file path and raw_key / label_key to specify the internal dataset. It also supports regular image formats, such as .tif. For these cases set raw_key=None / label_key=None.

Arguments:
  • raw_path: The file path to the raw image data. May also be a list of file paths.
  • raw_key: The key to the internal dataset containing the raw data.
  • label_path: The file path to the label data. May also be a list of file paths.
  • label_key: The key to the internal dataset containing the label data
  • patch_shape: The patch shape for a training sample.
  • raw_transform: Transformation applied to the raw data of a sample.
  • label_transform: Transformation applied to the label data of a sample, before applying augmentations via transform.
  • label_transform2: Transformation applied to the label data of a sample, after applying augmentations via transform.
  • transform: Transformation applied to both the raw data and label data of a sample. This can be used to implement data augmentations.
  • roi: Region of interest in the data. If given, the data will only be loaded from the corresponding area.
  • dtype: The return data type of the raw data.
  • label_dtype: The return data type of the label data.
  • n_samples: The length of this dataset. If None, the length is derived from the data shape and patch_shape via compute_len.
  • sampler: Sampler for rejecting samples according to a defined criterion. The sampler must be a callable that accepts the raw data and the label data of a sample (as numpy arrays) as input and returns whether the sample is accepted.
  • ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
  • with_channels: Whether the raw data has channels.
  • with_label_channels: Whether the label data has channels.
  • with_padding: Whether to pad samples to patch_shape if their shape is smaller.
  • z_ext: The extent of the bounding box along the first (z) axis. It is only used if patch_shape is None; the remaining axes are then loaded in full.
SegmentationDataset( raw_path: Union[List[Any], str, os.PathLike], raw_key: Optional[str], label_path: Union[List[Any], str, os.PathLike], label_key: Optional[str], patch_shape: Tuple[int, ...], raw_transform: Optional[Callable] = None, label_transform: Optional[Callable] = None, label_transform2: Optional[Callable] = None, transform: Optional[Callable] = None, roi: Union[slice, Tuple[slice, ...], NoneType] = None, dtype: torch.dtype = torch.float32, label_dtype: torch.dtype = torch.float32, n_samples: Optional[int] = None, sampler: Optional[Callable] = None, ndim: Optional[int] = None, with_channels: bool = False, with_label_channels: bool = False, with_padding: bool = True, z_ext: Optional[int] = None)
 63    def __init__(
 64        self,
 65        raw_path: Union[List[Any], str, os.PathLike],
 66        raw_key: Optional[str],
 67        label_path: Union[List[Any], str, os.PathLike],
 68        label_key: Optional[str],
 69        patch_shape: Tuple[int, ...],
 70        raw_transform: Optional[Callable] = None,
 71        label_transform: Optional[Callable] = None,
 72        label_transform2: Optional[Callable] = None,
 73        transform: Optional[Callable] = None,
 74        roi: Optional[Union[slice, Tuple[slice, ...]]] = None,
 75        dtype: torch.dtype = torch.float32,
 76        label_dtype: torch.dtype = torch.float32,
 77        n_samples: Optional[int] = None,
 78        sampler: Optional[Callable] = None,
 79        ndim: Optional[int] = None,
 80        with_channels: bool = False,
 81        with_label_channels: bool = False,
 82        with_padding: bool = True,
 83        z_ext: Optional[int] = None,
 84    ):
 85        self.raw_path = raw_path
 86        self.raw_key = raw_key
 87        self.raw = load_data(raw_path, raw_key)
 88
 89        self.label_path = label_path
 90        self.label_key = label_key
 91        self.labels = load_data(label_path, label_key)
 92
 93        self._with_channels = with_channels
 94        self._with_label_channels = with_label_channels
 95
 96        if roi is not None:
 97            if isinstance(roi, slice):
 98                roi = (roi,)
 99
100            self.raw = RoiWrapper(self.raw, (slice(None),) + roi) if self._with_channels else RoiWrapper(self.raw, roi)
101            self.labels = RoiWrapper(self.labels, (slice(None),) + roi) if self._with_label_channels else\
102                RoiWrapper(self.labels, roi)
103
104        shape_raw = self.raw.shape[1:] if self._with_channels else self.raw.shape
105        shape_label = self.labels.shape[1:] if self._with_label_channels else self.labels.shape
106        assert shape_raw == shape_label, f"{shape_raw}, {shape_label}"
107
108        self.shape = shape_raw
109        self.roi = roi
110
111        self._ndim = len(shape_raw) if ndim is None else ndim
112        assert self._ndim in (2, 3, 4), f"Invalid data dimensions: {self._ndim}. Only 2d, 3d or 4d data is supported"
113
114        if patch_shape is not None:
115            assert len(patch_shape) in (self._ndim, self._ndim + 1), f"{patch_shape}, {self._ndim}"
116
117        self.patch_shape = patch_shape
118
119        self.raw_transform = raw_transform
120        self.label_transform = label_transform
121        self.label_transform2 = label_transform2
122        self.transform = transform
123        self.sampler = sampler
124        self.with_padding = with_padding
125
126        self.dtype = dtype
127        self.label_dtype = label_dtype
128
129        self._len = self.compute_len(self.shape, self.patch_shape) if n_samples is None else n_samples
130
131        self.z_ext = z_ext
132
133        self.sample_shape = patch_shape
134        self.trafo_halo = None
135        # TODO add support for trafo halo: asking for a bigger bounding box before applying the trafo,
136        # which is then cut. See code below; but this ne needs to be properly tested
137
138        # self.trafo_halo = None if self.transform is None else self.transform.halo(self.patch_shape)
139        # if self.trafo_halo is not None:
140        #     if len(self.trafo_halo) == 2 and self._ndim == 3:
141        #         self.trafo_halo = (0,) + self.trafo_halo
142        #     assert len(self.trafo_halo) == self._ndim
143        #     self.sample_shape = tuple(sh + ha for sh, ha in zip(self.patch_shape, self.trafo_halo))
144        #     self.inner_bb = tuple(slice(ha, sh - ha) for sh, ha in zip(self.patch_shape, self.trafo_halo))
max_sampling_attempts = 500

The maximal number of sampling attempts, for loading a sample via __getitem__. This is used when sampler rejects a sample, to avoid an infinite loop if no valid sample can be found.

@staticmethod
def compute_len(shape, patch_shape):
55    @staticmethod
56    def compute_len(shape, patch_shape):
57        if patch_shape is None:
58            return 1
59        else:
60            n_samples = ceil(np.prod([float(sh / csh) for sh, csh in zip(shape, patch_shape)]))
61            return n_samples
raw_path
raw_key
raw
label_path
label_key
labels
shape
roi
patch_shape
raw_transform
label_transform
label_transform2
transform
sampler
with_padding
dtype
label_dtype
z_ext
sample_shape
trafo_halo
ndim
149    @property
150    def ndim(self):
151        return self._ndim