torch_em.data.image_collection_dataset
import os
import numpy as np
from typing import List, Optional, Tuple, Union

import torch

from ..util import (ensure_spatial_array, ensure_tensor_with_channels,
                    load_image, supports_memmap)


class ImageCollectionDataset(torch.utils.data.Dataset):
    """Dataset that draws 2D patches from a collection of raw / label image pairs.

    Every item is produced by loading one raw image and its label image,
    cropping a random patch of `patch_shape` from both (images smaller than
    the patch are padded first) and applying the configured transforms.
    Raw images may be 2D or 3D (2D plus one channel axis); label images
    must be 2D.
    """

    # Number of rejected patches after which `_get_sample` gives up.
    max_sampling_attempts = 500
    # Number of rejected patches after which a different image is drawn.
    max_sampling_attempts_image = 50

    @staticmethod
    def _is_channel_first(shape):
        """Guess whether a 3D image shape is stored in channel-first order.

        Heuristic: a last axis shorter than 16 is taken to be the channel axis
        (channel-last); otherwise the last axis is assumed to be spatial and
        the first axis is taken as the channel axis. This only works for
        images with fewer than 16 channels.
        """
        return shape[-1] >= 16

    def _check_inputs(self, raw_images, label_images, full_check):
        """Validate the raw / label image lists.

        Args:
            raw_images: Paths of the raw images.
            label_images: Paths of the label images, in matching order.
            full_check: If True, additionally compare the spatial shape of
                every raw / label pair.

        Raises:
            ValueError: If the lists differ in length, or (with `full_check`)
                if a raw image and its label image differ in spatial shape.
        """
        if len(raw_images) != len(label_images):
            raise ValueError(
                f"Expect same number of raw and label images, got {len(raw_images)} and {len(label_images)}"
            )

        if not full_check:
            return

        is_multichan = None
        for raw_im, label_im in zip(raw_images, label_images):
            # We only check for compatible shapes if both images support memmap,
            # because we don't want to load everything into RAM.
            if not (supports_memmap(raw_im) and supports_memmap(label_im)):
                continue

            shape = load_image(raw_im).shape
            assert len(shape) in (2, 3)

            # All checked raw images must agree on having a channel axis or not.
            multichan = len(shape) == 3
            if is_multichan is None:
                is_multichan = multichan
            else:
                assert is_multichan == multichan

            if is_multichan:
                # Strip the channel axis so that only the spatial shape is compared.
                shape = shape[1:] if self._is_channel_first(shape) else shape[:-1]

            label_shape = load_image(label_im).shape
            if shape != label_shape:
                msg = f"Expect raw and labels of same shape, got {shape}, {label_shape} for {raw_im}, {label_im}"
                raise ValueError(msg)

    def __init__(
        self,
        raw_image_paths: List[Union[str, os.PathLike]],
        label_image_paths: List[Union[str, os.PathLike]],
        patch_shape: Optional[Tuple[int, ...]],
        raw_transform=None,
        label_transform=None,
        label_transform2=None,
        transform=None,
        dtype: torch.dtype = torch.float32,
        label_dtype: torch.dtype = torch.float32,
        n_samples: Optional[int] = None,
        sampler=None,
        full_check: bool = False,
    ):
        """Set up the dataset.

        Args:
            raw_image_paths: Paths of the raw images.
            label_image_paths: Paths of the label images, in matching order.
            patch_shape: 2D shape of the patches to sample, or None to use
                each image at its full size.
            raw_transform: Optional callable applied to the raw patch.
            label_transform: Optional callable applied to the label patch.
            label_transform2: Optional callable applied to the labels after
                the joint `transform`.
            transform: Optional callable applied jointly as
                `transform(raw, labels)`.
            dtype: Target dtype of the returned raw data.
            label_dtype: Target dtype of the returned labels.
            n_samples: Length reported by the dataset. If None, the length is
                the number of images and the requested index selects the
                image; otherwise the image is drawn at random for every item.
            sampler: Optional callable `sampler(raw_patch, label_patch)`;
                patches are re-drawn until it returns a truthy value.
            full_check: Forwarded to `_check_inputs`.
        """
        self._check_inputs(raw_image_paths, label_image_paths, full_check=full_check)
        self.raw_images = raw_image_paths
        self.label_images = label_image_paths
        self._ndim = 2

        if patch_shape is not None:
            assert len(patch_shape) == self._ndim, \
                f"Expected a patch shape with {self._ndim} dimensions, got {patch_shape}"
        self.patch_shape = patch_shape

        self.raw_transform = raw_transform
        self.label_transform = label_transform
        self.label_transform2 = label_transform2
        self.transform = transform
        self.sampler = sampler

        self.dtype = dtype
        self.label_dtype = label_dtype

        if n_samples is None:
            self._len = len(self.raw_images)
            self.sample_random_index = False
        else:
            self._len = n_samples
            self.sample_random_index = True

    def __len__(self):
        """Return the number of samples reported for this dataset."""
        return self._len

    @property
    def ndim(self):
        """Number of spatial dimensions of the sampled patches."""
        return self._ndim

    def _sample_bounding_box(self, shape):
        """Draw a random patch-sized bounding box inside `shape`.

        Args:
            shape: Spatial shape of the image to sample from.

        Returns:
            Tuple of slices, one per spatial axis. Covers the full image if
            no patch shape is set.
        """
        if self.patch_shape is None:
            patch_shape_for_bb = shape
            bb_start = [0] * len(shape)
        else:
            patch_shape_for_bb = self.patch_shape
            # `np.random.randint` excludes its upper bound, so add one to make
            # the last valid start position (sh - psh) reachable.
            bb_start = [
                np.random.randint(0, sh - psh + 1) if sh > psh else 0
                for sh, psh in zip(shape, patch_shape_for_bb)
            ]

        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, patch_shape_for_bb))

    def _ensure_patch_shape(self, raw, labels, have_raw_channels, have_label_channels, channel_first):
        """Pad raw and labels at the end of each spatial axis to fit the patch shape.

        Args:
            raw: Raw image array, optionally with a channel axis.
            labels: Label image array without channel axis.
            have_raw_channels: Whether `raw` has a channel axis.
            have_label_channels: Whether the labels have a channel axis
                (currently unused, see TODO below).
            channel_first: Whether the channel axis of `raw` is the first axis.

        Returns:
            The (possibly padded) raw and label arrays.
        """
        if have_raw_channels:
            spatial_shape = raw.shape[1:] if channel_first else raw.shape[:-1]
        else:
            spatial_shape = raw.shape

        if any(sh < psh for sh, psh in zip(spatial_shape, self.patch_shape)):
            pw = [(0, max(0, psh - sh)) for sh, psh in zip(spatial_shape, self.patch_shape)]

            # The channel axis of the raw data is never padded.
            if have_raw_channels and channel_first:
                pw_raw = [(0, 0), *pw]
            elif have_raw_channels:
                pw_raw = [*pw, (0, 0)]
            else:
                pw_raw = pw

            # TODO: ensure padding for labels with channels, when supported (see `_load_data`)

            raw, labels = np.pad(raw, pw_raw), np.pad(labels, pw)
        return raw, labels

    def _load_data(self, raw_path, label_path):
        """Load one raw / label pair and derive its sampling geometry.

        Returns:
            Tuple `(raw, label, shape, prefix_box, have_raw_channels)`, where
            `shape` is the spatial shape of the raw data and `prefix_box` is
            `(slice(None),)` for channel-first raw data and empty otherwise.

        Raises:
            NotImplementedError: If the label image is 3D.
        """
        raw = load_image(raw_path, memmap=False)
        label = load_image(label_path, memmap=False)

        have_raw_channels = raw.ndim == 3
        have_label_channels = label.ndim == 3
        if have_label_channels:
            raise NotImplementedError("Multi-channel labels are not supported.")

        # The channel axis position is guessed from the array shape with the
        # same heuristic that `_check_inputs` uses.
        channel_first = None
        if have_raw_channels:
            channel_first = self._is_channel_first(raw.shape)

        if self.patch_shape is not None:
            raw, label = self._ensure_patch_shape(raw, label, have_raw_channels, have_label_channels, channel_first)
        shape = raw.shape

        prefix_box = tuple()
        if have_raw_channels:
            if channel_first:
                shape = shape[1:]
                prefix_box = (slice(None), )
            else:
                shape = shape[:-1]

        return raw, label, shape, prefix_box, have_raw_channels

    def _get_sample(self, index):
        """Load an image pair and cut a random patch from it.

        If a sampler is set, patches are re-drawn until the sampler accepts
        one; after every `max_sampling_attempts_image` rejections a new
        random image is loaded.

        Returns:
            Tuple `(raw_patch, label_patch)` with the raw patch in
            channel-first order if it has channels.

        Raises:
            RuntimeError: If more than `max_sampling_attempts` patches were
                rejected by the sampler.
        """
        if self.sample_random_index:
            index = np.random.randint(0, len(self.raw_images))

        # The filepath corresponding to this image.
        raw_path, label_path = self.raw_images[index], self.label_images[index]

        # Load the corresponding data.
        raw, label, shape, prefix_box, have_raw_channels = self._load_data(raw_path, label_path)

        # Sample random bounding box for this image.
        bb = self._sample_bounding_box(shape)
        raw_patch = np.array(raw[prefix_box + bb])
        label_patch = np.array(label[bb])

        if self.sampler is not None:
            sample_id = 0
            while not self.sampler(raw_patch, label_patch):
                sample_id += 1
                if sample_id > self.max_sampling_attempts:
                    raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")

                # We need to avoid sampling from the same image over and over again,
                # otherwise this will fail just because of one or a few empty images.
                # The image is switched BEFORE the patch is cut, so that the patch and
                # the channel layout flags always belong to the same image.
                if sample_id % self.max_sampling_attempts_image == 0:
                    index = np.random.randint(0, len(self.raw_images))
                    raw_path, label_path = self.raw_images[index], self.label_images[index]
                    raw, label, shape, prefix_box, have_raw_channels = self._load_data(raw_path, label_path)

                bb = self._sample_bounding_box(shape)
                raw_patch = np.array(raw[prefix_box + bb])
                label_patch = np.array(label[bb])

        # An empty prefix box means channel-last raw data: move channels to the front.
        if have_raw_channels and len(prefix_box) == 0:
            raw_patch = raw_patch.transpose((2, 0, 1))

        return raw_patch, label_patch

    def __getitem__(self, index):
        """Return the transformed `(raw, labels)` pair for `index`.

        The transforms run in the order `raw_transform`, `label_transform`,
        joint `transform`, `label_transform2`. Both outputs are finally passed
        through `ensure_tensor_with_channels` with the configured dtypes
        (presumably yielding torch tensors with a channel axis - confirm
        against `..util`).
        """
        raw, labels = self._get_sample(index)
        initial_label_dtype = labels.dtype

        if self.raw_transform is not None:
            raw = self.raw_transform(raw)

        if self.label_transform is not None:
            labels = self.label_transform(labels)

        if self.transform is not None:
            raw, labels = self.transform(raw, labels)

        # TODO: support enlarging the bounding box here as well (for affinity transform)?
        if self.label_transform2 is not None:
            # Restore the dtype the labels had before the transforms above.
            labels = ensure_spatial_array(labels, self.ndim, dtype=initial_label_dtype)
            labels = self.label_transform2(labels)

        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
        labels = ensure_tensor_with_channels(labels, ndim=self._ndim, dtype=self.label_dtype)
        return raw, labels
class ImageCollectionDataset(torch.utils.data.Dataset):
    """Dataset that draws 2D patches from a collection of raw / label image pairs.

    Every item is produced by loading one raw image and its label image,
    cropping a random patch of `patch_shape` from both (images smaller than
    the patch are padded first) and applying the configured transforms.
    Raw images may be 2D or 3D (2D plus one channel axis); label images
    must be 2D.
    """

    # Number of rejected patches after which `_get_sample` gives up.
    max_sampling_attempts = 500
    # Number of rejected patches after which a different image is drawn.
    max_sampling_attempts_image = 50

    @staticmethod
    def _is_channel_first(shape):
        """Guess whether a 3D image shape is stored in channel-first order.

        Heuristic: a last axis shorter than 16 is taken to be the channel axis
        (channel-last); otherwise the last axis is assumed to be spatial and
        the first axis is taken as the channel axis. This only works for
        images with fewer than 16 channels.
        """
        return shape[-1] >= 16

    def _check_inputs(self, raw_images, label_images, full_check):
        """Validate the raw / label image lists.

        Args:
            raw_images: Paths of the raw images.
            label_images: Paths of the label images, in matching order.
            full_check: If True, additionally compare the spatial shape of
                every raw / label pair.

        Raises:
            ValueError: If the lists differ in length, or (with `full_check`)
                if a raw image and its label image differ in spatial shape.
        """
        if len(raw_images) != len(label_images):
            raise ValueError(
                f"Expect same number of raw and label images, got {len(raw_images)} and {len(label_images)}"
            )

        if not full_check:
            return

        is_multichan = None
        for raw_im, label_im in zip(raw_images, label_images):
            # We only check for compatible shapes if both images support memmap,
            # because we don't want to load everything into RAM.
            if not (supports_memmap(raw_im) and supports_memmap(label_im)):
                continue

            shape = load_image(raw_im).shape
            assert len(shape) in (2, 3)

            # All checked raw images must agree on having a channel axis or not.
            multichan = len(shape) == 3
            if is_multichan is None:
                is_multichan = multichan
            else:
                assert is_multichan == multichan

            if is_multichan:
                # Strip the channel axis so that only the spatial shape is compared.
                shape = shape[1:] if self._is_channel_first(shape) else shape[:-1]

            label_shape = load_image(label_im).shape
            if shape != label_shape:
                msg = f"Expect raw and labels of same shape, got {shape}, {label_shape} for {raw_im}, {label_im}"
                raise ValueError(msg)

    def __init__(
        self,
        raw_image_paths: List[Union[str, os.PathLike]],
        label_image_paths: List[Union[str, os.PathLike]],
        patch_shape: Optional[Tuple[int, ...]],
        raw_transform=None,
        label_transform=None,
        label_transform2=None,
        transform=None,
        dtype: torch.dtype = torch.float32,
        label_dtype: torch.dtype = torch.float32,
        n_samples: Optional[int] = None,
        sampler=None,
        full_check: bool = False,
    ):
        """Set up the dataset.

        Args:
            raw_image_paths: Paths of the raw images.
            label_image_paths: Paths of the label images, in matching order.
            patch_shape: 2D shape of the patches to sample, or None to use
                each image at its full size.
            raw_transform: Optional callable applied to the raw patch.
            label_transform: Optional callable applied to the label patch.
            label_transform2: Optional callable applied to the labels after
                the joint `transform`.
            transform: Optional callable applied jointly as
                `transform(raw, labels)`.
            dtype: Target dtype of the returned raw data.
            label_dtype: Target dtype of the returned labels.
            n_samples: Length reported by the dataset. If None, the length is
                the number of images and the requested index selects the
                image; otherwise the image is drawn at random for every item.
            sampler: Optional callable `sampler(raw_patch, label_patch)`;
                patches are re-drawn until it returns a truthy value.
            full_check: Forwarded to `_check_inputs`.
        """
        self._check_inputs(raw_image_paths, label_image_paths, full_check=full_check)
        self.raw_images = raw_image_paths
        self.label_images = label_image_paths
        self._ndim = 2

        if patch_shape is not None:
            assert len(patch_shape) == self._ndim, \
                f"Expected a patch shape with {self._ndim} dimensions, got {patch_shape}"
        self.patch_shape = patch_shape

        self.raw_transform = raw_transform
        self.label_transform = label_transform
        self.label_transform2 = label_transform2
        self.transform = transform
        self.sampler = sampler

        self.dtype = dtype
        self.label_dtype = label_dtype

        if n_samples is None:
            self._len = len(self.raw_images)
            self.sample_random_index = False
        else:
            self._len = n_samples
            self.sample_random_index = True

    def __len__(self):
        """Return the number of samples reported for this dataset."""
        return self._len

    @property
    def ndim(self):
        """Number of spatial dimensions of the sampled patches."""
        return self._ndim

    def _sample_bounding_box(self, shape):
        """Draw a random patch-sized bounding box inside `shape`.

        Args:
            shape: Spatial shape of the image to sample from.

        Returns:
            Tuple of slices, one per spatial axis. Covers the full image if
            no patch shape is set.
        """
        if self.patch_shape is None:
            patch_shape_for_bb = shape
            bb_start = [0] * len(shape)
        else:
            patch_shape_for_bb = self.patch_shape
            # `np.random.randint` excludes its upper bound, so add one to make
            # the last valid start position (sh - psh) reachable.
            bb_start = [
                np.random.randint(0, sh - psh + 1) if sh > psh else 0
                for sh, psh in zip(shape, patch_shape_for_bb)
            ]

        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, patch_shape_for_bb))

    def _ensure_patch_shape(self, raw, labels, have_raw_channels, have_label_channels, channel_first):
        """Pad raw and labels at the end of each spatial axis to fit the patch shape.

        Args:
            raw: Raw image array, optionally with a channel axis.
            labels: Label image array without channel axis.
            have_raw_channels: Whether `raw` has a channel axis.
            have_label_channels: Whether the labels have a channel axis
                (currently unused, see TODO below).
            channel_first: Whether the channel axis of `raw` is the first axis.

        Returns:
            The (possibly padded) raw and label arrays.
        """
        if have_raw_channels:
            spatial_shape = raw.shape[1:] if channel_first else raw.shape[:-1]
        else:
            spatial_shape = raw.shape

        if any(sh < psh for sh, psh in zip(spatial_shape, self.patch_shape)):
            pw = [(0, max(0, psh - sh)) for sh, psh in zip(spatial_shape, self.patch_shape)]

            # The channel axis of the raw data is never padded.
            if have_raw_channels and channel_first:
                pw_raw = [(0, 0), *pw]
            elif have_raw_channels:
                pw_raw = [*pw, (0, 0)]
            else:
                pw_raw = pw

            # TODO: ensure padding for labels with channels, when supported (see `_load_data`)

            raw, labels = np.pad(raw, pw_raw), np.pad(labels, pw)
        return raw, labels

    def _load_data(self, raw_path, label_path):
        """Load one raw / label pair and derive its sampling geometry.

        Returns:
            Tuple `(raw, label, shape, prefix_box, have_raw_channels)`, where
            `shape` is the spatial shape of the raw data and `prefix_box` is
            `(slice(None),)` for channel-first raw data and empty otherwise.

        Raises:
            NotImplementedError: If the label image is 3D.
        """
        raw = load_image(raw_path, memmap=False)
        label = load_image(label_path, memmap=False)

        have_raw_channels = raw.ndim == 3
        have_label_channels = label.ndim == 3
        if have_label_channels:
            raise NotImplementedError("Multi-channel labels are not supported.")

        # The channel axis position is guessed from the array shape with the
        # same heuristic that `_check_inputs` uses.
        channel_first = None
        if have_raw_channels:
            channel_first = self._is_channel_first(raw.shape)

        if self.patch_shape is not None:
            raw, label = self._ensure_patch_shape(raw, label, have_raw_channels, have_label_channels, channel_first)
        shape = raw.shape

        prefix_box = tuple()
        if have_raw_channels:
            if channel_first:
                shape = shape[1:]
                prefix_box = (slice(None), )
            else:
                shape = shape[:-1]

        return raw, label, shape, prefix_box, have_raw_channels

    def _get_sample(self, index):
        """Load an image pair and cut a random patch from it.

        If a sampler is set, patches are re-drawn until the sampler accepts
        one; after every `max_sampling_attempts_image` rejections a new
        random image is loaded.

        Returns:
            Tuple `(raw_patch, label_patch)` with the raw patch in
            channel-first order if it has channels.

        Raises:
            RuntimeError: If more than `max_sampling_attempts` patches were
                rejected by the sampler.
        """
        if self.sample_random_index:
            index = np.random.randint(0, len(self.raw_images))

        # The filepath corresponding to this image.
        raw_path, label_path = self.raw_images[index], self.label_images[index]

        # Load the corresponding data.
        raw, label, shape, prefix_box, have_raw_channels = self._load_data(raw_path, label_path)

        # Sample random bounding box for this image.
        bb = self._sample_bounding_box(shape)
        raw_patch = np.array(raw[prefix_box + bb])
        label_patch = np.array(label[bb])

        if self.sampler is not None:
            sample_id = 0
            while not self.sampler(raw_patch, label_patch):
                sample_id += 1
                if sample_id > self.max_sampling_attempts:
                    raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")

                # We need to avoid sampling from the same image over and over again,
                # otherwise this will fail just because of one or a few empty images.
                # The image is switched BEFORE the patch is cut, so that the patch and
                # the channel layout flags always belong to the same image.
                if sample_id % self.max_sampling_attempts_image == 0:
                    index = np.random.randint(0, len(self.raw_images))
                    raw_path, label_path = self.raw_images[index], self.label_images[index]
                    raw, label, shape, prefix_box, have_raw_channels = self._load_data(raw_path, label_path)

                bb = self._sample_bounding_box(shape)
                raw_patch = np.array(raw[prefix_box + bb])
                label_patch = np.array(label[bb])

        # An empty prefix box means channel-last raw data: move channels to the front.
        if have_raw_channels and len(prefix_box) == 0:
            raw_patch = raw_patch.transpose((2, 0, 1))

        return raw_patch, label_patch

    def __getitem__(self, index):
        """Return the transformed `(raw, labels)` pair for `index`.

        The transforms run in the order `raw_transform`, `label_transform`,
        joint `transform`, `label_transform2`. Both outputs are finally passed
        through `ensure_tensor_with_channels` with the configured dtypes
        (presumably yielding torch tensors with a channel axis - confirm
        against `..util`).
        """
        raw, labels = self._get_sample(index)
        initial_label_dtype = labels.dtype

        if self.raw_transform is not None:
            raw = self.raw_transform(raw)

        if self.label_transform is not None:
            labels = self.label_transform(labels)

        if self.transform is not None:
            raw, labels = self.transform(raw, labels)

        # TODO: support enlarging the bounding box here as well (for affinity transform)?
        if self.label_transform2 is not None:
            # Restore the dtype the labels had before the transforms above.
            labels = ensure_spatial_array(labels, self.ndim, dtype=initial_label_dtype)
            labels = self.label_transform2(labels)

        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
        labels = ensure_tensor_with_channels(labels, ndim=self._ndim, dtype=self.label_dtype)
        return raw, labels
An abstract class representing a `Dataset`.
All datasets that represent a map from keys to data samples should subclass it. All subclasses should override `__getitem__()`, which fetches the data sample for a given key. Subclasses may also override `__len__()`, which many `torch.utils.data.Sampler` implementations and the default options of `torch.utils.data.DataLoader` expect to return the size of the dataset. Subclasses may also implement `__getitems__()` to speed up the loading of batched samples. This method accepts a list of indices of the samples in a batch and returns a list of samples.
Note: `torch.utils.data.DataLoader` by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided.
52 def __init__( 53 self, 54 raw_image_paths: List[Union[str, os.PathLike]], 55 label_image_paths: List[Union[str, os.PathLike]], 56 patch_shape: Tuple[int, ...], 57 raw_transform=None, 58 label_transform=None, 59 label_transform2=None, 60 transform=None, 61 dtype: torch.dtype = torch.float32, 62 label_dtype: torch.dtype = torch.float32, 63 n_samples: Optional[int] = None, 64 sampler=None, 65 full_check: bool = False, 66 ): 67 self._check_inputs(raw_image_paths, label_image_paths, full_check=full_check) 68 self.raw_images = raw_image_paths 69 self.label_images = label_image_paths 70 self._ndim = 2 71 72 if patch_shape is not None: 73 assert len(patch_shape) == self._ndim 74 self.patch_shape = patch_shape 75 76 self.raw_transform = raw_transform 77 self.label_transform = label_transform 78 self.label_transform2 = label_transform2 79 self.transform = transform 80 self.sampler = sampler 81 82 self.dtype = dtype 83 self.label_dtype = label_dtype 84 85 if n_samples is None: 86 self._len = len(self.raw_images) 87 self.sample_random_index = False 88 else: 89 self._len = n_samples 90 self.sample_random_index = True