torch_em.data.raw_dataset
import os
import warnings
import numpy as np
from typing import List, Union, Tuple, Optional, Any, Callable

import torch

from elf.wrapper import RoiWrapper

from ..util import ensure_tensor_with_channels, ensure_patch_shape, load_data, validate_roi


class RawDataset(torch.utils.data.Dataset):
    """Dataset that provides raw data stored in a container data format for unsupervised training.

    The dataset loads a patch from the raw data and returns a sample for a batch.
    The dataset supports all file formats that can be opened with `elf.io.open_file`, such as hdf5, zarr or n5.
    Use `raw_path` to specify the path to the file and `raw_key` to specify the internal dataset.
    It also supports regular image formats, such as .tif. For these cases set `raw_key=None`.

    The dataset can also be used for contrastive learning that relies on two different views of the same data.
    You can use the `augmentations` argument for this.

    Args:
        raw_path: The file path to the raw image data. May also be a list of file paths.
        raw_key: The key to the internal dataset containing the raw data.
        patch_shape: The patch shape for a training sample.
        raw_transform: Transformation applied to the raw data of a sample.
        transform: Transformation to the raw data. This can be used to implement data augmentations.
        roi: Region of interest in the raw data.
            If given, the raw data will only be loaded from the corresponding area.
        dtype: The return data type of the raw data.
        n_samples: The length of this dataset. If None, the length is derived from the shape of the
            raw data and `patch_shape` via `compute_len`.
        sampler: Sampler for rejecting samples according to a defined criterion.
            The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
        ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
        with_channels: Whether the raw data has channels.
        augmentations: Augmentations for contrastive learning. If given, these need to be two different callables.
            They will be applied to the sampled raw data to return two independent views of the raw data.
    """
    max_sampling_attempts = 500
    """The maximal number of sampling attempts, for loading a sample via `__getitem__`.
    This is used when `sampler` rejects a sample, to avoid an infinite loop if no valid sample can be found.
    """

    @staticmethod
    def compute_len(shape, patch_shape):
        """Return the product of the per-axis ratios `shape / patch_shape`, truncated to an integer."""
        n_samples = int(np.prod([float(sh / csh) for sh, csh in zip(shape, patch_shape)]))
        return n_samples

    def __init__(
        self,
        raw_path: Union[List[Any], str, os.PathLike],
        raw_key: Optional[str],
        patch_shape: Tuple[int, ...],
        raw_transform: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        roi: Optional[Union[slice, Tuple[slice, ...]]] = None,
        dtype: torch.dtype = torch.float32,
        n_samples: Optional[int] = None,
        sampler: Optional[Callable] = None,
        ndim: Optional[int] = None,
        with_channels: bool = False,
        augmentations: Optional[Tuple[Callable, Callable]] = None,
    ):
        self.raw_path = raw_path
        self.raw_key = raw_key
        self.raw = load_data(raw_path, raw_key)

        self._with_channels = with_channels

        if roi is not None:
            # The ROI only refers to the spatial axes; the channel axis (axis 0) is always kept in full.
            shape = self.raw.shape[1:] if self._with_channels else self.raw.shape
            roi = validate_roi(roi, shape, patch_shape)
            self.raw = RoiWrapper(self.raw, (slice(None),) + roi) if self._with_channels else RoiWrapper(self.raw, roi)

        self.shape = self.raw.shape[1:] if self._with_channels else self.raw.shape
        self.roi = roi

        self._ndim = len(self.shape) if ndim is None else ndim
        assert self._ndim in (2, 3, 4), f"Invalid data dimensions: {self._ndim}. Only 2d, 3d or 4d data is supported"

        # The patch shape may have one more (singleton) axis than the data dimensionality;
        # this extra axis is squeezed away in `_get_sample`.
        assert len(patch_shape) in (self._ndim, self._ndim + 1), f"{patch_shape}, {self._ndim}"
        self.patch_shape = patch_shape

        self.raw_transform = raw_transform
        self.transform = transform
        self.sampler = sampler
        self.dtype = dtype

        if augmentations is not None:
            assert len(augmentations) == 2
        self.augmentations = augmentations

        self._len = self.compute_len(self.shape, self.patch_shape) if n_samples is None else n_samples

        self.sample_shape = patch_shape
        self.trafo_halo = None
        # TODO add support for trafo halo: asking for a bigger bounding box before applying the trafo,
        # which is then cut. See code below; but this needs to be properly tested

        # self.trafo_halo = None if self.transform is None else self.transform.halo(self.patch_shape)
        # if self.trafo_halo is not None:
        #     if len(self.trafo_halo) == 2 and self._ndim == 3:
        #         self.trafo_halo = (0,) + self.trafo_halo
        #     assert len(self.trafo_halo) == self._ndim
        #     self.sample_shape = tuple(sh + ha for sh, ha in zip(self.patch_shape, self.trafo_halo))
        #     self.inner_bb = tuple(slice(ha, sh - ha) for sh, ha in zip(self.patch_shape, self.trafo_halo))

    def __len__(self):
        return self._len

    @property
    def ndim(self):
        return self._ndim

    def _sample_bounding_box(self):
        """Sample a random bounding box of shape `self.sample_shape` within `self.shape`.

        Returns:
            A tuple of slices, one per axis. Axes where the patch is at least as large as the data start at 0.
        """
        # `np.random.randint` excludes the upper bound, so it must be `sh - psh + 1`
        # to make the last valid start position (`sh - psh`) reachable.
        bb_start = [
            np.random.randint(0, sh - psh + 1) if sh - psh > 0 else 0
            for sh, psh in zip(self.shape, self.sample_shape)
        ]
        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, self.sample_shape))

    def _get_sample(self, index):
        """Load a random raw patch; `index` is not used because the patch position is sampled randomly.

        Raises:
            RuntimeError: If the raw data is not available after deserialization or if the sampler
                rejected more than `max_sampling_attempts` patches.
        """
        if self.raw is None:
            raise RuntimeError("RawDataset has not been properly deserialized.")
        bb = self._sample_bounding_box()
        raw = self.raw[(slice(None),) + bb] if self._with_channels else self.raw[bb]

        if self.sampler is not None:
            sample_id = 0
            # Re-sample until the sampler accepts the patch or the attempt budget is exhausted.
            while not self.sampler(raw):
                bb = self._sample_bounding_box()
                raw = self.raw[(slice(None),) + bb] if self._with_channels else self.raw[bb]
                sample_id += 1
                if sample_id > self.max_sampling_attempts:
                    raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")

        if self.patch_shape is not None:
            raw = ensure_patch_shape(
                raw=raw, labels=None, patch_shape=self.patch_shape, have_raw_channels=self._with_channels
            )

            # squeeze the singleton spatial axis if we have a spatial shape that is larger by one than self._ndim
            if len(self.patch_shape) == self._ndim + 1:
                raw = raw.squeeze(1 if self._with_channels else 0)

        return raw

    def crop(self, tensor):
        """Crop `tensor` to `self.inner_bb`, keeping all leading axes that the bounding box does not cover.

        NOTE: `self.inner_bb` is only set by the (currently disabled) trafo halo code in `__init__`;
        this method is only called when `self.trafo_halo` is not None.
        """
        bb = self.inner_bb
        if tensor.ndim > len(bb):
            bb = (tensor.ndim - len(bb)) * (slice(None),) + bb
        return tensor[bb]

    def __getitem__(self, index):
        """Return a raw sample as a tensor, or two augmented views of it if `augmentations` were given."""
        raw = self._get_sample(index)

        if self.raw_transform is not None:
            raw = self.raw_transform(raw)

        if self.transform is not None:
            raw = self.transform(raw)
            # The transform may return its result wrapped in a single-element list.
            if isinstance(raw, list):
                assert len(raw) == 1
                raw = raw[0]

            if self.trafo_halo is not None:
                raw = self.crop(raw)

        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
        if self.augmentations is not None:
            aug1, aug2 = self.augmentations
            raw1, raw2 = aug1(raw), aug2(raw)
            return raw1, raw2

        return raw

    # need to overwrite pickle to support h5py
    def __getstate__(self):
        state = self.__dict__.copy()
        # The data handle is dropped here and re-opened from `raw_path` / `raw_key` in `__setstate__`.
        del state["raw"]
        return state

    def __setstate__(self, state):
        raw_path, raw_key = state["raw_path"], state["raw_key"]
        roi = state["roi"]
        try:
            raw = load_data(raw_path, raw_key)
            if roi is not None:
                raw = RoiWrapper(raw, (slice(None),) + roi) if state["_with_channels"] else RoiWrapper(raw, roi)
            state["raw"] = raw
        except Exception:
            msg = f"RawDataset could not be deserialized because of missing {raw_path}, {raw_key}.\n"
            msg += "The dataset is deserialized in order to allow loading trained models from a checkpoint.\n"
            msg += "But it cannot be used for further training and wil throw an error."
            warnings.warn(msg)
            state["raw"] = None

        self.__dict__.update(state)


class RawDatasetWithMasks(RawDataset):
    """Extends `RawDataset` to support a sample mask and a background mask.

    - The sample mask is used by the sampler to extract patches from a region of interest, e.g.,
        using `MinForegroundSampler`, to avoid empty patches.
    - The background mask is a binary mask identifying regions or structures that belong to the background.
        It can be used during unsupervised training to subtract background regions from the predicted
        pseudo labels.

    Args:
        raw_path: The file path to the raw image data. May also be a list of file paths.
        raw_key: The key to the internal dataset containing the raw data.
        patch_shape: The patch shape for a training sample.
        raw_transform: Transformation applied to the raw data of a sample.
        transform: Transformation to the raw data. This can be used to implement data augmentations.
        roi: Region of interest in the raw data.
            If given, the raw data and the masks will only be loaded from the corresponding area.
        dtype: The return data type of the raw data.
        n_samples: The length of this dataset. If None, the length is derived from the shape of the
            raw data and `patch_shape` via `compute_len`.
        sampler: Sampler for rejecting samples according to a defined criterion.
            The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
            If a sample mask is given, the corresponding mask patch is passed as second argument.
        ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
        with_channels: Whether the raw data has channels.
        augmentations: Augmentations for contrastive learning. If given, these need to be two different callables.
            They will be applied to the sampled raw data to return two independent views of the raw data.
        sample_mask_path: Filepaths to the sample masks used by the sampler to accept or reject
            patches for training.
        sample_mask_key: The key to the dataset containing the sample masks.
        bg_mask_path: Filepaths to the background masks, which will be returned together with the raw sample.
        bg_mask_key: The key to the dataset containing the background masks.
    """

    def __init__(
        self,
        raw_path: Union[List[Any], str, os.PathLike],
        raw_key: Optional[str],
        patch_shape: Tuple[int, ...],
        raw_transform: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        roi: Optional[Union[slice, Tuple[slice, ...]]] = None,
        dtype: torch.dtype = torch.float32,
        n_samples: Optional[int] = None,
        sampler: Optional[Callable] = None,
        ndim: Optional[int] = None,
        with_channels: bool = False,
        augmentations: Optional[Tuple[Callable, Callable]] = None,
        sample_mask_path: Union[List[Any], str, os.PathLike] = None,
        sample_mask_key: Optional[str] = None,
        bg_mask_path: Union[List[Any], str, os.PathLike] = None,
        bg_mask_key: Optional[str] = None,
    ):
        super().__init__(
            raw_path=raw_path,
            raw_key=raw_key,
            patch_shape=patch_shape,
            raw_transform=raw_transform,
            transform=transform,
            roi=roi,
            dtype=dtype,
            n_samples=n_samples,
            sampler=sampler,
            ndim=ndim,
            with_channels=with_channels,
            augmentations=augmentations,
        )

        self.sample_mask_path = sample_mask_path
        self.sample_mask_key = sample_mask_key
        self.sample_mask = self._load_mask(sample_mask_path, sample_mask_key)

        self.bg_mask_path = bg_mask_path
        self.bg_mask_key = bg_mask_key
        self.bg_mask = self._load_mask(bg_mask_path, bg_mask_key)

    def _load_mask(self, mask_path, mask_key):
        """Load a mask and restrict it to the same ROI as the raw data; returns None if no path is given.

        Bounding boxes are sampled in the coordinates of the ROI-restricted raw data, so the masks
        must be wrapped with the same ROI to stay aligned with the raw patches.
        NOTE(review): assumes the masks have the same layout as the raw data (including a leading
        channel axis if `with_channels` is set), as `_extract_patch` already does - confirm.
        """
        if mask_path is None:
            return None
        mask = load_data(mask_path, mask_key)
        if self.roi is not None:
            mask = RoiWrapper(mask, (slice(None),) + self.roi) if self._with_channels else RoiWrapper(mask, self.roi)
        return mask

    def _extract_patch(self, data, bb):
        """Read the bounding box `bb` from `data`, keeping the full leading channel axis if there is one."""
        return data[(slice(None),) + bb] if self._with_channels else data[bb]

    def _accept_sample(self, raw, bb):
        """Ask the sampler whether the patch at `bb` is valid, passing the sample mask patch if available."""
        if self.sample_mask is None:
            return self.sampler(raw)
        return self.sampler(raw, self._extract_patch(self.sample_mask, bb))

    def _get_sample(self, index):
        """Load a random raw patch and the matching background mask patch (None without background mask).

        `index` is not used because the patch position is sampled randomly.

        Raises:
            RuntimeError: If the raw data is not available after deserialization or if the sampler
                rejected more than `max_sampling_attempts` patches.
        """
        if self.raw is None:
            raise RuntimeError("RawDataset has not been properly deserialized.")

        # default behavior; use if sampler is None
        bb = self._sample_bounding_box()
        raw = self._extract_patch(self.raw, bb)

        if self.sampler is not None:
            sample_id = 0
            # Re-sample until the sampler accepts the patch or the attempt budget is exhausted.
            while not self._accept_sample(raw, bb):
                bb = self._sample_bounding_box()
                raw = self._extract_patch(self.raw, bb)
                sample_id += 1
                if sample_id > self.max_sampling_attempts:
                    raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")

        bg_mask = self._extract_patch(self.bg_mask, bb) if self.bg_mask is not None else None

        if self.patch_shape is not None:
            if bg_mask is not None:
                raw, bg_mask = ensure_patch_shape(
                    raw=raw, labels=bg_mask, patch_shape=self.patch_shape,
                    have_raw_channels=self._with_channels, have_label_channels=self._with_channels
                )
            else:
                raw = ensure_patch_shape(
                    raw=raw, labels=None, patch_shape=self.patch_shape,
                    have_raw_channels=self._with_channels, have_label_channels=self._with_channels
                )
            # squeeze the singleton spatial axis if we have a spatial shape that is larger by one than self._ndim
            if len(self.patch_shape) == self._ndim + 1:
                raw = raw.squeeze(1 if self._with_channels else 0)

                if bg_mask is not None:
                    bg_mask = bg_mask.squeeze(1 if self._with_channels else 0)

        return raw, bg_mask

    def __getitem__(self, index):
        """Return the raw sample (or two augmented views), each stacked with the background mask if given."""
        raw, bg_mask = self._get_sample(index)

        if self.raw_transform is not None:
            raw = self.raw_transform(raw)

        # NOTE(review): `transform` is only applied to the raw data, not to the background mask.
        # If it changes the geometry, raw data and mask may no longer be aligned - confirm.
        if self.transform is not None:
            raw = self.transform(raw)
            if isinstance(raw, list):
                assert len(raw) == 1
                raw = raw[0]

            if self.trafo_halo is not None:
                raw = self.crop(raw)

        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
        if bg_mask is not None:
            bg_mask = ensure_tensor_with_channels(bg_mask, ndim=self._ndim, dtype=self.dtype)

        if self.augmentations is not None:
            aug1, aug2 = self.augmentations
            raw1, raw2 = aug1(raw), aug2(raw)

            if bg_mask is not None:

                # if background_mask, returned stacked data
                return torch.cat((raw1, bg_mask), dim=0), torch.cat((raw2, bg_mask), dim=0)

            # else, return augmented raw
            return raw1, raw2

        if bg_mask is not None:

            # if background_mask, returned stacked data
            return torch.cat((raw, bg_mask), dim=0)

        # else, return raw
        return raw

    def __getstate__(self):
        state = super().__getstate__()
        # The mask handles are dropped here and re-opened from their paths in `__setstate__`.
        del state["sample_mask"]
        del state["bg_mask"]
        return state

    def __setstate__(self, state):
        super().__setstate__(state)
        sample_mask_path = state.get("sample_mask_path")
        sample_mask_key = state.get("sample_mask_key")
        bg_mask_path = state.get("bg_mask_path")
        bg_mask_key = state.get("bg_mask_key")
        try:
            self.sample_mask = self._load_mask(sample_mask_path, sample_mask_key)
            self.bg_mask = self._load_mask(bg_mask_path, bg_mask_key)
        except Exception:
            # Mirror the parent class: deserialization must not fail because of missing data,
            # so that trained models can still be loaded from a checkpoint.
            msg = "RawDatasetWithMasks could not be deserialized because the masks could not be loaded from "
            msg += f"{sample_mask_path}, {sample_mask_key} and {bg_mask_path}, {bg_mask_key}.\n"
            msg += "It cannot be used for further training and will throw an error."
            warnings.warn(msg)
            self.sample_mask = None
            self.bg_mask = None
            # Mark the dataset as unusable so that `_get_sample` raises instead of silently ignoring the masks.
            self.raw = None
class RawDataset(torch.utils.data.Dataset):
    """Dataset that provides raw data stored in a container data format for unsupervised training.

    The dataset loads a patch from the raw data and returns a sample for a batch.
    The dataset supports all file formats that can be opened with `elf.io.open_file`, such as hdf5, zarr or n5.
    Use `raw_path` to specify the path to the file and `raw_key` to specify the internal dataset.
    It also supports regular image formats, such as .tif. For these cases set `raw_key=None`.

    The dataset can also be used for contrastive learning that relies on two different views of the same data.
    You can use the `augmentations` argument for this.

    Args:
        raw_path: The file path to the raw image data. May also be a list of file paths.
        raw_key: The key to the internal dataset containing the raw data.
        patch_shape: The patch shape for a training sample.
        raw_transform: Transformation applied to the raw data of a sample.
        transform: Transformation to the raw data. This can be used to implement data augmentations.
        roi: Region of interest in the raw data.
            If given, the raw data will only be loaded from the corresponding area.
        dtype: The return data type of the raw data.
        n_samples: The length of this dataset. If None, the length is derived from the shape of the
            raw data and `patch_shape` via `compute_len`.
        sampler: Sampler for rejecting samples according to a defined criterion.
            The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
        ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
        with_channels: Whether the raw data has channels.
        augmentations: Augmentations for contrastive learning. If given, these need to be two different callables.
            They will be applied to the sampled raw data to return two independent views of the raw data.
    """
    max_sampling_attempts = 500
    """The maximal number of sampling attempts, for loading a sample via `__getitem__`.
    This is used when `sampler` rejects a sample, to avoid an infinite loop if no valid sample can be found.
    """

    @staticmethod
    def compute_len(shape, patch_shape):
        """Return the product of the per-axis ratios `shape / patch_shape`, truncated to an integer."""
        n_samples = int(np.prod([float(sh / csh) for sh, csh in zip(shape, patch_shape)]))
        return n_samples

    def __init__(
        self,
        raw_path: Union[List[Any], str, os.PathLike],
        raw_key: Optional[str],
        patch_shape: Tuple[int, ...],
        raw_transform: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        roi: Optional[Union[slice, Tuple[slice, ...]]] = None,
        dtype: torch.dtype = torch.float32,
        n_samples: Optional[int] = None,
        sampler: Optional[Callable] = None,
        ndim: Optional[int] = None,
        with_channels: bool = False,
        augmentations: Optional[Tuple[Callable, Callable]] = None,
    ):
        self.raw_path = raw_path
        self.raw_key = raw_key
        self.raw = load_data(raw_path, raw_key)

        self._with_channels = with_channels

        if roi is not None:
            # The ROI only refers to the spatial axes; the channel axis (axis 0) is always kept in full.
            shape = self.raw.shape[1:] if self._with_channels else self.raw.shape
            roi = validate_roi(roi, shape, patch_shape)
            self.raw = RoiWrapper(self.raw, (slice(None),) + roi) if self._with_channels else RoiWrapper(self.raw, roi)

        self.shape = self.raw.shape[1:] if self._with_channels else self.raw.shape
        self.roi = roi

        self._ndim = len(self.shape) if ndim is None else ndim
        assert self._ndim in (2, 3, 4), f"Invalid data dimensions: {self._ndim}. Only 2d, 3d or 4d data is supported"

        # The patch shape may have one more (singleton) axis than the data dimensionality;
        # this extra axis is squeezed away in `_get_sample`.
        assert len(patch_shape) in (self._ndim, self._ndim + 1), f"{patch_shape}, {self._ndim}"
        self.patch_shape = patch_shape

        self.raw_transform = raw_transform
        self.transform = transform
        self.sampler = sampler
        self.dtype = dtype

        if augmentations is not None:
            assert len(augmentations) == 2
        self.augmentations = augmentations

        self._len = self.compute_len(self.shape, self.patch_shape) if n_samples is None else n_samples

        self.sample_shape = patch_shape
        self.trafo_halo = None
        # TODO add support for trafo halo: asking for a bigger bounding box before applying the trafo,
        # which is then cut. See code below; but this needs to be properly tested

        # self.trafo_halo = None if self.transform is None else self.transform.halo(self.patch_shape)
        # if self.trafo_halo is not None:
        #     if len(self.trafo_halo) == 2 and self._ndim == 3:
        #         self.trafo_halo = (0,) + self.trafo_halo
        #     assert len(self.trafo_halo) == self._ndim
        #     self.sample_shape = tuple(sh + ha for sh, ha in zip(self.patch_shape, self.trafo_halo))
        #     self.inner_bb = tuple(slice(ha, sh - ha) for sh, ha in zip(self.patch_shape, self.trafo_halo))

    def __len__(self):
        return self._len

    @property
    def ndim(self):
        return self._ndim

    def _sample_bounding_box(self):
        """Sample a random bounding box of shape `self.sample_shape` within `self.shape`.

        Returns:
            A tuple of slices, one per axis. Axes where the patch is at least as large as the data start at 0.
        """
        # `np.random.randint` excludes the upper bound, so it must be `sh - psh + 1`
        # to make the last valid start position (`sh - psh`) reachable.
        bb_start = [
            np.random.randint(0, sh - psh + 1) if sh - psh > 0 else 0
            for sh, psh in zip(self.shape, self.sample_shape)
        ]
        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, self.sample_shape))

    def _get_sample(self, index):
        """Load a random raw patch; `index` is not used because the patch position is sampled randomly.

        Raises:
            RuntimeError: If the raw data is not available after deserialization or if the sampler
                rejected more than `max_sampling_attempts` patches.
        """
        if self.raw is None:
            raise RuntimeError("RawDataset has not been properly deserialized.")
        bb = self._sample_bounding_box()
        raw = self.raw[(slice(None),) + bb] if self._with_channels else self.raw[bb]

        if self.sampler is not None:
            sample_id = 0
            # Re-sample until the sampler accepts the patch or the attempt budget is exhausted.
            while not self.sampler(raw):
                bb = self._sample_bounding_box()
                raw = self.raw[(slice(None),) + bb] if self._with_channels else self.raw[bb]
                sample_id += 1
                if sample_id > self.max_sampling_attempts:
                    raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")

        if self.patch_shape is not None:
            raw = ensure_patch_shape(
                raw=raw, labels=None, patch_shape=self.patch_shape, have_raw_channels=self._with_channels
            )

            # squeeze the singleton spatial axis if we have a spatial shape that is larger by one than self._ndim
            if len(self.patch_shape) == self._ndim + 1:
                raw = raw.squeeze(1 if self._with_channels else 0)

        return raw

    def crop(self, tensor):
        """Crop `tensor` to `self.inner_bb`, keeping all leading axes that the bounding box does not cover.

        NOTE: `self.inner_bb` is only set by the (currently disabled) trafo halo code in `__init__`;
        this method is only called when `self.trafo_halo` is not None.
        """
        bb = self.inner_bb
        if tensor.ndim > len(bb):
            bb = (tensor.ndim - len(bb)) * (slice(None),) + bb
        return tensor[bb]

    def __getitem__(self, index):
        """Return a raw sample as a tensor, or two augmented views of it if `augmentations` were given."""
        raw = self._get_sample(index)

        if self.raw_transform is not None:
            raw = self.raw_transform(raw)

        if self.transform is not None:
            raw = self.transform(raw)
            # The transform may return its result wrapped in a single-element list.
            if isinstance(raw, list):
                assert len(raw) == 1
                raw = raw[0]

            if self.trafo_halo is not None:
                raw = self.crop(raw)

        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
        if self.augmentations is not None:
            aug1, aug2 = self.augmentations
            raw1, raw2 = aug1(raw), aug2(raw)
            return raw1, raw2

        return raw

    # need to overwrite pickle to support h5py
    def __getstate__(self):
        state = self.__dict__.copy()
        # The data handle is dropped here and re-opened from `raw_path` / `raw_key` in `__setstate__`.
        del state["raw"]
        return state

    def __setstate__(self, state):
        raw_path, raw_key = state["raw_path"], state["raw_key"]
        roi = state["roi"]
        try:
            raw = load_data(raw_path, raw_key)
            if roi is not None:
                raw = RoiWrapper(raw, (slice(None),) + roi) if state["_with_channels"] else RoiWrapper(raw, roi)
            state["raw"] = raw
        except Exception:
            msg = f"RawDataset could not be deserialized because of missing {raw_path}, {raw_key}.\n"
            msg += "The dataset is deserialized in order to allow loading trained models from a checkpoint.\n"
            msg += "But it cannot be used for further training and wil throw an error."
            warnings.warn(msg)
            state["raw"] = None

        self.__dict__.update(state)
Dataset that provides raw data stored in a container data format for unsupervised training.
The dataset loads a patch from the raw data and returns a sample for a batch.
The dataset supports all file formats that can be opened with elf.io.open_file, such as hdf5, zarr or n5.
Use raw_path to specify the path to the file and raw_key to specify the internal dataset.
It also supports regular image formats, such as .tif. For these cases set raw_key=None.
The dataset can also be used for contrastive learning that relies on two different views of the same data.
You can use the augmentations argument for this.
Arguments:
- raw_path: The file path to the raw image data. May also be a list of file paths.
- raw_key: The key to the internal dataset containing the raw data.
- patch_shape: The patch shape for a training sample.
- raw_transform: Transformation applied to the raw data of a sample.
- transform: Transformation to the raw data. This can be used to implement data augmentations.
- roi: Region of interest in the raw data. If given, the raw data will only be loaded from the corresponding area.
- dtype: The return data type of the raw data.
- n_samples: The length of this dataset. If None, the length is derived from the shape of the raw data and patch_shape (see compute_len).
- sampler: Sampler for rejecting samples according to a defined criterion. The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
- ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
- with_channels: Whether the raw data has channels.
- augmentations: Augmentations for contrastive learning. If given, these need to be two different callables. They will be applied to the sampled raw data to return two independent views of the raw data.
def __init__(
    self,
    raw_path: Union[List[Any], str, os.PathLike],
    raw_key: Optional[str],
    patch_shape: Tuple[int, ...],
    raw_transform: Optional[Callable] = None,
    transform: Optional[Callable] = None,
    roi: Optional[Union[slice, Tuple[slice, ...]]] = None,
    dtype: torch.dtype = torch.float32,
    n_samples: Optional[int] = None,
    sampler: Optional[Callable] = None,
    ndim: Optional[int] = None,
    with_channels: bool = False,
    augmentations: Optional[Tuple[Callable, Callable]] = None,
):
    """Open the raw data, optionally restrict it to `roi`, and store the sampling configuration."""

    def _spatial_shape(data):
        # With channels, axis 0 is the channel axis and is not part of the spatial shape.
        if with_channels:
            return data.shape[1:]
        return data.shape

    self.raw_path = raw_path
    self.raw_key = raw_key
    self.raw = load_data(raw_path, raw_key)

    self._with_channels = with_channels

    if roi is not None:
        roi = validate_roi(roi, _spatial_shape(self.raw), patch_shape)
        # The channel axis is always loaded in full; the ROI only restricts the spatial axes.
        if with_channels:
            self.raw = RoiWrapper(self.raw, (slice(None),) + roi)
        else:
            self.raw = RoiWrapper(self.raw, roi)

    self.shape = _spatial_shape(self.raw)
    self.roi = roi

    # Use the explicitly given dimensionality, otherwise fall back to the number of spatial axes.
    self._ndim = ndim if ndim is not None else len(self.shape)
    assert self._ndim in (2, 3, 4), f"Invalid data dimensions: {self._ndim}. Only 2d, 3d or 4d data is supported"

    # The patch shape may carry one extra axis compared to the data dimensionality.
    assert len(patch_shape) in (self._ndim, self._ndim + 1), f"{patch_shape}, {self._ndim}"
    self.patch_shape = patch_shape

    self.raw_transform = raw_transform
    self.transform = transform
    self.sampler = sampler
    self.dtype = dtype

    if augmentations is not None:
        assert len(augmentations) == 2
    self.augmentations = augmentations

    # An explicit number of samples takes precedence over the length estimated from the shapes.
    if n_samples is None:
        self._len = self.compute_len(self.shape, self.patch_shape)
    else:
        self._len = n_samples

    self.sample_shape = patch_shape
    self.trafo_halo = None
    # TODO add support for trafo halo: asking for a bigger bounding box before applying the trafo,
    # which is then cut. See code below; but this needs to be properly tested

    # self.trafo_halo = None if self.transform is None else self.transform.halo(self.patch_shape)
    # if self.trafo_halo is not None:
    #     if len(self.trafo_halo) == 2 and self._ndim == 3:
    #         self.trafo_halo = (0,) + self.trafo_halo
    #     assert len(self.trafo_halo) == self._ndim
    #     self.sample_shape = tuple(sh + ha for sh, ha in zip(self.patch_shape, self.trafo_halo))
    #     self.inner_bb = tuple(slice(ha, sh - ha) for sh, ha in zip(self.patch_shape, self.trafo_halo))
The maximal number of sampling attempts, for loading a sample via __getitem__.
This is used when sampler rejects a sample, to avoid an infinite loop if no valid sample can be found.
class RawDatasetWithMasks(RawDataset):
    """Extends `RawDataset` to support a sample mask and a background mask.

    - The sample mask is used by the sampler to extract patches from a region of interest, e.g.,
        using `MinForegroundSampler`, to avoid empty patches.
    - The background mask is a binary mask identifying regions or structures that belong to the background.
        It can be used during unsupervised training to subtract background regions from the predicted
        pseudo labels.

    Args:
        raw_path: The file path to the raw image data. May also be a list of file paths.
        raw_key: The key to the internal dataset containing the raw data.
        patch_shape: The patch shape for a training sample.
        raw_transform: Transformation applied to the raw data of a sample.
        transform: Transformation to the raw data. This can be used to implement data augmentations.
        roi: Region of interest in the raw data.
            If given, the raw data and the masks will only be loaded from the corresponding area.
        dtype: The return data type of the raw data.
        n_samples: The length of this dataset. If None, the length is derived from the shape of the
            raw data and `patch_shape` via `compute_len`.
        sampler: Sampler for rejecting samples according to a defined criterion.
            The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
            If a sample mask is given, the corresponding mask patch is passed as second argument.
        ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
        with_channels: Whether the raw data has channels.
        augmentations: Augmentations for contrastive learning. If given, these need to be two different callables.
            They will be applied to the sampled raw data to return two independent views of the raw data.
        sample_mask_path: Filepaths to the sample masks used by the sampler to accept or reject
            patches for training.
        sample_mask_key: The key to the dataset containing the sample masks.
        bg_mask_path: Filepaths to the background masks, which will be returned together with the raw sample.
        bg_mask_key: The key to the dataset containing the background masks.
    """

    def __init__(
        self,
        raw_path: Union[List[Any], str, os.PathLike],
        raw_key: Optional[str],
        patch_shape: Tuple[int, ...],
        raw_transform: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        roi: Optional[Union[slice, Tuple[slice, ...]]] = None,
        dtype: torch.dtype = torch.float32,
        n_samples: Optional[int] = None,
        sampler: Optional[Callable] = None,
        ndim: Optional[int] = None,
        with_channels: bool = False,
        augmentations: Optional[Tuple[Callable, Callable]] = None,
        sample_mask_path: Union[List[Any], str, os.PathLike] = None,
        sample_mask_key: Optional[str] = None,
        bg_mask_path: Union[List[Any], str, os.PathLike] = None,
        bg_mask_key: Optional[str] = None,
    ):
        super().__init__(
            raw_path=raw_path,
            raw_key=raw_key,
            patch_shape=patch_shape,
            raw_transform=raw_transform,
            transform=transform,
            roi=roi,
            dtype=dtype,
            n_samples=n_samples,
            sampler=sampler,
            ndim=ndim,
            with_channels=with_channels,
            augmentations=augmentations,
        )

        self.sample_mask_path = sample_mask_path
        self.sample_mask_key = sample_mask_key
        self.sample_mask = self._load_mask(sample_mask_path, sample_mask_key)

        self.bg_mask_path = bg_mask_path
        self.bg_mask_key = bg_mask_key
        self.bg_mask = self._load_mask(bg_mask_path, bg_mask_key)

    def _load_mask(self, mask_path, mask_key):
        """Load a mask and restrict it to the same ROI as the raw data; returns None if no path is given.

        Bounding boxes are sampled in the coordinates of the ROI-restricted raw data, so the masks
        must be wrapped with the same ROI to stay aligned with the raw patches.
        NOTE(review): assumes the masks have the same layout as the raw data (including a leading
        channel axis if `with_channels` is set), as `_extract_patch` already does - confirm.
        """
        if mask_path is None:
            return None
        mask = load_data(mask_path, mask_key)
        if self.roi is not None:
            mask = RoiWrapper(mask, (slice(None),) + self.roi) if self._with_channels else RoiWrapper(mask, self.roi)
        return mask

    def _extract_patch(self, data, bb):
        """Read the bounding box `bb` from `data`, keeping the full leading channel axis if there is one."""
        return data[(slice(None),) + bb] if self._with_channels else data[bb]

    def _accept_sample(self, raw, bb):
        """Ask the sampler whether the patch at `bb` is valid, passing the sample mask patch if available."""
        if self.sample_mask is None:
            return self.sampler(raw)
        return self.sampler(raw, self._extract_patch(self.sample_mask, bb))

    def _get_sample(self, index):
        """Load a random raw patch and the matching background mask patch (None without background mask).

        `index` is not used because the patch position is sampled randomly.

        Raises:
            RuntimeError: If the raw data is not available after deserialization or if the sampler
                rejected more than `max_sampling_attempts` patches.
        """
        if self.raw is None:
            raise RuntimeError("RawDataset has not been properly deserialized.")

        # default behavior; use if sampler is None
        bb = self._sample_bounding_box()
        raw = self._extract_patch(self.raw, bb)

        if self.sampler is not None:
            sample_id = 0
            # Re-sample until the sampler accepts the patch or the attempt budget is exhausted.
            while not self._accept_sample(raw, bb):
                bb = self._sample_bounding_box()
                raw = self._extract_patch(self.raw, bb)
                sample_id += 1
                if sample_id > self.max_sampling_attempts:
                    raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")

        bg_mask = self._extract_patch(self.bg_mask, bb) if self.bg_mask is not None else None

        if self.patch_shape is not None:
            if bg_mask is not None:
                raw, bg_mask = ensure_patch_shape(
                    raw=raw, labels=bg_mask, patch_shape=self.patch_shape,
                    have_raw_channels=self._with_channels, have_label_channels=self._with_channels
                )
            else:
                raw = ensure_patch_shape(
                    raw=raw, labels=None, patch_shape=self.patch_shape,
                    have_raw_channels=self._with_channels, have_label_channels=self._with_channels
                )
            # squeeze the singleton spatial axis if we have a spatial shape that is larger by one than self._ndim
            if len(self.patch_shape) == self._ndim + 1:
                raw = raw.squeeze(1 if self._with_channels else 0)

                if bg_mask is not None:
                    bg_mask = bg_mask.squeeze(1 if self._with_channels else 0)

        return raw, bg_mask

    def __getitem__(self, index):
        """Return the raw sample (or two augmented views), each stacked with the background mask if given."""
        raw, bg_mask = self._get_sample(index)

        if self.raw_transform is not None:
            raw = self.raw_transform(raw)

        # NOTE(review): `transform` is only applied to the raw data, not to the background mask.
        # If it changes the geometry, raw data and mask may no longer be aligned - confirm.
        if self.transform is not None:
            raw = self.transform(raw)
            if isinstance(raw, list):
                assert len(raw) == 1
                raw = raw[0]

            if self.trafo_halo is not None:
                raw = self.crop(raw)

        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
        if bg_mask is not None:
            bg_mask = ensure_tensor_with_channels(bg_mask, ndim=self._ndim, dtype=self.dtype)

        if self.augmentations is not None:
            aug1, aug2 = self.augmentations
            raw1, raw2 = aug1(raw), aug2(raw)

            if bg_mask is not None:

                # if background_mask, returned stacked data
                return torch.cat((raw1, bg_mask), dim=0), torch.cat((raw2, bg_mask), dim=0)

            # else, return augmented raw
            return raw1, raw2

        if bg_mask is not None:

            # if background_mask, returned stacked data
            return torch.cat((raw, bg_mask), dim=0)

        # else, return raw
        return raw

    def __getstate__(self):
        state = super().__getstate__()
        # The mask handles are dropped here and re-opened from their paths in `__setstate__`.
        del state["sample_mask"]
        del state["bg_mask"]
        return state

    def __setstate__(self, state):
        super().__setstate__(state)
        sample_mask_path = state.get("sample_mask_path")
        sample_mask_key = state.get("sample_mask_key")
        bg_mask_path = state.get("bg_mask_path")
        bg_mask_key = state.get("bg_mask_key")
        try:
            self.sample_mask = self._load_mask(sample_mask_path, sample_mask_key)
            self.bg_mask = self._load_mask(bg_mask_path, bg_mask_key)
        except Exception:
            # Mirror the parent class: deserialization must not fail because of missing data,
            # so that trained models can still be loaded from a checkpoint.
            msg = "RawDatasetWithMasks could not be deserialized because the masks could not be loaded from "
            msg += f"{sample_mask_path}, {sample_mask_key} and {bg_mask_path}, {bg_mask_key}.\n"
            msg += "It cannot be used for further training and will throw an error."
            warnings.warn(msg)
            self.sample_mask = None
            self.bg_mask = None
            # Mark the dataset as unusable so that `_get_sample` raises instead of silently ignoring the masks.
            self.raw = None
Extends RawDataset to support a sample mask and a background mask.
- The sample mask is used by the sampler to extract patches from a region of interest, e.g.,
using `MinForegroundSampler`, to avoid empty patches.
- The background mask is a binary mask identifying regions or structures that belong to the background.
It can be used during unsupervised training to subtract background regions from the predicted
pseudo labels.
Arguments:
- raw_path: The file path to the raw image data. May also be a list of file paths.
- raw_key: The key to the internal dataset containing the raw data.
- patch_shape: The patch shape for a training sample.
- raw_transform: Transformation applied to the raw data of a sample.
- transform: Transformation to the raw data. This can be used to implement data augmentations.
- roi: Region of interest in the raw data. If given, the raw data will only be loaded from the corresponding area.
- dtype: The return data type of the raw data.
- n_samples: The length of this dataset. If None, the length will be set to `len(raw_image_paths)`.
- sampler: Sampler for rejecting samples according to a defined criterion. The sampler must be a callable that accepts the raw data (as numpy arrays) as input. If a sample mask is given, the sampler is called with the raw data and the corresponding sample mask patch.
- ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
- with_channels: Whether the raw data has channels.
- augmentations: Augmentations for contrastive learning. If given, these need to be two different callables. They will be applied to the sampled raw data to return two independent views of the raw data.
- sample_mask_path: Filepaths to the sample masks used by the sampler to accept or reject patches for training.
- sample_mask_key: The key to the dataset containing the sample masks.
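- bg_mask_path: Filepaths to the background masks, which will be returned together with the raw sample: the background mask patch is concatenated to the raw data along the first axis of the returned tensor.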
- bg_mask_key: The key to the dataset containing the background masks.
def __init__(
    self,
    raw_path: Union[List[Any], str, os.PathLike],
    raw_key: Optional[str],
    patch_shape: Tuple[int, ...],
    raw_transform: Optional[Callable] = None,
    transform: Optional[Callable] = None,
    roi: Optional[Union[slice, Tuple[slice, ...]]] = None,
    dtype: torch.dtype = torch.float32,
    n_samples: Optional[int] = None,
    sampler: Optional[Callable] = None,
    ndim: Optional[int] = None,
    with_channels: bool = False,
    augmentations: Optional[Tuple[Callable, Callable]] = None,
    sample_mask_path: Optional[Union[List[Any], str, os.PathLike]] = None,
    sample_mask_key: Optional[str] = None,
    bg_mask_path: Optional[Union[List[Any], str, os.PathLike]] = None,
    bg_mask_key: Optional[str] = None,
):
    """Set up the raw dataset and load the optional sample and background masks.

    All arguments up to `augmentations` are forwarded unchanged to the parent constructor.
    The mask paths and keys are stored on the instance; each mask is loaded with
    `load_data` if its path is not None and is set to None otherwise.
    """
    super().__init__(
        raw_path=raw_path,
        raw_key=raw_key,
        patch_shape=patch_shape,
        raw_transform=raw_transform,
        transform=transform,
        roi=roi,
        dtype=dtype,
        n_samples=n_samples,
        sampler=sampler,
        ndim=ndim,
        with_channels=with_channels,
        augmentations=augmentations,
    )

    # NOTE(review): `roi` is only passed to the parent class; the masks below are loaded
    # in full. If the parent restricts the raw data to the ROI, patches taken from the
    # masks with the same bounding box would presumably be offset - confirm.
    self.sample_mask_path = sample_mask_path
    self.sample_mask_key = sample_mask_key
    self.sample_mask = load_data(sample_mask_path, sample_mask_key) if sample_mask_path is not None else None

    self.bg_mask_path = bg_mask_path
    self.bg_mask_key = bg_mask_key
    self.bg_mask = load_data(bg_mask_path, bg_mask_key) if bg_mask_path is not None else None