torch_em.classification.classification_dataset

import numpy as np
import torch
from skimage.transform import resize


class ClassificationDataset(torch.utils.data.Dataset):
    """Dataset for classification: yields (image, label) pairs."""

    def __init__(self, data, target, normalization, augmentation, image_shape):
        if len(data) != len(target):
            raise ValueError(f"Length of data and target don't agree: {len(data)} != {len(target)}")
        self.data = data
        self.target = target
        self.normalization = normalization
        self.augmentation = augmentation
        self.image_shape = image_shape

    def __len__(self):
        return len(self.data)

    def resize(self, x):
        # resize each channel individually, then re-stack along the channel axis
        out = [resize(channel, self.image_shape, preserve_range=True)[None] for channel in x]
        return np.concatenate(out, axis=0)

    def __getitem__(self, index):
        x, y = self.data[index], self.target[index]

        # apply normalization
        if self.normalization is not None:
            x = self.normalization(x)

        # resize to the sample shape if one was given
        if self.image_shape is not None:
            x = self.resize(x)

        # apply augmentations (if any)
        if self.augmentation is not None:
            _shape = x.shape
            # the augmentation returns a batched result; [0][0] strips the batch axis again
            x = self.augmentation(x)[0][0]
            assert x.shape == _shape

        return x, y
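For orientation, here is a minimal usage sketch. The random data, the normalize helper, and the identity-style augmentation are illustrative assumptions; the only contract taken from the code above is that an augmentation must return a nested structure whose [0][0] entry preserves the sample's shape.

import numpy as np
import torch

# Fake inputs: ten single-channel 32x32 images with binary labels (illustrative only).
data = [np.random.rand(1, 32, 32).astype("float32") for _ in range(10)]
target = np.random.randint(0, 2, size=10)

def normalize(x):
    # simple per-sample min-max normalization (a stand-in, not torch_em's normalization)
    return (x - x.min()) / (x.max() - x.min() + 1e-7)

def identity_augmentation(x):
    # mimics the expected return structure: __getitem__ unpacks it via [0][0]
    return [x[None]]

dataset = ClassificationDataset(
    data, target,
    normalization=normalize,
    augmentation=identity_augmentation,
    image_shape=(64, 64),
)
loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)
x, y = next(iter(loader))
print(x.shape)  # torch.Size([4, 1, 64, 64])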
class ClassificationDataset(torch.utils.data.Dataset)

An abstract class representing a Dataset.

All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite __getitem__(), which fetches a data sample for a given key. Subclasses may also optionally overwrite __len__(), which is expected to return the size of the dataset by many torch.utils.data.Sampler implementations and by the default options of torch.utils.data.DataLoader, and __getitems__(), which speeds up loading of batched samples by accepting a list of sample indices for a batch and returning the list of samples.

torch.utils.data.DataLoader by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset that has non-integral indices/keys, a custom sampler must be provided.
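As an illustration of this map-style contract, here is a hypothetical toy dataset (not part of torch_em):

import torch

class SquaresDataset(torch.utils.data.Dataset):
    """Maps an integral key i to the pair (i, i**2)."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # queried by samplers and by the DataLoader's default options
        return self.n

    def __getitem__(self, index):
        # fetch the sample for a given key
        return index, index ** 2

loader = torch.utils.data.DataLoader(SquaresDataset(8), batch_size=4)
for x, y in loader:
    print(x, y)  # first batch: tensor([0, 1, 2, 3]) tensor([0, 1, 4, 9])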

ClassificationDataset(data, target, normalization, augmentation, image_shape)

Creates the dataset. data and target must be sequences of equal length (a ValueError is raised otherwise); normalization and augmentation are optional callables applied per sample, and image_shape is the optional spatial shape samples are resized to.
data
    The input samples; indexed per item.
target
    The classification targets; must have the same length as data.
normalization
    Optional callable applied to each sample before resizing.
augmentation
    Optional callable applied last; must return a result whose [0][0] entry preserves the sample shape.
image_shape
    Optional spatial shape to which each channel of a sample is resized.
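The length check from __init__ in action (the shapes are illustrative):

import numpy as np

try:
    ClassificationDataset(
        data=[np.zeros((1, 8, 8))] * 3, target=[0, 1],
        normalization=None, augmentation=None, image_shape=None,
    )
except ValueError as e:
    print(e)  # Length of data and target don't agree: 3 != 2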
def resize(self, x):

Resizes each channel of x to self.image_shape using skimage.transform.resize with preserve_range=True, then re-stacks the channels along axis 0.
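A small illustrative check of this behavior (the 3-channel shapes are assumptions):

import numpy as np

dataset = ClassificationDataset(
    data=[np.zeros((3, 32, 32))], target=[0],
    normalization=None, augmentation=None, image_shape=(64, 64),
)
out = dataset.resize(np.random.rand(3, 32, 32))
print(out.shape)  # (3, 64, 64): each channel is resized, then re-stacked along axis 0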