torch_em.data.image_collection_dataset
import os
import numpy as np
from typing import List, Optional, Tuple, Union, Callable

import torch

from ..util import (
    ensure_spatial_array, ensure_tensor_with_channels, load_image, supports_memmap, ensure_patch_shape
)


class ImageCollectionDataset(torch.utils.data.Dataset):
    """Dataset that provides raw data and labels stored in a regular image data format for segmentation training.

    The dataset returns patches loaded from the images and labels as samples for a batch.
    The raw data and labels are expected to be images of the same shape, except for possible channels.
    It supports all file formats that can be loaded with the imageio or tifffile library, such as tif, png or jpeg files.

    Args:
        raw_image_paths: The file paths to the raw data.
        label_image_paths: The file paths to the label data.
        patch_shape: The patch shape for a training sample.
        raw_transform: Transformation applied to the raw data of a sample.
        label_transform: Transformation applied to the label data of a sample,
            before applying augmentations via `transform`.
        label_transform2: Transformation applied to the label data of a sample,
            after applying augmentations via `transform`.
        transform: Transformation applied to both the raw data and label data of a sample.
            This can be used to implement data augmentations.
        dtype: The return data type of the raw data.
        label_dtype: The return data type of the label data.
        n_samples: The length of this dataset. If None, the length will be set to `len(raw_image_paths)`.
        sampler: Sampler for rejecting samples according to a defined criterion.
            The sampler must be a callable that accepts the raw data and label data (as numpy arrays) as input.
        full_check: Whether to check that the input data is valid for all image paths.
            This will ensure that the data is valid, but will take longer for creating the dataset.
        with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
    """
    max_sampling_attempts = 500
    """The maximal number of sampling attempts, for loading a sample via `__getitem__`.
    This is used when `sampler` rejects a sample, to avoid an infinite loop if no valid sample can be found.
    """
    max_sampling_attempts_image = 50
    """The maximal number of sampling attempts for a single image, for loading a sample via `__getitem__`.
    This is used when `sampler` rejects a sample, to avoid an infinite loop if no valid sample can be found.
    """

    def _check_inputs(self, raw_images, label_images, full_check):
        if len(raw_images) != len(label_images):
            raise ValueError(f"Expect same number of raw and label images, got {len(raw_images)} and {len(label_images)}")

        if not full_check:
            return

        is_multichan = None
        for raw_im, label_im in zip(raw_images, label_images):

            # We only check for compatible shapes if both images support memmap, because
            # we don't want to load everything into RAM.
            if supports_memmap(raw_im) and supports_memmap(label_im):
                shape = load_image(raw_im).shape
                assert len(shape) in (2, 3)

                multichan = len(shape) == 3
                if is_multichan is None:
                    is_multichan = multichan
                else:
                    assert is_multichan == multichan

                if is_multichan:
                    # Use a heuristic to decide whether the data is stored in channel last or channel first order:
                    # if the last axis has a length smaller than 16 we assume that it's the channel axis,
                    # otherwise we assume it's a spatial axis and that the first axis is the channel axis.
                    if shape[-1] < 16:
                        shape = shape[:-1]
                    else:
                        shape = shape[1:]

                label_shape = load_image(label_im).shape
                if shape != label_shape:
                    msg = f"Expect raw and labels of same shape, got {shape}, {label_shape} for {raw_im}, {label_im}"
                    raise ValueError(msg)

    def __init__(
        self,
        raw_image_paths: List[Union[str, os.PathLike]],
        label_image_paths: List[Union[str, os.PathLike]],
        patch_shape: Tuple[int, ...],
        raw_transform: Optional[Callable] = None,
        label_transform: Optional[Callable] = None,
        label_transform2: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        dtype: torch.dtype = torch.float32,
        label_dtype: torch.dtype = torch.float32,
        n_samples: Optional[int] = None,
        sampler: Optional[Callable] = None,
        full_check: bool = False,
        with_padding: bool = True,
    ) -> None:
        self._check_inputs(raw_image_paths, label_image_paths, full_check=full_check)
        self.raw_images = raw_image_paths
        self.label_images = label_image_paths
        self._ndim = 2

        if patch_shape is not None:
            assert len(patch_shape) == self._ndim
        self.patch_shape = patch_shape

        self.raw_transform = raw_transform
        self.label_transform = label_transform
        self.label_transform2 = label_transform2
        self.transform = transform
        self.sampler = sampler
        self.with_padding = with_padding

        self.dtype = dtype
        self.label_dtype = label_dtype

        if n_samples is None:
            self._len = len(self.raw_images)
            self.sample_random_index = False
        else:
            self._len = n_samples
            self.sample_random_index = True

    def __len__(self):
        return self._len

    @property
    def ndim(self):
        return self._ndim

    def _sample_bounding_box(self, shape):
        if self.patch_shape is None:
            patch_shape_for_bb = shape
            bb_start = [0] * len(shape)
        else:
            patch_shape_for_bb = self.patch_shape
            bb_start = [
                np.random.randint(0, sh - psh) if sh - psh > 0 else 0
                for sh, psh in zip(shape, patch_shape_for_bb)
            ]

        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, patch_shape_for_bb))

    def _load_data(self, raw_path, label_path):
        if getattr(self, "have_tensor_data", False):
            raw, label = raw_path, label_path
        else:
            raw = load_image(raw_path, memmap=False)
            label = load_image(label_path, memmap=False)

        have_raw_channels = getattr(self, "with_channels", raw.ndim == 3)
        have_label_channels = getattr(self, "with_label_channels", label.ndim == 3)
        if have_label_channels:
            raise NotImplementedError("Multi-channel labels are not supported.")

        # We determine if the image has channels as the first or last axis based on the array shape.
        # This will work only for images with less than 16 channels!
        # If the last axis has a length smaller than 16 we assume that it is the channel axis,
        # otherwise we assume it is a spatial axis and that the first axis is the channel axis.
        channel_first = None
        if have_raw_channels:
            channel_first = raw.shape[-1] > 16

        if self.patch_shape is not None and self.with_padding:
            raw, label = ensure_patch_shape(
                raw=raw,
                labels=label,
                patch_shape=self.patch_shape,
                have_raw_channels=have_raw_channels,
                have_label_channels=have_label_channels,
                channel_first=channel_first
            )

        shape = raw.shape

        prefix_box = tuple()
        if have_raw_channels:
            if channel_first:
                shape = shape[1:]
                prefix_box = (slice(None), )
            else:
                shape = shape[:-1]

        return raw, label, shape, prefix_box, have_raw_channels

    def _get_sample(self, index):
        if self.sample_random_index:
            index = np.random.randint(0, len(self.raw_images))

        # The filepath corresponding to this image.
        raw_path, label_path = self.raw_images[index], self.label_images[index]

        # Load the corresponding data.
        raw, label, shape, prefix_box, have_raw_channels = self._load_data(raw_path, label_path)

        # Sample random bounding box for this image.
        bb = self._sample_bounding_box(shape)
        raw_patch = np.array(raw[prefix_box + bb])
        label_patch = np.array(label[bb])

        if self.sampler is not None:
            sample_id = 0
            while not self.sampler(raw_patch, label_patch):
                bb = self._sample_bounding_box(shape)
                raw_patch = np.array(raw[prefix_box + bb])
                label_patch = np.array(label[bb])
                sample_id += 1

                # We need to avoid sampling from the same image over and over again,
                # otherwise this will fail just because of one or a few empty images.
                # Hence we update the image from which we sample sometimes.
                if sample_id % self.max_sampling_attempts_image == 0:
                    index = np.random.randint(0, len(self.raw_images))
                    raw_path, label_path = self.raw_images[index], self.label_images[index]
                    raw, label, shape, prefix_box, have_raw_channels = self._load_data(raw_path, label_path)

                if sample_id > self.max_sampling_attempts:
                    raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")

        # to channel first
        if have_raw_channels and len(prefix_box) == 0:
            raw_patch = raw_patch.transpose((2, 0, 1))

        return raw_patch, label_patch

    def __getitem__(self, index):
        raw, labels = self._get_sample(index)
        initial_label_dtype = labels.dtype

        if self.raw_transform is not None:
            raw = self.raw_transform(raw)

        if self.label_transform is not None:
            labels = self.label_transform(labels)

        if self.transform is not None:
            raw, labels = self.transform(raw, labels)
            # if self.trafo_halo is not None:
            #     raw = self.crop(raw)
            #     labels = self.crop(labels)

        # support enlarging bounding box here as well (for affinity transform) ?
        if self.label_transform2 is not None:
            labels = ensure_spatial_array(labels, self.ndim, dtype=initial_label_dtype)
            labels = self.label_transform2(labels)

        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
        labels = ensure_tensor_with_channels(labels, ndim=self._ndim, dtype=self.label_dtype)
        return raw, labels
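The channel-order heuristic used in _check_inputs and _load_data is worth spelling out. Below is a minimal standalone sketch of it; the 16-channel threshold comes from the code above, while the function name and the example shapes are made up for illustration:

def infer_channel_first(shape):
    # Mirrors the heuristic above: if the last axis is shorter than 16 entries,
    # it is assumed to be the channel axis (channel-last layout); otherwise the
    # first axis is assumed to hold the channels (channel-first layout).
    return shape[-1] > 16

print(infer_channel_first((512, 512, 3)))  # False -> RGB image stored channel-last
print(infer_channel_first((3, 512, 512)))  # True  -> the same image stored channel-first

As the comment in _load_data points out, this only works reliably for images with fewer than 16 channels.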
Dataset that provides raw data and labels stored in a regular image data format for segmentation training.
The dataset returns patches loaded from the images and labels as samples for a batch. The raw data and labels are expected to be images of the same shape, except for possible channels. It supports all file formats that can be loaded with the imageio or tifffile library, such as tif, png or jpeg files.
Arguments:
- raw_image_paths: The file paths to the raw data.
- label_image_paths: The file paths to the label data.
- patch_shape: The patch shape for a training sample.
- raw_transform: Transformation applied to the raw data of a sample.
- label_transform: Transformation applied to the label data of a sample, before applying augmentations via transform.
- label_transform2: Transformation applied to the label data of a sample, after applying augmentations via transform.
- transform: Transformation applied to both the raw data and label data of a sample. This can be used to implement data augmentations.
- dtype: The return data type of the raw data.
- label_dtype: The return data type of the label data.
- n_samples: The length of this dataset. If None, the length will be set to len(raw_image_paths).
- sampler: Sampler for rejecting samples according to a defined criterion. The sampler must be a callable that accepts the raw data and label data (as numpy arrays) as input.
- full_check: Whether to check that the input data is valid for all image paths. This will ensure that the data is valid, but will take longer for creating the dataset.
- with_padding: Whether to pad samples to patch_shape if their shape is smaller.
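A minimal construction sketch, assuming two hypothetical lists of tif files with matching shapes (the file names, patch shape, and batch size are made up; any format readable by imageio or tifffile works):

import torch
from torch_em.data.image_collection_dataset import ImageCollectionDataset

raw_paths = ["images/im_000.tif", "images/im_001.tif"]    # hypothetical raw images
label_paths = ["labels/im_000.tif", "labels/im_001.tif"]  # matching label images

dataset = ImageCollectionDataset(
    raw_image_paths=raw_paths,
    label_image_paths=label_paths,
    patch_shape=(256, 256),  # 2d patch shape, as required by the dataset
    n_samples=100,           # draw 100 random patches per epoch instead of one patch per image
)

loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)
raw_batch, label_batch = next(iter(loader))  # e.g. tensors of shape (4, 1, 256, 256) for single-channel data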
max_sampling_attempts = 500
The maximal number of sampling attempts for loading a sample via __getitem__. This is used when sampler rejects a sample, to avoid an infinite loop if no valid sample can be found.
max_sampling_attempts_image = 50
The maximal number of sampling attempts for a single image, for loading a sample via __getitem__. This is used when sampler rejects a sample, to avoid an infinite loop if no valid sample can be found.
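For illustration, a hand-rolled sampler that ties into this retry logic. The function name and the 5% threshold are made up, and the path lists are the hypothetical ones from the construction example above:

from torch_em.data.image_collection_dataset import ImageCollectionDataset

def require_foreground(raw, labels, min_fraction=0.05):
    # Accept a patch only if at least 5% of its label pixels are foreground.
    # Returning False triggers another sampling attempt; after
    # max_sampling_attempts_image rejections a different image is drawn, and after
    # max_sampling_attempts rejections a RuntimeError is raised.
    return (labels != 0).mean() >= min_fraction

dataset = ImageCollectionDataset(
    raw_image_paths=raw_paths,
    label_image_paths=label_paths,
    patch_shape=(256, 256),
    sampler=require_foreground,
)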