torch_em.data.segmentation_dataset
```python
import os
import warnings
from typing import List, Union, Tuple, Optional, Any, Callable

import numpy as np
from math import ceil

import torch

from elf.wrapper import RoiWrapper

from ..util import ensure_spatial_array, ensure_tensor_with_channels, load_data, ensure_patch_shape


class SegmentationDataset(torch.utils.data.Dataset):
    """Dataset that provides raw data and labels stored in a container data format for segmentation training.

    The dataset loads a patch from the raw and label data and returns a sample for a batch.
    Image data and label data must have the same shape, except for potential channels.
    The dataset supports all file formats that can be opened with `elf.io.open_file`, such as hdf5, zarr or n5.
    Use `raw_path` / `label_path` to specify the file path and `raw_key` / `label_key` to specify the internal dataset.
    It also supports regular image formats, such as .tif. For these cases set `raw_key=None` / `label_key=None`.

    Args:
        raw_path: The file path to the raw image data. May also be a list of file paths.
        raw_key: The key to the internal dataset containing the raw data.
        label_path: The file path to the label data. May also be a list of file paths.
        label_key: The key to the internal dataset containing the label data.
        patch_shape: The patch shape for a training sample.
        raw_transform: Transformation applied to the raw data of a sample.
        label_transform: Transformation applied to the label data of a sample,
            before applying augmentations via `transform`.
        label_transform2: Transformation applied to the label data of a sample,
            after applying augmentations via `transform`.
        transform: Transformation applied to both the raw data and label data of a sample.
            This can be used to implement data augmentations.
        roi: Region of interest in the data. If given, the data will only be loaded from the corresponding area.
        dtype: The return data type of the raw data.
        label_dtype: The return data type of the label data.
        n_samples: The length of this dataset. If None, the length will be computed from the data shape and the patch shape.
        sampler: Sampler for rejecting samples according to a defined criterion.
            The sampler must be a callable that accepts the raw data and label data (as numpy arrays) as input.
        ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
        with_channels: Whether the raw data has channels.
        with_label_channels: Whether the label data has channels.
        with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
        z_ext: Extra bounding box for loading the data across z.
    """
    max_sampling_attempts = 500
    """The maximal number of sampling attempts, for loading a sample via `__getitem__`.
    This is used when `sampler` rejects a sample, to avoid an infinite loop if no valid sample can be found.
    """

    @staticmethod
    def compute_len(shape, patch_shape):
        if patch_shape is None:
            return 1
        else:
            n_samples = ceil(np.prod([float(sh / csh) for sh, csh in zip(shape, patch_shape)]))
            return n_samples

    def __init__(
        self,
        raw_path: Union[List[Any], str, os.PathLike],
        raw_key: Optional[str],
        label_path: Union[List[Any], str, os.PathLike],
        label_key: Optional[str],
        patch_shape: Tuple[int, ...],
        raw_transform: Optional[Callable] = None,
        label_transform: Optional[Callable] = None,
        label_transform2: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        roi: Optional[Union[slice, Tuple[slice, ...]]] = None,
        dtype: torch.dtype = torch.float32,
        label_dtype: torch.dtype = torch.float32,
        n_samples: Optional[int] = None,
        sampler: Optional[Callable] = None,
        ndim: Optional[int] = None,
        with_channels: bool = False,
        with_label_channels: bool = False,
        with_padding: bool = True,
        z_ext: Optional[int] = None,
    ):
        self.raw_path = raw_path
        self.raw_key = raw_key
        self.raw = load_data(raw_path, raw_key)

        self.label_path = label_path
        self.label_key = label_key
        self.labels = load_data(label_path, label_key)

        self._with_channels = with_channels
        self._with_label_channels = with_label_channels

        if roi is not None:
            if isinstance(roi, slice):
                roi = (roi,)

            self.raw = RoiWrapper(self.raw, (slice(None),) + roi) if self._with_channels else RoiWrapper(self.raw, roi)
            self.labels = RoiWrapper(self.labels, (slice(None),) + roi) if self._with_label_channels else\
                RoiWrapper(self.labels, roi)

        shape_raw = self.raw.shape[1:] if self._with_channels else self.raw.shape
        shape_label = self.labels.shape[1:] if self._with_label_channels else self.labels.shape
        assert shape_raw == shape_label, f"{shape_raw}, {shape_label}"

        self.shape = shape_raw
        self.roi = roi

        self._ndim = len(shape_raw) if ndim is None else ndim
        assert self._ndim in (2, 3, 4), f"Invalid data dimensions: {self._ndim}. Only 2d, 3d or 4d data is supported"

        if patch_shape is not None:
            assert len(patch_shape) in (self._ndim, self._ndim + 1), f"{patch_shape}, {self._ndim}"

        self.patch_shape = patch_shape

        self.raw_transform = raw_transform
        self.label_transform = label_transform
        self.label_transform2 = label_transform2
        self.transform = transform
        self.sampler = sampler
        self.with_padding = with_padding

        self.dtype = dtype
        self.label_dtype = label_dtype

        self._len = self.compute_len(self.shape, self.patch_shape) if n_samples is None else n_samples

        self.z_ext = z_ext

        self.sample_shape = patch_shape
        self.trafo_halo = None
        # TODO add support for trafo halo: asking for a bigger bounding box before applying the trafo,
        # which is then cut. See code below; but this needs to be properly tested

        # self.trafo_halo = None if self.transform is None else self.transform.halo(self.patch_shape)
        # if self.trafo_halo is not None:
        #     if len(self.trafo_halo) == 2 and self._ndim == 3:
        #         self.trafo_halo = (0,) + self.trafo_halo
        #     assert len(self.trafo_halo) == self._ndim
        #     self.sample_shape = tuple(sh + ha for sh, ha in zip(self.patch_shape, self.trafo_halo))
        #     self.inner_bb = tuple(slice(ha, sh - ha) for sh, ha in zip(self.patch_shape, self.trafo_halo))

    def __len__(self):
        return self._len

    @property
    def ndim(self):
        return self._ndim

    def _sample_bounding_box(self):
        if self.sample_shape is None:
            if self.z_ext is None:
                bb_start = [0] * len(self.shape)
                patch_shape_for_bb = self.shape
            else:
                z_diff = self.shape[0] - self.z_ext
                bb_start = [np.random.randint(0, z_diff) if z_diff > 0 else 0] + [0] * len(self.shape[1:])
                patch_shape_for_bb = (self.z_ext, *self.shape[1:])

        else:
            bb_start = [
                np.random.randint(0, sh - psh) if sh - psh > 0 else 0 for sh, psh in zip(self.shape, self.sample_shape)
            ]
            patch_shape_for_bb = self.sample_shape

        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, patch_shape_for_bb))

    def _get_desired_raw_and_labels(self):
        bb = self._sample_bounding_box()
        bb_raw = (slice(None),) + bb if self._with_channels else bb
        bb_labels = (slice(None),) + bb if self._with_label_channels else bb
        raw, labels = self.raw[bb_raw], self.labels[bb_labels]
        return raw, labels

    def _get_sample(self, index):
        if self.raw is None or self.labels is None:
            raise RuntimeError("SegmentationDataset has not been properly deserialized.")

        raw, labels = self._get_desired_raw_and_labels()

        if self.sampler is not None:
            sample_id = 0
            while not self.sampler(raw, labels):
                raw, labels = self._get_desired_raw_and_labels()
                sample_id += 1
                if sample_id > self.max_sampling_attempts:
                    raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")

        # Padding the patch to match the expected input shape.
        if self.patch_shape is not None and self.with_padding:
            raw, labels = ensure_patch_shape(
                raw=raw,
                labels=labels,
                patch_shape=self.patch_shape,
                have_raw_channels=self._with_channels,
                have_label_channels=self._with_label_channels,
            )

        # squeeze the singleton spatial axis if we have a spatial shape that is larger by one than self._ndim
        if self.patch_shape is not None and len(self.patch_shape) == self._ndim + 1:
            raw = raw.squeeze(1 if self._with_channels else 0)
            labels = labels.squeeze(1 if self._with_label_channels else 0)

        return raw, labels

    def crop(self, tensor):
        """@private
        """
        bb = self.inner_bb
        if tensor.ndim > len(bb):
            bb = (tensor.ndim - len(bb)) * (slice(None),) + bb
        return tensor[bb]

    def __getitem__(self, index):
        raw, labels = self._get_sample(index)
        initial_label_dtype = labels.dtype

        if self.raw_transform is not None:
            raw = self.raw_transform(raw)

        if self.label_transform is not None:
            labels = self.label_transform(labels)

        if self.transform is not None:
            raw, labels = self.transform(raw, labels)
            if self.trafo_halo is not None:
                raw = self.crop(raw)
                labels = self.crop(labels)

        # support enlarging bounding box here as well (for affinity transform) ?
        if self.label_transform2 is not None:
            labels = ensure_spatial_array(labels, self.ndim, dtype=initial_label_dtype)
            labels = self.label_transform2(labels)

        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
        labels = ensure_tensor_with_channels(labels, ndim=self._ndim, dtype=self.label_dtype)
        return raw, labels

    # need to overwrite pickle to support h5py
    def __getstate__(self):
        state = self.__dict__.copy()
        del state["raw"]
        del state["labels"]
        return state

    def __setstate__(self, state):
        raw_path, raw_key = state["raw_path"], state["raw_key"]
        label_path, label_key = state["label_path"], state["label_key"]
        roi = state["roi"]
        try:
            raw = load_data(raw_path, raw_key)
            if roi is not None:
                raw = RoiWrapper(raw, (slice(None),) + roi) if state["_with_channels"] else RoiWrapper(raw, roi)
            state["raw"] = raw
        except Exception:
            msg = f"SegmentationDataset could not be deserialized because of missing {raw_path}, {raw_key}.\n"
            msg += "The dataset is deserialized in order to allow loading trained models from a checkpoint.\n"
            msg += "But it cannot be used for further training and will throw an error."
            warnings.warn(msg)
            state["raw"] = None

        try:
            labels = load_data(label_path, label_key)
            if roi is not None:
                labels = RoiWrapper(labels, (slice(None),) + roi) if state["_with_label_channels"] else\
                    RoiWrapper(labels, roi)
            state["labels"] = labels
        except Exception:
            msg = f"SegmentationDataset could not be deserialized because of missing {label_path}, {label_key}.\n"
            msg += "The dataset is deserialized in order to allow loading trained models from a checkpoint.\n"
            msg += "But it cannot be used for further training and will throw an error."
            warnings.warn(msg)
            state["labels"] = None

        self.__dict__.update(state)
```
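One detail that is easy to miss in the source above: if `n_samples` is not passed, the dataset length is derived by `compute_len`, which divides the data shape by the patch shape element-wise and takes the ceiling of the product. A small worked example with hypothetical shapes:

```python
from math import ceil
import numpy as np

# Hypothetical volume and patch shape, purely for illustration.
shape, patch_shape = (100, 512, 512), (32, 256, 256)

# Same arithmetic as SegmentationDataset.compute_len:
n_samples = ceil(np.prod([sh / psh for sh, psh in zip(shape, patch_shape)]))
print(n_samples)  # ceil(3.125 * 2.0 * 2.0) = ceil(12.5) = 13
```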
Dataset that provides raw data and labels stored in a container data format for segmentation training.

The dataset loads a patch from the raw and label data and returns a sample for a batch.
Image data and label data must have the same shape, except for potential channels.
The dataset supports all file formats that can be opened with `elf.io.open_file`, such as hdf5, zarr or n5.
Use `raw_path` / `label_path` to specify the file path and `raw_key` / `label_key` to specify the internal dataset.
It also supports regular image formats, such as .tif. For these cases set `raw_key=None` / `label_key=None`.
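A minimal sketch of the two path/key conventions; the file paths, dataset keys and patch shapes below are placeholders, not values shipped with torch_em:

```python
from torch_em.data.segmentation_dataset import SegmentationDataset

# Container format (hdf5 / zarr / n5): pass the file path and the internal dataset key.
ds_container = SegmentationDataset(
    raw_path="/data/sample.h5", raw_key="raw",
    label_path="/data/sample.h5", label_key="labels",
    patch_shape=(32, 256, 256),
)

# Regular image format (.tif): set the keys to None.
ds_tif = SegmentationDataset(
    raw_path="/data/image.tif", raw_key=None,
    label_path="/data/labels.tif", label_key=None,
    patch_shape=(256, 256),
)
```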
Arguments:

- raw_path: The file path to the raw image data. May also be a list of file paths.
- raw_key: The key to the internal dataset containing the raw data.
- label_path: The file path to the label data. May also be a list of file paths.
- label_key: The key to the internal dataset containing the label data.
- patch_shape: The patch shape for a training sample.
- raw_transform: Transformation applied to the raw data of a sample.
- label_transform: Transformation applied to the label data of a sample, before applying augmentations via `transform`.
- label_transform2: Transformation applied to the label data of a sample, after applying augmentations via `transform`.
- transform: Transformation applied to both the raw data and label data of a sample. This can be used to implement data augmentations.
- roi: Region of interest in the data. If given, the data will only be loaded from the corresponding area.
- dtype: The return data type of the raw data.
- label_dtype: The return data type of the label data.
- n_samples: The length of this dataset. If None, the length will be computed from the data shape and the patch shape.
- sampler: Sampler for rejecting samples according to a defined criterion. The sampler must be a callable that accepts the raw data and label data (as numpy arrays) as input.
- ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
- with_channels: Whether the raw data has channels.
- with_label_channels: Whether the label data has channels.
- with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
- z_ext: Extra bounding box for loading the data across z.
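A usage sketch that combines several of these arguments. The paths, keys, shapes and the two transforms are illustrative stand-ins for whatever a concrete training setup needs:

```python
import numpy as np
from torch.utils.data import DataLoader
from torch_em.data.segmentation_dataset import SegmentationDataset


def normalize_raw(raw):
    # Stand-in raw transform: simple min-max normalization to [0, 1].
    raw = raw.astype("float32")
    return (raw - raw.min()) / max(float(raw.max() - raw.min()), 1e-6)


def binarize_labels(labels):
    # Stand-in label transform: turn instance labels into a foreground mask.
    return (labels > 0).astype("float32")


dataset = SegmentationDataset(
    raw_path="/data/train.h5", raw_key="raw",          # placeholder path and key
    label_path="/data/train.h5", label_key="labels",   # placeholder path and key
    patch_shape=(32, 256, 256),
    raw_transform=normalize_raw,
    label_transform=binarize_labels,
    roi=np.s_[:64, :, :],    # only sample patches from the first 64 z-slices
    n_samples=100,           # fix the epoch length instead of deriving it from the shape
)

loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=2)
raw_batch, label_batch = next(iter(loader))
print(raw_batch.shape, label_batch.shape)  # e.g. torch.Size([2, 1, 32, 256, 256]) each
```

Because `__getstate__` / `__setstate__` re-open the underlying files, the dataset also works with multi-process data loading (`num_workers > 0`), as long as the files remain accessible in the worker processes.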
`max_sampling_attempts = 500`

The maximal number of sampling attempts for loading a sample via `__getitem__`.
This is used when `sampler` rejects a sample, to avoid an infinite loop if no valid sample can be found.
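To make the rejection loop concrete, below is a hypothetical sampler that only accepts patches with a minimum amount of foreground. It is purely illustrative (torch_em also provides ready-made sampler classes), and the paths and keys are again placeholders:

```python
from torch_em.data.segmentation_dataset import SegmentationDataset


def require_foreground(raw, labels, min_fraction=0.05):
    # Accept the patch only if at least 5% of the label pixels are foreground.
    return (labels > 0).mean() > min_fraction


dataset = SegmentationDataset(
    raw_path="/data/train.h5", raw_key="raw",        # placeholder path and key
    label_path="/data/train.h5", label_key="labels",
    patch_shape=(32, 256, 256),
    sampler=require_foreground,
)

# Each __getitem__ call redraws random patches until the sampler accepts one.
# If no valid patch is found, a RuntimeError is raised after max_sampling_attempts tries.
# The limit is a class attribute and can be adjusted before training:
SegmentationDataset.max_sampling_attempts = 1000
```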