torch_em.data.concat_dataset
import numpy as np

from torch.utils.data import Dataset


class ConcatDataset(Dataset):
    def __init__(
        self,
        *datasets: Dataset
    ):
        self.datasets = datasets
        self.ndim = datasets[0].ndim

        # compute the number of samples for each volume
        self.ds_lens = [len(dataset) for dataset in self.datasets]
        self._len = sum(self.ds_lens)

        # compute the offsets for the samples
        self.ds_offsets = np.cumsum(self.ds_lens)

    def __len__(self):
        return self._len

    def __getitem__(self, idx):
        # find the dataset id corresponding to this index
        ds_idx = 0
        while True:
            if idx < self.ds_offsets[ds_idx]:
                break
            ds_idx += 1

        # get sample from the dataset
        ds = self.datasets[ds_idx]
        offset = self.ds_offsets[ds_idx - 1] if ds_idx > 0 else 0
        idx_in_ds = idx - offset
        assert idx_in_ds < len(ds) and idx_in_ds >= 0, f"Failed with: {idx_in_ds}, {len(ds)}"
        return ds[idx_in_ds]
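For illustration, a minimal usage sketch (not part of the module source): ToyDataset below is a hypothetical stand-in for the volume datasets torch_em normally concatenates. ConcatDataset only assumes that each wrapped dataset provides __len__ and __getitem__, plus an ndim attribute, which it reads from the first dataset.

import torch
from torch.utils.data import Dataset

from torch_em.data.concat_dataset import ConcatDataset


class ToyDataset(Dataset):
    """Hypothetical map-style dataset that returns dummy tensors."""
    def __init__(self, n_samples, ndim=2):
        self.n_samples = n_samples
        self.ndim = ndim

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        # a real dataset would load image / label data here
        return torch.full((1, 4, 4), float(idx))


ds = ConcatDataset(ToyDataset(10), ToyDataset(5))
assert len(ds) == 15
sample = ds[12]  # global index 12 resolves to index 2 of the second dataset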
An abstract class representing a Dataset.

All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite __getitem__(), supporting fetching a data sample for a given key. Subclasses could also optionally overwrite __len__(), which is expected to return the size of the dataset by many torch.utils.data.Sampler implementations and the default options of torch.utils.data.DataLoader. Subclasses could also optionally implement __getitems__() to speed up the loading of batched samples; this method accepts a list of sample indices for a batch and returns the list of samples.

torch.utils.data.DataLoader by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided.
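The linear scan over ds_offsets in __getitem__ above amounts to finding the first cumulative offset greater than idx. As a sketch (an equivalent formulation, not the module's actual implementation), the same lookup can be expressed with numpy.searchsorted:

import numpy as np

ds_lens = [10, 5, 8]              # per-dataset sample counts
ds_offsets = np.cumsum(ds_lens)   # array([10, 15, 23])

def locate(idx):
    # first position whose cumulative offset exceeds idx
    ds_idx = int(np.searchsorted(ds_offsets, idx, side="right"))
    offset = ds_offsets[ds_idx - 1] if ds_idx > 0 else 0
    return ds_idx, idx - offset

assert locate(0) == (0, 0)
assert locate(12) == (1, 2)    # sample 2 of the second dataset
assert locate(22) == (2, 7)    # last sample of the third dataset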
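Because the concatenated dataset exposes __len__ and integral indices, it works with the default index sampler of torch.utils.data.DataLoader, as described above. Continuing the hypothetical usage sketch from earlier:

from torch.utils.data import DataLoader

loader = DataLoader(ds, batch_size=4, shuffle=True)
for batch in loader:
    # 15 samples -> three batches of 4 and a final batch of 3
    print(batch.shape)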