torch_em.classification.classification_dataset
"""Dataset for in-memory image classification data."""
import numpy as np
import torch


class ClassificationDataset(torch.utils.data.Dataset):
    """Map-style dataset that yields ``(image, label)`` pairs.

    Args:
        data: Sequence of images; each image is expected to have a leading
            channel axis (iterating it yields channels).
        target: Sequence of labels; must have the same length as ``data``.
        normalization: Callable applied to each image, or ``None`` to skip.
        augmentation: Callable applied after normalization and resizing, or
            ``None`` to skip. NOTE(review): it appears to return a batched
            result, hence the ``[0][0]`` unwrapping below — confirm against
            the augmentation implementation.
        image_shape: Spatial shape each channel is resized to, or ``None``
            to keep the original shape.

    Raises:
        ValueError: If ``data`` and ``target`` have different lengths.
    """

    def __init__(self, data, target, normalization, augmentation, image_shape):
        if len(data) != len(target):
            raise ValueError(f"Length of data and target don't agree: {len(data)} != {len(target)}")
        self.data = data
        self.target = target
        self.normalization = normalization
        self.augmentation = augmentation
        self.image_shape = image_shape

    def __len__(self):
        return len(self.data)

    def resize(self, x):
        """Resize each channel of ``x`` to ``self.image_shape``, preserving the value range."""
        # Imported lazily: skimage is only needed when resizing is requested
        # (image_shape is not None), so the dataset stays usable without it.
        from skimage.transform import resize

        # Stacking the per-channel results is equivalent to the previous
        # pattern of adding a leading axis and concatenating along axis 0.
        return np.stack([resize(channel, self.image_shape, preserve_range=True) for channel in x])

    def __getitem__(self, index):
        x, y = self.data[index], self.target[index]

        # apply normalization
        if self.normalization is not None:
            x = self.normalization(x)

        # resize to sample shape if it was given
        if self.image_shape is not None:
            x = self.resize(x)

        # apply augmentations (if any)
        if self.augmentation is not None:
            _shape = x.shape
            # the augmentation adds an unwanted batch axis; strip it off again
            x = self.augmentation(x)[0][0]
            assert x.shape == _shape

        return x, y
class ClassificationDataset(torch.utils.data.Dataset):
    """In-memory classification dataset returning ``(image, label)`` pairs."""

    def __init__(self, data, target, normalization, augmentation, image_shape):
        """Store samples, labels and optional transforms; lengths must match."""
        n_samples, n_labels = len(data), len(target)
        if n_samples != n_labels:
            raise ValueError(f"Length of data and target don't agree: {n_samples} != {n_labels}")
        self.data = data
        self.target = target
        self.normalization = normalization
        self.augmentation = augmentation
        self.image_shape = image_shape

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def resize(self, x):
        """Resize every channel of ``x`` to ``self.image_shape`` and restack along axis 0."""
        resized = []
        for channel in x:
            resized.append(resize(channel, self.image_shape, preserve_range=True)[None])
        return np.concatenate(resized, axis=0)

    def __getitem__(self, index):
        """Fetch sample ``index`` with normalization, resizing and augmentation applied."""
        x = self.data[index]
        y = self.target[index]

        if self.normalization is not None:
            x = self.normalization(x)

        if self.image_shape is not None:
            x = self.resize(x)

        if self.augmentation is not None:
            expected_shape = x.shape
            # the augmentation introduces a batch axis, which is stripped off again
            x = self.augmentation(x)[0][0]
            assert x.shape == expected_shape

        return x, y
An abstract class representing a :class:`Dataset`.
All datasets that represent a map from keys to data samples should subclass
it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
data sample for a given key. Subclasses could also optionally overwrite
:meth:`__len__`, which is expected to return the size of the dataset by many
:class:`~torch.utils.data.Sampler` implementations and the default options
of :class:`~torch.utils.data.DataLoader`. Subclasses could also
optionally implement :meth:`__getitems__` to speed up loading of batched
samples. This method accepts a list of sample indices for a batch and
returns the list of corresponding samples.
.. note::
    :class:`~torch.utils.data.DataLoader` by default constructs an index
    sampler that yields integral indices. To make it work with a map-style
    dataset with non-integral indices/keys, a custom sampler must be provided.
8 def __init__(self, data, target, normalization, augmentation, image_shape): 9 if len(data) != len(target): 10 raise ValueError(f"Length of data and target don't agree: {len(data)} != {len(target)}") 11 self.data = data 12 self.target = target 13 self.normalization = normalization 14 self.augmentation = augmentation 15 self.image_shape = image_shape