torch_em.classification.classification_dataset

from typing import Callable, Optional, Sequence, Tuple

import numpy as np
import torch

from numpy.typing import ArrayLike
from skimage.transform import resize
 8
 9
class ClassificationDataset(torch.utils.data.Dataset):
    """Dataset for classification training.

    Args:
        data: The input data for classification. Expects a sequence of array-like data.
            The data can be two- or three-dimensional.
        target: The target data for classification. Expects a sequence of the same length as `data`.
            Each value in the sequence must be a scalar.
        normalization: The normalization function. Applied first to each sample; skipped if None.
        augmentation: The augmentation function. Applied last to each sample; skipped if None.
        image_shape: The target shape of the data. If given, each sample will be resampled to this size.

    Raises:
        ValueError: If `data` and `target` have different lengths.
    """
    def __init__(
        self,
        data: Sequence[ArrayLike],
        target: Sequence[ArrayLike],
        normalization: Optional[Callable] = None,
        augmentation: Optional[Callable] = None,
        image_shape: Optional[Tuple[int, ...]] = None,
    ):
        if len(data) != len(target):
            raise ValueError(f"Length of data and target don't agree: {len(data)} != {len(target)}")
        self.data = data
        self.target = target
        self.normalization = normalization
        self.augmentation = augmentation
        self.image_shape = image_shape

    def __len__(self):
        return len(self.data)

    def resize(self, x):
        """@private

        Resize every slice along the first axis of `x` to `self.image_shape` and stack the
        results back along axis 0. `preserve_range=True` keeps the original value range
        instead of applying skimage's `img_as_float` conversion.
        """
        out = [resize(channel, self.image_shape, preserve_range=True)[None] for channel in x]
        return np.concatenate(out, axis=0)

    def __getitem__(self, index):
        """Return the processed sample at `index` together with its target as `(x, y)`.

        The sample is normalized, then resized (if `image_shape` is set), then augmented.
        The target is returned unchanged.

        Raises:
            RuntimeError: If the augmentation changes the shape of the sample.
        """
        x, y = self.data[index], self.target[index]

        # apply normalization
        if self.normalization is not None:
            x = self.normalization(x)

        # resize to sample shape if it was given
        if self.image_shape is not None:
            x = self.resize(x)

        # apply augmentations (if any)
        if self.augmentation is not None:
            expected_shape = x.shape
            # The augmentation adds an unwanted batch axis; its output is indexed twice to drop it.
            # NOTE(review): assumes the augmentation returns a sequence whose first entry is batched.
            x = self.augmentation(x)[0][0]
            # Explicit check rather than `assert`, which is stripped when Python runs with -O.
            if x.shape != expected_shape:
                raise RuntimeError(
                    f"Augmentation changed the sample shape: {tuple(expected_shape)} -> {tuple(x.shape)}"
                )

        return x, y
class ClassificationDataset(typing.Generic[+_T_co]):
class ClassificationDataset(torch.utils.data.Dataset):
    """Dataset for classification training.

    Args:
        data: The input data for classification. Expects a sequence of array-like data.
            The data can be two- or three-dimensional.
        target: The target data for classification. Expects a sequence of the same length as `data`.
            Each value in the sequence must be a scalar.
        normalization: The normalization function. Applied first to each sample; skipped if None.
        augmentation: The augmentation function. Applied last to each sample; skipped if None.
        image_shape: The target shape of the data. If given, each sample will be resampled to this size.

    Raises:
        ValueError: If `data` and `target` have different lengths.
    """
    def __init__(
        self,
        data: Sequence[ArrayLike],
        target: Sequence[ArrayLike],
        normalization: Optional[Callable] = None,
        augmentation: Optional[Callable] = None,
        image_shape: Optional[Tuple[int, ...]] = None,
    ):
        if len(data) != len(target):
            raise ValueError(f"Length of data and target don't agree: {len(data)} != {len(target)}")
        self.data = data
        self.target = target
        self.normalization = normalization
        self.augmentation = augmentation
        self.image_shape = image_shape

    def __len__(self):
        return len(self.data)

    def resize(self, x):
        """@private

        Resize every slice along the first axis of `x` to `self.image_shape` and stack the
        results back along axis 0. `preserve_range=True` keeps the original value range
        instead of applying skimage's `img_as_float` conversion.
        """
        out = [resize(channel, self.image_shape, preserve_range=True)[None] for channel in x]
        return np.concatenate(out, axis=0)

    def __getitem__(self, index):
        """Return the processed sample at `index` together with its target as `(x, y)`.

        The sample is normalized, then resized (if `image_shape` is set), then augmented.
        The target is returned unchanged.

        Raises:
            RuntimeError: If the augmentation changes the shape of the sample.
        """
        x, y = self.data[index], self.target[index]

        # apply normalization
        if self.normalization is not None:
            x = self.normalization(x)

        # resize to sample shape if it was given
        if self.image_shape is not None:
            x = self.resize(x)

        # apply augmentations (if any)
        if self.augmentation is not None:
            expected_shape = x.shape
            # The augmentation adds an unwanted batch axis; its output is indexed twice to drop it.
            # NOTE(review): assumes the augmentation returns a sequence whose first entry is batched.
            x = self.augmentation(x)[0][0]
            # Explicit check rather than `assert`, which is stripped when Python runs with -O.
            if x.shape != expected_shape:
                raise RuntimeError(
                    f"Augmentation changed the sample shape: {tuple(expected_shape)} -> {tuple(x.shape)}"
                )

        return x, y

Dataset for classification training.

Arguments:
  • data: The input data for classification. Expects a sequence of array-like data. The data can be two- or three-dimensional.
  • target: The target data for classification. Expects a sequence of the same length as `data`. Each value in the sequence must be a scalar.
  • normalization: The normalization function.
  • augmentation: The augmentation function.
  • image_shape: The target shape of the data. If given, each sample will be resampled to this size.
ClassificationDataset( data: Sequence[Union[Buffer, numpy._typing._array_like._SupportsArray[numpy.dtype[Any]], numpy._typing._nested_sequence._NestedSequence[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]]], bool, int, float, complex, str, bytes, numpy._typing._nested_sequence._NestedSequence[Union[bool, int, float, complex, str, bytes]]]], target: Sequence[Union[Buffer, numpy._typing._array_like._SupportsArray[numpy.dtype[Any]], numpy._typing._nested_sequence._NestedSequence[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]]], bool, int, float, complex, str, bytes, numpy._typing._nested_sequence._NestedSequence[Union[bool, int, float, complex, str, bytes]]]], normalization: <built-in function callable>, augmentation: <built-in function callable>, image_shape: Tuple[int, ...])
23    def __init__(
24        self,
25        data: Sequence[ArrayLike],
26        target: Sequence[ArrayLike],
27        normalization: callable,
28        augmentation: callable,
29        image_shape: Tuple[int, ...],
30    ):
31        if len(data) != len(target):
32            raise ValueError(f"Length of data and target don't agree: {len(data)} != {len(target)}")
33        self.data = data
34        self.target = target
35        self.normalization = normalization
36        self.augmentation = augmentation
37        self.image_shape = image_shape
data
target
normalization
augmentation
image_shape