torch_em.data.concat_dataset

 1import numpy as np
 2
 3from torch.utils.data import Dataset
 4
 5
 6class ConcatDataset(Dataset):
 7    """Dataset to concatenate multiple PyTorch datasets.
 8
 9    Args:
10        datasets: The datasets to concatenate.
11    """
12    def __init__(self, *datasets: Dataset):
13        self.datasets = datasets
14        self.ndim = datasets[0].ndim
15
16        # compute the number of samples for each volume
17        self.ds_lens = [len(dataset) for dataset in self.datasets]
18        self._len = sum(self.ds_lens)
19
20        # compute the offsets for the samples
21        self.ds_offsets = np.cumsum(self.ds_lens)
22
23    def __len__(self):
24        return self._len
25
26    def __getitem__(self, idx):
27        # find the dataset id corresponding to this index
28        ds_idx = 0
29        while True:
30            if idx < self.ds_offsets[ds_idx]:
31                break
32            ds_idx += 1
33
34        # get sample from the dataset
35        ds = self.datasets[ds_idx]
36        offset = self.ds_offsets[ds_idx - 1] if ds_idx > 0 else 0
37        idx_in_ds = idx - offset
38        assert idx_in_ds < len(ds) and idx_in_ds >= 0, f"Failed with: {idx_in_ds}, {len(ds)}"
39        return ds[idx_in_ds]
class ConcatDataset(torch.utils.data.dataset.Dataset):
 7class ConcatDataset(Dataset):
 8    """Dataset to concatenate multiple PyTorch datasets.
 9
10    Args:
11        datasets: The datasets to concatenate.
12    """
13    def __init__(self, *datasets: Dataset):
14        self.datasets = datasets
15        self.ndim = datasets[0].ndim
16
17        # compute the number of samples for each volume
18        self.ds_lens = [len(dataset) for dataset in self.datasets]
19        self._len = sum(self.ds_lens)
20
21        # compute the offsets for the samples
22        self.ds_offsets = np.cumsum(self.ds_lens)
23
24    def __len__(self):
25        return self._len
26
27    def __getitem__(self, idx):
28        # find the dataset id corresponding to this index
29        ds_idx = 0
30        while True:
31            if idx < self.ds_offsets[ds_idx]:
32                break
33            ds_idx += 1
34
35        # get sample from the dataset
36        ds = self.datasets[ds_idx]
37        offset = self.ds_offsets[ds_idx - 1] if ds_idx > 0 else 0
38        idx_in_ds = idx - offset
39        assert idx_in_ds < len(ds) and idx_in_ds >= 0, f"Failed with: {idx_in_ds}, {len(ds)}"
40        return ds[idx_in_ds]

Dataset to concatenate multiple PyTorch datasets.

Arguments:
  • datasets: The datasets to concatenate, in order; sample indices run through them sequentially.
ConcatDataset(*datasets: torch.utils.data.dataset.Dataset)
def __init__(self, *datasets: Dataset):
    """Store the datasets and precompute the index bookkeeping.

    Args:
        datasets: The datasets to concatenate, in order.
    """
    # keep the datasets exactly as passed (a tuple, in call order)
    self.datasets = datasets
    # dimensionality is read from the first dataset only
    self.ndim = datasets[0].ndim

    # per-dataset sample counts and their total
    sizes = list(map(len, self.datasets))
    self.ds_lens = sizes
    self._len = sum(sizes)

    # running totals: entry i is the number of samples in datasets 0..i
    self.ds_offsets = np.cumsum(sizes)
datasets
ndim
ds_lens
ds_offsets