torch_em.data.image_collection_dataset
import os
import numpy as np
from typing import List, Optional, Tuple, Union, Callable

import torch

from ..util import (
    ensure_spatial_array, ensure_tensor_with_channels, load_image, supports_memmap, ensure_patch_shape
)


class ImageCollectionDataset(torch.utils.data.Dataset):
    """Dataset that provides raw data and labels stored in a regular image data format for segmentation training.

    The dataset returns patches loaded from the images and labels as samples for a batch.
    The raw data and labels are expected to be images of the same shape, except for possible channels.
    It supports all file formats that can be loaded with the imageio or tifffile library, such as tif, png or jpeg files.

    Multi-channel raw data may be stored channel-first or channel-last. The channel axis is determined by a heuristic:
    if the last axis has fewer than 16 entries it is treated as the channel axis, otherwise the first axis is.
    Raw patches are converted to channel-first order before the transformations are applied.
    Multi-channel labels are not supported.

    Args:
        raw_image_paths: The file paths to the raw data.
        label_image_paths: The file paths to the label data.
        patch_shape: The patch shape for a training sample. If None, the full image is used.
        raw_transform: Transformation applied to the raw data of a sample.
        label_transform: Transformation applied to the label data of a sample,
            before applying augmentations via `transform`.
        label_transform2: Transformation applied to the label data of a sample,
            after applying augmentations via `transform`.
        transform: Transformation applied to both the raw data and label data of a sample.
            This can be used to implement data augmentations.
        dtype: The return data type of the raw data.
        label_dtype: The return data type of the label data.
        n_samples: The length of this dataset. If None, the length will be set to `len(raw_image_paths)`.
        sampler: Sampler for rejecting samples according to a defined criterion.
            The sampler must be a callable that accepts the raw and label patches (as numpy arrays) as input
            and returns a truthy value if the sample is valid.
        full_check: Whether to check that the input data is valid for all image paths.
            This will ensure that the data is valid, but will take longer for creating the dataset.
        with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
        pre_label_transform: Transformation applied to the label data of a chosen random sample,
            before checking the sample's validity with the `sampler`.
    """
    max_sampling_attempts = 500
    """The maximum number of sampling attempts when loading a sample via `__getitem__`.
    This is used when `sampler` rejects a sample, to avoid an infinite loop if no valid sample can be found.
    """
    max_sampling_attempts_image = 50
    """The number of sampling attempts after which a new image is chosen when loading a sample via `__getitem__`.
    This is used when `sampler` rejects samples, to avoid failing because a few images contain no valid samples.
    """

    @staticmethod
    def _is_channel_first(shape):
        """Return whether a multi-channel image of the given shape stores its channels on the first axis.

        Heuristic: a last axis with fewer than 16 entries is assumed to be the channel axis (channel-last);
        otherwise it is assumed to be spatial and the first axis holds the channels (channel-first).
        Consequently this only works for images with fewer than 16 channels.
        """
        return shape[-1] >= 16

    def _check_inputs(self, raw_images, label_images, full_check):
        """Validate the raw and label image paths.

        Args:
            raw_images: The file paths to the raw data.
            label_images: The file paths to the label data.
            full_check: Whether to also compare the shapes of each raw / label pair. Only pairs where both
                files support memmap are compared, so that not all data has to be loaded into RAM.

        Raises:
            ValueError: If the number of raw and label images differ, or a checked pair has mismatching shapes.
        """
        if len(raw_images) != len(label_images):
            raise ValueError(
                f"Expect same number of raw and label images, got {len(raw_images)} and {len(label_images)}"
            )

        if not full_check:
            return

        is_multichan = None
        for raw_im, label_im in zip(raw_images, label_images):

            # We only check for compatible shapes if both images support memmap,
            # because we don't want to load everything into RAM.
            if supports_memmap(raw_im) and supports_memmap(label_im):
                shape = load_image(raw_im).shape
                assert len(shape) in (2, 3), f"Expect 2d raw data with optional channels, got {shape} for {raw_im}"

                multichan = len(shape) == 3
                if is_multichan is None:
                    is_multichan = multichan
                else:
                    assert is_multichan == multichan, f"Inconsistent channel dimension for raw data at {raw_im}"

                if is_multichan:
                    # Drop the channel axis, using the same heuristic as `_load_data`.
                    shape = shape[1:] if self._is_channel_first(shape) else shape[:-1]

                label_shape = load_image(label_im).shape
                if shape != label_shape:
                    msg = f"Expect raw and labels of same shape, got {shape}, {label_shape} for {raw_im}, {label_im}"
                    raise ValueError(msg)

    def __init__(
        self,
        raw_image_paths: List[Union[str, os.PathLike]],
        label_image_paths: List[Union[str, os.PathLike]],
        patch_shape: Tuple[int, ...],
        raw_transform: Optional[Callable] = None,
        label_transform: Optional[Callable] = None,
        label_transform2: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        dtype: torch.dtype = torch.float32,
        label_dtype: torch.dtype = torch.float32,
        n_samples: Optional[int] = None,
        sampler: Optional[Callable] = None,
        full_check: bool = False,
        with_padding: bool = True,
        pre_label_transform: Optional[Callable] = None,
    ) -> None:
        self._check_inputs(raw_image_paths, label_image_paths, full_check=full_check)
        self.raw_images = raw_image_paths
        self.label_images = label_image_paths
        # This dataset always produces 2d (spatial) samples.
        self._ndim = 2

        if patch_shape is not None:
            assert len(patch_shape) == self._ndim
        self.patch_shape = patch_shape

        self.raw_transform = raw_transform
        self.label_transform = label_transform
        self.label_transform2 = label_transform2
        self.transform = transform
        self.sampler = sampler
        self.with_padding = with_padding
        self.pre_label_transform = pre_label_transform

        self.dtype = dtype
        self.label_dtype = label_dtype

        # Without `n_samples` each index maps to its image; otherwise images are drawn at random.
        if n_samples is None:
            self._len = len(self.raw_images)
            self.sample_random_index = False
        else:
            self._len = n_samples
            self.sample_random_index = True

    def __len__(self):
        return self._len

    @property
    def ndim(self):
        """The spatial dimensionality of the samples."""
        return self._ndim

    def _sample_bounding_box(self, shape):
        """Sample a random bounding box of `patch_shape` inside an image with the given spatial shape.

        If `patch_shape` is None, the bounding box covers the full image. Along axes where the image is not
        larger than the patch, the box starts at 0.

        Returns:
            A tuple of slices, one per spatial axis.
        """
        if self.patch_shape is None:
            return tuple(slice(0, sh) for sh in shape)

        bb_start = [
            # `randint`'s upper bound is exclusive: `sh - psh + 1` makes the last valid start reachable.
            np.random.randint(0, sh - psh + 1) if sh - psh > 0 else 0
            for sh, psh in zip(shape, self.patch_shape)
        ]
        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, self.patch_shape))

    def _load_data(self, raw_path, label_path):
        """Load a raw / label pair, pad it if requested, and determine the spatial shape of the raw data.

        Returns:
            The raw data, the label data, the spatial shape of the raw data, the index prefix that selects all
            channels (empty for channel-last or single-channel data), and whether the raw data has channels.

        Raises:
            NotImplementedError: If the labels have a channel axis.
        """
        # If `have_tensor_data` is set (it is not set by this class), the "paths" are used as the data directly.
        if getattr(self, "have_tensor_data", False):
            raw, label = raw_path, label_path
        else:
            raw = load_image(raw_path, memmap=False)
            label = load_image(label_path, memmap=False)

        have_raw_channels = getattr(self, "with_channels", raw.ndim == 3)
        have_label_channels = getattr(self, "with_label_channels", label.ndim == 3)
        if have_label_channels:
            raise NotImplementedError("Multi-channel labels are not supported.")

        # Determine whether the channels are on the first or last axis (see `_is_channel_first`).
        channel_first = None
        if have_raw_channels:
            channel_first = self._is_channel_first(raw.shape)

        if self.patch_shape is not None and self.with_padding:
            raw, label = ensure_patch_shape(
                raw=raw,
                labels=label,
                patch_shape=self.patch_shape,
                have_raw_channels=have_raw_channels,
                have_label_channels=have_label_channels,
                channel_first=channel_first
            )

        shape = raw.shape

        prefix_box = tuple()
        if have_raw_channels:
            if channel_first:
                shape = shape[1:]
                prefix_box = (slice(None), )
            else:
                shape = shape[:-1]

        return raw, label, shape, prefix_box, have_raw_channels

    def _get_desired_raw_and_labels(self, raw, label, shape, prefix_box):
        """Cut a random patch from raw and label and apply `pre_label_transform` to the label patch, if set."""
        bb = self._sample_bounding_box(shape)
        raw_patch = np.array(raw[prefix_box + bb])
        label_patch = np.array(label[bb])

        # Additional label transform on top to make the sampler consider the expected labels
        # (e.g. run connected components on disconnected semantic labels).
        pre_label_transform = getattr(self, "pre_label_transform", None)
        if pre_label_transform is not None:
            label_patch = pre_label_transform(label_patch)

        return raw_patch, label_patch

    def _get_sample(self, index):
        """Load a raw / label patch pair, resampling until `sampler` (if given) accepts it.

        If `n_samples` was passed, `index` is ignored and a random image is chosen instead.
        Raw patches with channels are returned channel-first.

        Raises:
            RuntimeError: If no valid sample is found within `max_sampling_attempts` attempts.
        """
        if self.sample_random_index:
            index = np.random.randint(0, len(self.raw_images))

        # The filepaths corresponding to this image.
        raw_path, label_path = self.raw_images[index], self.label_images[index]

        # Load the corresponding data.
        raw, label, shape, prefix_box, have_raw_channels = self._load_data(raw_path, label_path)

        # Sample a random bounding box for this image.
        raw_patch, label_patch = self._get_desired_raw_and_labels(raw, label, shape, prefix_box)

        if self.sampler is not None:
            sample_id = 0
            while not self.sampler(raw_patch, label_patch):
                raw_patch, label_patch = self._get_desired_raw_and_labels(raw, label, shape, prefix_box)
                sample_id += 1

                # We need to avoid sampling from the same image over and over again,
                # otherwise this will fail just because of one or a few empty images.
                # Hence we switch to a new random image every `max_sampling_attempts_image` attempts.
                if sample_id % self.max_sampling_attempts_image == 0:
                    index = np.random.randint(0, len(self.raw_images))
                    raw_path, label_path = self.raw_images[index], self.label_images[index]
                    raw, label, shape, prefix_box, have_raw_channels = self._load_data(raw_path, label_path)

                if sample_id > self.max_sampling_attempts:
                    raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")

        # Bring channel-last raw data to channel-first order.
        if have_raw_channels and len(prefix_box) == 0:
            raw_patch = raw_patch.transpose((2, 0, 1))

        return raw_patch, label_patch

    def __getitem__(self, index):
        """Return the transformed raw and label sample, converted via `ensure_tensor_with_channels`."""
        raw, labels = self._get_sample(index)
        initial_label_dtype = labels.dtype

        if self.raw_transform is not None:
            raw = self.raw_transform(raw)

        if self.label_transform is not None:
            labels = self.label_transform(labels)

        if self.transform is not None:
            raw, labels = self.transform(raw, labels)

        # TODO: support enlarging the bounding box here as well (for the affinity transform)?
        if self.label_transform2 is not None:
            # Convert the labels via `ensure_spatial_array`, using the dtype they had before the transforms.
            labels = ensure_spatial_array(labels, self.ndim, dtype=initial_label_dtype)
            labels = self.label_transform2(labels)

        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
        labels = ensure_tensor_with_channels(labels, ndim=self._ndim, dtype=self.label_dtype)
        return raw, labels
class ImageCollectionDataset(torch.utils.data.Dataset):
    """Dataset that provides raw data and labels stored in a regular image data format for segmentation training.

    The dataset returns patches loaded from the images and labels as samples for a batch.
    The raw data and labels are expected to be images of the same shape, except for possible channels.
    It supports all file formats that can be loaded with the imageio or tifffile library, such as tif, png or jpeg files.

    Multi-channel raw data may be stored channel-first or channel-last. The channel axis is determined by a heuristic:
    if the last axis has fewer than 16 entries it is treated as the channel axis, otherwise the first axis is.
    Raw patches are converted to channel-first order before the transformations are applied.
    Multi-channel labels are not supported.

    Args:
        raw_image_paths: The file paths to the raw data.
        label_image_paths: The file paths to the label data.
        patch_shape: The patch shape for a training sample. If None, the full image is used.
        raw_transform: Transformation applied to the raw data of a sample.
        label_transform: Transformation applied to the label data of a sample,
            before applying augmentations via `transform`.
        label_transform2: Transformation applied to the label data of a sample,
            after applying augmentations via `transform`.
        transform: Transformation applied to both the raw data and label data of a sample.
            This can be used to implement data augmentations.
        dtype: The return data type of the raw data.
        label_dtype: The return data type of the label data.
        n_samples: The length of this dataset. If None, the length will be set to `len(raw_image_paths)`.
        sampler: Sampler for rejecting samples according to a defined criterion.
            The sampler must be a callable that accepts the raw and label patches (as numpy arrays) as input
            and returns a truthy value if the sample is valid.
        full_check: Whether to check that the input data is valid for all image paths.
            This will ensure that the data is valid, but will take longer for creating the dataset.
        with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
        pre_label_transform: Transformation applied to the label data of a chosen random sample,
            before checking the sample's validity with the `sampler`.
    """
    max_sampling_attempts = 500
    """The maximum number of sampling attempts when loading a sample via `__getitem__`.
    This is used when `sampler` rejects a sample, to avoid an infinite loop if no valid sample can be found.
    """
    max_sampling_attempts_image = 50
    """The number of sampling attempts after which a new image is chosen when loading a sample via `__getitem__`.
    This is used when `sampler` rejects samples, to avoid failing because a few images contain no valid samples.
    """

    @staticmethod
    def _is_channel_first(shape):
        """Return whether a multi-channel image of the given shape stores its channels on the first axis.

        Heuristic: a last axis with fewer than 16 entries is assumed to be the channel axis (channel-last);
        otherwise it is assumed to be spatial and the first axis holds the channels (channel-first).
        Consequently this only works for images with fewer than 16 channels.
        """
        return shape[-1] >= 16

    def _check_inputs(self, raw_images, label_images, full_check):
        """Validate the raw and label image paths.

        Args:
            raw_images: The file paths to the raw data.
            label_images: The file paths to the label data.
            full_check: Whether to also compare the shapes of each raw / label pair. Only pairs where both
                files support memmap are compared, so that not all data has to be loaded into RAM.

        Raises:
            ValueError: If the number of raw and label images differ, or a checked pair has mismatching shapes.
        """
        if len(raw_images) != len(label_images):
            raise ValueError(
                f"Expect same number of raw and label images, got {len(raw_images)} and {len(label_images)}"
            )

        if not full_check:
            return

        is_multichan = None
        for raw_im, label_im in zip(raw_images, label_images):

            # We only check for compatible shapes if both images support memmap,
            # because we don't want to load everything into RAM.
            if supports_memmap(raw_im) and supports_memmap(label_im):
                shape = load_image(raw_im).shape
                assert len(shape) in (2, 3), f"Expect 2d raw data with optional channels, got {shape} for {raw_im}"

                multichan = len(shape) == 3
                if is_multichan is None:
                    is_multichan = multichan
                else:
                    assert is_multichan == multichan, f"Inconsistent channel dimension for raw data at {raw_im}"

                if is_multichan:
                    # Drop the channel axis, using the same heuristic as `_load_data`.
                    shape = shape[1:] if self._is_channel_first(shape) else shape[:-1]

                label_shape = load_image(label_im).shape
                if shape != label_shape:
                    msg = f"Expect raw and labels of same shape, got {shape}, {label_shape} for {raw_im}, {label_im}"
                    raise ValueError(msg)

    def __init__(
        self,
        raw_image_paths: List[Union[str, os.PathLike]],
        label_image_paths: List[Union[str, os.PathLike]],
        patch_shape: Tuple[int, ...],
        raw_transform: Optional[Callable] = None,
        label_transform: Optional[Callable] = None,
        label_transform2: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        dtype: torch.dtype = torch.float32,
        label_dtype: torch.dtype = torch.float32,
        n_samples: Optional[int] = None,
        sampler: Optional[Callable] = None,
        full_check: bool = False,
        with_padding: bool = True,
        pre_label_transform: Optional[Callable] = None,
    ) -> None:
        self._check_inputs(raw_image_paths, label_image_paths, full_check=full_check)
        self.raw_images = raw_image_paths
        self.label_images = label_image_paths
        # This dataset always produces 2d (spatial) samples.
        self._ndim = 2

        if patch_shape is not None:
            assert len(patch_shape) == self._ndim
        self.patch_shape = patch_shape

        self.raw_transform = raw_transform
        self.label_transform = label_transform
        self.label_transform2 = label_transform2
        self.transform = transform
        self.sampler = sampler
        self.with_padding = with_padding
        self.pre_label_transform = pre_label_transform

        self.dtype = dtype
        self.label_dtype = label_dtype

        # Without `n_samples` each index maps to its image; otherwise images are drawn at random.
        if n_samples is None:
            self._len = len(self.raw_images)
            self.sample_random_index = False
        else:
            self._len = n_samples
            self.sample_random_index = True

    def __len__(self):
        return self._len

    @property
    def ndim(self):
        """The spatial dimensionality of the samples."""
        return self._ndim

    def _sample_bounding_box(self, shape):
        """Sample a random bounding box of `patch_shape` inside an image with the given spatial shape.

        If `patch_shape` is None, the bounding box covers the full image. Along axes where the image is not
        larger than the patch, the box starts at 0.

        Returns:
            A tuple of slices, one per spatial axis.
        """
        if self.patch_shape is None:
            return tuple(slice(0, sh) for sh in shape)

        bb_start = [
            # `randint`'s upper bound is exclusive: `sh - psh + 1` makes the last valid start reachable.
            np.random.randint(0, sh - psh + 1) if sh - psh > 0 else 0
            for sh, psh in zip(shape, self.patch_shape)
        ]
        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, self.patch_shape))

    def _load_data(self, raw_path, label_path):
        """Load a raw / label pair, pad it if requested, and determine the spatial shape of the raw data.

        Returns:
            The raw data, the label data, the spatial shape of the raw data, the index prefix that selects all
            channels (empty for channel-last or single-channel data), and whether the raw data has channels.

        Raises:
            NotImplementedError: If the labels have a channel axis.
        """
        # If `have_tensor_data` is set (it is not set by this class), the "paths" are used as the data directly.
        if getattr(self, "have_tensor_data", False):
            raw, label = raw_path, label_path
        else:
            raw = load_image(raw_path, memmap=False)
            label = load_image(label_path, memmap=False)

        have_raw_channels = getattr(self, "with_channels", raw.ndim == 3)
        have_label_channels = getattr(self, "with_label_channels", label.ndim == 3)
        if have_label_channels:
            raise NotImplementedError("Multi-channel labels are not supported.")

        # Determine whether the channels are on the first or last axis (see `_is_channel_first`).
        channel_first = None
        if have_raw_channels:
            channel_first = self._is_channel_first(raw.shape)

        if self.patch_shape is not None and self.with_padding:
            raw, label = ensure_patch_shape(
                raw=raw,
                labels=label,
                patch_shape=self.patch_shape,
                have_raw_channels=have_raw_channels,
                have_label_channels=have_label_channels,
                channel_first=channel_first
            )

        shape = raw.shape

        prefix_box = tuple()
        if have_raw_channels:
            if channel_first:
                shape = shape[1:]
                prefix_box = (slice(None), )
            else:
                shape = shape[:-1]

        return raw, label, shape, prefix_box, have_raw_channels

    def _get_desired_raw_and_labels(self, raw, label, shape, prefix_box):
        """Cut a random patch from raw and label and apply `pre_label_transform` to the label patch, if set."""
        bb = self._sample_bounding_box(shape)
        raw_patch = np.array(raw[prefix_box + bb])
        label_patch = np.array(label[bb])

        # Additional label transform on top to make the sampler consider the expected labels
        # (e.g. run connected components on disconnected semantic labels).
        pre_label_transform = getattr(self, "pre_label_transform", None)
        if pre_label_transform is not None:
            label_patch = pre_label_transform(label_patch)

        return raw_patch, label_patch

    def _get_sample(self, index):
        """Load a raw / label patch pair, resampling until `sampler` (if given) accepts it.

        If `n_samples` was passed, `index` is ignored and a random image is chosen instead.
        Raw patches with channels are returned channel-first.

        Raises:
            RuntimeError: If no valid sample is found within `max_sampling_attempts` attempts.
        """
        if self.sample_random_index:
            index = np.random.randint(0, len(self.raw_images))

        # The filepaths corresponding to this image.
        raw_path, label_path = self.raw_images[index], self.label_images[index]

        # Load the corresponding data.
        raw, label, shape, prefix_box, have_raw_channels = self._load_data(raw_path, label_path)

        # Sample a random bounding box for this image.
        raw_patch, label_patch = self._get_desired_raw_and_labels(raw, label, shape, prefix_box)

        if self.sampler is not None:
            sample_id = 0
            while not self.sampler(raw_patch, label_patch):
                raw_patch, label_patch = self._get_desired_raw_and_labels(raw, label, shape, prefix_box)
                sample_id += 1

                # We need to avoid sampling from the same image over and over again,
                # otherwise this will fail just because of one or a few empty images.
                # Hence we switch to a new random image every `max_sampling_attempts_image` attempts.
                if sample_id % self.max_sampling_attempts_image == 0:
                    index = np.random.randint(0, len(self.raw_images))
                    raw_path, label_path = self.raw_images[index], self.label_images[index]
                    raw, label, shape, prefix_box, have_raw_channels = self._load_data(raw_path, label_path)

                if sample_id > self.max_sampling_attempts:
                    raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")

        # Bring channel-last raw data to channel-first order.
        if have_raw_channels and len(prefix_box) == 0:
            raw_patch = raw_patch.transpose((2, 0, 1))

        return raw_patch, label_patch

    def __getitem__(self, index):
        """Return the transformed raw and label sample, converted via `ensure_tensor_with_channels`."""
        raw, labels = self._get_sample(index)
        initial_label_dtype = labels.dtype

        if self.raw_transform is not None:
            raw = self.raw_transform(raw)

        if self.label_transform is not None:
            labels = self.label_transform(labels)

        if self.transform is not None:
            raw, labels = self.transform(raw, labels)

        # TODO: support enlarging the bounding box here as well (for the affinity transform)?
        if self.label_transform2 is not None:
            # Convert the labels via `ensure_spatial_array`, using the dtype they had before the transforms.
            labels = ensure_spatial_array(labels, self.ndim, dtype=initial_label_dtype)
            labels = self.label_transform2(labels)

        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
        labels = ensure_tensor_with_channels(labels, ndim=self._ndim, dtype=self.label_dtype)
        return raw, labels
Dataset that provides raw data and labels stored in a regular image data format for segmentation training.
The dataset returns patches loaded from the images and labels as samples for a batch. The raw data and labels are expected to be images of the same shape, except for possible channels. It supports all file formats that can be loaded with the imageio or tifffile library, such as tif, png or jpeg files.
Arguments:
- raw_image_paths: The file paths to the raw data.
- label_image_paths: The file paths to the label data.
- patch_shape: The patch shape for a training sample.
- raw_transform: Transformation applied to the raw data of a sample.
- label_transform: Transformation applied to the label data of a sample,
before applying augmentations via `transform`.
- label_transform2: Transformation applied to the label data of a sample,
after applying augmentations via `transform`.
- transform: Transformation applied to both the raw data and label data of a sample. This can be used to implement data augmentations.
- dtype: The return data type of the raw data.
- label_dtype: The return data type of the label data.
- n_samples: The length of this dataset. If None, the length will be set to `len(raw_image_paths)`.
- sampler: Sampler for rejecting samples according to a defined criterion. The sampler must be a callable that accepts the raw and label patches (as numpy arrays) as input and returns a truthy value if the sample is valid.
- full_check: Whether to check that the input data is valid for all image paths. This will ensure that the data is valid, but will take longer for creating the dataset.
- with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
- pre_label_transform: Transformation applied to the label data of a chosen random sample,
before checking the sample's validity with the `sampler`.
def __init__(
    self,
    raw_image_paths: List[Union[str, os.PathLike]],
    label_image_paths: List[Union[str, os.PathLike]],
    patch_shape: Tuple[int, ...],
    raw_transform: Optional[Callable] = None,
    label_transform: Optional[Callable] = None,
    label_transform2: Optional[Callable] = None,
    transform: Optional[Callable] = None,
    dtype: torch.dtype = torch.float32,
    label_dtype: torch.dtype = torch.float32,
    n_samples: Optional[int] = None,
    sampler: Optional[Callable] = None,
    full_check: bool = False,
    with_padding: bool = True,
    pre_label_transform: Optional[Callable] = None,
) -> None:
    """Set up the dataset; the arguments are described in the class docstring."""
    # Validate the path lists before storing anything (length match; shapes too when `full_check` is set).
    self._check_inputs(raw_image_paths, label_image_paths, full_check=full_check)
    self.raw_images, self.label_images = raw_image_paths, label_image_paths

    # Samples are always 2d; a given patch shape must match that.
    self._ndim = 2
    if patch_shape is not None:
        assert len(patch_shape) == self._ndim
    self.patch_shape = patch_shape

    # Transformations, applied in `__getitem__`.
    self.raw_transform, self.label_transform = raw_transform, label_transform
    self.label_transform2, self.transform = label_transform2, transform

    # Sampling options.
    self.sampler = sampler
    self.with_padding = with_padding
    self.pre_label_transform = pre_label_transform

    # Output data types.
    self.dtype, self.label_dtype = dtype, label_dtype

    # With `n_samples` the dataset has a fixed length and images are drawn at random;
    # otherwise each index maps to the image at that position.
    random_index = n_samples is not None
    self._len = n_samples if random_index else len(self.raw_images)
    self.sample_random_index = random_index
The maximum number of sampling attempts when loading a sample via `__getitem__`.
This is used when `sampler` rejects a sample, to avoid an infinite loop if no valid sample can be found.
The number of sampling attempts after which a new image is chosen when loading a sample via `__getitem__`.
This is used when `sampler` rejects samples, to avoid failing just because a few images contain no valid samples.