torch_em.data.segmentation_dataset
import os
import warnings
import numpy as np
from typing import List, Union, Tuple, Optional, Any, Callable

import torch

from elf.wrapper import RoiWrapper

from ..util import ensure_spatial_array, ensure_tensor_with_channels, load_data, ensure_patch_shape


class SegmentationDataset(torch.utils.data.Dataset):
    """Dataset that provides raw data and labels stored in a container data format for segmentation training.

    The dataset loads a patch from the raw and label data and returns a sample for a batch.
    Image data and label data must have the same shape, except for potential channels.
    The dataset supports all file formats that can be opened with `elf.io.open_file`, such as hdf5, zarr or n5.
    Use `raw_path` / `label_path` to specify the file path and `raw_key` / `label_key` to specify the internal dataset.
    It also supports regular image formats, such as .tif. For these cases set `raw_key=None` / `label_key=None`.

    Args:
        raw_path: The file path to the raw image data. May also be a list of file paths.
        raw_key: The key to the internal dataset containing the raw data.
        label_path: The file path to the label data. May also be a list of file paths.
        label_key: The key to the internal dataset containing the label data.
        patch_shape: The patch shape for a training sample.
        raw_transform: Transformation applied to the raw data of a sample.
        label_transform: Transformation applied to the label data of a sample,
            before applying augmentations via `transform`.
        label_transform2: Transformation applied to the label data of a sample,
            after applying augmentations via `transform`.
        transform: Transformation applied to both the raw data and label data of a sample.
            This can be used to implement data augmentations.
        roi: Region of interest in the data. If given, the data will only be loaded from the corresponding area.
        dtype: The return data type of the raw data.
        label_dtype: The return data type of the label data.
        n_samples: The length of this dataset. If None, the length will be computed from the data shape and the patch shape.
        sampler: Sampler for rejecting samples according to a defined criterion.
            The sampler must be a callable that accepts the raw data and label data (as numpy arrays) as input.
        ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
        with_channels: Whether the raw data has channels.
        with_label_channels: Whether the label data has channels.
        with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
        z_ext: Extra bounding box for loading the data across z.
    """
    max_sampling_attempts = 500
    """The maximal number of sampling attempts, for loading a sample via `__getitem__`.
    This is used when `sampler` rejects a sample, to avoid an infinite loop if no valid sample can be found.
    """

    @staticmethod
    def compute_len(shape, patch_shape):
        if patch_shape is None:
            return 1
        else:
            n_samples = int(np.prod([float(sh / csh) for sh, csh in zip(shape, patch_shape)]))
            return n_samples

    def __init__(
        self,
        raw_path: Union[List[Any], str, os.PathLike],
        raw_key: Optional[str],
        label_path: Union[List[Any], str, os.PathLike],
        label_key: Optional[str],
        patch_shape: Tuple[int, ...],
        raw_transform: Optional[Callable] = None,
        label_transform: Optional[Callable] = None,
        label_transform2: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        roi: Optional[Union[slice, Tuple[slice, ...]]] = None,
        dtype: torch.dtype = torch.float32,
        label_dtype: torch.dtype = torch.float32,
        n_samples: Optional[int] = None,
        sampler: Optional[Callable] = None,
        ndim: Optional[int] = None,
        with_channels: bool = False,
        with_label_channels: bool = False,
        with_padding: bool = True,
        z_ext: Optional[int] = None,
    ):
        self.raw_path = raw_path
        self.raw_key = raw_key
        self.raw = load_data(raw_path, raw_key)

        self.label_path = label_path
        self.label_key = label_key
        self.labels = load_data(label_path, label_key)

        self._with_channels = with_channels
        self._with_label_channels = with_label_channels

        if roi is not None:
            if isinstance(roi, slice):
                roi = (roi,)

            self.raw = RoiWrapper(self.raw, (slice(None),) + roi) if self._with_channels else RoiWrapper(self.raw, roi)
            self.labels = RoiWrapper(self.labels, (slice(None),) + roi) if self._with_label_channels else\
                RoiWrapper(self.labels, roi)

        shape_raw = self.raw.shape[1:] if self._with_channels else self.raw.shape
        shape_label = self.labels.shape[1:] if self._with_label_channels else self.labels.shape
        assert shape_raw == shape_label, f"{shape_raw}, {shape_label}"

        self.shape = shape_raw
        self.roi = roi

        self._ndim = len(shape_raw) if ndim is None else ndim
        assert self._ndim in (2, 3, 4), f"Invalid data dimensions: {self._ndim}. Only 2d, 3d or 4d data is supported"

        if patch_shape is not None:
            assert len(patch_shape) in (self._ndim, self._ndim + 1), f"{patch_shape}, {self._ndim}"

        self.patch_shape = patch_shape

        self.raw_transform = raw_transform
        self.label_transform = label_transform
        self.label_transform2 = label_transform2
        self.transform = transform
        self.sampler = sampler
        self.with_padding = with_padding

        self.dtype = dtype
        self.label_dtype = label_dtype

        self._len = self.compute_len(self.shape, self.patch_shape) if n_samples is None else n_samples

        self.z_ext = z_ext

        self.sample_shape = patch_shape
        self.trafo_halo = None
        # TODO add support for trafo halo: asking for a bigger bounding box before applying the trafo,
        # which is then cut. See code below; but this needs to be properly tested.

        # self.trafo_halo = None if self.transform is None else self.transform.halo(self.patch_shape)
        # if self.trafo_halo is not None:
        #     if len(self.trafo_halo) == 2 and self._ndim == 3:
        #         self.trafo_halo = (0,) + self.trafo_halo
        #     assert len(self.trafo_halo) == self._ndim
        #     self.sample_shape = tuple(sh + ha for sh, ha in zip(self.patch_shape, self.trafo_halo))
        #     self.inner_bb = tuple(slice(ha, sh - ha) for sh, ha in zip(self.patch_shape, self.trafo_halo))

    def __len__(self):
        return self._len

    @property
    def ndim(self):
        return self._ndim

    def _sample_bounding_box(self):
        if self.sample_shape is None:
            if self.z_ext is None:
                bb_start = [0] * len(self.shape)
                patch_shape_for_bb = self.shape
            else:
                z_diff = self.shape[0] - self.z_ext
                bb_start = [np.random.randint(0, z_diff) if z_diff > 0 else 0] + [0] * len(self.shape[1:])
                patch_shape_for_bb = (self.z_ext, *self.shape[1:])

        else:
            bb_start = [
                np.random.randint(0, sh - psh) if sh - psh > 0 else 0 for sh, psh in zip(self.shape, self.sample_shape)
            ]
            patch_shape_for_bb = self.sample_shape

        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, patch_shape_for_bb))

    def _get_desired_raw_and_labels(self):
        bb = self._sample_bounding_box()
        bb_raw = (slice(None),) + bb if self._with_channels else bb
        bb_labels = (slice(None),) + bb if self._with_label_channels else bb
        raw, labels = self.raw[bb_raw], self.labels[bb_labels]
        return raw, labels

    def _get_sample(self, index):
        if self.raw is None or self.labels is None:
            raise RuntimeError("SegmentationDataset has not been properly deserialized.")

        raw, labels = self._get_desired_raw_and_labels()

        if self.sampler is not None:
            sample_id = 0
            while not self.sampler(raw, labels):
                raw, labels = self._get_desired_raw_and_labels()
                sample_id += 1
                if sample_id > self.max_sampling_attempts:
                    raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")

        # Padding the patch to match the expected input shape.
        if self.patch_shape is not None and self.with_padding:
            raw, labels = ensure_patch_shape(
                raw=raw,
                labels=labels,
                patch_shape=self.patch_shape,
                have_raw_channels=self._with_channels,
                have_label_channels=self._with_label_channels,
            )

        # squeeze the singleton spatial axis if we have a spatial shape that is larger by one than self._ndim
        if self.patch_shape is not None and len(self.patch_shape) == self._ndim + 1:
            raw = raw.squeeze(1 if self._with_channels else 0)
            labels = labels.squeeze(1 if self._with_label_channels else 0)

        return raw, labels

    def crop(self, tensor):
        """@private
        """
        bb = self.inner_bb
        if tensor.ndim > len(bb):
            bb = (tensor.ndim - len(bb)) * (slice(None),) + bb
        return tensor[bb]

    def __getitem__(self, index):
        raw, labels = self._get_sample(index)
        initial_label_dtype = labels.dtype

        if self.raw_transform is not None:
            raw = self.raw_transform(raw)

        if self.label_transform is not None:
            labels = self.label_transform(labels)

        if self.transform is not None:
            raw, labels = self.transform(raw, labels)
            if self.trafo_halo is not None:
                raw = self.crop(raw)
                labels = self.crop(labels)

        # support enlarging bounding box here as well (for affinity transform) ?
        if self.label_transform2 is not None:
            labels = ensure_spatial_array(labels, self.ndim, dtype=initial_label_dtype)
            labels = self.label_transform2(labels)

        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
        labels = ensure_tensor_with_channels(labels, ndim=self._ndim, dtype=self.label_dtype)
        return raw, labels

    # need to overwrite pickle to support h5py
    def __getstate__(self):
        state = self.__dict__.copy()
        del state["raw"]
        del state["labels"]
        return state

    def __setstate__(self, state):
        raw_path, raw_key = state["raw_path"], state["raw_key"]
        label_path, label_key = state["label_path"], state["label_key"]
        roi = state["roi"]
        try:
            raw = load_data(raw_path, raw_key)
            if roi is not None:
                raw = RoiWrapper(raw, (slice(None),) + roi) if state["_with_channels"] else RoiWrapper(raw, roi)
            state["raw"] = raw
        except Exception:
            msg = f"SegmentationDataset could not be deserialized because of missing {raw_path}, {raw_key}.\n"
            msg += "The dataset is deserialized in order to allow loading trained models from a checkpoint.\n"
            msg += "But it cannot be used for further training and will throw an error."
            warnings.warn(msg)
            state["raw"] = None

        try:
            labels = load_data(label_path, label_key)
            if roi is not None:
                labels = RoiWrapper(labels, (slice(None),) + roi) if state["_with_label_channels"] else\
                    RoiWrapper(labels, roi)
            state["labels"] = labels
        except Exception:
            msg = f"SegmentationDataset could not be deserialized because of missing {label_path}, {label_key}.\n"
            msg += "The dataset is deserialized in order to allow loading trained models from a checkpoint.\n"
            msg += "But it cannot be used for further training and will throw an error."
            warnings.warn(msg)
            state["labels"] = None

        self.__dict__.update(state)
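For orientation, the length reported by `__len__` comes from `compute_len` in the source above: unless `n_samples` overrides it, the number of samples is the floored product of the per-axis ratios between the data shape and the patch shape. A minimal arithmetic sketch, using a hypothetical volume shape:

import numpy as np

shape = (100, 512, 512)        # hypothetical spatial shape of the raw data
patch_shape = (32, 256, 256)   # patch shape passed to the dataset

# Same computation as SegmentationDataset.compute_len:
n_samples = int(np.prod([float(sh / psh) for sh, psh in zip(shape, patch_shape)]))
print(n_samples)  # 3.125 * 2.0 * 2.0 = 12.5 -> 12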
class SegmentationDataset(torch.utils.data.Dataset)

Dataset that provides raw data and labels stored in a container data format for segmentation training.

The dataset loads a patch from the raw and label data and returns a sample for a batch.
Image data and label data must have the same shape, except for potential channels.
The dataset supports all file formats that can be opened with `elf.io.open_file`, such as hdf5, zarr or n5.
Use `raw_path` / `label_path` to specify the file path and `raw_key` / `label_key` to specify the internal dataset.
It also supports regular image formats, such as .tif. For these cases set `raw_key=None` / `label_key=None`.

Arguments:
- raw_path: The file path to the raw image data. May also be a list of file paths.
- raw_key: The key to the internal dataset containing the raw data.
- label_path: The file path to the label data. May also be a list of file paths.
- label_key: The key to the internal dataset containing the label data.
- patch_shape: The patch shape for a training sample.
- raw_transform: Transformation applied to the raw data of a sample.
- label_transform: Transformation applied to the label data of a sample, before applying augmentations via `transform`.
- label_transform2: Transformation applied to the label data of a sample, after applying augmentations via `transform`.
- transform: Transformation applied to both the raw data and label data of a sample. This can be used to implement data augmentations.
- roi: Region of interest in the data. If given, the data will only be loaded from the corresponding area.
- dtype: The return data type of the raw data.
- label_dtype: The return data type of the label data.
- n_samples: The length of this dataset. If None, the length will be computed from the data shape and the patch shape.
- sampler: Sampler for rejecting samples according to a defined criterion. The sampler must be a callable that accepts the raw data and label data (as numpy arrays) as input.
- ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
- with_channels: Whether the raw data has channels.
- with_label_channels: Whether the label data has channels.
- with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
- z_ext: Extra bounding box for loading the data across z.
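A short usage sketch follows; the file path, internal keys and patch shape are hypothetical and only illustrate the argument layout described above:

import torch
from torch_em.data.segmentation_dataset import SegmentationDataset

# Hypothetical HDF5 file containing a 3d raw volume and a label volume of the same shape.
dataset = SegmentationDataset(
    raw_path="/path/to/data.h5", raw_key="raw",
    label_path="/path/to/data.h5", label_key="labels",
    patch_shape=(32, 256, 256),      # one training patch per sample
    label_dtype=torch.int64,         # e.g. for integer instance labels
)

# The dataset can be used with a standard DataLoader; __getstate__ / __setstate__
# reload the underlying files, so worker processes can pickle it.
loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True, num_workers=2)
raw_batch, label_batch = next(iter(loader))
print(raw_batch.shape, label_batch.shape)  # typically (2, 1, 32, 256, 256) for both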
max_sampling_attempts = 500

The maximal number of sampling attempts for loading a sample via `__getitem__`.
This is used when `sampler` rejects a sample, to avoid an infinite loop if no valid sample can be found.
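Because `sampler` is just a callable on the raw and label patches, a rejection criterion can be supplied directly; the function and paths below are illustrative, not part of this module. Raising `max_sampling_attempts` on the class relaxes the limit on rejected draws before `__getitem__` raises a RuntimeError:

from torch_em.data.segmentation_dataset import SegmentationDataset

def require_foreground(raw, labels, min_fraction=0.05):
    # Illustrative sampler: accept a patch only if at least 5% of the labels are foreground.
    return (labels > 0).mean() > min_fraction

dataset = SegmentationDataset(
    raw_path="/path/to/data.h5", raw_key="raw",        # hypothetical paths and keys
    label_path="/path/to/data.h5", label_key="labels",
    patch_shape=(32, 256, 256),
    sampler=require_foreground,
)

# Allow more rejected draws per sample before giving up (default: 500).
SegmentationDataset.max_sampling_attempts = 1000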