torch_em.data.dataset_wrapper
```python
from collections.abc import Sized
from typing import Callable

from torch.utils.data import Dataset


class DatasetWrapper(Dataset):
    def __init__(self, dataset: Dataset, wrap_item: Callable):
        # Only map-style datasets with a known length are supported.
        assert isinstance(dataset, Dataset) and isinstance(dataset, Sized), "iterable datasets not supported"
        self.dataset = dataset
        self.wrap_item = wrap_item

    def __getitem__(self, item):
        # Fetch the sample from the wrapped dataset and transform it.
        return self.wrap_item(self.dataset[item])

    def __len__(self):
        return len(self.dataset)
```
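A minimal usage sketch: wrap an existing map-style dataset so that every sample is transformed on the fly. The TensorDataset and the rescaling lambda are stand-ins chosen for illustration, not part of torch_em.

```python
import torch
from torch.utils.data import TensorDataset

from torch_em.data.dataset_wrapper import DatasetWrapper

images = torch.rand(8, 1, 32, 32)
dataset = TensorDataset(images)  # map-style and Sized, so the assert passes

# TensorDataset yields 1-tuples; unpack and rescale each sample to [-1, 1].
wrapped = DatasetWrapper(dataset, wrap_item=lambda item: item[0] * 2 - 1)

print(wrapped[0].shape, len(wrapped))  # torch.Size([1, 32, 32]) 8
```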
class DatasetWrapper(Dataset):

Wraps a map-style dataset and applies wrap_item to every sample fetched from it; the documentation below is inherited from torch.utils.data.Dataset.
An abstract class representing a Dataset.

All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite __getitem__(), supporting fetching a data sample for a given key. Subclasses can also optionally overwrite __len__(), which is expected to return the size of the dataset by many torch.utils.data.Sampler implementations and by the default options of torch.utils.data.DataLoader. Subclasses can also optionally implement __getitems__() to speed up loading of batched samples; this method accepts a list of indices for a batch and returns the corresponding list of samples.
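As a minimal sketch (not from torch_em) of the map-style contract just described, the dataset below implements __getitem__() to map a key to a sample and __len__() to report the size:

```python
import torch
from torch.utils.data import Dataset


class SquaresDataset(Dataset):
    """Maps the integer keys 0..n-1 to (x, x**2) pairs."""

    def __init__(self, n: int):
        self.n = n

    def __getitem__(self, index: int):
        x = torch.tensor(float(index))
        return x, x ** 2

    def __len__(self):
        # Used by samplers and by DataLoader's default options.
        return self.n


ds = SquaresDataset(4)
print(len(ds), ds[2])  # 4 (tensor(2.), tensor(4.))
```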
torch.utils.data.DataLoader by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided (see the sketch below).
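To illustrate this note, here is a hedged sketch (the class names are hypothetical, not from torch_em or PyTorch) of a dataset keyed by strings together with a custom sampler that yields those keys:

```python
from torch.utils.data import DataLoader, Dataset, Sampler


class KeyedDataset(Dataset):
    """Map-style dataset whose keys are strings instead of integers."""

    def __init__(self, data: dict):
        self.data = data
        self.keys = list(data)

    def __getitem__(self, key: str):
        return self.data[key]

    def __len__(self):
        return len(self.data)


class KeySampler(Sampler):
    """Yields the dataset's string keys so DataLoader can index it."""

    def __init__(self, keys):
        self.keys = keys

    def __iter__(self):
        return iter(self.keys)

    def __len__(self):
        return len(self.keys)


dataset = KeyedDataset({"a": 1, "b": 2, "c": 3})
loader = DataLoader(dataset, batch_size=2, sampler=KeySampler(dataset.keys))
for batch in loader:
    print(batch)  # e.g. tensor([1, 2]), then tensor([3])
```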