torch_em.data.raw_image_collection_dataset
import os
from typing import List, Union, Tuple, Optional, Any

import numpy as np
import torch

from ..util import ensure_tensor_with_channels, load_image, supports_memmap


class RawImageCollectionDataset(torch.utils.data.Dataset):
    """Dataset that serves 2D patches cropped from a collection of raw images.

    For every item one image is loaded with `load_image`, zero-padded at the end of
    each spatial axis if it is smaller than `patch_shape`, and a randomly placed
    window of size `patch_shape` is cropped from it. The patch is passed through the
    optional transforms and converted with `ensure_tensor_with_channels`.

    Args:
        raw_image_paths: The images of the collection. A single `str` / `os.PathLike`
            is treated as a collection with one image.
        patch_shape: The 2D shape of the patches to crop.
        raw_transform: Optional callable applied to the cropped patch.
        transform: Optional callable applied after `raw_transform`.
            It must return a sequence of length one, whose element is used.
        dtype: The dtype passed to `ensure_tensor_with_channels`.
        n_samples: Length reported by the dataset. If given, the image is drawn at random
            for every item; otherwise the length is the number of images and the item
            index selects the image.
        sampler: Optional callable that receives a candidate patch and returns whether to
            accept it. Rejected patches are re-drawn from the same image.
        augmentations: Optional pair of callables. If given, `__getitem__` returns
            the two augmented versions of the patch instead of the patch itself.
        full_check: Whether to check that all memory-mappable images agree on having
            a channel axis or not.
    """

    # Number of times the sampler may reject a patch before `_sample_patch` gives up.
    max_sampling_attempts = 500

    def _check_inputs(self, raw_images, full_check):
        """Check that the images consistently have (or do not have) a channel axis.

        Does nothing unless `full_check` is set.
        """
        if not full_check:
            return

        is_multichan = None
        for raw_im in raw_images:
            # We only check images that support memmap, because
            # we don't want to load everything into RAM.
            if not supports_memmap(raw_im):
                continue

            shape = load_image(raw_im).shape
            assert len(shape) in (2, 3)

            multichan = len(shape) == 3
            if is_multichan is None:
                is_multichan = multichan
            else:
                assert is_multichan == multichan

    def __init__(
        self,
        raw_image_paths: Union[List[Any], str, os.PathLike],
        patch_shape: Tuple[int, ...],
        raw_transform=None,
        transform=None,
        dtype: torch.dtype = torch.float32,
        n_samples: Optional[int] = None,
        sampler=None,
        augmentations=None,
        full_check: bool = False,
    ):
        # A single path must not be iterated character by character.
        # NOTE(review): assumes a lone str / PathLike denotes exactly one image - confirm.
        if isinstance(raw_image_paths, (str, os.PathLike)):
            raw_image_paths = [raw_image_paths]

        self._check_inputs(raw_image_paths, full_check)
        self.raw_images = raw_image_paths
        self._ndim = 2

        assert len(patch_shape) == self._ndim
        self.patch_shape = patch_shape

        self.raw_transform = raw_transform
        self.transform = transform
        self.dtype = dtype
        self.sampler = sampler

        if n_samples is None:
            self._len = len(self.raw_images)
            self.sample_random_index = False
        else:
            self._len = n_samples
            self.sample_random_index = True

        if augmentations is not None:
            assert len(augmentations) == 2
        # Always define the attribute, because __getitem__ reads it unconditionally.
        self.augmentations = augmentations

    def __len__(self):
        return self._len

    @property
    def ndim(self):
        """The number of spatial dimensions of the patches."""
        return self._ndim

    def _sample_bounding_box(self, shape):
        """Draw a random window of size `patch_shape` inside the spatial `shape`.

        Returns a tuple with one slice per spatial axis.
        """
        # NOTE: the upper bound of np.random.randint is exclusive, so the
        # last valid start offset (sh - psh) is never drawn.
        bb_start = [
            np.random.randint(0, sh - psh) if sh - psh > 0 else 0
            for sh, psh in zip(shape, self.patch_shape)
        ]
        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, self.patch_shape))

    def _ensure_patch_shape(self, raw, have_raw_channels, channel_first):
        """Zero-pad `raw` at the end of its spatial axes up to `patch_shape`.

        The channel axis (if any) is never padded. Returns `raw` unchanged
        if it is already large enough.
        """
        shape = raw.shape
        if have_raw_channels and channel_first:
            shape = shape[1:]
        if any(sh < psh for sh, psh in zip(shape, self.patch_shape)):
            pw = [(0, max(0, psh - sh)) for sh, psh in zip(shape, self.patch_shape)]

            if have_raw_channels and channel_first:
                pw_raw = [(0, 0), *pw]
            elif have_raw_channels and not channel_first:
                pw_raw = [*pw, (0, 0)]
            else:
                pw_raw = pw

            raw = np.pad(raw, pw_raw)
        return raw

    def _sample_patch(self, image):
        """Crop a random patch from `image` and return it, channel axis first.

        Raises:
            RuntimeError: If the sampler rejects more than `max_sampling_attempts` patches.
        """
        have_raw_channels = image.ndim == 3

        # We determine if the image has channels as the first or last axis based on the array shape.
        # This will work only for images with less than 16 channels!
        # If the last axis has a length of at most 16 we assume that it is the channel axis,
        # otherwise we assume it is a spatial axis and that the first axis is the channel axis.
        channel_first = None
        if have_raw_channels:
            channel_first = image.shape[-1] > 16

        image = self._ensure_patch_shape(image, have_raw_channels, channel_first)

        # The spatial shape depends on where the channel axis sits.
        if not have_raw_channels:
            spatial_shape = image.shape
        elif channel_first:
            spatial_shape = image.shape[1:]
        else:
            spatial_shape = image.shape[:-1]

        def crop():
            bb = self._sample_bounding_box(spatial_shape)
            if have_raw_channels and channel_first:
                # Keep all channels; the bounding box only covers the spatial axes.
                bb = (slice(None),) + bb
            return np.array(image[bb])

        raw = crop()

        if self.sampler is not None:
            sample_id = 0
            # The sampler sees the patch in the channel layout of the image.
            while not self.sampler(raw):
                # Always crop from the full image, never from the rejected patch.
                raw = crop()
                sample_id += 1
                if sample_id > self.max_sampling_attempts:
                    raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")

        # to channel first
        if have_raw_channels and not channel_first:
            raw = raw.transpose((2, 0, 1))

        return raw

    def _get_sample(self, index):
        """Load the image for `index` (or a random one) and crop a patch from it."""
        if self.sample_random_index:
            index = np.random.randint(0, len(self.raw_images))
        return self._sample_patch(load_image(self.raw_images[index]))

    def __getitem__(self, index):
        raw = self._get_sample(index)

        if self.raw_transform is not None:
            raw = self.raw_transform(raw)

        if self.transform is not None:
            raw = self.transform(raw)
            assert len(raw) == 1
            raw = raw[0]

        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
        if self.augmentations is not None:
            aug1, aug2 = self.augmentations
            return aug1(raw), aug2(raw)

        return raw
class RawImageCollectionDataset(torch.utils.data.Dataset):
    """Dataset that serves 2D patches cropped from a collection of raw images.

    For every item one image is loaded with `load_image`, zero-padded at the end of
    each spatial axis if it is smaller than `patch_shape`, and a randomly placed
    window of size `patch_shape` is cropped from it. The patch is passed through the
    optional transforms and converted with `ensure_tensor_with_channels`.

    Args:
        raw_image_paths: The images of the collection. A single `str` / `os.PathLike`
            is treated as a collection with one image.
        patch_shape: The 2D shape of the patches to crop.
        raw_transform: Optional callable applied to the cropped patch.
        transform: Optional callable applied after `raw_transform`.
            It must return a sequence of length one, whose element is used.
        dtype: The dtype passed to `ensure_tensor_with_channels`.
        n_samples: Length reported by the dataset. If given, the image is drawn at random
            for every item; otherwise the length is the number of images and the item
            index selects the image.
        sampler: Optional callable that receives a candidate patch and returns whether to
            accept it. Rejected patches are re-drawn from the same image.
        augmentations: Optional pair of callables. If given, `__getitem__` returns
            the two augmented versions of the patch instead of the patch itself.
        full_check: Whether to check that all memory-mappable images agree on having
            a channel axis or not.
    """

    # Number of times the sampler may reject a patch before `_sample_patch` gives up.
    max_sampling_attempts = 500

    def _check_inputs(self, raw_images, full_check):
        """Check that the images consistently have (or do not have) a channel axis.

        Does nothing unless `full_check` is set.
        """
        if not full_check:
            return

        is_multichan = None
        for raw_im in raw_images:
            # We only check images that support memmap, because
            # we don't want to load everything into RAM.
            if not supports_memmap(raw_im):
                continue

            shape = load_image(raw_im).shape
            assert len(shape) in (2, 3)

            multichan = len(shape) == 3
            if is_multichan is None:
                is_multichan = multichan
            else:
                assert is_multichan == multichan

    def __init__(
        self,
        raw_image_paths: Union[List[Any], str, os.PathLike],
        patch_shape: Tuple[int, ...],
        raw_transform=None,
        transform=None,
        dtype: torch.dtype = torch.float32,
        n_samples: Optional[int] = None,
        sampler=None,
        augmentations=None,
        full_check: bool = False,
    ):
        # A single path must not be iterated character by character.
        # NOTE(review): assumes a lone str / PathLike denotes exactly one image - confirm.
        if isinstance(raw_image_paths, (str, os.PathLike)):
            raw_image_paths = [raw_image_paths]

        self._check_inputs(raw_image_paths, full_check)
        self.raw_images = raw_image_paths
        self._ndim = 2

        assert len(patch_shape) == self._ndim
        self.patch_shape = patch_shape

        self.raw_transform = raw_transform
        self.transform = transform
        self.dtype = dtype
        self.sampler = sampler

        if n_samples is None:
            self._len = len(self.raw_images)
            self.sample_random_index = False
        else:
            self._len = n_samples
            self.sample_random_index = True

        if augmentations is not None:
            assert len(augmentations) == 2
        # Always define the attribute, because __getitem__ reads it unconditionally.
        self.augmentations = augmentations

    def __len__(self):
        return self._len

    @property
    def ndim(self):
        """The number of spatial dimensions of the patches."""
        return self._ndim

    def _sample_bounding_box(self, shape):
        """Draw a random window of size `patch_shape` inside the spatial `shape`.

        Returns a tuple with one slice per spatial axis.
        """
        # NOTE: the upper bound of np.random.randint is exclusive, so the
        # last valid start offset (sh - psh) is never drawn.
        bb_start = [
            np.random.randint(0, sh - psh) if sh - psh > 0 else 0
            for sh, psh in zip(shape, self.patch_shape)
        ]
        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, self.patch_shape))

    def _ensure_patch_shape(self, raw, have_raw_channels, channel_first):
        """Zero-pad `raw` at the end of its spatial axes up to `patch_shape`.

        The channel axis (if any) is never padded. Returns `raw` unchanged
        if it is already large enough.
        """
        shape = raw.shape
        if have_raw_channels and channel_first:
            shape = shape[1:]
        if any(sh < psh for sh, psh in zip(shape, self.patch_shape)):
            pw = [(0, max(0, psh - sh)) for sh, psh in zip(shape, self.patch_shape)]

            if have_raw_channels and channel_first:
                pw_raw = [(0, 0), *pw]
            elif have_raw_channels and not channel_first:
                pw_raw = [*pw, (0, 0)]
            else:
                pw_raw = pw

            raw = np.pad(raw, pw_raw)
        return raw

    def _sample_patch(self, image):
        """Crop a random patch from `image` and return it, channel axis first.

        Raises:
            RuntimeError: If the sampler rejects more than `max_sampling_attempts` patches.
        """
        have_raw_channels = image.ndim == 3

        # We determine if the image has channels as the first or last axis based on the array shape.
        # This will work only for images with less than 16 channels!
        # If the last axis has a length of at most 16 we assume that it is the channel axis,
        # otherwise we assume it is a spatial axis and that the first axis is the channel axis.
        channel_first = None
        if have_raw_channels:
            channel_first = image.shape[-1] > 16

        image = self._ensure_patch_shape(image, have_raw_channels, channel_first)

        # The spatial shape depends on where the channel axis sits.
        if not have_raw_channels:
            spatial_shape = image.shape
        elif channel_first:
            spatial_shape = image.shape[1:]
        else:
            spatial_shape = image.shape[:-1]

        def crop():
            bb = self._sample_bounding_box(spatial_shape)
            if have_raw_channels and channel_first:
                # Keep all channels; the bounding box only covers the spatial axes.
                bb = (slice(None),) + bb
            return np.array(image[bb])

        raw = crop()

        if self.sampler is not None:
            sample_id = 0
            # The sampler sees the patch in the channel layout of the image.
            while not self.sampler(raw):
                # Always crop from the full image, never from the rejected patch.
                raw = crop()
                sample_id += 1
                if sample_id > self.max_sampling_attempts:
                    raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")

        # to channel first
        if have_raw_channels and not channel_first:
            raw = raw.transpose((2, 0, 1))

        return raw

    def _get_sample(self, index):
        """Load the image for `index` (or a random one) and crop a patch from it."""
        if self.sample_random_index:
            index = np.random.randint(0, len(self.raw_images))
        return self._sample_patch(load_image(self.raw_images[index]))

    def __getitem__(self, index):
        raw = self._get_sample(index)

        if self.raw_transform is not None:
            raw = self.raw_transform(raw)

        if self.transform is not None:
            raw = self.transform(raw)
            assert len(raw) == 1
            raw = raw[0]

        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
        if self.augmentations is not None:
            aug1, aug2 = self.augmentations
            return aug1(raw), aug2(raw)

        return raw
An abstract class representing a Dataset.

All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite `__getitem__()`, supporting fetching a data sample for a given key. Subclasses could also optionally overwrite `__len__()`, which is expected to return the size of the dataset by many `torch.utils.data.Sampler` implementations and the default options of `torch.utils.data.DataLoader`. Subclasses could also optionally implement `__getitems__()`, to speed up the loading of batched samples. This method accepts a list of indices of the samples of a batch and returns a list of samples.
`torch.utils.data.DataLoader` by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided.
def __init__(
    self,
    raw_image_paths: Union[List[Any], str, os.PathLike],
    patch_shape: Tuple[int, ...],
    raw_transform=None,
    transform=None,
    dtype: torch.dtype = torch.float32,
    n_samples: Optional[int] = None,
    sampler=None,
    augmentations=None,
    full_check: bool = False,
):
    """Store the image collection and the patch sampling configuration.

    Args:
        raw_image_paths: The images of the collection. A single `str` / `os.PathLike`
            is treated as a collection with one image.
        patch_shape: The shape of the patches; its length must be 2.
        raw_transform: Stored as `self.raw_transform`.
        transform: Stored as `self.transform`.
        dtype: Stored as `self.dtype`.
        n_samples: Length reported by the dataset. If None, the number of images is used
            and `self.sample_random_index` is False; otherwise it is True.
        sampler: Stored as `self.sampler`.
        augmentations: None or a sequence of length 2, stored as `self.augmentations`.
        full_check: Forwarded to `self._check_inputs`.
    """
    # A single path must not be iterated character by character.
    # NOTE(review): assumes a lone str / PathLike denotes exactly one image - confirm.
    if isinstance(raw_image_paths, (str, os.PathLike)):
        raw_image_paths = [raw_image_paths]

    self._check_inputs(raw_image_paths, full_check)
    self.raw_images = raw_image_paths
    self._ndim = 2

    assert len(patch_shape) == self._ndim
    self.patch_shape = patch_shape

    self.raw_transform = raw_transform
    self.transform = transform
    self.dtype = dtype
    self.sampler = sampler

    if n_samples is None:
        self._len = len(self.raw_images)
        self.sample_random_index = False
    else:
        self._len = n_samples
        self.sample_random_index = True

    if augmentations is not None:
        assert len(augmentations) == 2
    # Always define the attribute so that later reads never hit an AttributeError.
    self.augmentations = augmentations