torch_em.data.concat_dataset
import numpy as np

from torch.utils.data import Dataset


class ConcatDataset(Dataset):
    """Dataset to concatenate multiple PyTorch datasets.

    Samples are numbered consecutively across the datasets in the order they
    were passed: the first `len(datasets[0])` indices map to the first dataset,
    the next `len(datasets[1])` indices to the second dataset, and so on.

    Args:
        datasets: The datasets to concatenate. At least one dataset is required,
            because `ndim` is read from the first one.
    """
    def __init__(self, *datasets: Dataset):
        self.datasets = datasets
        self.ndim = datasets[0].ndim

        # compute the number of samples for each dataset
        self.ds_lens = [len(dataset) for dataset in self.datasets]
        self._len = sum(self.ds_lens)

        # compute the offsets for the samples:
        # ds_offsets[i] is the first global index that lies beyond dataset i
        self.ds_offsets = np.cumsum(self.ds_lens)

    def __len__(self):
        """Return the total number of samples over all datasets."""
        return self._len

    def __getitem__(self, idx):
        """Return the sample at the global index `idx`.

        Args:
            idx: Integer index into the concatenated datasets. Negative values
                count from the end, as for Python sequences.

        Returns:
            The sample returned by the dataset that contains `idx`.

        Raises:
            IndexError: If `idx` is outside of `[-len(self), len(self))`.
        """
        requested_idx = idx
        if idx < 0:
            idx += self._len
        # Explicit check instead of an assert: asserts are stripped under `-O`,
        # which would let an invalid index reach the wrong dataset silently.
        if idx < 0 or idx >= self._len:
            raise IndexError(
                f"Index {requested_idx} is out of range for a dataset with {self._len} samples"
            )

        # find the dataset id corresponding to this index: the first dataset
        # whose cumulative offset is strictly larger than idx
        # (side="right" also skips over datasets of length zero)
        ds_idx = int(np.searchsorted(self.ds_offsets, idx, side="right"))

        # get sample from the dataset
        offset = self.ds_offsets[ds_idx - 1] if ds_idx > 0 else 0
        return self.datasets[ds_idx][idx - offset]
class ConcatDataset(torch.utils.data.dataset.Dataset):
class ConcatDataset(Dataset):
    """Dataset to concatenate multiple PyTorch datasets.

    Samples are numbered consecutively across the datasets in the order they
    were passed: the first `len(datasets[0])` indices map to the first dataset,
    the next `len(datasets[1])` indices to the second dataset, and so on.

    Args:
        datasets: The datasets to concatenate. At least one dataset is required,
            because `ndim` is read from the first one.
    """
    def __init__(self, *datasets: Dataset):
        self.datasets = datasets
        self.ndim = datasets[0].ndim

        # compute the number of samples for each dataset
        self.ds_lens = [len(dataset) for dataset in self.datasets]
        self._len = sum(self.ds_lens)

        # compute the offsets for the samples:
        # ds_offsets[i] is the first global index that lies beyond dataset i
        self.ds_offsets = np.cumsum(self.ds_lens)

    def __len__(self):
        """Return the total number of samples over all datasets."""
        return self._len

    def __getitem__(self, idx):
        """Return the sample at the global index `idx`.

        Args:
            idx: Integer index into the concatenated datasets. Negative values
                count from the end, as for Python sequences.

        Returns:
            The sample returned by the dataset that contains `idx`.

        Raises:
            IndexError: If `idx` is outside of `[-len(self), len(self))`.
        """
        requested_idx = idx
        if idx < 0:
            idx += self._len
        # Explicit check instead of an assert: asserts are stripped under `-O`,
        # which would let an invalid index reach the wrong dataset silently.
        if idx < 0 or idx >= self._len:
            raise IndexError(
                f"Index {requested_idx} is out of range for a dataset with {self._len} samples"
            )

        # find the dataset id corresponding to this index: the first dataset
        # whose cumulative offset is strictly larger than idx
        # (side="right" also skips over datasets of length zero)
        ds_idx = int(np.searchsorted(self.ds_offsets, idx, side="right"))

        # get sample from the dataset
        offset = self.ds_offsets[ds_idx - 1] if ds_idx > 0 else 0
        return self.datasets[ds_idx][idx - offset]
Dataset to concatenate multiple PyTorch datasets.
Arguments:
- datasets: The datasets to concatenate.
ConcatDataset(*datasets: torch.utils.data.dataset.Dataset)
def __init__(self, *datasets: Dataset):
    """Store the datasets and precompute the index bookkeeping.

    Args:
        datasets: The datasets to concatenate. `ndim` is taken from the
            first one, so at least one dataset has to be passed.
    """
    self.datasets = datasets
    self.ndim = datasets[0].ndim

    # per-dataset sample counts, in the order the datasets were given
    self.ds_lens = list(map(len, self.datasets))

    # running totals: entry i is the number of samples in datasets 0..i
    self.ds_offsets = np.cumsum(self.ds_lens)

    # overall number of samples across all datasets
    self._len = sum(self.ds_lens)