torch_em.data.raw_dataset
import os
import warnings
import numpy as np
from typing import List, Union, Tuple, Optional, Any, Callable

import torch

from elf.wrapper import RoiWrapper

from ..util import ensure_tensor_with_channels, ensure_patch_shape, load_data


class RawDataset(torch.utils.data.Dataset):
    """Dataset that provides raw data stored in a container data format for unsupervised training.

    The dataset loads a patch from the raw data and returns a sample for a batch.
    The dataset supports all file formats that can be opened with `elf.io.open_file`, such as hdf5, zarr or n5.
    Use `raw_path` to specify the path to the file and `raw_key` to specify the internal dataset.
    It also supports regular image formats, such as .tif. For these cases set `raw_key=None`.

    The dataset can also be used for contrastive learning that relies on two different views of the same data.
    You can use the `augmentations` argument for this.

    Args:
        raw_path: The file path to the raw image data. May also be a list of file paths.
        raw_key: The key to the internal dataset containing the raw data.
        patch_shape: The patch shape for a training sample.
        raw_transform: Transformation applied to the raw data of a sample.
        transform: Transformation to the raw data. This can be used to implement data augmentations.
        roi: Region of interest in the raw data.
            If given, the raw data will only be loaded from the corresponding area.
        dtype: The return data type of the raw data.
        n_samples: The length of this dataset. If None, the length is computed from the shape of the
            raw data and `patch_shape` via `compute_len`.
        sampler: Sampler for rejecting samples according to a defined criterion.
            The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
        ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
        with_channels: Whether the raw data has channels.
        augmentations: Augmentations for contrastive learning. If given, these need to be two different callables.
            They will be applied to the sampled raw data to return two independent views of the raw data.
    """
    max_sampling_attempts = 500
    """The maximal number of sampling attempts, for loading a sample via `__getitem__`.
    This is used when `sampler` rejects a sample, to avoid an infinite loop if no valid sample can be found.
    """

    @staticmethod
    def compute_len(shape, patch_shape):
        """Compute the default dataset length from the data shape and the patch shape.

        The per-axis ratios `shape / patch_shape` are multiplied and the product is truncated to an integer.
        Axes are paired with `zip`, so surplus entries of the longer argument are ignored.
        """
        n_samples = int(np.prod([float(sh / csh) for sh, csh in zip(shape, patch_shape)]))
        return n_samples

    def __init__(
        self,
        raw_path: Union[List[Any], str, os.PathLike],
        raw_key: Optional[str],
        patch_shape: Tuple[int, ...],
        raw_transform: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        roi: Optional[Union[slice, Tuple[slice, ...]]] = None,
        dtype: torch.dtype = torch.float32,
        n_samples: Optional[int] = None,
        sampler: Optional[Callable] = None,
        ndim: Optional[int] = None,
        with_channels: bool = False,
        augmentations: Optional[Tuple[Callable, Callable]] = None,
    ):
        # The path and key are kept so that the data can be re-opened in `__setstate__`.
        self.raw_path = raw_path
        self.raw_key = raw_key
        self.raw = load_data(raw_path, raw_key)

        self._with_channels = with_channels

        if roi is not None:
            if isinstance(roi, slice):
                roi = (roi,)

            # The roi only addresses the spatial axes; with channels the leading axis is taken in full.
            self.raw = RoiWrapper(self.raw, (slice(None),) + roi) if self._with_channels else RoiWrapper(self.raw, roi)

        # `shape` is the spatial shape, i.e. without the leading channel axis.
        self.shape = self.raw.shape[1:] if self._with_channels else self.raw.shape
        self.roi = roi

        self._ndim = len(self.shape) if ndim is None else ndim
        assert self._ndim in (2, 3, 4), f"Invalid data dimensions: {self._ndim}. Only 2d, 3d or 4d data is supported"

        # The patch shape may have one more axis than `ndim`; that axis is squeezed in `_get_sample`.
        assert len(patch_shape) in (self._ndim, self._ndim + 1), f"{patch_shape}, {self._ndim}"
        self.patch_shape = patch_shape

        self.raw_transform = raw_transform
        self.transform = transform
        self.sampler = sampler
        self.dtype = dtype

        if augmentations is not None:
            assert len(augmentations) == 2
        self.augmentations = augmentations

        self._len = self.compute_len(self.shape, self.patch_shape) if n_samples is None else n_samples

        self.sample_shape = patch_shape
        self.trafo_halo = None
        # TODO add support for trafo halo: asking for a bigger bounding box before applying the trafo,
        # which is then cut. See code below; but this needs to be properly tested

        # self.trafo_halo = None if self.transform is None else self.transform.halo(self.patch_shape)
        # if self.trafo_halo is not None:
        #     if len(self.trafo_halo) == 2 and self._ndim == 3:
        #         self.trafo_halo = (0,) + self.trafo_halo
        #     assert len(self.trafo_halo) == self._ndim
        #     self.sample_shape = tuple(sh + ha for sh, ha in zip(self.patch_shape, self.trafo_halo))
        #     self.inner_bb = tuple(slice(ha, sh - ha) for sh, ha in zip(self.patch_shape, self.trafo_halo))

    def __len__(self):
        """Return the dataset length (`n_samples` if it was given, otherwise the result of `compute_len`)."""
        return self._len

    @property
    def ndim(self):
        """The spatial dimensionality of the data."""
        return self._ndim

    def _sample_bounding_box(self):
        """Sample a random bounding box of `sample_shape` inside the spatial data shape.

        Returns:
            A tuple of slices, one per spatial axis. If the data is not larger than the
            sample shape along an axis, the slice for this axis starts at 0.
        """
        bb_start = [
            # `np.random.randint` excludes the upper bound, hence the `+ 1`:
            # the last valid offset is `sh - psh`, where the patch ends exactly at the border.
            np.random.randint(0, sh - psh + 1) if sh - psh > 0 else 0
            for sh, psh in zip(self.shape, self.sample_shape)
        ]
        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, self.sample_shape))

    def _get_sample(self, index):
        """Load a randomly positioned patch of raw data.

        Args:
            index: The sample index. It is not used, the patch position is drawn randomly.

        Returns:
            The raw data of the patch.

        Raises:
            RuntimeError: If the raw data is not available after unpickling,
                or if the sampler keeps rejecting the resampled patches.
        """
        if self.raw is None:
            raise RuntimeError("RawDataset has not been properly deserialized.")
        bb = self._sample_bounding_box()
        raw = self.raw[(slice(None),) + bb] if self._with_channels else self.raw[bb]

        if self.sampler is not None:
            sample_id = 0
            # Draw new patches until the sampler accepts one or the attempt limit is exceeded.
            while not self.sampler(raw):
                bb = self._sample_bounding_box()
                raw = self.raw[(slice(None),) + bb] if self._with_channels else self.raw[bb]
                sample_id += 1
                if sample_id > self.max_sampling_attempts:
                    raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")

        if self.patch_shape is not None:
            raw = ensure_patch_shape(
                raw=raw, labels=None, patch_shape=self.patch_shape, have_raw_channels=self._with_channels
            )

        # squeeze the singleton spatial axis if we have a spatial shape that is larger by one than self._ndim
        if len(self.patch_shape) == self._ndim + 1:
            raw = raw.squeeze(1 if self._with_channels else 0)

        return raw

    def crop(self, tensor):
        """Crop `tensor` to `self.inner_bb`; leading axes not covered by the bounding box are kept in full.

        NOTE: `inner_bb` is only assigned by the halo code that is commented out in `__init__`.
        This method is only called from `__getitem__` when `trafo_halo` is not None.
        """
        bb = self.inner_bb
        if tensor.ndim > len(bb):
            bb = (tensor.ndim - len(bb)) * (slice(None),) + bb
        return tensor[bb]

    def __getitem__(self, index):
        """Return a sample: a tensor of raw data, or two augmented views of it if `augmentations` is set."""
        raw = self._get_sample(index)

        if self.raw_transform is not None:
            raw = self.raw_transform(raw)

        if self.transform is not None:
            raw = self.transform(raw)
            # The transform may return its result wrapped in a list with a single element.
            if isinstance(raw, list):
                assert len(raw) == 1
                raw = raw[0]

            if self.trafo_halo is not None:
                raw = self.crop(raw)

        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
        if self.augmentations is not None:
            aug1, aug2 = self.augmentations
            raw1, raw2 = aug1(raw), aug2(raw)
            return raw1, raw2

        return raw

    # need to overwrite pickle to support h5py
    def __getstate__(self):
        """Return the instance state without the opened raw data, which is re-opened in `__setstate__`."""
        state = self.__dict__.copy()
        del state["raw"]
        return state

    def __setstate__(self, state):
        """Restore the instance state and re-open the raw data from `raw_path` and `raw_key`.

        If the data cannot be opened, a warning is issued and `raw` is set to None,
        so that `_get_sample` raises a RuntimeError when a sample is requested.
        """
        raw_path, raw_key = state["raw_path"], state["raw_key"]
        roi = state["roi"]
        try:
            raw = load_data(raw_path, raw_key)
            if roi is not None:
                raw = RoiWrapper(raw, (slice(None),) + roi) if state["_with_channels"] else RoiWrapper(raw, roi)
            state["raw"] = raw
        except Exception:
            # Deliberately best-effort: the dataset object must stay loadable from a checkpoint.
            msg = f"RawDataset could not be deserialized because of missing {raw_path}, {raw_key}.\n"
            msg += "The dataset is deserialized in order to allow loading trained models from a checkpoint.\n"
            msg += "But it cannot be used for further training and wil throw an error."
            warnings.warn(msg)
            state["raw"] = None

        self.__dict__.update(state)
class RawDataset(torch.utils.data.Dataset):
    """Dataset that provides raw data stored in a container data format for unsupervised training.

    The dataset loads a patch from the raw data and returns a sample for a batch.
    The dataset supports all file formats that can be opened with `elf.io.open_file`, such as hdf5, zarr or n5.
    Use `raw_path` to specify the path to the file and `raw_key` to specify the internal dataset.
    It also supports regular image formats, such as .tif. For these cases set `raw_key=None`.

    The dataset can also be used for contrastive learning that relies on two different views of the same data.
    You can use the `augmentations` argument for this.

    Args:
        raw_path: The file path to the raw image data. May also be a list of file paths.
        raw_key: The key to the internal dataset containing the raw data.
        patch_shape: The patch shape for a training sample.
        raw_transform: Transformation applied to the raw data of a sample.
        transform: Transformation to the raw data. This can be used to implement data augmentations.
        roi: Region of interest in the raw data.
            If given, the raw data will only be loaded from the corresponding area.
        dtype: The return data type of the raw data.
        n_samples: The length of this dataset. If None, the length is computed from the shape of the
            raw data and `patch_shape` via `compute_len`.
        sampler: Sampler for rejecting samples according to a defined criterion.
            The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
        ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
        with_channels: Whether the raw data has channels.
        augmentations: Augmentations for contrastive learning. If given, these need to be two different callables.
            They will be applied to the sampled raw data to return two independent views of the raw data.
    """
    max_sampling_attempts = 500
    """The maximal number of sampling attempts, for loading a sample via `__getitem__`.
    This is used when `sampler` rejects a sample, to avoid an infinite loop if no valid sample can be found.
    """

    @staticmethod
    def compute_len(shape, patch_shape):
        """Compute the default dataset length from the data shape and the patch shape.

        The per-axis ratios `shape / patch_shape` are multiplied and the product is truncated to an integer.
        Axes are paired with `zip`, so surplus entries of the longer argument are ignored.
        """
        n_samples = int(np.prod([float(sh / csh) for sh, csh in zip(shape, patch_shape)]))
        return n_samples

    def __init__(
        self,
        raw_path: Union[List[Any], str, os.PathLike],
        raw_key: Optional[str],
        patch_shape: Tuple[int, ...],
        raw_transform: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        roi: Optional[Union[slice, Tuple[slice, ...]]] = None,
        dtype: torch.dtype = torch.float32,
        n_samples: Optional[int] = None,
        sampler: Optional[Callable] = None,
        ndim: Optional[int] = None,
        with_channels: bool = False,
        augmentations: Optional[Tuple[Callable, Callable]] = None,
    ):
        # The path and key are kept so that the data can be re-opened in `__setstate__`.
        self.raw_path = raw_path
        self.raw_key = raw_key
        self.raw = load_data(raw_path, raw_key)

        self._with_channels = with_channels

        if roi is not None:
            if isinstance(roi, slice):
                roi = (roi,)

            # The roi only addresses the spatial axes; with channels the leading axis is taken in full.
            self.raw = RoiWrapper(self.raw, (slice(None),) + roi) if self._with_channels else RoiWrapper(self.raw, roi)

        # `shape` is the spatial shape, i.e. without the leading channel axis.
        self.shape = self.raw.shape[1:] if self._with_channels else self.raw.shape
        self.roi = roi

        self._ndim = len(self.shape) if ndim is None else ndim
        assert self._ndim in (2, 3, 4), f"Invalid data dimensions: {self._ndim}. Only 2d, 3d or 4d data is supported"

        # The patch shape may have one more axis than `ndim`; that axis is squeezed in `_get_sample`.
        assert len(patch_shape) in (self._ndim, self._ndim + 1), f"{patch_shape}, {self._ndim}"
        self.patch_shape = patch_shape

        self.raw_transform = raw_transform
        self.transform = transform
        self.sampler = sampler
        self.dtype = dtype

        if augmentations is not None:
            assert len(augmentations) == 2
        self.augmentations = augmentations

        self._len = self.compute_len(self.shape, self.patch_shape) if n_samples is None else n_samples

        self.sample_shape = patch_shape
        self.trafo_halo = None
        # TODO add support for trafo halo: asking for a bigger bounding box before applying the trafo,
        # which is then cut. See code below; but this needs to be properly tested

        # self.trafo_halo = None if self.transform is None else self.transform.halo(self.patch_shape)
        # if self.trafo_halo is not None:
        #     if len(self.trafo_halo) == 2 and self._ndim == 3:
        #         self.trafo_halo = (0,) + self.trafo_halo
        #     assert len(self.trafo_halo) == self._ndim
        #     self.sample_shape = tuple(sh + ha for sh, ha in zip(self.patch_shape, self.trafo_halo))
        #     self.inner_bb = tuple(slice(ha, sh - ha) for sh, ha in zip(self.patch_shape, self.trafo_halo))

    def __len__(self):
        """Return the dataset length (`n_samples` if it was given, otherwise the result of `compute_len`)."""
        return self._len

    @property
    def ndim(self):
        """The spatial dimensionality of the data."""
        return self._ndim

    def _sample_bounding_box(self):
        """Sample a random bounding box of `sample_shape` inside the spatial data shape.

        Returns:
            A tuple of slices, one per spatial axis. If the data is not larger than the
            sample shape along an axis, the slice for this axis starts at 0.
        """
        bb_start = [
            # `np.random.randint` excludes the upper bound, hence the `+ 1`:
            # the last valid offset is `sh - psh`, where the patch ends exactly at the border.
            np.random.randint(0, sh - psh + 1) if sh - psh > 0 else 0
            for sh, psh in zip(self.shape, self.sample_shape)
        ]
        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, self.sample_shape))

    def _get_sample(self, index):
        """Load a randomly positioned patch of raw data.

        Args:
            index: The sample index. It is not used, the patch position is drawn randomly.

        Returns:
            The raw data of the patch.

        Raises:
            RuntimeError: If the raw data is not available after unpickling,
                or if the sampler keeps rejecting the resampled patches.
        """
        if self.raw is None:
            raise RuntimeError("RawDataset has not been properly deserialized.")
        bb = self._sample_bounding_box()
        raw = self.raw[(slice(None),) + bb] if self._with_channels else self.raw[bb]

        if self.sampler is not None:
            sample_id = 0
            # Draw new patches until the sampler accepts one or the attempt limit is exceeded.
            while not self.sampler(raw):
                bb = self._sample_bounding_box()
                raw = self.raw[(slice(None),) + bb] if self._with_channels else self.raw[bb]
                sample_id += 1
                if sample_id > self.max_sampling_attempts:
                    raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")

        if self.patch_shape is not None:
            raw = ensure_patch_shape(
                raw=raw, labels=None, patch_shape=self.patch_shape, have_raw_channels=self._with_channels
            )

        # squeeze the singleton spatial axis if we have a spatial shape that is larger by one than self._ndim
        if len(self.patch_shape) == self._ndim + 1:
            raw = raw.squeeze(1 if self._with_channels else 0)

        return raw

    def crop(self, tensor):
        """Crop `tensor` to `self.inner_bb`; leading axes not covered by the bounding box are kept in full.

        NOTE: `inner_bb` is only assigned by the halo code that is commented out in `__init__`.
        This method is only called from `__getitem__` when `trafo_halo` is not None.
        """
        bb = self.inner_bb
        if tensor.ndim > len(bb):
            bb = (tensor.ndim - len(bb)) * (slice(None),) + bb
        return tensor[bb]

    def __getitem__(self, index):
        """Return a sample: a tensor of raw data, or two augmented views of it if `augmentations` is set."""
        raw = self._get_sample(index)

        if self.raw_transform is not None:
            raw = self.raw_transform(raw)

        if self.transform is not None:
            raw = self.transform(raw)
            # The transform may return its result wrapped in a list with a single element.
            if isinstance(raw, list):
                assert len(raw) == 1
                raw = raw[0]

            if self.trafo_halo is not None:
                raw = self.crop(raw)

        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
        if self.augmentations is not None:
            aug1, aug2 = self.augmentations
            raw1, raw2 = aug1(raw), aug2(raw)
            return raw1, raw2

        return raw

    # need to overwrite pickle to support h5py
    def __getstate__(self):
        """Return the instance state without the opened raw data, which is re-opened in `__setstate__`."""
        state = self.__dict__.copy()
        del state["raw"]
        return state

    def __setstate__(self, state):
        """Restore the instance state and re-open the raw data from `raw_path` and `raw_key`.

        If the data cannot be opened, a warning is issued and `raw` is set to None,
        so that `_get_sample` raises a RuntimeError when a sample is requested.
        """
        raw_path, raw_key = state["raw_path"], state["raw_key"]
        roi = state["roi"]
        try:
            raw = load_data(raw_path, raw_key)
            if roi is not None:
                raw = RoiWrapper(raw, (slice(None),) + roi) if state["_with_channels"] else RoiWrapper(raw, roi)
            state["raw"] = raw
        except Exception:
            # Deliberately best-effort: the dataset object must stay loadable from a checkpoint.
            msg = f"RawDataset could not be deserialized because of missing {raw_path}, {raw_key}.\n"
            msg += "The dataset is deserialized in order to allow loading trained models from a checkpoint.\n"
            msg += "But it cannot be used for further training and wil throw an error."
            warnings.warn(msg)
            state["raw"] = None

        self.__dict__.update(state)
Dataset that provides raw data stored in a container data format for unsupervised training.
The dataset loads a patch from the raw data and returns a sample for a batch.
The dataset supports all file formats that can be opened with `elf.io.open_file`, such as hdf5, zarr or n5.
Use `raw_path` to specify the path to the file and `raw_key` to specify the internal dataset.
It also supports regular image formats, such as .tif. For these cases set `raw_key=None`.
The dataset can also be used for contrastive learning that relies on two different views of the same data.
You can use the `augmentations` argument for this.
Arguments:
- raw_path: The file path to the raw image data. May also be a list of file paths.
- raw_key: The key to the internal dataset containing the raw data.
- patch_shape: The patch shape for a training sample.
- raw_transform: Transformation applied to the raw data of a sample.
- transform: Transformation to the raw data. This can be used to implement data augmentations.
- roi: Region of interest in the raw data. If given, the raw data will only be loaded from the corresponding area.
- dtype: The return data type of the raw data.
- n_samples: The length of this dataset. If None, the length is computed from the shape of the raw data and `patch_shape`.
- sampler: Sampler for rejecting samples according to a defined criterion. The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
- ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
- with_channels: Whether the raw data has channels.
- augmentations: Augmentations for contrastive learning. If given, these need to be two different callables. They will be applied to the sampled raw data to return two independent views of the raw data.
def __init__(
    self,
    raw_path: Union[List[Any], str, os.PathLike],
    raw_key: Optional[str],
    patch_shape: Tuple[int, ...],
    raw_transform: Optional[Callable] = None,
    transform: Optional[Callable] = None,
    roi: Optional[Union[slice, Tuple[slice, ...]]] = None,
    dtype: torch.dtype = torch.float32,
    n_samples: Optional[int] = None,
    sampler: Optional[Callable] = None,
    ndim: Optional[int] = None,
    with_channels: bool = False,
    augmentations: Optional[Tuple[Callable, Callable]] = None,
):
    """Open the raw data, optionally restrict it to `roi`, and store the sampling configuration."""
    # Data source; path and key are stored next to the opened data.
    self.raw_path = raw_path
    self.raw_key = raw_key
    self.raw = load_data(raw_path, raw_key)
    self._with_channels = with_channels

    # A single slice is promoted to a one-element tuple before it is applied.
    if isinstance(roi, slice):
        roi = (roi,)
    if roi is not None:
        # With channels, the leading axis is taken in full and the roi applies to the remaining axes.
        full_roi = (slice(None),) + roi if with_channels else roi
        self.raw = RoiWrapper(self.raw, full_roi)
    self.roi = roi

    # Spatial shape: everything after the leading channel axis, if there is one.
    data_shape = self.raw.shape
    self.shape = data_shape[1:] if with_channels else data_shape

    self._ndim = ndim if ndim is not None else len(self.shape)
    assert self._ndim in (2, 3, 4), f"Invalid data dimensions: {self._ndim}. Only 2d, 3d or 4d data is supported"
    assert len(patch_shape) in (self._ndim, self._ndim + 1), f"{patch_shape}, {self._ndim}"

    self.patch_shape = patch_shape
    self.raw_transform, self.transform = raw_transform, transform
    self.sampler, self.dtype = sampler, dtype

    # Contrastive learning expects exactly two augmentation callables.
    if augmentations is not None:
        assert len(augmentations) == 2
    self.augmentations = augmentations

    # An explicit `n_samples` takes precedence over the length computed from the shapes.
    if n_samples is None:
        n_samples = self.compute_len(self.shape, self.patch_shape)
    self._len = n_samples

    # Halo support is not enabled yet, so patches are sampled at exactly `patch_shape`.
    self.sample_shape = patch_shape
    self.trafo_halo = None
    # TODO add support for trafo halo: asking for a bigger bounding box before applying the trafo,
    # which is then cut. The sketch below needs to be properly tested first.

    # self.trafo_halo = None if self.transform is None else self.transform.halo(self.patch_shape)
    # if self.trafo_halo is not None:
    #     if len(self.trafo_halo) == 2 and self._ndim == 3:
    #         self.trafo_halo = (0,) + self.trafo_halo
    #     assert len(self.trafo_halo) == self._ndim
    #     self.sample_shape = tuple(sh + ha for sh, ha in zip(self.patch_shape, self.trafo_halo))
    #     self.inner_bb = tuple(slice(ha, sh - ha) for sh, ha in zip(self.patch_shape, self.trafo_halo))
The maximal number of sampling attempts, for loading a sample via `__getitem__`.
This is used when `sampler` rejects a sample, to avoid an infinite loop if no valid sample can be found.