torch_em.data.raw_image_collection_dataset
import os
import numpy as np
from typing import List, Union, Tuple, Optional, Any, Callable

import torch

from ..util import ensure_tensor_with_channels, load_image, supports_memmap


class RawImageCollectionDataset(torch.utils.data.Dataset):
    """Dataset that provides raw data stored in a regular image data format for unsupervised training.

    The dataset loads a patch of the raw data and returns a sample for a batch.
    It supports all file formats that can be loaded with the imageio or tifffile libraries, such as tif, png or jpeg files.

    The dataset can also be used for contrastive learning that relies on two different views of the same data.
    You can use the `augmentations` argument for this.

    Args:
        raw_image_paths: The file paths to the raw data.
        patch_shape: The patch shape for a training sample.
        raw_transform: Transformation applied to the raw data of a sample.
        transform: Transformation to the raw data. This can be used to implement data augmentations.
        dtype: The return data type of the raw data.
        n_samples: The length of this dataset. If None, the length will be set to `len(raw_image_paths)`.
        sampler: Sampler for rejecting samples according to a defined criterion.
            The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
        augmentations: Augmentations for contrastive learning. If given, these need to be two different callables.
            They will be applied to the sampled raw data to return two independent views of the raw data.
        full_check: Whether to check that the input data is valid for all image paths.
            This will ensure that the data is valid, but will take longer for creating the dataset.
    """
    max_sampling_attempts = 500
    """The maximal number of sampling attempts for loading a sample via `__getitem__`.
    This is used when `sampler` rejects a sample, to avoid an infinite loop if no valid sample can be found.
    """

    def _check_inputs(self, raw_images, full_check):
        if not full_check:
            return

        is_multichan = None
        for raw_im in raw_images:

            # we only check for compatible shapes if images support memmap, because
            # we don't want to load everything into ram
            if supports_memmap(raw_im):
                shape = load_image(raw_im).shape
                assert len(shape) in (2, 3)

                multichan = len(shape) == 3
                if is_multichan is None:
                    is_multichan = multichan
                else:
                    assert is_multichan == multichan

                # we assume axis last
                if is_multichan:
                    shape = shape[:-1]

    def __init__(
        self,
        raw_image_paths: Union[List[Any], str, os.PathLike],
        patch_shape: Tuple[int, ...],
        raw_transform: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        dtype: torch.dtype = torch.float32,
        n_samples: Optional[int] = None,
        sampler: Optional[Callable] = None,
        augmentations: Optional[Callable] = None,
        full_check: bool = False,
    ):
        self._check_inputs(raw_image_paths, full_check)
        self.raw_images = raw_image_paths
        self._ndim = 2

        assert len(patch_shape) == self._ndim
        self.patch_shape = patch_shape

        self.raw_transform = raw_transform
        self.transform = transform
        self.dtype = dtype
        self.sampler = sampler

        if n_samples is None:
            self._len = len(self.raw_images)
            self.sample_random_index = False
        else:
            self._len = n_samples
            self.sample_random_index = True

        if augmentations is not None:
            assert len(augmentations) == 2
        self.augmentations = augmentations

    def __len__(self):
        return self._len

    @property
    def ndim(self):
        return self._ndim

    def _sample_bounding_box(self, shape):
        bb_start = [
            np.random.randint(0, sh - psh) if sh - psh > 0 else 0 for sh, psh in zip(shape, self.patch_shape)
        ]
        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, self.patch_shape))

    def _ensure_patch_shape(self, raw, have_raw_channels, channel_first):
        shape = raw.shape
        if have_raw_channels and channel_first:
            shape = shape[1:]

        if any(sh < psh for sh, psh in zip(shape, self.patch_shape)):
            pw = [(0, max(0, psh - sh)) for sh, psh in zip(shape, self.patch_shape)]

            if have_raw_channels and channel_first:
                pw_raw = [(0, 0), *pw]
            elif have_raw_channels and not channel_first:
                pw_raw = [*pw, (0, 0)]
            else:
                pw_raw = pw

            raw = np.pad(raw, pw_raw)
        return raw

    def _get_sample(self, index):
        if self.sample_random_index:
            index = np.random.randint(0, len(self.raw_images))

        raw = load_image(self.raw_images[index])
        have_raw_channels = raw.ndim == 3

        # We determine if the image has channels as the first or last axis based on the array shape.
        # This will work only for images with less than 16 channels!
        # If the last axis has a length smaller than 16 we assume that it is the channel axis,
        # otherwise we assume it is a spatial axis and that the first axis is the channel axis.
        channel_first = None
        if have_raw_channels:
            channel_first = raw.shape[-1] > 16

        raw = self._ensure_patch_shape(raw, have_raw_channels, channel_first)

        shape = raw.shape
        # we assume images are loaded with channel last!
        if have_raw_channels:
            shape = shape[:-1]

        # sample random bounding box for this image
        bb = self._sample_bounding_box(shape)
        raw = np.array(raw[bb])

        if self.sampler is not None:
            sample_id = 0
            while not self.sampler(raw):
                bb = self._sample_bounding_box(shape)
                raw = np.array(raw[bb])
                sample_id += 1
                if sample_id > self.max_sampling_attempts:
                    raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")

        # to channel first
        if have_raw_channels:
            raw = raw.transpose((2, 0, 1))

        return raw

    def __getitem__(self, index):
        raw = self._get_sample(index)

        if self.raw_transform is not None:
            raw = self.raw_transform(raw)

        if self.transform is not None:
            raw = self.transform(raw)
            assert len(raw) == 1
            raw = raw[0]
            # if self.trafo_halo is not None:
            #     raw = self.crop(raw)

        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
        if self.augmentations is not None:
            aug1, aug2 = self.augmentations
            raw1, raw2 = aug1(raw), aug2(raw)
            return raw1, raw2

        return raw
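The comments in `_get_sample` describe how the channel axis of a 3d array is guessed from its shape: a short last axis (fewer than 16 entries) is taken to be the channel axis, otherwise the first axis is assumed to hold the channels. A minimal sketch of that heuristic in isolation; the helper name and the example arrays are made up for illustration:

import numpy as np

def guess_channel_first(shape):
    # mirrors the heuristic in RawImageCollectionDataset._get_sample:
    # a last axis longer than 16 entries is assumed to be spatial,
    # so the channels must come first
    return shape[-1] > 16

rgb_image = np.zeros((512, 512, 3))   # channel-last RGB image
stack = np.zeros((32, 512, 512))      # 32-channel stack, channel-first
print(guess_channel_first(rgb_image.shape))  # False -> channels last
print(guess_channel_first(stack.shape))      # True  -> channels first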
class RawImageCollectionDataset(torch.utils.data.Dataset):
Dataset that provides raw data stored in a regular image data format for unsupervised training.
The dataset loads a patch of the raw data and returns a sample for a batch. It supports all file formats that can be loaded with the imageio or tifffile libraries, such as tif, png or jpeg files.
The dataset can also be used for contrastive learning that relies on two different views of the same data.
You can use the `augmentations` argument for this.
Arguments:
- raw_image_paths: The file paths to the raw data.
- patch_shape: The patch shape for a training sample.
- raw_transform: Transformation applied to the raw data of a sample.
- transform: Transformation to the raw data. This can be used to implement data augmentations.
- dtype: The return data type of the raw data.
- n_samples: The length of this dataset. If None, the length will be set to `len(raw_image_paths)`.
- sampler: Sampler for rejecting samples according to a defined criterion. The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
- augmentations: Augmentations for contrastive learning. If given, these need to be two different callables. They will be applied to the sampled raw data to return two independent views of the raw data.
- full_check: Whether to check that the input data is valid for all image paths. This will ensure that the data is valid, but will take longer for creating the dataset.
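A rough usage sketch under assumed inputs (the image folder, glob pattern, and patch shape below are placeholders, not part of torch_em): construct the dataset from a list of image files and wrap it in a standard PyTorch DataLoader.

import glob

from torch.utils.data import DataLoader
from torch_em.data.raw_image_collection_dataset import RawImageCollectionDataset

# hypothetical folder of 2d training images; any imageio/tifffile-readable format works
image_paths = sorted(glob.glob("/path/to/images/*.tif"))

dataset = RawImageCollectionDataset(
    raw_image_paths=image_paths,
    patch_shape=(256, 256),  # must be 2d, since the dataset only supports 2d patches
)
loader = DataLoader(dataset, batch_size=8, shuffle=True)

for batch in loader:
    # each batch is a float32 tensor of shape (8, n_channels, 256, 256)
    print(batch.shape)
    break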
RawImageCollectionDataset(
    raw_image_paths: Union[List[Any], str, os.PathLike],
    patch_shape: Tuple[int, ...],
    raw_transform: Optional[Callable] = None,
    transform: Optional[Callable] = None,
    dtype: torch.dtype = torch.float32,
    n_samples: Optional[int] = None,
    sampler: Optional[Callable] = None,
    augmentations: Optional[Callable] = None,
    full_check: bool = False,
)
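For the contrastive-learning mode, `augmentations` is a pair of callables that each receive the sampled tensor and return one view of it; `__getitem__` then returns both views. The flip and noise transforms below are illustrative stand-ins under assumed inputs, not transforms shipped with torch_em.

import glob

import torch
from torch_em.data.raw_image_collection_dataset import RawImageCollectionDataset

image_paths = sorted(glob.glob("/path/to/images/*.tif"))  # hypothetical image folder

# two illustrative view-generating callables; any pair of callables works
def view_a(x: torch.Tensor) -> torch.Tensor:
    return torch.flip(x, dims=(-1,))        # horizontal flip

def view_b(x: torch.Tensor) -> torch.Tensor:
    return x + 0.1 * torch.randn_like(x)    # additive gaussian noise

dataset = RawImageCollectionDataset(
    raw_image_paths=image_paths,
    patch_shape=(256, 256),
    augmentations=(view_a, view_b),
    n_samples=10_000,  # draw 10000 random patches per epoch instead of one per image
)

view1, view2 = dataset[0]  # two independent views of the same randomly sampled patch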
max_sampling_attempts = 500

The maximal number of sampling attempts for loading a sample via `__getitem__`.
This is used when `sampler` rejects a sample, to avoid an infinite loop if no valid sample can be found.
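To illustrate how `sampler` and `max_sampling_attempts` interact: the sampler is called on each candidate numpy patch, and patches are re-drawn until it returns True or the attempt limit is exceeded. The intensity criterion and file paths below are arbitrary examples, not something torch_em prescribes.

import glob

import numpy as np
from torch_em.data.raw_image_collection_dataset import RawImageCollectionDataset

image_paths = sorted(glob.glob("/path/to/images/*.tif"))  # hypothetical image folder

def reject_flat_patches(raw: np.ndarray) -> bool:
    # example criterion: only accept patches with some intensity variation
    return raw.std() > 5.0

dataset = RawImageCollectionDataset(
    raw_image_paths=image_paths,
    patch_shape=(256, 256),
    sampler=reject_flat_patches,
)

# allow more re-draws per sample before giving up (the class default is 500)
dataset.max_sampling_attempts = 1000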