torch_em.data.segmentation_dataset

import os
import warnings
import numpy as np
from typing import List, Union, Tuple, Optional, Any, Callable

import torch

from elf.wrapper import RoiWrapper

from ..util import ensure_spatial_array, ensure_tensor_with_channels, load_data, ensure_patch_shape


class SegmentationDataset(torch.utils.data.Dataset):
    """Dataset that provides raw data and labels stored in a container data format for segmentation training.

    The dataset loads a patch from the raw and label data and returns a sample for a batch.
    Image data and label data must have the same shape, except for potential channels.
    The dataset supports all file formats that can be opened with `elf.io.open_file`, such as hdf5, zarr or n5.
    Use `raw_path` / `label_path` to specify the file path and `raw_key` / `label_key` to specify the internal dataset.
    It also supports regular image formats, such as .tif. For these cases set `raw_key=None` / `label_key=None`.

    Args:
        raw_path: The file path to the raw image data. May also be a list of file paths.
        raw_key: The key to the internal dataset containing the raw data.
        label_path: The file path to the label data. May also be a list of file paths.
        label_key: The key to the internal dataset containing the label data.
        patch_shape: The patch shape for a training sample.
        raw_transform: Transformation applied to the raw data of a sample.
        label_transform: Transformation applied to the label data of a sample,
            before applying augmentations via `transform`.
        label_transform2: Transformation applied to the label data of a sample,
            after applying augmentations via `transform`.
        transform: Transformation applied to both the raw data and label data of a sample.
            This can be used to implement data augmentations.
        roi: Region of interest in the data. If given, the data will only be loaded from the corresponding area.
        dtype: The return data type of the raw data.
        label_dtype: The return data type of the label data.
        n_samples: The length of this dataset. If None, the length will be computed from the
            data shape and the patch shape.
        sampler: Sampler for rejecting samples according to a defined criterion.
            The sampler must be a callable that accepts the raw data and label data (as numpy arrays) as input.
        ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
        with_channels: Whether the raw data has channels.
        with_label_channels: Whether the label data has channels.
        with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
        z_ext: Extra bounding box for loading the data across z.
    """
    max_sampling_attempts = 500
    """The maximal number of sampling attempts, for loading a sample via `__getitem__`.
    This is used when `sampler` rejects a sample, to avoid an infinite loop if no valid sample can be found.
    """

    @staticmethod
    def compute_len(shape, patch_shape):
        if patch_shape is None:
            return 1
        else:
            n_samples = int(np.prod([float(sh / csh) for sh, csh in zip(shape, patch_shape)]))
            return n_samples

    def __init__(
        self,
        raw_path: Union[List[Any], str, os.PathLike],
        raw_key: Optional[str],
        label_path: Union[List[Any], str, os.PathLike],
        label_key: Optional[str],
        patch_shape: Tuple[int, ...],
        raw_transform: Optional[Callable] = None,
        label_transform: Optional[Callable] = None,
        label_transform2: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        roi: Optional[Union[slice, Tuple[slice, ...]]] = None,
        dtype: torch.dtype = torch.float32,
        label_dtype: torch.dtype = torch.float32,
        n_samples: Optional[int] = None,
        sampler: Optional[Callable] = None,
        ndim: Optional[int] = None,
        with_channels: bool = False,
        with_label_channels: bool = False,
        with_padding: bool = True,
        z_ext: Optional[int] = None,
    ):
        self.raw_path = raw_path
        self.raw_key = raw_key
        self.raw = load_data(raw_path, raw_key)

        self.label_path = label_path
        self.label_key = label_key
        self.labels = load_data(label_path, label_key)

        self._with_channels = with_channels
        self._with_label_channels = with_label_channels

        if roi is not None:
            if isinstance(roi, slice):
                roi = (roi,)

            self.raw = RoiWrapper(self.raw, (slice(None),) + roi) if self._with_channels else RoiWrapper(self.raw, roi)
            self.labels = RoiWrapper(self.labels, (slice(None),) + roi) if self._with_label_channels else\
                RoiWrapper(self.labels, roi)

        shape_raw = self.raw.shape[1:] if self._with_channels else self.raw.shape
        shape_label = self.labels.shape[1:] if self._with_label_channels else self.labels.shape
        assert shape_raw == shape_label, f"{shape_raw}, {shape_label}"

        self.shape = shape_raw
        self.roi = roi

        self._ndim = len(shape_raw) if ndim is None else ndim
        assert self._ndim in (2, 3, 4), f"Invalid data dimensions: {self._ndim}. Only 2d, 3d or 4d data is supported"

        if patch_shape is not None:
            assert len(patch_shape) in (self._ndim, self._ndim + 1), f"{patch_shape}, {self._ndim}"

        self.patch_shape = patch_shape

        self.raw_transform = raw_transform
        self.label_transform = label_transform
        self.label_transform2 = label_transform2
        self.transform = transform
        self.sampler = sampler
        self.with_padding = with_padding

        self.dtype = dtype
        self.label_dtype = label_dtype

        self._len = self.compute_len(self.shape, self.patch_shape) if n_samples is None else n_samples

        self.z_ext = z_ext

        self.sample_shape = patch_shape
        self.trafo_halo = None
        # TODO add support for trafo halo: asking for a bigger bounding box before applying the trafo,
        # which is then cut. See code below; but this needs to be properly tested.

        # self.trafo_halo = None if self.transform is None else self.transform.halo(self.patch_shape)
        # if self.trafo_halo is not None:
        #     if len(self.trafo_halo) == 2 and self._ndim == 3:
        #         self.trafo_halo = (0,) + self.trafo_halo
        #     assert len(self.trafo_halo) == self._ndim
        #     self.sample_shape = tuple(sh + ha for sh, ha in zip(self.patch_shape, self.trafo_halo))
        #     self.inner_bb = tuple(slice(ha, sh - ha) for sh, ha in zip(self.patch_shape, self.trafo_halo))

    def __len__(self):
        return self._len

    @property
    def ndim(self):
        return self._ndim

    def _sample_bounding_box(self):
        if self.sample_shape is None:
            if self.z_ext is None:
                bb_start = [0] * len(self.shape)
                patch_shape_for_bb = self.shape
            else:
                z_diff = self.shape[0] - self.z_ext
                bb_start = [np.random.randint(0, z_diff) if z_diff > 0 else 0] + [0] * len(self.shape[1:])
                patch_shape_for_bb = (self.z_ext, *self.shape[1:])

        else:
            bb_start = [
                np.random.randint(0, sh - psh) if sh - psh > 0 else 0 for sh, psh in zip(self.shape, self.sample_shape)
            ]
            patch_shape_for_bb = self.sample_shape

        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, patch_shape_for_bb))

    def _get_desired_raw_and_labels(self):
        bb = self._sample_bounding_box()
        bb_raw = (slice(None),) + bb if self._with_channels else bb
        bb_labels = (slice(None),) + bb if self._with_label_channels else bb
        raw, labels = self.raw[bb_raw], self.labels[bb_labels]
        return raw, labels

    def _get_sample(self, index):
        if self.raw is None or self.labels is None:
            raise RuntimeError("SegmentationDataset has not been properly deserialized.")

        raw, labels = self._get_desired_raw_and_labels()

        if self.sampler is not None:
            sample_id = 0
            while not self.sampler(raw, labels):
                raw, labels = self._get_desired_raw_and_labels()
                sample_id += 1
                if sample_id > self.max_sampling_attempts:
                    raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")

        # Padding the patch to match the expected input shape.
        if self.patch_shape is not None and self.with_padding:
            raw, labels = ensure_patch_shape(
                raw=raw,
                labels=labels,
                patch_shape=self.patch_shape,
                have_raw_channels=self._with_channels,
                have_label_channels=self._with_label_channels,
            )

        # squeeze the singleton spatial axis if we have a spatial shape that is larger by one than self._ndim
        if self.patch_shape is not None and len(self.patch_shape) == self._ndim + 1:
            raw = raw.squeeze(1 if self._with_channels else 0)
            labels = labels.squeeze(1 if self._with_label_channels else 0)

        return raw, labels

    def crop(self, tensor):
        """@private
        """
        bb = self.inner_bb
        if tensor.ndim > len(bb):
            bb = (tensor.ndim - len(bb)) * (slice(None),) + bb
        return tensor[bb]

    def __getitem__(self, index):
        raw, labels = self._get_sample(index)
        initial_label_dtype = labels.dtype

        if self.raw_transform is not None:
            raw = self.raw_transform(raw)

        if self.label_transform is not None:
            labels = self.label_transform(labels)

        if self.transform is not None:
            raw, labels = self.transform(raw, labels)
            if self.trafo_halo is not None:
                raw = self.crop(raw)
                labels = self.crop(labels)

        # support enlarging bounding box here as well (for affinity transform) ?
        if self.label_transform2 is not None:
            labels = ensure_spatial_array(labels, self.ndim, dtype=initial_label_dtype)
            labels = self.label_transform2(labels)

        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
        labels = ensure_tensor_with_channels(labels, ndim=self._ndim, dtype=self.label_dtype)
        return raw, labels

    # need to overwrite pickle to support h5py
    def __getstate__(self):
        state = self.__dict__.copy()
        del state["raw"]
        del state["labels"]
        return state

    def __setstate__(self, state):
        raw_path, raw_key = state["raw_path"], state["raw_key"]
        label_path, label_key = state["label_path"], state["label_key"]
        roi = state["roi"]
        try:
            raw = load_data(raw_path, raw_key)
            if roi is not None:
                raw = RoiWrapper(raw, (slice(None),) + roi) if state["_with_channels"] else RoiWrapper(raw, roi)
            state["raw"] = raw
        except Exception:
            msg = f"SegmentationDataset could not be deserialized because of missing {raw_path}, {raw_key}.\n"
            msg += "The dataset is deserialized in order to allow loading trained models from a checkpoint.\n"
            msg += "But it cannot be used for further training and will throw an error."
            warnings.warn(msg)
            state["raw"] = None

        try:
            labels = load_data(label_path, label_key)
            if roi is not None:
                labels = RoiWrapper(labels, (slice(None),) + roi) if state["_with_label_channels"] else\
                    RoiWrapper(labels, roi)
            state["labels"] = labels
        except Exception:
            msg = f"SegmentationDataset could not be deserialized because of missing {label_path}, {label_key}.\n"
            msg += "The dataset is deserialized in order to allow loading trained models from a checkpoint.\n"
            msg += "But it cannot be used for further training and will throw an error."
            warnings.warn(msg)
            state["labels"] = None

        self.__dict__.update(state)

class SegmentationDataset(torch.utils.data.Dataset):

Dataset that provides raw data and labels stored in a container data format for segmentation training.

The dataset loads a patch from the raw and label data and returns a sample for a batch. Image data and label data must have the same shape, except for potential channels. The dataset supports all file formats that can be opened with elf.io.open_file, such as hdf5, zarr or n5. Use raw_path / label_path to specify the file path and raw_key / label_key to specify the internal dataset. It also supports regular image formats, such as .tif. For these cases set raw_key=None / label_key=None.

Arguments:
  • raw_path: The file path to the raw image data. May also be a list of file paths.
  • raw_key: The key to the internal dataset containing the raw data.
  • label_path: The file path to the label data. May also be a list of file paths.
  • label_key: The key to the internal dataset containing the label data.
  • patch_shape: The patch shape for a training sample.
  • raw_transform: Transformation applied to the raw data of a sample.
  • label_transform: Transformation applied to the label data of a sample, before applying augmentations via transform.
  • label_transform2: Transformation applied to the label data of a sample, after applying augmentations via transform.
  • transform: Transformation applied to both the raw data and label data of a sample. This can be used to implement data augmentations.
  • roi: Region of interest in the data. If given, the data will only be loaded from the corresponding area.
  • dtype: The return data type of the raw data.
  • label_dtype: The return data type of the label data.
  • n_samples: The length of this dataset. If None, the length will be computed from the data shape and the patch shape.
  • sampler: Sampler for rejecting samples according to a defined criterion. The sampler must be a callable that accepts the raw data and label data (as numpy arrays) as input.
  • ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
  • with_channels: Whether the raw data has channels.
  • with_label_channels: Whether the label data has channels.
  • with_padding: Whether to pad samples to patch_shape if their shape is smaller.
  • z_ext: Extra bounding box for loading the data across z.

SegmentationDataset(
    raw_path: Union[List[Any], str, os.PathLike],
    raw_key: Optional[str],
    label_path: Union[List[Any], str, os.PathLike],
    label_key: Optional[str],
    patch_shape: Tuple[int, ...],
    raw_transform: Optional[Callable] = None,
    label_transform: Optional[Callable] = None,
    label_transform2: Optional[Callable] = None,
    transform: Optional[Callable] = None,
    roi: Optional[Union[slice, Tuple[slice, ...]]] = None,
    dtype: torch.dtype = torch.float32,
    label_dtype: torch.dtype = torch.float32,
    n_samples: Optional[int] = None,
    sampler: Optional[Callable] = None,
    ndim: Optional[int] = None,
    with_channels: bool = False,
    with_label_channels: bool = False,
    with_padding: bool = True,
    z_ext: Optional[int] = None,
)
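
A minimal usage sketch: the hdf5 file name, the dataset keys and the patch shape below are hypothetical, everything else follows the constructor signature above. Since the class is a regular torch.utils.data.Dataset, it plugs directly into a PyTorch DataLoader.

# Minimal usage sketch; "data.h5" and its keys "raw" / "labels" are hypothetical.
import torch
from torch_em.data.segmentation_dataset import SegmentationDataset

dataset = SegmentationDataset(
    raw_path="data.h5",
    raw_key="raw",
    label_path="data.h5",
    label_key="labels",
    patch_shape=(32, 256, 256),  # 3d patches sampled from the volume
)

# Wrap it in a standard DataLoader for training.
loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)
raw_batch, label_batch = next(iter(loader))  # each item is a (raw, labels) pair of tensors
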
max_sampling_attempts = 500

The maximal number of sampling attempts, for loading a sample via __getitem__. This is used when sampler rejects a sample, to avoid an infinite loop if no valid sample can be found.
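
For illustration, a sampler only needs to return a truthy value for patches that should be accepted; it is called as sampler(raw, labels) in the rejection loop of _get_sample. The acceptance criterion below is a hypothetical example, and the file paths and keys are placeholders as before.

from torch_em.data.segmentation_dataset import SegmentationDataset

# Hypothetical rejection criterion: only accept patches that contain some foreground label.
def require_foreground(raw, labels):
    return (labels != 0).sum() > 0

dataset = SegmentationDataset(
    raw_path="data.h5", raw_key="raw",
    label_path="data.h5", label_key="labels",
    patch_shape=(32, 256, 256),
    sampler=require_foreground,
)

# If valid patches are rare, the per-sample retry budget can be raised on the class.
SegmentationDataset.max_sampling_attempts = 1000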

@staticmethod
def compute_len(shape, patch_shape):
    if patch_shape is None:
        return 1
    else:
        n_samples = int(np.prod([float(sh / csh) for sh, csh in zip(shape, patch_shape)]))
        return n_samples
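
As a quick sanity check of the default length computation, with an assumed volume shape of (100, 512, 512) and a patch shape of (32, 256, 256):

from torch_em.data.segmentation_dataset import SegmentationDataset

# 100/32 * 512/256 * 512/256 = 3.125 * 2 * 2 = 12.5, truncated by int(...) to 12.
print(SegmentationDataset.compute_len((100, 512, 512), (32, 256, 256)))  # 12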

Instance attributes: raw_path, raw_key, raw, label_path, label_key, labels, shape, roi, patch_shape, raw_transform, label_transform, label_transform2, transform, sampler, with_padding, dtype, label_dtype, z_ext, sample_shape, trafo_halo

ndim

@property
def ndim(self):
    return self._ndim