torch_em.data.dataset_wrapper
from typing import Callable
from collections.abc import Sized

from torch.utils.data import Dataset


class DatasetWrapper(Dataset):
    """Wrapper around a dataset that applies a function to every item it returns.

    Indexing the wrapper fetches ``dataset[index]`` and returns
    ``wrap_item(dataset[index])``; its length is the length of ``dataset``.

    Args:
        dataset: The dataset to wrap. Must be a map-style dataset that
            implements ``__len__``; iterable datasets are not supported.
        wrap_item: The function applied to each item fetched from ``dataset``
            before it is returned.
    """
    def __init__(self, dataset: Dataset, wrap_item: Callable):
        # Two separate checks so the error names the actual problem: a non-Dataset
        # argument (e.g. a plain list) is not the same failure as an iterable dataset.
        assert isinstance(dataset, Dataset), (
            f"expected a torch Dataset, got {type(dataset).__name__}"
        )
        # __len__ below is forwarded to the wrapped dataset, so it must be Sized.
        assert isinstance(dataset, Sized), (
            "iterable datasets not supported: the dataset must implement __len__"
        )
        self.dataset = dataset
        self.wrap_item = wrap_item

    def __getitem__(self, item):
        return self.wrap_item(self.dataset[item])

    def __len__(self):
        return len(self.dataset)
class DatasetWrapper(torch.utils.data.Dataset):
class DatasetWrapper(Dataset):
    """Wrapper around a dataset that applies a function to every item it returns.

    Indexing the wrapper fetches ``dataset[index]`` and returns
    ``wrap_item(dataset[index])``; its length is the length of ``dataset``.

    Args:
        dataset: The dataset to wrap. Must be a map-style dataset that
            implements ``__len__``; iterable datasets are not supported.
        wrap_item: The function applied to each item fetched from ``dataset``
            before it is returned.
    """
    def __init__(self, dataset: Dataset, wrap_item: Callable):
        # Two separate checks so the error names the actual problem: a non-Dataset
        # argument (e.g. a plain list) is not the same failure as an iterable dataset.
        assert isinstance(dataset, Dataset), (
            f"expected a torch Dataset, got {type(dataset).__name__}"
        )
        # __len__ below is forwarded to the wrapped dataset, so it must be Sized.
        assert isinstance(dataset, Sized), (
            "iterable datasets not supported: the dataset must implement __len__"
        )
        self.dataset = dataset
        self.wrap_item = wrap_item

    def __getitem__(self, item):
        return self.wrap_item(self.dataset[item])

    def __len__(self):
        return len(self.dataset)
Wrapper around a dataset that applies a function to each item retrieved from the wrapped dataset before returning it.
Arguments:
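- dataset: The dataset to wrap. It must be a map-style dataset that implements `__len__`; iterable datasets are not supported.
- wrap_item: The function applied to each item retrieved from `dataset` before it is returned.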