torch_em.data.dataset_wrapper

 1from collections.abc import Sized
 2from typing import Callable
 3
 4from torch.utils.data import Dataset
 5
 6
 7class DatasetWrapper(Dataset):
 8    def __init__(
 9        self,
10        dataset: Dataset,
11        wrap_item: Callable
12    ):
13        assert isinstance(dataset, Dataset) and isinstance(dataset, Sized), "iterable datasets not supported"
14        self.dataset = dataset
15        self.wrap_item = wrap_item
16
17    def __getitem__(self, item):
18        return self.wrap_item(self.dataset[item])
19
20    def __len__(self):
21        return len(self.dataset)
class DatasetWrapper(torch.utils.data.Dataset):
 8class DatasetWrapper(Dataset):
 9    def __init__(
10        self,
11        dataset: Dataset,
12        wrap_item: Callable
13    ):
14        assert isinstance(dataset, Dataset) and isinstance(dataset, Sized), "iterable datasets not supported"
15        self.dataset = dataset
16        self.wrap_item = wrap_item
17
18    def __getitem__(self, item):
19        return self.wrap_item(self.dataset[item])
20
21    def __len__(self):
22        return len(self.dataset)

An abstract class representing a Dataset.

All datasets that represent a map from keys to data samples should subclass it. All subclasses should override `__getitem__()`, supporting fetching a data sample for a given key. Subclasses can also optionally override `__len__()`, which many `torch.utils.data.Sampler` implementations and the default options of `torch.utils.data.DataLoader` expect to return the size of the dataset. Subclasses can also optionally implement `__getitems__()` to speed up loading of batched samples. This method accepts a list of sample indices for a batch and returns a list of samples.

`torch.utils.data.DataLoader` by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided.

DatasetWrapper(dataset: torch.utils.data.dataset.Dataset, wrap_item: Callable)
 9    def __init__(
10        self,
11        dataset: Dataset,
12        wrap_item: Callable
13    ):
14        assert isinstance(dataset, Dataset) and isinstance(dataset, Sized), "iterable datasets not supported"
15        self.dataset = dataset
16        self.wrap_item = wrap_item
dataset
wrap_item