torch_em.data.image_collection_dataset
import os
import numpy as np
from typing import List, Optional, Tuple, Union, Callable

import torch

from ..util import (
    ensure_spatial_array, ensure_tensor_with_channels, load_image, supports_memmap, ensure_patch_shape
)


class ImageCollectionDataset(torch.utils.data.Dataset):
    """Dataset that provides raw data and labels stored in a regular image data format for segmentation training.

    The dataset returns patches loaded from the images and labels as samples for a batch.
    The raw data and labels are expected to be images of the same shape, except for possible channels.
    It supports all file formats that can be loaded with the imageio or tifffile library, such as tif, png or jpeg files.

    Raw images with channels may store the channels in the first or in the last axis. The layout is decided
    by a heuristic: if the last axis has a length smaller than 16 it is treated as the channel axis,
    otherwise it is treated as a spatial axis and the first axis is the channel axis.

    Args:
        raw_image_paths: The file paths to the raw data.
        label_image_paths: The file paths to the label data.
        patch_shape: The patch shape for a training sample. If None, the full images are returned.
        raw_transform: Transformation applied to the raw data of a sample.
        label_transform: Transformation applied to the label data of a sample,
            before applying augmentations via `transform`.
        label_transform2: Transformation applied to the label data of a sample,
            after applying augmentations via `transform`.
        transform: Transformation applied to both the raw data and label data of a sample.
            This can be used to implement data augmentations.
        dtype: The return data type of the raw data.
        label_dtype: The return data type of the label data.
        n_samples: The length of this dataset. If None, the length will be set to `len(raw_image_paths)`.
            If given, the image for each sample is chosen at random.
        sampler: Sampler for rejecting samples according to a defined criterion.
            The sampler must be a callable that accepts the raw data and label data (as numpy arrays) as input
            and returns a truthy value if the sample is accepted.
        full_check: Whether to check that the input data is valid for all image paths.
            This will ensure that the data is valid, but will take longer for creating the dataset.
        with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
    """
    max_sampling_attempts = 500
    """The maximal number of sampling attempts, for loading a sample via `__getitem__`.
    This is used when `sampler` rejects a sample, to avoid an infinite loop if no valid sample can be found.
    """
    max_sampling_attempts_image = 50
    """The maximal number of sampling attempts for a single image, for loading a sample via `__getitem__`.
    This is used when `sampler` rejects a sample, to avoid an infinite loop if no valid sample can be found.
    """

    @staticmethod
    def _is_channel_last(shape):
        """Return whether a raw image shape with channels stores the channels in the last axis.

        Heuristic shared by `_check_inputs` and `_load_data`: a last axis shorter than 16 is the
        channel axis; otherwise it is a spatial axis and the first axis holds the channels.
        This is only reliable for images with less than 16 channels.
        """
        return shape[-1] < 16

    def _check_inputs(self, raw_images, label_images, full_check):
        """Validate that raw and label paths are paired and, optionally, that their shapes are compatible.

        Raises:
            ValueError: If the number of raw and label images differs, or (only with `full_check`)
                if a raw image and its label image have different spatial shapes.
        """
        if len(raw_images) != len(label_images):
            raise ValueError(
                f"Expect same number of raw and label images, got {len(raw_images)} and {len(label_images)}"
            )

        if not full_check:
            return

        is_multichan = None
        for raw_im, label_im in zip(raw_images, label_images):

            # We only check for compatible shapes if both images support memmap,
            # because we don't want to load everything into RAM.
            if supports_memmap(raw_im) and supports_memmap(label_im):
                shape = load_image(raw_im).shape
                assert len(shape) in (2, 3)

                # All raw images must consistently have channels or not.
                multichan = len(shape) == 3
                if is_multichan is None:
                    is_multichan = multichan
                else:
                    assert is_multichan == multichan

                if is_multichan:
                    # Strip the channel axis to obtain the spatial shape.
                    shape = shape[:-1] if self._is_channel_last(shape) else shape[1:]

                label_shape = load_image(label_im).shape
                if shape != label_shape:
                    msg = f"Expect raw and labels of same shape, got {shape}, {label_shape} for {raw_im}, {label_im}"
                    raise ValueError(msg)

    def __init__(
        self,
        raw_image_paths: List[Union[str, os.PathLike]],
        label_image_paths: List[Union[str, os.PathLike]],
        patch_shape: Optional[Tuple[int, ...]],
        raw_transform: Optional[Callable] = None,
        label_transform: Optional[Callable] = None,
        label_transform2: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        dtype: torch.dtype = torch.float32,
        label_dtype: torch.dtype = torch.float32,
        n_samples: Optional[int] = None,
        sampler: Optional[Callable] = None,
        full_check: bool = False,
        with_padding: bool = True,
    ):
        self._check_inputs(raw_image_paths, label_image_paths, full_check=full_check)
        self.raw_images = raw_image_paths
        self.label_images = label_image_paths
        # Only 2D images (optionally with channels) are supported.
        self._ndim = 2

        if patch_shape is not None:
            assert len(patch_shape) == self._ndim
        self.patch_shape = patch_shape

        self.raw_transform = raw_transform
        self.label_transform = label_transform
        self.label_transform2 = label_transform2
        self.transform = transform
        self.sampler = sampler
        self.with_padding = with_padding

        self.dtype = dtype
        self.label_dtype = label_dtype

        # Without n_samples the index passed to __getitem__ selects the image;
        # with n_samples the image is chosen at random for every sample.
        if n_samples is None:
            self._len = len(self.raw_images)
            self.sample_random_index = False
        else:
            self._len = n_samples
            self.sample_random_index = True

    def __len__(self):
        return self._len

    @property
    def ndim(self):
        """The number of spatial dimensions of the data (always 2)."""
        return self._ndim

    def _sample_bounding_box(self, shape):
        """Return a tuple of slices for a random patch within the spatial `shape`.

        If `patch_shape` is None the box covers the full shape. Axes where the image is not larger
        than the patch start at 0.
        """
        if self.patch_shape is None:
            patch_shape_for_bb = shape
            bb_start = [0] * len(shape)
        else:
            patch_shape_for_bb = self.patch_shape
            # np.random.randint excludes `high`, so use `sh - psh + 1` to make the last
            # valid start position `sh - psh` reachable.
            bb_start = [
                np.random.randint(0, sh - psh + 1) if sh - psh > 0 else 0
                for sh, psh in zip(shape, patch_shape_for_bb)
            ]

        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, patch_shape_for_bb))

    def _load_data(self, raw_path, label_path):
        """Load a raw / label image pair and determine the spatial shape and channel layout.

        Returns:
            The raw image (padded to `patch_shape` if `with_padding` is set), the label image,
            the spatial shape of the raw image, the slice prefix selecting all channels for
            channel-first raw data (empty otherwise) and whether the raw image has channels.

        Raises:
            NotImplementedError: If the label image has channels.
        """
        raw = load_image(raw_path, memmap=False)
        label = load_image(label_path, memmap=False)

        have_raw_channels = raw.ndim == 3
        have_label_channels = label.ndim == 3
        if have_label_channels:
            raise NotImplementedError("Multi-channel labels are not supported.")

        # Decide whether the channels are the first or the last axis, using the same
        # heuristic as `_check_inputs` so both agree (including for a last axis of length 16).
        channel_first = None
        if have_raw_channels:
            channel_first = not self._is_channel_last(raw.shape)

        if self.patch_shape is not None and self.with_padding:
            raw, label = ensure_patch_shape(
                raw=raw,
                labels=label,
                patch_shape=self.patch_shape,
                have_raw_channels=have_raw_channels,
                have_label_channels=have_label_channels,
                channel_first=channel_first
            )

        shape = raw.shape

        prefix_box = tuple()
        if have_raw_channels:
            if channel_first:
                shape = shape[1:]
                prefix_box = (slice(None), )
            else:
                shape = shape[:-1]

        return raw, label, shape, prefix_box, have_raw_channels

    def _get_sample(self, index):
        """Sample a raw / label patch pair, re-sampling while `sampler` rejects it.

        Returns:
            The raw patch (channel first if it has channels) and the label patch as numpy arrays.

        Raises:
            RuntimeError: If `sampler` rejects more than `max_sampling_attempts` patches.
        """
        if self.sample_random_index:
            index = np.random.randint(0, len(self.raw_images))

        raw, label, shape, prefix_box, have_raw_channels = self._load_data(
            self.raw_images[index], self.label_images[index]
        )

        sample_id = 0
        while True:
            # Always draw the patch from the currently loaded image, so that the channel
            # metadata (prefix_box, have_raw_channels) matches the returned patch.
            bb = self._sample_bounding_box(shape)
            raw_patch = np.array(raw[prefix_box + bb])
            label_patch = np.array(label[bb])

            if self.sampler is None or self.sampler(raw_patch, label_patch):
                break

            sample_id += 1
            if sample_id > self.max_sampling_attempts:
                raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")

            # We need to avoid sampling from the same image over and over again,
            # otherwise this will fail just because of one or a few empty images.
            # Hence we switch the image from which we sample from time to time.
            if sample_id % self.max_sampling_attempts_image == 0:
                index = np.random.randint(0, len(self.raw_images))
                raw, label, shape, prefix_box, have_raw_channels = self._load_data(
                    self.raw_images[index], self.label_images[index]
                )

        # Bring channel-last raw patches to channel-first order.
        if have_raw_channels and len(prefix_box) == 0:
            raw_patch = raw_patch.transpose((2, 0, 1))

        return raw_patch, label_patch

    def __getitem__(self, index):
        """Return the raw and label tensors for one sample.

        `index` selects the image unless `n_samples` was given, in which case a random image is used.
        """
        raw, labels = self._get_sample(index)
        initial_label_dtype = labels.dtype

        if self.raw_transform is not None:
            raw = self.raw_transform(raw)

        if self.label_transform is not None:
            labels = self.label_transform(labels)

        if self.transform is not None:
            raw, labels = self.transform(raw, labels)

        # TODO: support enlarging the bounding box here as well (e.g. for affinity transforms)?
        if self.label_transform2 is not None:
            labels = ensure_spatial_array(labels, self.ndim, dtype=initial_label_dtype)
            labels = self.label_transform2(labels)

        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
        labels = ensure_tensor_with_channels(labels, ndim=self._ndim, dtype=self.label_dtype)
        return raw, labels
class ImageCollectionDataset(torch.utils.data.Dataset):
    """Dataset that provides raw data and labels stored in a regular image data format for segmentation training.

    The dataset returns patches loaded from the images and labels as samples for a batch.
    The raw data and labels are expected to be images of the same shape, except for possible channels.
    It supports all file formats that can be loaded with the imageio or tifffile library, such as tif, png or jpeg files.

    Raw images with channels may store the channels in the first or in the last axis. The layout is decided
    by a heuristic: if the last axis has a length smaller than 16 it is treated as the channel axis,
    otherwise it is treated as a spatial axis and the first axis is the channel axis.

    Args:
        raw_image_paths: The file paths to the raw data.
        label_image_paths: The file paths to the label data.
        patch_shape: The patch shape for a training sample. If None, the full images are returned.
        raw_transform: Transformation applied to the raw data of a sample.
        label_transform: Transformation applied to the label data of a sample,
            before applying augmentations via `transform`.
        label_transform2: Transformation applied to the label data of a sample,
            after applying augmentations via `transform`.
        transform: Transformation applied to both the raw data and label data of a sample.
            This can be used to implement data augmentations.
        dtype: The return data type of the raw data.
        label_dtype: The return data type of the label data.
        n_samples: The length of this dataset. If None, the length will be set to `len(raw_image_paths)`.
            If given, the image for each sample is chosen at random.
        sampler: Sampler for rejecting samples according to a defined criterion.
            The sampler must be a callable that accepts the raw data and label data (as numpy arrays) as input
            and returns a truthy value if the sample is accepted.
        full_check: Whether to check that the input data is valid for all image paths.
            This will ensure that the data is valid, but will take longer for creating the dataset.
        with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
    """
    max_sampling_attempts = 500
    """The maximal number of sampling attempts, for loading a sample via `__getitem__`.
    This is used when `sampler` rejects a sample, to avoid an infinite loop if no valid sample can be found.
    """
    max_sampling_attempts_image = 50
    """The maximal number of sampling attempts for a single image, for loading a sample via `__getitem__`.
    This is used when `sampler` rejects a sample, to avoid an infinite loop if no valid sample can be found.
    """

    @staticmethod
    def _is_channel_last(shape):
        """Return whether a raw image shape with channels stores the channels in the last axis.

        Heuristic shared by `_check_inputs` and `_load_data`: a last axis shorter than 16 is the
        channel axis; otherwise it is a spatial axis and the first axis holds the channels.
        This is only reliable for images with less than 16 channels.
        """
        return shape[-1] < 16

    def _check_inputs(self, raw_images, label_images, full_check):
        """Validate that raw and label paths are paired and, optionally, that their shapes are compatible.

        Raises:
            ValueError: If the number of raw and label images differs, or (only with `full_check`)
                if a raw image and its label image have different spatial shapes.
        """
        if len(raw_images) != len(label_images):
            raise ValueError(
                f"Expect same number of raw and label images, got {len(raw_images)} and {len(label_images)}"
            )

        if not full_check:
            return

        is_multichan = None
        for raw_im, label_im in zip(raw_images, label_images):

            # We only check for compatible shapes if both images support memmap,
            # because we don't want to load everything into RAM.
            if supports_memmap(raw_im) and supports_memmap(label_im):
                shape = load_image(raw_im).shape
                assert len(shape) in (2, 3)

                # All raw images must consistently have channels or not.
                multichan = len(shape) == 3
                if is_multichan is None:
                    is_multichan = multichan
                else:
                    assert is_multichan == multichan

                if is_multichan:
                    # Strip the channel axis to obtain the spatial shape.
                    shape = shape[:-1] if self._is_channel_last(shape) else shape[1:]

                label_shape = load_image(label_im).shape
                if shape != label_shape:
                    msg = f"Expect raw and labels of same shape, got {shape}, {label_shape} for {raw_im}, {label_im}"
                    raise ValueError(msg)

    def __init__(
        self,
        raw_image_paths: List[Union[str, os.PathLike]],
        label_image_paths: List[Union[str, os.PathLike]],
        patch_shape: Optional[Tuple[int, ...]],
        raw_transform: Optional[Callable] = None,
        label_transform: Optional[Callable] = None,
        label_transform2: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        dtype: torch.dtype = torch.float32,
        label_dtype: torch.dtype = torch.float32,
        n_samples: Optional[int] = None,
        sampler: Optional[Callable] = None,
        full_check: bool = False,
        with_padding: bool = True,
    ):
        self._check_inputs(raw_image_paths, label_image_paths, full_check=full_check)
        self.raw_images = raw_image_paths
        self.label_images = label_image_paths
        # Only 2D images (optionally with channels) are supported.
        self._ndim = 2

        if patch_shape is not None:
            assert len(patch_shape) == self._ndim
        self.patch_shape = patch_shape

        self.raw_transform = raw_transform
        self.label_transform = label_transform
        self.label_transform2 = label_transform2
        self.transform = transform
        self.sampler = sampler
        self.with_padding = with_padding

        self.dtype = dtype
        self.label_dtype = label_dtype

        # Without n_samples the index passed to __getitem__ selects the image;
        # with n_samples the image is chosen at random for every sample.
        if n_samples is None:
            self._len = len(self.raw_images)
            self.sample_random_index = False
        else:
            self._len = n_samples
            self.sample_random_index = True

    def __len__(self):
        return self._len

    @property
    def ndim(self):
        """The number of spatial dimensions of the data (always 2)."""
        return self._ndim

    def _sample_bounding_box(self, shape):
        """Return a tuple of slices for a random patch within the spatial `shape`.

        If `patch_shape` is None the box covers the full shape. Axes where the image is not larger
        than the patch start at 0.
        """
        if self.patch_shape is None:
            patch_shape_for_bb = shape
            bb_start = [0] * len(shape)
        else:
            patch_shape_for_bb = self.patch_shape
            # np.random.randint excludes `high`, so use `sh - psh + 1` to make the last
            # valid start position `sh - psh` reachable.
            bb_start = [
                np.random.randint(0, sh - psh + 1) if sh - psh > 0 else 0
                for sh, psh in zip(shape, patch_shape_for_bb)
            ]

        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, patch_shape_for_bb))

    def _load_data(self, raw_path, label_path):
        """Load a raw / label image pair and determine the spatial shape and channel layout.

        Returns:
            The raw image (padded to `patch_shape` if `with_padding` is set), the label image,
            the spatial shape of the raw image, the slice prefix selecting all channels for
            channel-first raw data (empty otherwise) and whether the raw image has channels.

        Raises:
            NotImplementedError: If the label image has channels.
        """
        raw = load_image(raw_path, memmap=False)
        label = load_image(label_path, memmap=False)

        have_raw_channels = raw.ndim == 3
        have_label_channels = label.ndim == 3
        if have_label_channels:
            raise NotImplementedError("Multi-channel labels are not supported.")

        # Decide whether the channels are the first or the last axis, using the same
        # heuristic as `_check_inputs` so both agree (including for a last axis of length 16).
        channel_first = None
        if have_raw_channels:
            channel_first = not self._is_channel_last(raw.shape)

        if self.patch_shape is not None and self.with_padding:
            raw, label = ensure_patch_shape(
                raw=raw,
                labels=label,
                patch_shape=self.patch_shape,
                have_raw_channels=have_raw_channels,
                have_label_channels=have_label_channels,
                channel_first=channel_first
            )

        shape = raw.shape

        prefix_box = tuple()
        if have_raw_channels:
            if channel_first:
                shape = shape[1:]
                prefix_box = (slice(None), )
            else:
                shape = shape[:-1]

        return raw, label, shape, prefix_box, have_raw_channels

    def _get_sample(self, index):
        """Sample a raw / label patch pair, re-sampling while `sampler` rejects it.

        Returns:
            The raw patch (channel first if it has channels) and the label patch as numpy arrays.

        Raises:
            RuntimeError: If `sampler` rejects more than `max_sampling_attempts` patches.
        """
        if self.sample_random_index:
            index = np.random.randint(0, len(self.raw_images))

        raw, label, shape, prefix_box, have_raw_channels = self._load_data(
            self.raw_images[index], self.label_images[index]
        )

        sample_id = 0
        while True:
            # Always draw the patch from the currently loaded image, so that the channel
            # metadata (prefix_box, have_raw_channels) matches the returned patch.
            bb = self._sample_bounding_box(shape)
            raw_patch = np.array(raw[prefix_box + bb])
            label_patch = np.array(label[bb])

            if self.sampler is None or self.sampler(raw_patch, label_patch):
                break

            sample_id += 1
            if sample_id > self.max_sampling_attempts:
                raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")

            # We need to avoid sampling from the same image over and over again,
            # otherwise this will fail just because of one or a few empty images.
            # Hence we switch the image from which we sample from time to time.
            if sample_id % self.max_sampling_attempts_image == 0:
                index = np.random.randint(0, len(self.raw_images))
                raw, label, shape, prefix_box, have_raw_channels = self._load_data(
                    self.raw_images[index], self.label_images[index]
                )

        # Bring channel-last raw patches to channel-first order.
        if have_raw_channels and len(prefix_box) == 0:
            raw_patch = raw_patch.transpose((2, 0, 1))

        return raw_patch, label_patch

    def __getitem__(self, index):
        """Return the raw and label tensors for one sample.

        `index` selects the image unless `n_samples` was given, in which case a random image is used.
        """
        raw, labels = self._get_sample(index)
        initial_label_dtype = labels.dtype

        if self.raw_transform is not None:
            raw = self.raw_transform(raw)

        if self.label_transform is not None:
            labels = self.label_transform(labels)

        if self.transform is not None:
            raw, labels = self.transform(raw, labels)

        # TODO: support enlarging the bounding box here as well (e.g. for affinity transforms)?
        if self.label_transform2 is not None:
            labels = ensure_spatial_array(labels, self.ndim, dtype=initial_label_dtype)
            labels = self.label_transform2(labels)

        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
        labels = ensure_tensor_with_channels(labels, ndim=self._ndim, dtype=self.label_dtype)
        return raw, labels
Dataset that provides raw data and labels stored in a regular image data format for segmentation training.
The dataset returns patches loaded from the images and labels as samples for a batch. The raw data and labels are expected to be images of the same shape, except for possible channels. It supports all file formats that can be loaded with the imageio or tifffile library, such as tif, png or jpeg files.
Arguments:
- raw_image_paths: The file paths to the raw data.
- label_image_paths: The file paths to the label data.
- patch_shape: The patch shape for a training sample.
- raw_transform: Transformation applied to the raw data of a sample.
- label_transform: Transformation applied to the label data of a sample, before applying augmentations via `transform`.
- label_transform2: Transformation applied to the label data of a sample,
after applying augmentations via `transform`.
- transform: Transformation applied to both the raw data and label data of a sample. This can be used to implement data augmentations.
- dtype: The return data type of the raw data.
- label_dtype: The return data type of the label data.
- n_samples: The length of this dataset. If None, the length will be set to `len(raw_image_paths)`.
- sampler: Sampler for rejecting samples according to a defined criterion. The sampler must be a callable that accepts the raw data and label data (as numpy arrays) as input.
- full_check: Whether to check that the input data is valid for all image paths. This will ensure that the data is valid, but will take longer for creating the dataset.
- with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
def __init__(
    self,
    raw_image_paths: List[Union[str, os.PathLike]],
    label_image_paths: List[Union[str, os.PathLike]],
    patch_shape: Optional[Tuple[int, ...]],
    raw_transform: Optional[Callable] = None,
    label_transform: Optional[Callable] = None,
    label_transform2: Optional[Callable] = None,
    transform: Optional[Callable] = None,
    dtype: torch.dtype = torch.float32,
    label_dtype: torch.dtype = torch.float32,
    n_samples: Optional[int] = None,
    sampler: Optional[Callable] = None,
    full_check: bool = False,
    with_padding: bool = True,
):
    """Initialize the dataset; see the class docstring for the meaning of all arguments.

    `patch_shape` may be None (it is only validated when given), in which case no
    fixed patch size is used.
    """
    self._check_inputs(raw_image_paths, label_image_paths, full_check=full_check)
    self.raw_images = raw_image_paths
    self.label_images = label_image_paths
    # Only 2D images (optionally with channels) are supported.
    self._ndim = 2

    if patch_shape is not None:
        assert len(patch_shape) == self._ndim
    self.patch_shape = patch_shape

    self.raw_transform = raw_transform
    self.label_transform = label_transform
    self.label_transform2 = label_transform2
    self.transform = transform
    self.sampler = sampler
    self.with_padding = with_padding

    self.dtype = dtype
    self.label_dtype = label_dtype

    # Without n_samples the dataset length equals the number of images and the index selects
    # the image; with n_samples the length is n_samples and images are picked at random.
    if n_samples is None:
        self._len = len(self.raw_images)
        self.sample_random_index = False
    else:
        self._len = n_samples
        self.sample_random_index = True
The maximal number of sampling attempts, for loading a sample via `__getitem__`.
This is used when `sampler` rejects a sample, to avoid an infinite loop if no valid sample can be found.
The maximal number of sampling attempts for a single image, for loading a sample via `__getitem__`.
This is used when `sampler` rejects a sample, to avoid an infinite loop if no valid sample can be found.