torch_em.data.dataset_wrapper

 1from typing import Callable
 2from collections.abc import Sized
 3
 4from torch.utils.data import Dataset
 5
 6
 7class DatasetWrapper(Dataset):
 8    """Wrapper around a dataset that applies a function to items before retrieval.
 9
10    Args:
11        dataset: The datast.
12        wrap_item: The function to apply to items before retrieval.
13    """
14    def __init__(self, dataset: Dataset, wrap_item: Callable):
15        assert isinstance(dataset, Dataset) and isinstance(dataset, Sized), "iterable datasets not supported"
16        self.dataset = dataset
17        self.wrap_item = wrap_item
18
19    def __getitem__(self, item):
20        return self.wrap_item(self.dataset[item])
21
22    def __len__(self):
23        return len(self.dataset)
class DatasetWrapper(torch.utils.data.dataset.Dataset):
 8class DatasetWrapper(Dataset):
 9    """Wrapper around a dataset that applies a function to items before retrieval.
10
11    Args:
12        dataset: The datast.
13        wrap_item: The function to apply to items before retrieval.
14    """
15    def __init__(self, dataset: Dataset, wrap_item: Callable):
16        assert isinstance(dataset, Dataset) and isinstance(dataset, Sized), "iterable datasets not supported"
17        self.dataset = dataset
18        self.wrap_item = wrap_item
19
20    def __getitem__(self, item):
21        return self.wrap_item(self.dataset[item])
22
23    def __len__(self):
24        return len(self.dataset)

Wrapper around a dataset that applies a function to items before retrieval.

Arguments:
  • dataset: The dataset.
  • wrap_item: The function to apply to items before retrieval.
DatasetWrapper(dataset: torch.utils.data.dataset.Dataset, wrap_item: Callable)
def __init__(self, dataset: Dataset, wrap_item: Callable):
    """Store the dataset and item-wrapping function after checking the dataset is sized."""
    # Raised explicitly instead of via ``assert`` so the check is not stripped
    # under ``python -O``; AssertionError is kept for backward compatibility.
    if not (isinstance(dataset, Dataset) and isinstance(dataset, Sized)):
        raise AssertionError("iterable datasets not supported")
    self.dataset = dataset
    self.wrap_item = wrap_item
dataset
wrap_item