torch_em.data.segmentation_dataset

  1import os
  2import warnings
  3from typing import List, Union, Tuple, Optional, Any, Callable
  4
  5import numpy as np
  6from math import ceil
  7
  8import torch
  9
 10from elf.wrapper import RoiWrapper
 11
 12from ..util import ensure_spatial_array, ensure_tensor_with_channels, load_data, ensure_patch_shape, validate_roi
 13
 14
 15class SegmentationDataset(torch.utils.data.Dataset):
 16    """Dataset that provides raw data and labels stored in a container data format for segmentation training.
 17
 18    The dataset loads a patch from the raw and label data and returns a sample for a batch.
 19    Image data and label data must have the same shape, except for potential channels.
 20    The dataset supports all file formats that can be opened with `elf.io.open_file`, such as hdf5, zarr or n5.
 21    Use `raw_path` / `label_path` to specify the file path and `raw_key` / `label_key` to specify the internal dataset.
 22    It also supports regular image formats, such as .tif. For these cases set `raw_key=None` / `label_key=None`.
 23
 24    Args:
 25        raw_path: The file path to the raw image data. May also be a list of file paths.
 26        raw_key: The key to the internal dataset containing the raw data.
 27        label_path: The file path to the label data. May also be a list of file paths.
 28        label_key: The key to the internal dataset containing the label data
 29        patch_shape: The patch shape for a training sample.
 30        raw_transform: Transformation applied to the raw data of a sample.
 31        label_transform: Transformation applied to the label data of a sample,
 32            before applying augmentations via `transform`.
 33        label_transform2: Transformation applied to the label data of a sample,
 34            after applying augmentations via `transform`.
 35        transform: Transformation applied to both the raw data and label data of a sample.
 36            This can be used to implement data augmentations.
 37        roi: Region of interest in the data. If given, the data will only be loaded from the corresponding area.
 38        dtype: The return data type of the raw data.
 39        label_dtype: The return data type of the label data.
 40        n_samples: The length of this dataset. If None, the length will be set to `len(raw_image_paths)`.
 41        sampler: Sampler for rejecting samples according to a defined criterion.
 42            The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
 43        ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
 44        with_channels: Whether the raw data has channels.
 45        with_label_channels: Whether the label data has channels.
 46        with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
 47        z_ext: Extra bounding box for loading the data across z.
 48        pre_label_transform: Transformation applied to the label data of a chosen random sample,
 49            before applying the sample validity via the `sampler`.
 50    """
 51    max_sampling_attempts = 500
 52    """The maximal number of sampling attempts, for loading a sample via `__getitem__`.
 53    This is used when `sampler` rejects a sample, to avoid an infinite loop if no valid sample can be found.
 54    """
 55
 56    @staticmethod
 57    def compute_len(shape, patch_shape):
 58        if patch_shape is None:
 59            return 1
 60        else:
 61            n_samples = ceil(np.prod([float(sh / csh) for sh, csh in zip(shape, patch_shape)]))
 62            return n_samples
 63
 64    def __init__(
 65        self,
 66        raw_path: Union[List[Any], str, os.PathLike],
 67        raw_key: Optional[str],
 68        label_path: Union[List[Any], str, os.PathLike],
 69        label_key: Optional[str],
 70        patch_shape: Tuple[int, ...],
 71        raw_transform: Optional[Callable] = None,
 72        label_transform: Optional[Callable] = None,
 73        label_transform2: Optional[Callable] = None,
 74        transform: Optional[Callable] = None,
 75        roi: Optional[Union[slice, Tuple[slice, ...]]] = None,
 76        dtype: torch.dtype = torch.float32,
 77        label_dtype: torch.dtype = torch.float32,
 78        n_samples: Optional[int] = None,
 79        sampler: Optional[Callable] = None,
 80        ndim: Optional[int] = None,
 81        with_channels: bool = False,
 82        with_label_channels: bool = False,
 83        with_padding: bool = True,
 84        z_ext: Optional[int] = None,
 85        pre_label_transform: Optional[Callable] = None,
 86    ):
 87        self.raw_path = raw_path
 88        self.raw_key = raw_key
 89        self.raw = load_data(raw_path, raw_key)
 90
 91        self.label_path = label_path
 92        self.label_key = label_key
 93        self.labels = load_data(label_path, label_key)
 94
 95        self._with_channels = with_channels
 96        self._with_label_channels = with_label_channels
 97
 98        if roi is not None:
 99            shape = self.raw.shape[1:] if self._with_channels else self.raw.shape
100            roi = validate_roi(roi, shape, patch_shape)
101            self.raw = RoiWrapper(self.raw, (slice(None),) + roi) if self._with_channels else RoiWrapper(self.raw, roi)
102            self.labels = RoiWrapper(self.labels, (slice(None),) + roi) if self._with_label_channels else\
103                RoiWrapper(self.labels, roi)
104
105        shape_raw = self.raw.shape[1:] if self._with_channels else self.raw.shape
106        shape_label = self.labels.shape[1:] if self._with_label_channels else self.labels.shape
107        assert shape_raw == shape_label, f"{shape_raw}, {shape_label}"
108
109        self.shape = shape_raw
110        self.roi = roi
111
112        self._ndim = len(shape_raw) if ndim is None else ndim
113        assert self._ndim in (2, 3, 4), f"Invalid data dimensions: {self._ndim}. Only 2d, 3d or 4d data is supported"
114
115        if patch_shape is not None:
116            assert len(patch_shape) in (self._ndim, self._ndim + 1), f"{patch_shape}, {self._ndim}"
117
118        self.patch_shape = patch_shape
119
120        self.raw_transform = raw_transform
121        self.label_transform = label_transform
122        self.label_transform2 = label_transform2
123        self.transform = transform
124        self.sampler = sampler
125        self.with_padding = with_padding
126        self.pre_label_transform = pre_label_transform
127
128        self.dtype = dtype
129        self.label_dtype = label_dtype
130
131        self._len = self.compute_len(self.shape, self.patch_shape) if n_samples is None else n_samples
132
133        self.z_ext = z_ext
134
135        self.sample_shape = patch_shape
136        self.trafo_halo = None
137        # TODO add support for trafo halo: asking for a bigger bounding box before applying the trafo,
138        # which is then cut. See code below; but this ne needs to be properly tested
139
140        # self.trafo_halo = None if self.transform is None else self.transform.halo(self.patch_shape)
141        # if self.trafo_halo is not None:
142        #     if len(self.trafo_halo) == 2 and self._ndim == 3:
143        #         self.trafo_halo = (0,) + self.trafo_halo
144        #     assert len(self.trafo_halo) == self._ndim
145        #     self.sample_shape = tuple(sh + ha for sh, ha in zip(self.patch_shape, self.trafo_halo))
146        #     self.inner_bb = tuple(slice(ha, sh - ha) for sh, ha in zip(self.patch_shape, self.trafo_halo))
147
148    def __len__(self):
149        return self._len
150
151    @property
152    def ndim(self):
153        return self._ndim
154
155    def _sample_bounding_box(self):
156        if self.sample_shape is None:
157            if self.z_ext is None:
158                bb_start = [0] * len(self.shape)
159                patch_shape_for_bb = self.shape
160            else:
161                z_diff = self.shape[0] - self.z_ext
162                bb_start = [np.random.randint(0, z_diff) if z_diff > 0 else 0] + [0] * len(self.shape[1:])
163                patch_shape_for_bb = (self.z_ext, *self.shape[1:])
164
165        else:
166            bb_start = [
167                np.random.randint(0, sh - psh) if sh - psh > 0 else 0 for sh, psh in zip(self.shape, self.sample_shape)
168            ]
169            patch_shape_for_bb = self.sample_shape
170
171        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, patch_shape_for_bb))
172
173    def _get_desired_raw_and_labels(self):
174        bb = self._sample_bounding_box()
175        bb_raw = (slice(None),) + bb if self._with_channels else bb
176        bb_labels = (slice(None),) + bb if self._with_label_channels else bb
177        raw, labels = self.raw[bb_raw], self.labels[bb_labels]
178
179        # Additional label transform on top to make sampler consider expected labels
180        # (eg. run connected components on disconnected semantic labels)
181        pre_label_transform = getattr(self, "pre_label_transform", None)
182        if pre_label_transform is not None:
183            labels = pre_label_transform(labels)
184
185        return raw, labels
186
187    def _get_sample(self, index):
188        if self.raw is None or self.labels is None:
189            raise RuntimeError("SegmentationDataset has not been properly deserialized.")
190
191        raw, labels = self._get_desired_raw_and_labels()
192
193        if self.sampler is not None:
194            sample_id = 0
195            while not self.sampler(raw, labels):
196                raw, labels = self._get_desired_raw_and_labels()
197                sample_id += 1
198                if sample_id > self.max_sampling_attempts:
199                    raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")
200
201        # Padding the patch to match the expected input shape.
202        if self.patch_shape is not None and self.with_padding:
203            raw, labels = ensure_patch_shape(
204                raw=raw,
205                labels=labels,
206                patch_shape=self.patch_shape,
207                have_raw_channels=self._with_channels,
208                have_label_channels=self._with_label_channels,
209            )
210
211        # squeeze the singleton spatial axis if we have a spatial shape that is larger by one than self._ndim
212        if self.patch_shape is not None and len(self.patch_shape) == self._ndim + 1:
213            raw = raw.squeeze(1 if self._with_channels else 0)
214            labels = labels.squeeze(1 if self._with_label_channels else 0)
215
216        return raw, labels
217
218    def crop(self, tensor):
219        """@private
220        """
221        bb = self.inner_bb
222        if tensor.ndim > len(bb):
223            bb = (tensor.ndim - len(bb)) * (slice(None),) + bb
224        return tensor[bb]
225
226    def __getitem__(self, index):
227        raw, labels = self._get_sample(index)
228        initial_label_dtype = labels.dtype
229
230        if self.raw_transform is not None:
231            raw = self.raw_transform(raw)
232
233        if self.label_transform is not None:
234            labels = self.label_transform(labels)
235
236        if self.transform is not None:
237            raw, labels = self.transform(raw, labels)
238            if self.trafo_halo is not None:
239                raw = self.crop(raw)
240                labels = self.crop(labels)
241
242        # support enlarging bounding box here as well (for affinity transform) ?
243        if self.label_transform2 is not None:
244            labels = ensure_spatial_array(labels, self.ndim, dtype=initial_label_dtype)
245            labels = self.label_transform2(labels)
246
247        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
248        labels = ensure_tensor_with_channels(labels, ndim=self._ndim, dtype=self.label_dtype)
249        return raw, labels
250
251    # need to overwrite pickle to support h5py
252    def __getstate__(self):
253        state = self.__dict__.copy()
254        del state["raw"]
255        del state["labels"]
256        return state
257
258    def __setstate__(self, state):
259        raw_path, raw_key = state["raw_path"], state["raw_key"]
260        label_path, label_key = state["label_path"], state["label_key"]
261        roi = state["roi"]
262        try:
263            raw = load_data(raw_path, raw_key)
264            if roi is not None:
265                raw = RoiWrapper(raw, (slice(None),) + roi) if state["_with_channels"] else RoiWrapper(raw, roi)
266            state["raw"] = raw
267        except Exception:
268            msg = f"SegmentationDataset could not be deserialized because of missing {raw_path}, {raw_key}.\n"
269            msg += "The dataset is deserialized in order to allow loading trained models from a checkpoint.\n"
270            msg += "But it cannot be used for further training and will throw an error."
271            warnings.warn(msg)
272            state["raw"] = None
273
274        try:
275            labels = load_data(label_path, label_key)
276            if roi is not None:
277                labels = RoiWrapper(labels, (slice(None),) + roi) if state["_with_label_channels"] else\
278                    RoiWrapper(labels, roi)
279            state["labels"] = labels
280        except Exception:
281            msg = f"SegmentationDataset could not be deserialized because of missing {label_path}, {label_key}.\n"
282            msg += "The dataset is deserialized in order to allow loading trained models from a checkpoint.\n"
283            msg += "But it cannot be used for further training and will throw an error."
284            warnings.warn(msg)
285            state["labels"] = None
286
287        self.__dict__.update(state)
class SegmentationDataset(torch.utils.data.Dataset):
 16class SegmentationDataset(torch.utils.data.Dataset):
 17    """Dataset that provides raw data and labels stored in a container data format for segmentation training.
 18
 19    The dataset loads a patch from the raw and label data and returns a sample for a batch.
 20    Image data and label data must have the same shape, except for potential channels.
 21    The dataset supports all file formats that can be opened with `elf.io.open_file`, such as hdf5, zarr or n5.
 22    Use `raw_path` / `label_path` to specify the file path and `raw_key` / `label_key` to specify the internal dataset.
 23    It also supports regular image formats, such as .tif. For these cases set `raw_key=None` / `label_key=None`.
 24
 25    Args:
 26        raw_path: The file path to the raw image data. May also be a list of file paths.
 27        raw_key: The key to the internal dataset containing the raw data.
 28        label_path: The file path to the label data. May also be a list of file paths.
 29        label_key: The key to the internal dataset containing the label data
 30        patch_shape: The patch shape for a training sample.
 31        raw_transform: Transformation applied to the raw data of a sample.
 32        label_transform: Transformation applied to the label data of a sample,
 33            before applying augmentations via `transform`.
 34        label_transform2: Transformation applied to the label data of a sample,
 35            after applying augmentations via `transform`.
 36        transform: Transformation applied to both the raw data and label data of a sample.
 37            This can be used to implement data augmentations.
 38        roi: Region of interest in the data. If given, the data will only be loaded from the corresponding area.
 39        dtype: The return data type of the raw data.
 40        label_dtype: The return data type of the label data.
 41        n_samples: The length of this dataset. If None, the length will be set to `len(raw_image_paths)`.
 42        sampler: Sampler for rejecting samples according to a defined criterion.
 43            The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
 44        ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
 45        with_channels: Whether the raw data has channels.
 46        with_label_channels: Whether the label data has channels.
 47        with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
 48        z_ext: Extra bounding box for loading the data across z.
 49        pre_label_transform: Transformation applied to the label data of a chosen random sample,
 50            before applying the sample validity via the `sampler`.
 51    """
 52    max_sampling_attempts = 500
 53    """The maximal number of sampling attempts, for loading a sample via `__getitem__`.
 54    This is used when `sampler` rejects a sample, to avoid an infinite loop if no valid sample can be found.
 55    """
 56
 57    @staticmethod
 58    def compute_len(shape, patch_shape):
 59        if patch_shape is None:
 60            return 1
 61        else:
 62            n_samples = ceil(np.prod([float(sh / csh) for sh, csh in zip(shape, patch_shape)]))
 63            return n_samples
 64
 65    def __init__(
 66        self,
 67        raw_path: Union[List[Any], str, os.PathLike],
 68        raw_key: Optional[str],
 69        label_path: Union[List[Any], str, os.PathLike],
 70        label_key: Optional[str],
 71        patch_shape: Tuple[int, ...],
 72        raw_transform: Optional[Callable] = None,
 73        label_transform: Optional[Callable] = None,
 74        label_transform2: Optional[Callable] = None,
 75        transform: Optional[Callable] = None,
 76        roi: Optional[Union[slice, Tuple[slice, ...]]] = None,
 77        dtype: torch.dtype = torch.float32,
 78        label_dtype: torch.dtype = torch.float32,
 79        n_samples: Optional[int] = None,
 80        sampler: Optional[Callable] = None,
 81        ndim: Optional[int] = None,
 82        with_channels: bool = False,
 83        with_label_channels: bool = False,
 84        with_padding: bool = True,
 85        z_ext: Optional[int] = None,
 86        pre_label_transform: Optional[Callable] = None,
 87    ):
 88        self.raw_path = raw_path
 89        self.raw_key = raw_key
 90        self.raw = load_data(raw_path, raw_key)
 91
 92        self.label_path = label_path
 93        self.label_key = label_key
 94        self.labels = load_data(label_path, label_key)
 95
 96        self._with_channels = with_channels
 97        self._with_label_channels = with_label_channels
 98
 99        if roi is not None:
100            shape = self.raw.shape[1:] if self._with_channels else self.raw.shape
101            roi = validate_roi(roi, shape, patch_shape)
102            self.raw = RoiWrapper(self.raw, (slice(None),) + roi) if self._with_channels else RoiWrapper(self.raw, roi)
103            self.labels = RoiWrapper(self.labels, (slice(None),) + roi) if self._with_label_channels else\
104                RoiWrapper(self.labels, roi)
105
106        shape_raw = self.raw.shape[1:] if self._with_channels else self.raw.shape
107        shape_label = self.labels.shape[1:] if self._with_label_channels else self.labels.shape
108        assert shape_raw == shape_label, f"{shape_raw}, {shape_label}"
109
110        self.shape = shape_raw
111        self.roi = roi
112
113        self._ndim = len(shape_raw) if ndim is None else ndim
114        assert self._ndim in (2, 3, 4), f"Invalid data dimensions: {self._ndim}. Only 2d, 3d or 4d data is supported"
115
116        if patch_shape is not None:
117            assert len(patch_shape) in (self._ndim, self._ndim + 1), f"{patch_shape}, {self._ndim}"
118
119        self.patch_shape = patch_shape
120
121        self.raw_transform = raw_transform
122        self.label_transform = label_transform
123        self.label_transform2 = label_transform2
124        self.transform = transform
125        self.sampler = sampler
126        self.with_padding = with_padding
127        self.pre_label_transform = pre_label_transform
128
129        self.dtype = dtype
130        self.label_dtype = label_dtype
131
132        self._len = self.compute_len(self.shape, self.patch_shape) if n_samples is None else n_samples
133
134        self.z_ext = z_ext
135
136        self.sample_shape = patch_shape
137        self.trafo_halo = None
138        # TODO add support for trafo halo: asking for a bigger bounding box before applying the trafo,
 138        # which is then cut. See code below; but this needs to be properly tested
140
141        # self.trafo_halo = None if self.transform is None else self.transform.halo(self.patch_shape)
142        # if self.trafo_halo is not None:
143        #     if len(self.trafo_halo) == 2 and self._ndim == 3:
144        #         self.trafo_halo = (0,) + self.trafo_halo
145        #     assert len(self.trafo_halo) == self._ndim
146        #     self.sample_shape = tuple(sh + ha for sh, ha in zip(self.patch_shape, self.trafo_halo))
147        #     self.inner_bb = tuple(slice(ha, sh - ha) for sh, ha in zip(self.patch_shape, self.trafo_halo))
148
149    def __len__(self):
150        return self._len
151
152    @property
153    def ndim(self):
154        return self._ndim
155
156    def _sample_bounding_box(self):
157        if self.sample_shape is None:
158            if self.z_ext is None:
159                bb_start = [0] * len(self.shape)
160                patch_shape_for_bb = self.shape
161            else:
162                z_diff = self.shape[0] - self.z_ext
163                bb_start = [np.random.randint(0, z_diff) if z_diff > 0 else 0] + [0] * len(self.shape[1:])
164                patch_shape_for_bb = (self.z_ext, *self.shape[1:])
165
166        else:
167            bb_start = [
168                np.random.randint(0, sh - psh) if sh - psh > 0 else 0 for sh, psh in zip(self.shape, self.sample_shape)
169            ]
170            patch_shape_for_bb = self.sample_shape
171
172        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, patch_shape_for_bb))
173
174    def _get_desired_raw_and_labels(self):
175        bb = self._sample_bounding_box()
176        bb_raw = (slice(None),) + bb if self._with_channels else bb
177        bb_labels = (slice(None),) + bb if self._with_label_channels else bb
178        raw, labels = self.raw[bb_raw], self.labels[bb_labels]
179
180        # Additional label transform on top to make sampler consider expected labels
181        # (eg. run connected components on disconnected semantic labels)
182        pre_label_transform = getattr(self, "pre_label_transform", None)
183        if pre_label_transform is not None:
184            labels = pre_label_transform(labels)
185
186        return raw, labels
187
188    def _get_sample(self, index):
189        if self.raw is None or self.labels is None:
190            raise RuntimeError("SegmentationDataset has not been properly deserialized.")
191
192        raw, labels = self._get_desired_raw_and_labels()
193
194        if self.sampler is not None:
195            sample_id = 0
196            while not self.sampler(raw, labels):
197                raw, labels = self._get_desired_raw_and_labels()
198                sample_id += 1
199                if sample_id > self.max_sampling_attempts:
200                    raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")
201
202        # Padding the patch to match the expected input shape.
203        if self.patch_shape is not None and self.with_padding:
204            raw, labels = ensure_patch_shape(
205                raw=raw,
206                labels=labels,
207                patch_shape=self.patch_shape,
208                have_raw_channels=self._with_channels,
209                have_label_channels=self._with_label_channels,
210            )
211
212        # squeeze the singleton spatial axis if we have a spatial shape that is larger by one than self._ndim
213        if self.patch_shape is not None and len(self.patch_shape) == self._ndim + 1:
214            raw = raw.squeeze(1 if self._with_channels else 0)
215            labels = labels.squeeze(1 if self._with_label_channels else 0)
216
217        return raw, labels
218
219    def crop(self, tensor):
220        """@private
221        """
222        bb = self.inner_bb
223        if tensor.ndim > len(bb):
224            bb = (tensor.ndim - len(bb)) * (slice(None),) + bb
225        return tensor[bb]
226
227    def __getitem__(self, index):
228        raw, labels = self._get_sample(index)
229        initial_label_dtype = labels.dtype
230
231        if self.raw_transform is not None:
232            raw = self.raw_transform(raw)
233
234        if self.label_transform is not None:
235            labels = self.label_transform(labels)
236
237        if self.transform is not None:
238            raw, labels = self.transform(raw, labels)
239            if self.trafo_halo is not None:
240                raw = self.crop(raw)
241                labels = self.crop(labels)
242
243        # support enlarging bounding box here as well (for affinity transform) ?
244        if self.label_transform2 is not None:
245            labels = ensure_spatial_array(labels, self.ndim, dtype=initial_label_dtype)
246            labels = self.label_transform2(labels)
247
248        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
249        labels = ensure_tensor_with_channels(labels, ndim=self._ndim, dtype=self.label_dtype)
250        return raw, labels
251
252    # need to overwrite pickle to support h5py
253    def __getstate__(self):
254        state = self.__dict__.copy()
255        del state["raw"]
256        del state["labels"]
257        return state
258
259    def __setstate__(self, state):
260        raw_path, raw_key = state["raw_path"], state["raw_key"]
261        label_path, label_key = state["label_path"], state["label_key"]
262        roi = state["roi"]
263        try:
264            raw = load_data(raw_path, raw_key)
265            if roi is not None:
266                raw = RoiWrapper(raw, (slice(None),) + roi) if state["_with_channels"] else RoiWrapper(raw, roi)
267            state["raw"] = raw
268        except Exception:
269            msg = f"SegmentationDataset could not be deserialized because of missing {raw_path}, {raw_key}.\n"
270            msg += "The dataset is deserialized in order to allow loading trained models from a checkpoint.\n"
271            msg += "But it cannot be used for further training and will throw an error."
272            warnings.warn(msg)
273            state["raw"] = None
274
275        try:
276            labels = load_data(label_path, label_key)
277            if roi is not None:
278                labels = RoiWrapper(labels, (slice(None),) + roi) if state["_with_label_channels"] else\
279                    RoiWrapper(labels, roi)
280            state["labels"] = labels
281        except Exception:
282            msg = f"SegmentationDataset could not be deserialized because of missing {label_path}, {label_key}.\n"
283            msg += "The dataset is deserialized in order to allow loading trained models from a checkpoint.\n"
284            msg += "But it cannot be used for further training and will throw an error."
285            warnings.warn(msg)
286            state["labels"] = None
287
288        self.__dict__.update(state)

Dataset that provides raw data and labels stored in a container data format for segmentation training.

The dataset loads a patch from the raw and label data and returns a sample for a batch. Image data and label data must have the same shape, except for potential channels. The dataset supports all file formats that can be opened with elf.io.open_file, such as hdf5, zarr or n5. Use raw_path / label_path to specify the file path and raw_key / label_key to specify the internal dataset. It also supports regular image formats, such as .tif. For these cases set raw_key=None / label_key=None.

Arguments:
  • raw_path: The file path to the raw image data. May also be a list of file paths.
  • raw_key: The key to the internal dataset containing the raw data.
  • label_path: The file path to the label data. May also be a list of file paths.
  • label_key: The key to the internal dataset containing the label data.
  • patch_shape: The patch shape for a training sample.
  • raw_transform: Transformation applied to the raw data of a sample.
  • label_transform: Transformation applied to the label data of a sample, before applying augmentations via transform.
  • label_transform2: Transformation applied to the label data of a sample, after applying augmentations via transform.
  • transform: Transformation applied to both the raw data and label data of a sample. This can be used to implement data augmentations.
  • roi: Region of interest in the data. If given, the data will only be loaded from the corresponding area.
  • dtype: The return data type of the raw data.
  • label_dtype: The return data type of the label data.
  • n_samples: The length of this dataset. If None, the length is computed from the data shape and the patch shape via compute_len.
  • sampler: Sampler for rejecting samples according to a defined criterion. The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
  • ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
  • with_channels: Whether the raw data has channels.
  • with_label_channels: Whether the label data has channels.
  • with_padding: Whether to pad samples to patch_shape if their shape is smaller.
  • z_ext: Extra bounding box for loading the data across z.
  • pre_label_transform: Transformation applied to the label data of a chosen random sample, before applying the sample validity via the sampler.
SegmentationDataset( raw_path: Union[List[Any], str, os.PathLike], raw_key: Optional[str], label_path: Union[List[Any], str, os.PathLike], label_key: Optional[str], patch_shape: Tuple[int, ...], raw_transform: Optional[Callable] = None, label_transform: Optional[Callable] = None, label_transform2: Optional[Callable] = None, transform: Optional[Callable] = None, roi: Union[slice, Tuple[slice, ...], NoneType] = None, dtype: torch.dtype = torch.float32, label_dtype: torch.dtype = torch.float32, n_samples: Optional[int] = None, sampler: Optional[Callable] = None, ndim: Optional[int] = None, with_channels: bool = False, with_label_channels: bool = False, with_padding: bool = True, z_ext: Optional[int] = None, pre_label_transform: Optional[Callable] = None)
 65    def __init__(
 66        self,
 67        raw_path: Union[List[Any], str, os.PathLike],
 68        raw_key: Optional[str],
 69        label_path: Union[List[Any], str, os.PathLike],
 70        label_key: Optional[str],
 71        patch_shape: Tuple[int, ...],
 72        raw_transform: Optional[Callable] = None,
 73        label_transform: Optional[Callable] = None,
 74        label_transform2: Optional[Callable] = None,
 75        transform: Optional[Callable] = None,
 76        roi: Optional[Union[slice, Tuple[slice, ...]]] = None,
 77        dtype: torch.dtype = torch.float32,
 78        label_dtype: torch.dtype = torch.float32,
 79        n_samples: Optional[int] = None,
 80        sampler: Optional[Callable] = None,
 81        ndim: Optional[int] = None,
 82        with_channels: bool = False,
 83        with_label_channels: bool = False,
 84        with_padding: bool = True,
 85        z_ext: Optional[int] = None,
 86        pre_label_transform: Optional[Callable] = None,
 87    ):
 88        self.raw_path = raw_path
 89        self.raw_key = raw_key
 90        self.raw = load_data(raw_path, raw_key)
 91
 92        self.label_path = label_path
 93        self.label_key = label_key
 94        self.labels = load_data(label_path, label_key)
 95
 96        self._with_channels = with_channels
 97        self._with_label_channels = with_label_channels
 98
 99        if roi is not None:
100            shape = self.raw.shape[1:] if self._with_channels else self.raw.shape
101            roi = validate_roi(roi, shape, patch_shape)
102            self.raw = RoiWrapper(self.raw, (slice(None),) + roi) if self._with_channels else RoiWrapper(self.raw, roi)
103            self.labels = RoiWrapper(self.labels, (slice(None),) + roi) if self._with_label_channels else\
104                RoiWrapper(self.labels, roi)
105
106        shape_raw = self.raw.shape[1:] if self._with_channels else self.raw.shape
107        shape_label = self.labels.shape[1:] if self._with_label_channels else self.labels.shape
108        assert shape_raw == shape_label, f"{shape_raw}, {shape_label}"
109
110        self.shape = shape_raw
111        self.roi = roi
112
113        self._ndim = len(shape_raw) if ndim is None else ndim
114        assert self._ndim in (2, 3, 4), f"Invalid data dimensions: {self._ndim}. Only 2d, 3d or 4d data is supported"
115
116        if patch_shape is not None:
117            assert len(patch_shape) in (self._ndim, self._ndim + 1), f"{patch_shape}, {self._ndim}"
118
119        self.patch_shape = patch_shape
120
121        self.raw_transform = raw_transform
122        self.label_transform = label_transform
123        self.label_transform2 = label_transform2
124        self.transform = transform
125        self.sampler = sampler
126        self.with_padding = with_padding
127        self.pre_label_transform = pre_label_transform
128
129        self.dtype = dtype
130        self.label_dtype = label_dtype
131
132        self._len = self.compute_len(self.shape, self.patch_shape) if n_samples is None else n_samples
133
134        self.z_ext = z_ext
135
136        self.sample_shape = patch_shape
137        self.trafo_halo = None
138        # TODO add support for trafo halo: asking for a bigger bounding box before applying the trafo,
139        # which is then cut. See code below; but this ne needs to be properly tested
140
141        # self.trafo_halo = None if self.transform is None else self.transform.halo(self.patch_shape)
142        # if self.trafo_halo is not None:
143        #     if len(self.trafo_halo) == 2 and self._ndim == 3:
144        #         self.trafo_halo = (0,) + self.trafo_halo
145        #     assert len(self.trafo_halo) == self._ndim
146        #     self.sample_shape = tuple(sh + ha for sh, ha in zip(self.patch_shape, self.trafo_halo))
147        #     self.inner_bb = tuple(slice(ha, sh - ha) for sh, ha in zip(self.patch_shape, self.trafo_halo))
max_sampling_attempts = 500

The maximal number of sampling attempts, for loading a sample via __getitem__. This is used when sampler rejects a sample, to avoid an infinite loop if no valid sample can be found.

@staticmethod
def compute_len(shape, patch_shape):
57    @staticmethod
58    def compute_len(shape, patch_shape):
59        if patch_shape is None:
60            return 1
61        else:
62            n_samples = ceil(np.prod([float(sh / csh) for sh, csh in zip(shape, patch_shape)]))
63            return n_samples
raw_path
raw_key
raw
label_path
label_key
labels
shape
roi
patch_shape
raw_transform
label_transform
label_transform2
transform
sampler
with_padding
pre_label_transform
dtype
label_dtype
z_ext
sample_shape
trafo_halo
ndim
152    @property
153    def ndim(self):
154        return self._ndim