torch_em.data.segmentation_dataset
import os
import warnings
from typing import List, Union, Tuple, Optional, Any, Callable

import numpy as np
from math import ceil

import torch

from elf.wrapper import RoiWrapper

from ..util import ensure_spatial_array, ensure_tensor_with_channels, load_data, ensure_patch_shape


class SegmentationDataset(torch.utils.data.Dataset):
    """Dataset that provides raw data and labels stored in a container data format for segmentation training.

    The dataset loads a patch from the raw and label data and returns a sample for a batch.
    Image data and label data must have the same shape, except for potential channels.
    The dataset supports all file formats that can be opened with `elf.io.open_file`, such as hdf5, zarr or n5.
    Use `raw_path` / `label_path` to specify the file path and `raw_key` / `label_key` to specify the internal dataset.
    It also supports regular image formats, such as .tif. For these cases set `raw_key=None` / `label_key=None`.

    Args:
        raw_path: The file path to the raw image data. May also be a list of file paths.
        raw_key: The key to the internal dataset containing the raw data.
        label_path: The file path to the label data. May also be a list of file paths.
        label_key: The key to the internal dataset containing the label data.
        patch_shape: The patch shape for a training sample.
        raw_transform: Transformation applied to the raw data of a sample.
        label_transform: Transformation applied to the label data of a sample,
            before applying augmentations via `transform`.
        label_transform2: Transformation applied to the label data of a sample,
            after applying augmentations via `transform`.
        transform: Transformation applied to both the raw data and label data of a sample.
            This can be used to implement data augmentations.
        roi: Region of interest in the data. If given, the data will only be loaded from the corresponding area.
        dtype: The return data type of the raw data.
        label_dtype: The return data type of the label data.
        n_samples: The length of this dataset. If None, the length is computed from the data shape
            and the patch shape via `compute_len`.
        sampler: Sampler for rejecting samples according to a defined criterion.
            The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
        ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
        with_channels: Whether the raw data has channels.
        with_label_channels: Whether the label data has channels.
        with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
        z_ext: Extra bounding box for loading the data across z.
        pre_label_transform: Transformation applied to the label data of a chosen random sample,
            before applying the sample validity via the `sampler`.
    """
    max_sampling_attempts = 500
    """The maximal number of sampling attempts, for loading a sample via `__getitem__`.
    This is used when `sampler` rejects a sample, to avoid an infinite loop if no valid sample can be found.
    """

    @staticmethod
    def compute_len(shape, patch_shape):
        """Compute the dataset length as the number of patches needed to tile `shape` with `patch_shape`.

        Returns 1 if no patch shape is given (the whole volume is a single sample).
        """
        if patch_shape is None:
            return 1
        else:
            n_samples = ceil(np.prod([float(sh / csh) for sh, csh in zip(shape, patch_shape)]))
            return n_samples

    def __init__(
        self,
        raw_path: Union[List[Any], str, os.PathLike],
        raw_key: Optional[str],
        label_path: Union[List[Any], str, os.PathLike],
        label_key: Optional[str],
        patch_shape: Tuple[int, ...],
        raw_transform: Optional[Callable] = None,
        label_transform: Optional[Callable] = None,
        label_transform2: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        roi: Optional[Union[slice, Tuple[slice, ...]]] = None,
        dtype: torch.dtype = torch.float32,
        label_dtype: torch.dtype = torch.float32,
        n_samples: Optional[int] = None,
        sampler: Optional[Callable] = None,
        ndim: Optional[int] = None,
        with_channels: bool = False,
        with_label_channels: bool = False,
        with_padding: bool = True,
        z_ext: Optional[int] = None,
        pre_label_transform: Optional[Callable] = None,
    ):
        self.raw_path = raw_path
        self.raw_key = raw_key
        self.raw = load_data(raw_path, raw_key)

        self.label_path = label_path
        self.label_key = label_key
        self.labels = load_data(label_path, label_key)

        self._with_channels = with_channels
        self._with_label_channels = with_label_channels

        if roi is not None:
            if isinstance(roi, slice):
                roi = (roi,)

            # Restrict raw and labels to the roi; if the data has a leading channel axis
            # the roi only applies to the spatial axes, so prepend a full slice for channels.
            self.raw = RoiWrapper(self.raw, (slice(None),) + roi) if self._with_channels else RoiWrapper(self.raw, roi)
            self.labels = RoiWrapper(self.labels, (slice(None),) + roi) if self._with_label_channels else\
                RoiWrapper(self.labels, roi)

        # The spatial shapes (channel axis stripped) of raw and labels must agree.
        shape_raw = self.raw.shape[1:] if self._with_channels else self.raw.shape
        shape_label = self.labels.shape[1:] if self._with_label_channels else self.labels.shape
        assert shape_raw == shape_label, f"{shape_raw}, {shape_label}"

        self.shape = shape_raw
        self.roi = roi

        self._ndim = len(shape_raw) if ndim is None else ndim
        assert self._ndim in (2, 3, 4), f"Invalid data dimensions: {self._ndim}. Only 2d, 3d or 4d data is supported"

        if patch_shape is not None:
            # A patch shape with one extra axis is allowed (singleton axis squeezed in _get_sample).
            assert len(patch_shape) in (self._ndim, self._ndim + 1), f"{patch_shape}, {self._ndim}"

        self.patch_shape = patch_shape

        self.raw_transform = raw_transform
        self.label_transform = label_transform
        self.label_transform2 = label_transform2
        self.transform = transform
        self.sampler = sampler
        self.with_padding = with_padding
        self.pre_label_transform = pre_label_transform

        self.dtype = dtype
        self.label_dtype = label_dtype

        self._len = self.compute_len(self.shape, self.patch_shape) if n_samples is None else n_samples

        self.z_ext = z_ext

        self.sample_shape = patch_shape
        self.trafo_halo = None
        # TODO add support for trafo halo: asking for a bigger bounding box before applying the trafo,
        # which is then cut. See code below; but this needs to be properly tested

        # self.trafo_halo = None if self.transform is None else self.transform.halo(self.patch_shape)
        # if self.trafo_halo is not None:
        #     if len(self.trafo_halo) == 2 and self._ndim == 3:
        #         self.trafo_halo = (0,) + self.trafo_halo
        #     assert len(self.trafo_halo) == self._ndim
        #     self.sample_shape = tuple(sh + ha for sh, ha in zip(self.patch_shape, self.trafo_halo))
        #     self.inner_bb = tuple(slice(ha, sh - ha) for sh, ha in zip(self.patch_shape, self.trafo_halo))

    def __len__(self):
        return self._len

    @property
    def ndim(self):
        # The spatial dimensionality of the data (2, 3 or 4).
        return self._ndim

    def _sample_bounding_box(self):
        # Draw a random bounding box of size `sample_shape` within the data shape.
        # Without a sample shape the full volume is used; with `z_ext` only the
        # z-start is randomized and the full spatial extent is kept otherwise.
        if self.sample_shape is None:
            if self.z_ext is None:
                bb_start = [0] * len(self.shape)
                patch_shape_for_bb = self.shape
            else:
                z_diff = self.shape[0] - self.z_ext
                bb_start = [np.random.randint(0, z_diff) if z_diff > 0 else 0] + [0] * len(self.shape[1:])
                patch_shape_for_bb = (self.z_ext, *self.shape[1:])

        else:
            bb_start = [
                np.random.randint(0, sh - psh) if sh - psh > 0 else 0 for sh, psh in zip(self.shape, self.sample_shape)
            ]
            patch_shape_for_bb = self.sample_shape

        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, patch_shape_for_bb))

    def _get_desired_raw_and_labels(self):
        # Load one random patch of raw and label data; prepend a channel slice where needed.
        bb = self._sample_bounding_box()
        bb_raw = (slice(None),) + bb if self._with_channels else bb
        bb_labels = (slice(None),) + bb if self._with_label_channels else bb
        raw, labels = self.raw[bb_raw], self.labels[bb_labels]

        # Additional label transform on top to make sampler consider expected labels
        # (eg. run connected components on disconnected semantic labels)
        # NOTE: accessed via getattr, presumably so that datasets pickled before this
        # attribute was introduced can still be deserialized — confirm before removing.
        pre_label_transform = getattr(self, "pre_label_transform", None)
        if pre_label_transform is not None:
            labels = pre_label_transform(labels)

        return raw, labels

    def _get_sample(self, index):
        # Draw a patch, optionally rejection-sample until `sampler` accepts it,
        # then pad and/or squeeze to match the expected patch shape.
        if self.raw is None or self.labels is None:
            raise RuntimeError("SegmentationDataset has not been properly deserialized.")

        raw, labels = self._get_desired_raw_and_labels()

        if self.sampler is not None:
            sample_id = 0
            while not self.sampler(raw, labels):
                raw, labels = self._get_desired_raw_and_labels()
                sample_id += 1
                if sample_id > self.max_sampling_attempts:
                    raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")

        # Padding the patch to match the expected input shape.
        if self.patch_shape is not None and self.with_padding:
            raw, labels = ensure_patch_shape(
                raw=raw,
                labels=labels,
                patch_shape=self.patch_shape,
                have_raw_channels=self._with_channels,
                have_label_channels=self._with_label_channels,
            )

        # squeeze the singleton spatial axis if we have a spatial shape that is larger by one than self._ndim
        if self.patch_shape is not None and len(self.patch_shape) == self._ndim + 1:
            raw = raw.squeeze(1 if self._with_channels else 0)
            labels = labels.squeeze(1 if self._with_label_channels else 0)

        return raw, labels

    def crop(self, tensor):
        """@private
        Crop the tensor to the inner bounding box, prepending full slices for leading
        (e.g. channel) axes. NOTE(review): `self.inner_bb` is only set by the currently
        commented-out trafo-halo support in `__init__`; this is only reachable when
        `trafo_halo` is not None.
        """
        bb = self.inner_bb
        if tensor.ndim > len(bb):
            bb = (tensor.ndim - len(bb)) * (slice(None),) + bb
        return tensor[bb]

    def __getitem__(self, index):
        # Load a sample and apply the transforms in order:
        # raw_transform -> label_transform -> transform (joint) -> label_transform2.
        raw, labels = self._get_sample(index)
        initial_label_dtype = labels.dtype

        if self.raw_transform is not None:
            raw = self.raw_transform(raw)

        if self.label_transform is not None:
            labels = self.label_transform(labels)

        if self.transform is not None:
            raw, labels = self.transform(raw, labels)
            if self.trafo_halo is not None:
                raw = self.crop(raw)
                labels = self.crop(labels)

        # support enlarging bounding box here as well (for affinity transform) ?
        if self.label_transform2 is not None:
            labels = ensure_spatial_array(labels, self.ndim, dtype=initial_label_dtype)
            labels = self.label_transform2(labels)

        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
        labels = ensure_tensor_with_channels(labels, ndim=self._ndim, dtype=self.label_dtype)
        return raw, labels

    # need to overwrite pickle to support h5py:
    # open dataset handles cannot be pickled, so they are dropped on serialization
    # and re-opened from the stored paths/keys on deserialization.
    def __getstate__(self):
        state = self.__dict__.copy()
        del state["raw"]
        del state["labels"]
        return state

    def __setstate__(self, state):
        raw_path, raw_key = state["raw_path"], state["raw_key"]
        label_path, label_key = state["label_path"], state["label_key"]
        roi = state["roi"]
        try:
            raw = load_data(raw_path, raw_key)
            if roi is not None:
                raw = RoiWrapper(raw, (slice(None),) + roi) if state["_with_channels"] else RoiWrapper(raw, roi)
            state["raw"] = raw
        except Exception:
            # Deserialization must not fail hard, so that checkpoints containing this
            # dataset can still be loaded; `_get_sample` raises if training is attempted.
            msg = f"SegmentationDataset could not be deserialized because of missing {raw_path}, {raw_key}.\n"
            msg += "The dataset is deserialized in order to allow loading trained models from a checkpoint.\n"
            msg += "But it cannot be used for further training and will throw an error."
            warnings.warn(msg)
            state["raw"] = None

        try:
            labels = load_data(label_path, label_key)
            if roi is not None:
                labels = RoiWrapper(labels, (slice(None),) + roi) if state["_with_label_channels"] else\
                    RoiWrapper(labels, roi)
            state["labels"] = labels
        except Exception:
            msg = f"SegmentationDataset could not be deserialized because of missing {label_path}, {label_key}.\n"
            msg += "The dataset is deserialized in order to allow loading trained models from a checkpoint.\n"
            msg += "But it cannot be used for further training and will throw an error."
            warnings.warn(msg)
            state["labels"] = None

        self.__dict__.update(state)
16class SegmentationDataset(torch.utils.data.Dataset): 17 """Dataset that provides raw data and labels stored in a container data format for segmentation training. 18 19 The dataset loads a patch from the raw and label data and returns a sample for a batch. 20 Image data and label data must have the same shape, except for potential channels. 21 The dataset supports all file formats that can be opened with `elf.io.open_file`, such as hdf5, zarr or n5. 22 Use `raw_path` / `label_path` to specify the file path and `raw_key` / `label_key` to specify the internal dataset. 23 It also supports regular image formats, such as .tif. For these cases set `raw_key=None` / `label_key=None`. 24 25 Args: 26 raw_path: The file path to the raw image data. May also be a list of file paths. 27 raw_key: The key to the internal dataset containing the raw data. 28 label_path: The file path to the label data. May also be a list of file paths. 29 label_key: The key to the internal dataset containing the label data 30 patch_shape: The patch shape for a training sample. 31 raw_transform: Transformation applied to the raw data of a sample. 32 label_transform: Transformation applied to the label data of a sample, 33 before applying augmentations via `transform`. 34 label_transform2: Transformation applied to the label data of a sample, 35 after applying augmentations via `transform`. 36 transform: Transformation applied to both the raw data and label data of a sample. 37 This can be used to implement data augmentations. 38 roi: Region of interest in the data. If given, the data will only be loaded from the corresponding area. 39 dtype: The return data type of the raw data. 40 label_dtype: The return data type of the label data. 41 n_samples: The length of this dataset. If None, the length will be set to `len(raw_image_paths)`. 42 sampler: Sampler for rejecting samples according to a defined criterion. 43 The sampler must be a callable that accepts the raw data (as numpy arrays) as input. 
44 ndim: The spatial dimensionality of the data. If None, will be derived from the raw data. 45 with_channels: Whether the raw data has channels. 46 with_label_channels: Whether the label data has channels. 47 with_padding: Whether to pad samples to `patch_shape` if their shape is smaller. 48 z_ext: Extra bounding box for loading the data across z. 49 pre_label_transform: Transformation applied to the label data of a chosen random sample, 50 before applying the sample validity via the `sampler`. 51 """ 52 max_sampling_attempts = 500 53 """The maximal number of sampling attempts, for loading a sample via `__getitem__`. 54 This is used when `sampler` rejects a sample, to avoid an infinite loop if no valid sample can be found. 55 """ 56 57 @staticmethod 58 def compute_len(shape, patch_shape): 59 if patch_shape is None: 60 return 1 61 else: 62 n_samples = ceil(np.prod([float(sh / csh) for sh, csh in zip(shape, patch_shape)])) 63 return n_samples 64 65 def __init__( 66 self, 67 raw_path: Union[List[Any], str, os.PathLike], 68 raw_key: Optional[str], 69 label_path: Union[List[Any], str, os.PathLike], 70 label_key: Optional[str], 71 patch_shape: Tuple[int, ...], 72 raw_transform: Optional[Callable] = None, 73 label_transform: Optional[Callable] = None, 74 label_transform2: Optional[Callable] = None, 75 transform: Optional[Callable] = None, 76 roi: Optional[Union[slice, Tuple[slice, ...]]] = None, 77 dtype: torch.dtype = torch.float32, 78 label_dtype: torch.dtype = torch.float32, 79 n_samples: Optional[int] = None, 80 sampler: Optional[Callable] = None, 81 ndim: Optional[int] = None, 82 with_channels: bool = False, 83 with_label_channels: bool = False, 84 with_padding: bool = True, 85 z_ext: Optional[int] = None, 86 pre_label_transform: Optional[Callable] = None, 87 ): 88 self.raw_path = raw_path 89 self.raw_key = raw_key 90 self.raw = load_data(raw_path, raw_key) 91 92 self.label_path = label_path 93 self.label_key = label_key 94 self.labels = load_data(label_path, 
label_key) 95 96 self._with_channels = with_channels 97 self._with_label_channels = with_label_channels 98 99 if roi is not None: 100 if isinstance(roi, slice): 101 roi = (roi,) 102 103 self.raw = RoiWrapper(self.raw, (slice(None),) + roi) if self._with_channels else RoiWrapper(self.raw, roi) 104 self.labels = RoiWrapper(self.labels, (slice(None),) + roi) if self._with_label_channels else\ 105 RoiWrapper(self.labels, roi) 106 107 shape_raw = self.raw.shape[1:] if self._with_channels else self.raw.shape 108 shape_label = self.labels.shape[1:] if self._with_label_channels else self.labels.shape 109 assert shape_raw == shape_label, f"{shape_raw}, {shape_label}" 110 111 self.shape = shape_raw 112 self.roi = roi 113 114 self._ndim = len(shape_raw) if ndim is None else ndim 115 assert self._ndim in (2, 3, 4), f"Invalid data dimensions: {self._ndim}. Only 2d, 3d or 4d data is supported" 116 117 if patch_shape is not None: 118 assert len(patch_shape) in (self._ndim, self._ndim + 1), f"{patch_shape}, {self._ndim}" 119 120 self.patch_shape = patch_shape 121 122 self.raw_transform = raw_transform 123 self.label_transform = label_transform 124 self.label_transform2 = label_transform2 125 self.transform = transform 126 self.sampler = sampler 127 self.with_padding = with_padding 128 self.pre_label_transform = pre_label_transform 129 130 self.dtype = dtype 131 self.label_dtype = label_dtype 132 133 self._len = self.compute_len(self.shape, self.patch_shape) if n_samples is None else n_samples 134 135 self.z_ext = z_ext 136 137 self.sample_shape = patch_shape 138 self.trafo_halo = None 139 # TODO add support for trafo halo: asking for a bigger bounding box before applying the trafo, 140 # which is then cut. 
See code below; but this ne needs to be properly tested 141 142 # self.trafo_halo = None if self.transform is None else self.transform.halo(self.patch_shape) 143 # if self.trafo_halo is not None: 144 # if len(self.trafo_halo) == 2 and self._ndim == 3: 145 # self.trafo_halo = (0,) + self.trafo_halo 146 # assert len(self.trafo_halo) == self._ndim 147 # self.sample_shape = tuple(sh + ha for sh, ha in zip(self.patch_shape, self.trafo_halo)) 148 # self.inner_bb = tuple(slice(ha, sh - ha) for sh, ha in zip(self.patch_shape, self.trafo_halo)) 149 150 def __len__(self): 151 return self._len 152 153 @property 154 def ndim(self): 155 return self._ndim 156 157 def _sample_bounding_box(self): 158 if self.sample_shape is None: 159 if self.z_ext is None: 160 bb_start = [0] * len(self.shape) 161 patch_shape_for_bb = self.shape 162 else: 163 z_diff = self.shape[0] - self.z_ext 164 bb_start = [np.random.randint(0, z_diff) if z_diff > 0 else 0] + [0] * len(self.shape[1:]) 165 patch_shape_for_bb = (self.z_ext, *self.shape[1:]) 166 167 else: 168 bb_start = [ 169 np.random.randint(0, sh - psh) if sh - psh > 0 else 0 for sh, psh in zip(self.shape, self.sample_shape) 170 ] 171 patch_shape_for_bb = self.sample_shape 172 173 return tuple(slice(start, start + psh) for start, psh in zip(bb_start, patch_shape_for_bb)) 174 175 def _get_desired_raw_and_labels(self): 176 bb = self._sample_bounding_box() 177 bb_raw = (slice(None),) + bb if self._with_channels else bb 178 bb_labels = (slice(None),) + bb if self._with_label_channels else bb 179 raw, labels = self.raw[bb_raw], self.labels[bb_labels] 180 181 # Additional label transform on top to make sampler consider expected labels 182 # (eg. 
run connected components on disconnected semantic labels) 183 pre_label_transform = getattr(self, "pre_label_transform", None) 184 if pre_label_transform is not None: 185 labels = pre_label_transform(labels) 186 187 return raw, labels 188 189 def _get_sample(self, index): 190 if self.raw is None or self.labels is None: 191 raise RuntimeError("SegmentationDataset has not been properly deserialized.") 192 193 raw, labels = self._get_desired_raw_and_labels() 194 195 if self.sampler is not None: 196 sample_id = 0 197 while not self.sampler(raw, labels): 198 raw, labels = self._get_desired_raw_and_labels() 199 sample_id += 1 200 if sample_id > self.max_sampling_attempts: 201 raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts") 202 203 # Padding the patch to match the expected input shape. 204 if self.patch_shape is not None and self.with_padding: 205 raw, labels = ensure_patch_shape( 206 raw=raw, 207 labels=labels, 208 patch_shape=self.patch_shape, 209 have_raw_channels=self._with_channels, 210 have_label_channels=self._with_label_channels, 211 ) 212 213 # squeeze the singleton spatial axis if we have a spatial shape that is larger by one than self._ndim 214 if self.patch_shape is not None and len(self.patch_shape) == self._ndim + 1: 215 raw = raw.squeeze(1 if self._with_channels else 0) 216 labels = labels.squeeze(1 if self._with_label_channels else 0) 217 218 return raw, labels 219 220 def crop(self, tensor): 221 """@private 222 """ 223 bb = self.inner_bb 224 if tensor.ndim > len(bb): 225 bb = (tensor.ndim - len(bb)) * (slice(None),) + bb 226 return tensor[bb] 227 228 def __getitem__(self, index): 229 raw, labels = self._get_sample(index) 230 initial_label_dtype = labels.dtype 231 232 if self.raw_transform is not None: 233 raw = self.raw_transform(raw) 234 235 if self.label_transform is not None: 236 labels = self.label_transform(labels) 237 238 if self.transform is not None: 239 raw, labels = self.transform(raw, labels) 240 
if self.trafo_halo is not None: 241 raw = self.crop(raw) 242 labels = self.crop(labels) 243 244 # support enlarging bounding box here as well (for affinity transform) ? 245 if self.label_transform2 is not None: 246 labels = ensure_spatial_array(labels, self.ndim, dtype=initial_label_dtype) 247 labels = self.label_transform2(labels) 248 249 raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype) 250 labels = ensure_tensor_with_channels(labels, ndim=self._ndim, dtype=self.label_dtype) 251 return raw, labels 252 253 # need to overwrite pickle to support h5py 254 def __getstate__(self): 255 state = self.__dict__.copy() 256 del state["raw"] 257 del state["labels"] 258 return state 259 260 def __setstate__(self, state): 261 raw_path, raw_key = state["raw_path"], state["raw_key"] 262 label_path, label_key = state["label_path"], state["label_key"] 263 roi = state["roi"] 264 try: 265 raw = load_data(raw_path, raw_key) 266 if roi is not None: 267 raw = RoiWrapper(raw, (slice(None),) + roi) if state["_with_channels"] else RoiWrapper(raw, roi) 268 state["raw"] = raw 269 except Exception: 270 msg = f"SegmentationDataset could not be deserialized because of missing {raw_path}, {raw_key}.\n" 271 msg += "The dataset is deserialized in order to allow loading trained models from a checkpoint.\n" 272 msg += "But it cannot be used for further training and will throw an error." 273 warnings.warn(msg) 274 state["raw"] = None 275 276 try: 277 labels = load_data(label_path, label_key) 278 if roi is not None: 279 labels = RoiWrapper(labels, (slice(None),) + roi) if state["_with_label_channels"] else\ 280 RoiWrapper(labels, roi) 281 state["labels"] = labels 282 except Exception: 283 msg = f"SegmentationDataset could not be deserialized because of missing {label_path}, {label_key}.\n" 284 msg += "The dataset is deserialized in order to allow loading trained models from a checkpoint.\n" 285 msg += "But it cannot be used for further training and will throw an error." 
286 warnings.warn(msg) 287 state["labels"] = None 288 289 self.__dict__.update(state)
Dataset that provides raw data and labels stored in a container data format for segmentation training.
The dataset loads a patch from the raw and label data and returns a sample for a batch.
Image data and label data must have the same shape, except for potential channels.
The dataset supports all file formats that can be opened with elf.io.open_file, such as hdf5, zarr or n5.
Use raw_path / label_path to specify the file path and raw_key / label_key to specify the internal dataset.
It also supports regular image formats, such as .tif. For these cases set raw_key=None / label_key=None.
Arguments:
- raw_path: The file path to the raw image data. May also be a list of file paths.
- raw_key: The key to the internal dataset containing the raw data.
- label_path: The file path to the label data. May also be a list of file paths.
- label_key: The key to the internal dataset containing the label data
- patch_shape: The patch shape for a training sample.
- raw_transform: Transformation applied to the raw data of a sample.
- label_transform: Transformation applied to the label data of a sample,
  before applying augmentations via `transform`.
- label_transform2: Transformation applied to the label data of a sample,
  after applying augmentations via `transform`.
- transform: Transformation applied to both the raw data and label data of a sample. This can be used to implement data augmentations.
- roi: Region of interest in the data. If given, the data will only be loaded from the corresponding area.
- dtype: The return data type of the raw data.
- label_dtype: The return data type of the label data.
- n_samples: The length of this dataset. If None, the length is computed from the data shape and the patch shape.
- sampler: Sampler for rejecting samples according to a defined criterion. The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
- ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
- with_channels: Whether the raw data has channels.
- with_label_channels: Whether the label data has channels.
- with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
- z_ext: Extra bounding box for loading the data across z.
- pre_label_transform: Transformation applied to the label data of a chosen random sample,
  before applying the sample validity check via the `sampler`.
65 def __init__( 66 self, 67 raw_path: Union[List[Any], str, os.PathLike], 68 raw_key: Optional[str], 69 label_path: Union[List[Any], str, os.PathLike], 70 label_key: Optional[str], 71 patch_shape: Tuple[int, ...], 72 raw_transform: Optional[Callable] = None, 73 label_transform: Optional[Callable] = None, 74 label_transform2: Optional[Callable] = None, 75 transform: Optional[Callable] = None, 76 roi: Optional[Union[slice, Tuple[slice, ...]]] = None, 77 dtype: torch.dtype = torch.float32, 78 label_dtype: torch.dtype = torch.float32, 79 n_samples: Optional[int] = None, 80 sampler: Optional[Callable] = None, 81 ndim: Optional[int] = None, 82 with_channels: bool = False, 83 with_label_channels: bool = False, 84 with_padding: bool = True, 85 z_ext: Optional[int] = None, 86 pre_label_transform: Optional[Callable] = None, 87 ): 88 self.raw_path = raw_path 89 self.raw_key = raw_key 90 self.raw = load_data(raw_path, raw_key) 91 92 self.label_path = label_path 93 self.label_key = label_key 94 self.labels = load_data(label_path, label_key) 95 96 self._with_channels = with_channels 97 self._with_label_channels = with_label_channels 98 99 if roi is not None: 100 if isinstance(roi, slice): 101 roi = (roi,) 102 103 self.raw = RoiWrapper(self.raw, (slice(None),) + roi) if self._with_channels else RoiWrapper(self.raw, roi) 104 self.labels = RoiWrapper(self.labels, (slice(None),) + roi) if self._with_label_channels else\ 105 RoiWrapper(self.labels, roi) 106 107 shape_raw = self.raw.shape[1:] if self._with_channels else self.raw.shape 108 shape_label = self.labels.shape[1:] if self._with_label_channels else self.labels.shape 109 assert shape_raw == shape_label, f"{shape_raw}, {shape_label}" 110 111 self.shape = shape_raw 112 self.roi = roi 113 114 self._ndim = len(shape_raw) if ndim is None else ndim 115 assert self._ndim in (2, 3, 4), f"Invalid data dimensions: {self._ndim}. 
Only 2d, 3d or 4d data is supported" 116 117 if patch_shape is not None: 118 assert len(patch_shape) in (self._ndim, self._ndim + 1), f"{patch_shape}, {self._ndim}" 119 120 self.patch_shape = patch_shape 121 122 self.raw_transform = raw_transform 123 self.label_transform = label_transform 124 self.label_transform2 = label_transform2 125 self.transform = transform 126 self.sampler = sampler 127 self.with_padding = with_padding 128 self.pre_label_transform = pre_label_transform 129 130 self.dtype = dtype 131 self.label_dtype = label_dtype 132 133 self._len = self.compute_len(self.shape, self.patch_shape) if n_samples is None else n_samples 134 135 self.z_ext = z_ext 136 137 self.sample_shape = patch_shape 138 self.trafo_halo = None 139 # TODO add support for trafo halo: asking for a bigger bounding box before applying the trafo, 140 # which is then cut. See code below; but this ne needs to be properly tested 141 142 # self.trafo_halo = None if self.transform is None else self.transform.halo(self.patch_shape) 143 # if self.trafo_halo is not None: 144 # if len(self.trafo_halo) == 2 and self._ndim == 3: 145 # self.trafo_halo = (0,) + self.trafo_halo 146 # assert len(self.trafo_halo) == self._ndim 147 # self.sample_shape = tuple(sh + ha for sh, ha in zip(self.patch_shape, self.trafo_halo)) 148 # self.inner_bb = tuple(slice(ha, sh - ha) for sh, ha in zip(self.patch_shape, self.trafo_halo))
The maximal number of sampling attempts, for loading a sample via __getitem__.
This is used when sampler rejects a sample, to avoid an infinite loop if no valid sample can be found.