torch_em.data.concat_dataset

import numpy as np

from torch.utils.data import Dataset


class ConcatDataset(Dataset):
    """Dataset that concatenates multiple map-style datasets."""

    def __init__(self, *datasets: Dataset):
        self.datasets = datasets
        self.ndim = datasets[0].ndim

        # compute the number of samples for each volume
        self.ds_lens = [len(dataset) for dataset in self.datasets]
        self._len = sum(self.ds_lens)

        # compute the cumulative offsets for the samples
        self.ds_offsets = np.cumsum(self.ds_lens)

    def __len__(self):
        return self._len

    def __getitem__(self, idx):
        # find the dataset id corresponding to this index:
        # the first dataset whose cumulative offset exceeds idx
        ds_idx = 0
        while idx >= self.ds_offsets[ds_idx]:
            ds_idx += 1

        # translate the global index into an index within that dataset
        ds = self.datasets[ds_idx]
        offset = self.ds_offsets[ds_idx - 1] if ds_idx > 0 else 0
        idx_in_ds = idx - offset
        assert 0 <= idx_in_ds < len(ds), f"Failed with: {idx_in_ds}, {len(ds)}"
        return ds[idx_in_ds]
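A minimal usage sketch, assuming only what the class requires: map-style datasets that define an ndim attribute. ToyDataset is a made-up stand-in, not part of torch_em.

import torch
from torch.utils.data import Dataset

from torch_em.data.concat_dataset import ConcatDataset


class ToyDataset(Dataset):
    ndim = 2  # ConcatDataset reads ndim from its first dataset

    def __init__(self, n_samples):
        self.data = torch.arange(n_samples)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


ds = ConcatDataset(ToyDataset(4), ToyDataset(6))
assert len(ds) == 10
print(ds[7])  # global index 7 -> second dataset, local index 3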

An abstract class representing a Dataset.

All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite __getitem__() to support fetching a data sample for a given key. Subclasses can also optionally overwrite __len__(), which is expected to return the size of the dataset by many ~torch.utils.data.Sampler implementations and the default options of ~torch.utils.data.DataLoader. Subclasses can also optionally implement __getitems__() to speed up batched sample loading. This method accepts a list of sample indices for a batch and returns a list of samples.

~torch.utils.data.DataLoader by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided.
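To make the contract concrete, a minimal map-style dataset could look like the sketch below (SquaresDataset is hypothetical, not part of torch_em):

import torch
from torch.utils.data import Dataset


class SquaresDataset(Dataset):
    def __init__(self, n):
        self.n = n

    def __len__(self):
        # consulted by samplers and the default DataLoader options
        return self.n

    def __getitem__(self, idx):
        # fetch the data sample for the given key
        return torch.tensor(idx ** 2)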

ConcatDataset(*datasets: torch.utils.data.dataset.Dataset)
datasets: the datasets to concatenate.
ndim: the dimensionality of the data, taken from the first dataset.
ds_lens: the number of samples in each dataset.
ds_offsets: the cumulative offsets used to map a global sample index to a dataset.
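A small worked example of the offset arithmetic, with made-up dataset lengths; np.searchsorted computes the same dataset index as the linear scan in __getitem__:

import numpy as np

ds_lens = [4, 6, 5]               # made-up lengths of three datasets
ds_offsets = np.cumsum(ds_lens)   # array([ 4, 10, 15])

idx = 7                                                       # global index
ds_idx = int(np.searchsorted(ds_offsets, idx, side="right"))  # -> 1
offset = ds_offsets[ds_idx - 1] if ds_idx > 0 else 0          # -> 4
idx_in_ds = idx - offset                                      # -> 3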