torch_em.data.raw_dataset

  1import os
  2import warnings
  3import numpy as np
  4from typing import List, Union, Tuple, Optional, Any, Callable
  5
  6import torch
  7
  8from elf.wrapper import RoiWrapper
  9
 10from ..util import ensure_tensor_with_channels, ensure_patch_shape, load_data, validate_roi
 11
 12
 13class RawDataset(torch.utils.data.Dataset):
 14    """Dataset that provides raw data stored in a container data format for unsupervised training.
 15
 16    The dataset loads a patch from the raw data and returns a sample for a batch.
 17    The dataset supports all file formats that can be opened with `elf.io.open_file`, such as hdf5, zarr or n5.
 18    Use `raw_path` to specify the path to the file and `raw_key` to specify the internal dataset.
 19    It also supports regular image formats, such as .tif. For these cases set `raw_key=None`.
 20
 21    The dataset can also be used for contrastive learning that relies on two different views of the same data.
 22    You can use the `augmentations` argument for this.
 23
 24    Args:
 25        raw_path: The file path to the raw image data. May also be a list of file paths.
 26        raw_key: The key to the internal dataset containing the raw data.
 27        patch_shape: The patch shape for a training sample.
 28        raw_transform: Transformation applied to the raw data of a sample.
 29        transform: Transformation to the raw data. This can be used to implement data augmentations.
 30        roi: Region of interest in the raw data.
 31            If given, the raw data will only be loaded from the corresponding area.
 32        dtype: The return data type of the raw data.
 33        n_samples: The length of this dataset. If None, the length will be set to `len(raw_image_paths)`.
 34        sampler: Sampler for rejecting samples according to a defined criterion.
 35            The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
 36        ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
 37        with_channels: Whether the raw data has channels.
 38        augmentations: Augmentations for contrastive learning. If given, these need to be two different callables.
 39            They will be applied to the sampled raw data to return two independent views of the raw data.
 40    """
 41    max_sampling_attempts = 500
 42    """The maximal number of sampling attempts, for loading a sample via `__getitem__`.
 43    This is used when `sampler` rejects a sample, to avoid an infinite loop if no valid sample can be found.
 44    """
 45
 46    @staticmethod
 47    def compute_len(shape, patch_shape):
 48        n_samples = int(np.prod([float(sh / csh) for sh, csh in zip(shape, patch_shape)]))
 49        return n_samples
 50
 51    def __init__(
 52        self,
 53        raw_path: Union[List[Any], str, os.PathLike],
 54        raw_key: Optional[str],
 55        patch_shape: Tuple[int, ...],
 56        raw_transform: Optional[Callable] = None,
 57        transform: Optional[Callable] = None,
 58        roi: Optional[Union[slice, Tuple[slice, ...]]] = None,
 59        dtype: torch.dtype = torch.float32,
 60        n_samples: Optional[int] = None,
 61        sampler: Optional[Callable] = None,
 62        ndim: Optional[int] = None,
 63        with_channels: bool = False,
 64        augmentations: Optional[Tuple[Callable, Callable]] = None,
 65    ):
 66        self.raw_path = raw_path
 67        self.raw_key = raw_key
 68        self.raw = load_data(raw_path, raw_key)
 69
 70        self._with_channels = with_channels
 71
 72        if roi is not None:
 73            shape = self.raw.shape[1:] if self._with_channels else self.raw.shape
 74            roi = validate_roi(roi, shape, patch_shape)
 75            self.raw = RoiWrapper(self.raw, (slice(None),) + roi) if self._with_channels else RoiWrapper(self.raw, roi)
 76
 77        self.shape = self.raw.shape[1:] if self._with_channels else self.raw.shape
 78        self.roi = roi
 79
 80        self._ndim = len(self.shape) if ndim is None else ndim
 81        assert self._ndim in (2, 3, 4), f"Invalid data dimensions: {self._ndim}. Only 2d, 3d or 4d data is supported"
 82
 83        assert len(patch_shape) in (self._ndim, self._ndim + 1), f"{patch_shape}, {self._ndim}"
 84        self.patch_shape = patch_shape
 85
 86        self.raw_transform = raw_transform
 87        self.transform = transform
 88        self.sampler = sampler
 89        self.dtype = dtype
 90
 91        if augmentations is not None:
 92            assert len(augmentations) == 2
 93        self.augmentations = augmentations
 94
 95        self._len = self.compute_len(self.shape, self.patch_shape) if n_samples is None else n_samples
 96
 97        self.sample_shape = patch_shape
 98        self.trafo_halo = None
 99        # TODO add support for trafo halo: asking for a bigger bounding box before applying the trafo,
100        # which is then cut. See code below; but this ne needs to be properly tested
101
102        # self.trafo_halo = None if self.transform is None else self.transform.halo(self.patch_shape)
103        # if self.trafo_halo is not None:
104        #     if len(self.trafo_halo) == 2 and self._ndim == 3:
105        #         self.trafo_halo = (0,) + self.trafo_halo
106        #     assert len(self.trafo_halo) == self._ndim
107        #     self.sample_shape = tuple(sh + ha for sh, ha in zip(self.patch_shape, self.trafo_halo))
108        #     self.inner_bb = tuple(slice(ha, sh - ha) for sh, ha in zip(self.patch_shape, self.trafo_halo))
109
110    def __len__(self):
111        return self._len
112
113    @property
114    def ndim(self):
115        return self._ndim
116
117    def _sample_bounding_box(self):
118        bb_start = [
119            np.random.randint(0, sh - psh) if sh - psh > 0 else 0
120            for sh, psh in zip(self.shape, self.sample_shape)
121        ]
122        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, self.sample_shape))
123
124    def _get_sample(self, index):
125        if self.raw is None:
126            raise RuntimeError("RawDataset has not been properly deserialized.")
127        bb = self._sample_bounding_box()
128        raw = self.raw[(slice(None),) + bb] if self._with_channels else self.raw[bb]
129
130        if self.sampler is not None:
131            sample_id = 0
132            while not self.sampler(raw):
133                bb = self._sample_bounding_box()
134                raw = self.raw[(slice(None),) + bb] if self._with_channels else self.raw[bb]
135                sample_id += 1
136                if sample_id > self.max_sampling_attempts:
137                    raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")
138
139        if self.patch_shape is not None:
140            raw = ensure_patch_shape(
141                raw=raw, labels=None, patch_shape=self.patch_shape, have_raw_channels=self._with_channels
142            )
143
144        # squeeze the singleton spatial axis if we have a spatial shape that is larger by one than self._ndim
145        if len(self.patch_shape) == self._ndim + 1:
146            raw = raw.squeeze(1 if self._with_channels else 0)
147
148        return raw
149
150    def crop(self, tensor):
151        bb = self.inner_bb
152        if tensor.ndim > len(bb):
153            bb = (tensor.ndim - len(bb)) * (slice(None),) + bb
154        return tensor[bb]
155
156    def __getitem__(self, index):
157        raw = self._get_sample(index)
158
159        if self.raw_transform is not None:
160            raw = self.raw_transform(raw)
161
162        if self.transform is not None:
163            raw = self.transform(raw)
164            if isinstance(raw, list):
165                assert len(raw) == 1
166                raw = raw[0]
167
168            if self.trafo_halo is not None:
169                raw = self.crop(raw)
170
171        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
172        if self.augmentations is not None:
173            aug1, aug2 = self.augmentations
174            raw1, raw2 = aug1(raw), aug2(raw)
175            return raw1, raw2
176
177        return raw
178
179    # need to overwrite pickle to support h5py
180    def __getstate__(self):
181        state = self.__dict__.copy()
182        del state["raw"]
183        return state
184
185    def __setstate__(self, state):
186        raw_path, raw_key = state["raw_path"], state["raw_key"]
187        roi = state["roi"]
188        try:
189            raw = load_data(raw_path, raw_key)
190            if roi is not None:
191                raw = RoiWrapper(raw, (slice(None),) + roi) if state["_with_channels"] else RoiWrapper(raw, roi)
192            state["raw"] = raw
193        except Exception:
194            msg = f"RawDataset could not be deserialized because of missing {raw_path}, {raw_key}.\n"
195            msg += "The dataset is deserialized in order to allow loading trained models from a checkpoint.\n"
196            msg += "But it cannot be used for further training and wil throw an error."
197            warnings.warn(msg)
198            state["raw"] = None
199
200        self.__dict__.update(state)
201
202
class RawDatasetWithMasks(RawDataset):
    """Extends `RawDataset` to support a sample mask and a background mask.

        - The sample mask is used by the sampler to extract patches from a region of interest, e.g.,
            using `MinForegroundSampler`, to avoid empty patches.
        - The background mask is a binary mask identifying regions or structures that belong to the background.
            It can be used during unsupervised training to subtract background regions from the predicted
            pseudo labels.

    Args:
        raw_path: The file path to the raw image data. May also be a list of file paths.
        raw_key: The key to the internal dataset containing the raw data.
        patch_shape: The patch shape for a training sample.
        raw_transform: Transformation applied to the raw data of a sample.
        transform: Transformation to the raw data. This can be used to implement data augmentations.
        roi: Region of interest in the raw data.
            If given, the raw data will only be loaded from the corresponding area.
        dtype: The return data type of the raw data.
        n_samples: The length of this dataset. If None, the length will be computed from the
            data shape and the patch shape.
        sampler: Sampler for rejecting samples according to a defined criterion.
            The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
        ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
        with_channels: Whether the raw data has channels.
        augmentations: Augmentations for contrastive learning. If given, these need to be two different callables.
            They will be applied to the sampled raw data to return two independent views of the raw data.
        sample_mask_path: Filepaths to the sample masks used by the sampler to accept or reject
            patches for training.
        sample_mask_key: The key to the dataset containing the sample masks.
        bg_mask_path: Filepaths to the background masks, which will be returned together with the raw sample.
        bg_mask_key: The key to the dataset containing the background masks.
    """

    def __init__(
        self,
        raw_path: Union[List[Any], str, os.PathLike],
        raw_key: Optional[str],
        patch_shape: Tuple[int, ...],
        raw_transform: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        roi: Optional[Union[slice, Tuple[slice, ...]]] = None,
        dtype: torch.dtype = torch.float32,
        n_samples: Optional[int] = None,
        sampler: Optional[Callable] = None,
        ndim: Optional[int] = None,
        with_channels: bool = False,
        augmentations: Optional[Tuple[Callable, Callable]] = None,
        sample_mask_path: Optional[Union[List[Any], str, os.PathLike]] = None,
        sample_mask_key: Optional[str] = None,
        bg_mask_path: Optional[Union[List[Any], str, os.PathLike]] = None,
        bg_mask_key: Optional[str] = None,
    ):
        super().__init__(
            raw_path=raw_path,
            raw_key=raw_key,
            patch_shape=patch_shape,
            raw_transform=raw_transform,
            transform=transform,
            roi=roi,
            dtype=dtype,
            n_samples=n_samples,
            sampler=sampler,
            ndim=ndim,
            with_channels=with_channels,
            augmentations=augmentations,
        )

        # Optional mask used only by the sampler to accept/reject patches; not returned to the caller.
        self.sample_mask_path = sample_mask_path
        self.sample_mask_key = sample_mask_key
        self.sample_mask = load_data(sample_mask_path, sample_mask_key) if sample_mask_path is not None else None

        # Optional mask returned (stacked with the raw data) by __getitem__.
        self.bg_mask_path = bg_mask_path
        self.bg_mask_key = bg_mask_key
        self.bg_mask = load_data(bg_mask_path, bg_mask_key) if bg_mask_path is not None else None

    def _extract_patch(self, data, bb):
        """Extract the patch for the bounding box `bb`, keeping the full channel axis if present."""
        return data[(slice(None),) + bb] if self._with_channels else data[bb]

    def _get_sample(self, index):
        """Load a raw patch and the corresponding background mask patch (or None).

        If a sampler is set, bounding boxes are re-drawn until the sampler accepts one;
        when a sample mask is available it is passed to the sampler as a second argument.
        """
        if self.raw is None:
            raise RuntimeError("RawDataset has not been properly deserialized.")

        # default behavior; use if sampler is None
        bb = self._sample_bounding_box()
        raw = self._extract_patch(self.raw, bb)

        if self.sampler is not None:
            sample_id = 0
            if self.sample_mask is not None:
                # The sampler decides based on both the raw patch and the sample-mask patch.
                sample_mask = self._extract_patch(self.sample_mask, bb)

                while not self.sampler(raw, sample_mask):
                    bb = self._sample_bounding_box()
                    raw = self._extract_patch(self.raw, bb)
                    sample_mask = self._extract_patch(self.sample_mask, bb)

                    sample_id += 1
                    if sample_id > self.max_sampling_attempts:
                        raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")
            else:
                while not self.sampler(raw):
                    bb = self._sample_bounding_box()
                    raw = self._extract_patch(self.raw, bb)
                    sample_id += 1
                    if sample_id > self.max_sampling_attempts:
                        raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")

        # The background mask is cut with the same (accepted) bounding box as the raw data.
        bg_mask = self._extract_patch(self.bg_mask, bb) if self.bg_mask is not None else None

        if self.patch_shape is not None:
            if bg_mask is not None:
                raw, bg_mask = ensure_patch_shape(
                    raw=raw, labels=bg_mask, patch_shape=self.patch_shape,
                    have_raw_channels=self._with_channels, have_label_channels=self._with_channels
                )
            else:
                # NOTE(review): have_label_channels is passed here although labels is None;
                # presumably it is ignored in that case — confirm (the base class omits it).
                raw = ensure_patch_shape(
                    raw=raw, labels=None, patch_shape=self.patch_shape,
                    have_raw_channels=self._with_channels, have_label_channels=self._with_channels
                )
        # squeeze the singleton spatial axis if we have a spatial shape that is larger by one than self._ndim
        if len(self.patch_shape) == self._ndim + 1:
            raw = raw.squeeze(1 if self._with_channels else 0)

            if bg_mask is not None:
                bg_mask = bg_mask.squeeze(1 if self._with_channels else 0)

        return raw, bg_mask

    def __getitem__(self, index):
        raw, bg_mask = self._get_sample(index)

        if self.raw_transform is not None:
            raw = self.raw_transform(raw)

        if self.transform is not None:
            raw = self.transform(raw)
            if isinstance(raw, list):
                assert len(raw) == 1
                raw = raw[0]

            if self.trafo_halo is not None:
                raw = self.crop(raw)

        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
        if bg_mask is not None:
            bg_mask = ensure_tensor_with_channels(bg_mask, ndim=self._ndim, dtype=self.dtype)

        if self.augmentations is not None:
            aug1, aug2 = self.augmentations
            raw1, raw2 = aug1(raw), aug2(raw)

            if bg_mask is not None:

                # if we have a background mask, return the stacked data for both views
                return torch.cat((raw1, bg_mask), dim=0), torch.cat((raw2, bg_mask), dim=0)

            # else, return augmented raw
            return raw1, raw2

        if bg_mask is not None:

            # if we have a background mask, return the stacked data
            return torch.cat((raw, bg_mask), dim=0)

        # else, return raw
        return raw

    def __getstate__(self):
        # The mask datasets (like the raw data in the base class) may be h5py handles,
        # which cannot be pickled; they are reloaded from their paths in __setstate__.
        state = super().__getstate__()
        del state["sample_mask"]
        del state["bg_mask"]
        return state

    def __setstate__(self, state):
        """Restore the dataset, re-opening the mask data from the stored paths (None if no path was set)."""
        super().__setstate__(state)
        sample_mask_path = state.get("sample_mask_path")
        sample_mask_key = state.get("sample_mask_key")
        bg_mask_path = state.get("bg_mask_path")
        bg_mask_key = state.get("bg_mask_key")
        self.sample_mask = load_data(sample_mask_path, sample_mask_key) if sample_mask_path is not None else None
        self.bg_mask = load_data(bg_mask_path, bg_mask_key) if bg_mask_path is not None else None
class RawDataset(typing.Generic[+_T_co]):
 14class RawDataset(torch.utils.data.Dataset):
 15    """Dataset that provides raw data stored in a container data format for unsupervised training.
 16
 17    The dataset loads a patch from the raw data and returns a sample for a batch.
 18    The dataset supports all file formats that can be opened with `elf.io.open_file`, such as hdf5, zarr or n5.
 19    Use `raw_path` to specify the path to the file and `raw_key` to specify the internal dataset.
 20    It also supports regular image formats, such as .tif. For these cases set `raw_key=None`.
 21
 22    The dataset can also be used for contrastive learning that relies on two different views of the same data.
 23    You can use the `augmentations` argument for this.
 24
 25    Args:
 26        raw_path: The file path to the raw image data. May also be a list of file paths.
 27        raw_key: The key to the internal dataset containing the raw data.
 28        patch_shape: The patch shape for a training sample.
 29        raw_transform: Transformation applied to the raw data of a sample.
 30        transform: Transformation to the raw data. This can be used to implement data augmentations.
 31        roi: Region of interest in the raw data.
 32            If given, the raw data will only be loaded from the corresponding area.
 33        dtype: The return data type of the raw data.
 34        n_samples: The length of this dataset. If None, the length will be set to `len(raw_image_paths)`.
 35        sampler: Sampler for rejecting samples according to a defined criterion.
 36            The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
 37        ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
 38        with_channels: Whether the raw data has channels.
 39        augmentations: Augmentations for contrastive learning. If given, these need to be two different callables.
 40            They will be applied to the sampled raw data to return two independent views of the raw data.
 41    """
 42    max_sampling_attempts = 500
 43    """The maximal number of sampling attempts, for loading a sample via `__getitem__`.
 44    This is used when `sampler` rejects a sample, to avoid an infinite loop if no valid sample can be found.
 45    """
 46
 47    @staticmethod
 48    def compute_len(shape, patch_shape):
 49        n_samples = int(np.prod([float(sh / csh) for sh, csh in zip(shape, patch_shape)]))
 50        return n_samples
 51
 52    def __init__(
 53        self,
 54        raw_path: Union[List[Any], str, os.PathLike],
 55        raw_key: Optional[str],
 56        patch_shape: Tuple[int, ...],
 57        raw_transform: Optional[Callable] = None,
 58        transform: Optional[Callable] = None,
 59        roi: Optional[Union[slice, Tuple[slice, ...]]] = None,
 60        dtype: torch.dtype = torch.float32,
 61        n_samples: Optional[int] = None,
 62        sampler: Optional[Callable] = None,
 63        ndim: Optional[int] = None,
 64        with_channels: bool = False,
 65        augmentations: Optional[Tuple[Callable, Callable]] = None,
 66    ):
 67        self.raw_path = raw_path
 68        self.raw_key = raw_key
 69        self.raw = load_data(raw_path, raw_key)
 70
 71        self._with_channels = with_channels
 72
 73        if roi is not None:
 74            shape = self.raw.shape[1:] if self._with_channels else self.raw.shape
 75            roi = validate_roi(roi, shape, patch_shape)
 76            self.raw = RoiWrapper(self.raw, (slice(None),) + roi) if self._with_channels else RoiWrapper(self.raw, roi)
 77
 78        self.shape = self.raw.shape[1:] if self._with_channels else self.raw.shape
 79        self.roi = roi
 80
 81        self._ndim = len(self.shape) if ndim is None else ndim
 82        assert self._ndim in (2, 3, 4), f"Invalid data dimensions: {self._ndim}. Only 2d, 3d or 4d data is supported"
 83
 84        assert len(patch_shape) in (self._ndim, self._ndim + 1), f"{patch_shape}, {self._ndim}"
 85        self.patch_shape = patch_shape
 86
 87        self.raw_transform = raw_transform
 88        self.transform = transform
 89        self.sampler = sampler
 90        self.dtype = dtype
 91
 92        if augmentations is not None:
 93            assert len(augmentations) == 2
 94        self.augmentations = augmentations
 95
 96        self._len = self.compute_len(self.shape, self.patch_shape) if n_samples is None else n_samples
 97
 98        self.sample_shape = patch_shape
 99        self.trafo_halo = None
100        # TODO add support for trafo halo: asking for a bigger bounding box before applying the trafo,
101        # which is then cut. See code below; but this needs to be properly tested
102
103        # self.trafo_halo = None if self.transform is None else self.transform.halo(self.patch_shape)
104        # if self.trafo_halo is not None:
105        #     if len(self.trafo_halo) == 2 and self._ndim == 3:
106        #         self.trafo_halo = (0,) + self.trafo_halo
107        #     assert len(self.trafo_halo) == self._ndim
108        #     self.sample_shape = tuple(sh + ha for sh, ha in zip(self.patch_shape, self.trafo_halo))
109        #     self.inner_bb = tuple(slice(ha, sh - ha) for sh, ha in zip(self.patch_shape, self.trafo_halo))
110
111    def __len__(self):
112        return self._len
113
114    @property
115    def ndim(self):
116        return self._ndim
117
118    def _sample_bounding_box(self):
119        bb_start = [
120            np.random.randint(0, sh - psh) if sh - psh > 0 else 0
121            for sh, psh in zip(self.shape, self.sample_shape)
122        ]
123        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, self.sample_shape))
124
125    def _get_sample(self, index):
126        if self.raw is None:
127            raise RuntimeError("RawDataset has not been properly deserialized.")
128        bb = self._sample_bounding_box()
129        raw = self.raw[(slice(None),) + bb] if self._with_channels else self.raw[bb]
130
131        if self.sampler is not None:
132            sample_id = 0
133            while not self.sampler(raw):
134                bb = self._sample_bounding_box()
135                raw = self.raw[(slice(None),) + bb] if self._with_channels else self.raw[bb]
136                sample_id += 1
137                if sample_id > self.max_sampling_attempts:
138                    raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")
139
140        if self.patch_shape is not None:
141            raw = ensure_patch_shape(
142                raw=raw, labels=None, patch_shape=self.patch_shape, have_raw_channels=self._with_channels
143            )
144
145        # squeeze the singleton spatial axis if we have a spatial shape that is larger by one than self._ndim
146        if len(self.patch_shape) == self._ndim + 1:
147            raw = raw.squeeze(1 if self._with_channels else 0)
148
149        return raw
150
151    def crop(self, tensor):
152        bb = self.inner_bb
153        if tensor.ndim > len(bb):
154            bb = (tensor.ndim - len(bb)) * (slice(None),) + bb
155        return tensor[bb]
156
157    def __getitem__(self, index):
158        raw = self._get_sample(index)
159
160        if self.raw_transform is not None:
161            raw = self.raw_transform(raw)
162
163        if self.transform is not None:
164            raw = self.transform(raw)
165            if isinstance(raw, list):
166                assert len(raw) == 1
167                raw = raw[0]
168
169            if self.trafo_halo is not None:
170                raw = self.crop(raw)
171
172        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
173        if self.augmentations is not None:
174            aug1, aug2 = self.augmentations
175            raw1, raw2 = aug1(raw), aug2(raw)
176            return raw1, raw2
177
178        return raw
179
180    # need to overwrite pickle to support h5py
181    def __getstate__(self):
182        state = self.__dict__.copy()
183        del state["raw"]
184        return state
185
186    def __setstate__(self, state):
187        raw_path, raw_key = state["raw_path"], state["raw_key"]
188        roi = state["roi"]
189        try:
190            raw = load_data(raw_path, raw_key)
191            if roi is not None:
192                raw = RoiWrapper(raw, (slice(None),) + roi) if state["_with_channels"] else RoiWrapper(raw, roi)
193            state["raw"] = raw
194        except Exception:
195            msg = f"RawDataset could not be deserialized because of missing {raw_path}, {raw_key}.\n"
196            msg += "The dataset is deserialized in order to allow loading trained models from a checkpoint.\n"
197            msg += "But it cannot be used for further training and will throw an error."
198            warnings.warn(msg)
199            state["raw"] = None
200
201        self.__dict__.update(state)

Dataset that provides raw data stored in a container data format for unsupervised training.

The dataset loads a patch from the raw data and returns a sample for a batch. The dataset supports all file formats that can be opened with elf.io.open_file, such as hdf5, zarr or n5. Use raw_path to specify the path to the file and raw_key to specify the internal dataset. It also supports regular image formats, such as .tif. For these cases set raw_key=None.

The dataset can also be used for contrastive learning that relies on two different views of the same data. You can use the augmentations argument for this.

Arguments:
  • raw_path: The file path to the raw image data. May also be a list of file paths.
  • raw_key: The key to the internal dataset containing the raw data.
  • patch_shape: The patch shape for a training sample.
  • raw_transform: Transformation applied to the raw data of a sample.
  • transform: Transformation to the raw data. This can be used to implement data augmentations.
  • roi: Region of interest in the raw data. If given, the raw data will only be loaded from the corresponding area.
  • dtype: The return data type of the raw data.
  • n_samples: The length of this dataset. If None, the length will be set to len(raw_image_paths).
  • sampler: Sampler for rejecting samples according to a defined criterion. The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
  • ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
  • with_channels: Whether the raw data has channels.
  • augmentations: Augmentations for contrastive learning. If given, these need to be two different callables. They will be applied to the sampled raw data to return two independent views of the raw data.
RawDataset( raw_path: Union[List[Any], str, os.PathLike], raw_key: Optional[str], patch_shape: Tuple[int, ...], raw_transform: Optional[Callable] = None, transform: Optional[Callable] = None, roi: Union[slice, Tuple[slice, ...], NoneType] = None, dtype: torch.dtype = torch.float32, n_samples: Optional[int] = None, sampler: Optional[Callable] = None, ndim: Optional[int] = None, with_channels: bool = False, augmentations: Optional[Tuple[Callable, Callable]] = None)
 52    def __init__(
 53        self,
 54        raw_path: Union[List[Any], str, os.PathLike],
 55        raw_key: Optional[str],
 56        patch_shape: Tuple[int, ...],
 57        raw_transform: Optional[Callable] = None,
 58        transform: Optional[Callable] = None,
 59        roi: Optional[Union[slice, Tuple[slice, ...]]] = None,
 60        dtype: torch.dtype = torch.float32,
 61        n_samples: Optional[int] = None,
 62        sampler: Optional[Callable] = None,
 63        ndim: Optional[int] = None,
 64        with_channels: bool = False,
 65        augmentations: Optional[Tuple[Callable, Callable]] = None,
 66    ):
 67        self.raw_path = raw_path
 68        self.raw_key = raw_key
 69        self.raw = load_data(raw_path, raw_key)
 70
 71        self._with_channels = with_channels
 72
 73        if roi is not None:
 74            shape = self.raw.shape[1:] if self._with_channels else self.raw.shape
 75            roi = validate_roi(roi, shape, patch_shape)
 76            self.raw = RoiWrapper(self.raw, (slice(None),) + roi) if self._with_channels else RoiWrapper(self.raw, roi)
 77
 78        self.shape = self.raw.shape[1:] if self._with_channels else self.raw.shape
 79        self.roi = roi
 80
 81        self._ndim = len(self.shape) if ndim is None else ndim
 82        assert self._ndim in (2, 3, 4), f"Invalid data dimensions: {self._ndim}. Only 2d, 3d or 4d data is supported"
 83
 84        assert len(patch_shape) in (self._ndim, self._ndim + 1), f"{patch_shape}, {self._ndim}"
 85        self.patch_shape = patch_shape
 86
 87        self.raw_transform = raw_transform
 88        self.transform = transform
 89        self.sampler = sampler
 90        self.dtype = dtype
 91
 92        if augmentations is not None:
 93            assert len(augmentations) == 2
 94        self.augmentations = augmentations
 95
 96        self._len = self.compute_len(self.shape, self.patch_shape) if n_samples is None else n_samples
 97
 98        self.sample_shape = patch_shape
 99        self.trafo_halo = None
100        # TODO add support for trafo halo: asking for a bigger bounding box before applying the trafo,
101        # which is then cut. See code below; but this needs to be properly tested
102
103        # self.trafo_halo = None if self.transform is None else self.transform.halo(self.patch_shape)
104        # if self.trafo_halo is not None:
105        #     if len(self.trafo_halo) == 2 and self._ndim == 3:
106        #         self.trafo_halo = (0,) + self.trafo_halo
107        #     assert len(self.trafo_halo) == self._ndim
108        #     self.sample_shape = tuple(sh + ha for sh, ha in zip(self.patch_shape, self.trafo_halo))
109        #     self.inner_bb = tuple(slice(ha, sh - ha) for sh, ha in zip(self.patch_shape, self.trafo_halo))
max_sampling_attempts = 500

The maximal number of sampling attempts, for loading a sample via __getitem__. This is used when sampler rejects a sample, to avoid an infinite loop if no valid sample can be found.

@staticmethod
def compute_len(shape, patch_shape):
47    @staticmethod
48    def compute_len(shape, patch_shape):
49        n_samples = int(np.prod([float(sh / csh) for sh, csh in zip(shape, patch_shape)]))
50        return n_samples
raw_path
raw_key
raw
shape
roi
patch_shape
raw_transform
transform
sampler
dtype
augmentations
sample_shape
trafo_halo
ndim
114    @property
115    def ndim(self):
116        return self._ndim
def crop(self, tensor):
151    def crop(self, tensor):
152        bb = self.inner_bb
153        if tensor.ndim > len(bb):
154            bb = (tensor.ndim - len(bb)) * (slice(None),) + bb
155        return tensor[bb]
class RawDatasetWithMasks(typing.Generic[+_T_co]):
204class RawDatasetWithMasks(RawDataset):
205    """Extends `RawDataset` to support a sample mask and a background mask.
206
207        - The sample mask is used by the sampler to extract patches from a region of interest, e.g.,
208            using `MinForegroundSampler`, to avoid empty patches.
209        - The background mask is a binary mask identifying regions or structures that belong to the background.
210            It can be used during unsupervised training to subtract background regions from the predicted
211            pseudo labels.
212
213    Args:
214        raw_path: The file path to the raw image data. May also be a list of file paths.
215        raw_key: The key to the internal dataset containing the raw data.
216        patch_shape: The patch shape for a training sample.
217        raw_transform: Transformation applied to the raw data of a sample.
218        transform: Transformation to the raw data. This can be used to implement data augmentations.
219        roi: Region of interest in the raw data.
220            If given, the raw data will only be loaded from the corresponding area.
221        dtype: The return data type of the raw data.
222        n_samples: The length of this dataset. If None, the length will be set to `len(raw_image_paths)`.
223        sampler: Sampler for rejecting samples according to a defined criterion.
224            The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
225        ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
226        with_channels: Whether the raw data has channels.
227        augmentations: Augmentations for contrastive learning. If given, these need to be two different callables.
228            They will be applied to the sampled raw data to return two independent views of the raw data.
229        sample_mask_path: Filepaths to the sample masks used by the sampler to accept or reject
230            patches for training.
231        sample_mask_key: The key to the dataset containing the sample masks.
232        bg_mask_path: Filepaths to the background masks, which will be returned together with the raw sample.
233        bg_mask_key: The key to the dataset containing the background masks.
234    """
235
236    def __init__(
237        self,
238        raw_path: Union[List[Any], str, os.PathLike],
239        raw_key: Optional[str],
240        patch_shape: Tuple[int, ...],
241        raw_transform: Optional[Callable] = None,
242        transform: Optional[Callable] = None,
243        roi: Optional[Union[slice, Tuple[slice, ...]]] = None,
244        dtype: torch.dtype = torch.float32,
245        n_samples: Optional[int] = None,
246        sampler: Optional[Callable] = None,
247        ndim: Optional[int] = None,
248        with_channels: bool = False,
249        augmentations: Optional[Tuple[Callable, Callable]] = None,
250        sample_mask_path: Union[List[Any], str, os.PathLike] = None,
251        sample_mask_key: Optional[str] = None,
252        bg_mask_path: Union[List[Any], str, os.PathLike] = None,
253        bg_mask_key: Optional[str] = None,
254    ):
255        super().__init__(
256            raw_path=raw_path,
257            raw_key=raw_key,
258            patch_shape=patch_shape,
259            raw_transform=raw_transform,
260            transform=transform,
261            roi=roi,
262            dtype=dtype,
263            n_samples=n_samples,
264            sampler=sampler,
265            ndim=ndim,
266            with_channels=with_channels,
267            augmentations=augmentations,
268        )
269
270        self.sample_mask_path = sample_mask_path
271        self.sample_mask_key = sample_mask_key
272        self.sample_mask = load_data(sample_mask_path, sample_mask_key) if sample_mask_path is not None else None
273
274        self.bg_mask_path = bg_mask_path
275        self.bg_mask_key = bg_mask_key
276        self.bg_mask = load_data(bg_mask_path, bg_mask_key) if bg_mask_path is not None else None
277
278    def _extract_patch(self, data, bb):
279        return data[(slice(None),) + bb] if self._with_channels else data[bb]
280
281    def _get_sample(self, index):
282        if self.raw is None:
283            raise RuntimeError("RawDataset has not been properly deserialized.")
284
285        # default behavior; use if sampler is None
286        bb = self._sample_bounding_box()
287        raw = self._extract_patch(self.raw, bb)
288
289        if self.sampler is not None:
290            sample_id = 0
291            if self.sample_mask is not None:
292                sample_mask = self._extract_patch(self.sample_mask, bb)
293
294                while not self.sampler(raw, sample_mask):
295                    bb = self._sample_bounding_box()
296                    raw = self._extract_patch(self.raw, bb)
297                    sample_mask = self._extract_patch(self.sample_mask, bb)
298
299                    sample_id += 1
300                    if sample_id > self.max_sampling_attempts:
301                        raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")
302            else:
303                while not self.sampler(raw):
304                    bb = self._sample_bounding_box()
305                    raw = self._extract_patch(self.raw, bb)
306                    sample_id += 1
307                    if sample_id > self.max_sampling_attempts:
308                        raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")
309
310        bg_mask = self._extract_patch(self.bg_mask, bb) if self.bg_mask is not None else None
311
312        if self.patch_shape is not None:
313            if bg_mask is not None:
314                raw, bg_mask = ensure_patch_shape(
315                    raw=raw, labels=bg_mask, patch_shape=self.patch_shape,
316                    have_raw_channels=self._with_channels, have_label_channels=self._with_channels
317                )
318            else:
319                raw = ensure_patch_shape(
320                    raw=raw, labels=None, patch_shape=self.patch_shape,
321                    have_raw_channels=self._with_channels, have_label_channels=self._with_channels
322                )
323        # squeeze the singleton spatial axis if we have a spatial shape that is larger by one than self._ndim
324        if len(self.patch_shape) == self._ndim + 1:
325            raw = raw.squeeze(1 if self._with_channels else 0)
326
327            if bg_mask is not None:
328                bg_mask = bg_mask.squeeze(1 if self._with_channels else 0)
329
330        return raw, bg_mask
331
332    def __getitem__(self, index):
333        raw, bg_mask = self._get_sample(index)
334
335        if self.raw_transform is not None:
336            raw = self.raw_transform(raw)
337
338        if self.transform is not None:
339            raw = self.transform(raw)
340            if isinstance(raw, list):
341                assert len(raw) == 1
342                raw = raw[0]
343
344            if self.trafo_halo is not None:
345                raw = self.crop(raw)
346
347        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
348        if bg_mask is not None:
349            bg_mask = ensure_tensor_with_channels(bg_mask, ndim=self._ndim, dtype=self.dtype)
350
351        if self.augmentations is not None:
352            aug1, aug2 = self.augmentations
353            raw1, raw2 = aug1(raw), aug2(raw)
354
355            if bg_mask is not None:
356
357                # if background_mask, returned stacked data
358                return torch.cat((raw1, bg_mask), dim=0), torch.cat((raw2, bg_mask), dim=0)
359
360            # else, return augmented raw
361            return raw1, raw2
362
363        if bg_mask is not None:
364
365            # if background_mask, returned stacked data
366            return torch.cat((raw, bg_mask), dim=0)
367
368        # else, return raw
369        return raw
370
371    def __getstate__(self):
372        state = super().__getstate__()
373        del state["sample_mask"]
374        del state["bg_mask"]
375        return state
376
377    def __setstate__(self, state):
378        super().__setstate__(state)
379        sample_mask_path = state.get("sample_mask_path")
380        sample_mask_key = state.get("sample_mask_key")
381        bg_mask_path = state.get("bg_mask_path")
382        bg_mask_key = state.get("bg_mask_key")
383        self.sample_mask = load_data(sample_mask_path, sample_mask_key) if sample_mask_path is not None else None
384        self.bg_mask = load_data(bg_mask_path, bg_mask_key) if bg_mask_path is not None else None

Extends RawDataset to support a sample mask and a background mask.

- The sample mask is used by the sampler to extract patches from a region of interest, e.g.,
    using `MinForegroundSampler`, to avoid empty patches.
- The background mask is a binary mask identifying regions or structures that belong to the background.
    It can be used during unsupervised training to subtract background regions from the predicted
    pseudo labels.
Arguments:
  • raw_path: The file path to the raw image data. May also be a list of file paths.
  • raw_key: The key to the internal dataset containing the raw data.
  • patch_shape: The patch shape for a training sample.
  • raw_transform: Transformation applied to the raw data of a sample.
  • transform: Transformation to the raw data. This can be used to implement data augmentations.
  • roi: Region of interest in the raw data. If given, the raw data will only be loaded from the corresponding area.
  • dtype: The return data type of the raw data.
  • n_samples: The length of this dataset. If None, the length will be set to len(raw_image_paths).
  • sampler: Sampler for rejecting samples according to a defined criterion. The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
  • ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
  • with_channels: Whether the raw data has channels.
  • augmentations: Augmentations for contrastive learning. If given, these need to be two different callables. They will be applied to the sampled raw data to return two independent views of the raw data.
  • sample_mask_path: Filepaths to the sample masks used by the sampler to accept or reject patches for training.
  • sample_mask_key: The key to the dataset containing the sample masks.
  • bg_mask_path: Filepaths to the background masks, which will be returned together with the raw sample.
  • bg_mask_key: The key to the dataset containing the background masks.
RawDatasetWithMasks( raw_path: Union[List[Any], str, os.PathLike], raw_key: Optional[str], patch_shape: Tuple[int, ...], raw_transform: Optional[Callable] = None, transform: Optional[Callable] = None, roi: Union[slice, Tuple[slice, ...], NoneType] = None, dtype: torch.dtype = torch.float32, n_samples: Optional[int] = None, sampler: Optional[Callable] = None, ndim: Optional[int] = None, with_channels: bool = False, augmentations: Optional[Tuple[Callable, Callable]] = None, sample_mask_path: Union[List[Any], str, os.PathLike] = None, sample_mask_key: Optional[str] = None, bg_mask_path: Union[List[Any], str, os.PathLike] = None, bg_mask_key: Optional[str] = None)
236    def __init__(
237        self,
238        raw_path: Union[List[Any], str, os.PathLike],
239        raw_key: Optional[str],
240        patch_shape: Tuple[int, ...],
241        raw_transform: Optional[Callable] = None,
242        transform: Optional[Callable] = None,
243        roi: Optional[Union[slice, Tuple[slice, ...]]] = None,
244        dtype: torch.dtype = torch.float32,
245        n_samples: Optional[int] = None,
246        sampler: Optional[Callable] = None,
247        ndim: Optional[int] = None,
248        with_channels: bool = False,
249        augmentations: Optional[Tuple[Callable, Callable]] = None,
250        sample_mask_path: Union[List[Any], str, os.PathLike] = None,
251        sample_mask_key: Optional[str] = None,
252        bg_mask_path: Union[List[Any], str, os.PathLike] = None,
253        bg_mask_key: Optional[str] = None,
254    ):
255        super().__init__(
256            raw_path=raw_path,
257            raw_key=raw_key,
258            patch_shape=patch_shape,
259            raw_transform=raw_transform,
260            transform=transform,
261            roi=roi,
262            dtype=dtype,
263            n_samples=n_samples,
264            sampler=sampler,
265            ndim=ndim,
266            with_channels=with_channels,
267            augmentations=augmentations,
268        )
269
270        self.sample_mask_path = sample_mask_path
271        self.sample_mask_key = sample_mask_key
272        self.sample_mask = load_data(sample_mask_path, sample_mask_key) if sample_mask_path is not None else None
273
274        self.bg_mask_path = bg_mask_path
275        self.bg_mask_key = bg_mask_key
276        self.bg_mask = load_data(bg_mask_path, bg_mask_key) if bg_mask_path is not None else None
sample_mask_path
sample_mask_key
sample_mask
bg_mask_path
bg_mask_key
bg_mask