torch_em.classification.classification_dataset
from typing import Callable, Optional, Sequence, Tuple

import numpy as np
import torch

from numpy.typing import ArrayLike
from skimage.transform import resize


class ClassificationDataset(torch.utils.data.Dataset):
    """Dataset for classification training.

    Args:
        data: The input data for classification. Expects a sequence of array-like data.
            The data can be two or three dimensional.
        target: The target data for classification. Expects a sequence of the same length as `data`.
            Each value in the sequence must be a scalar.
        normalization: The normalization function. If None, the data is not normalized.
        augmentation: The augmentation function. If None, no augmentation is applied.
        image_shape: The target shape of the data. If given, each sample will be resampled to this size.

    Raises:
        ValueError: If `data` and `target` have different lengths.
    """
    def __init__(
        self,
        data: Sequence[ArrayLike],
        target: Sequence[ArrayLike],
        normalization: Optional[Callable] = None,
        augmentation: Optional[Callable] = None,
        image_shape: Optional[Tuple[int, ...]] = None,
    ):
        if len(data) != len(target):
            raise ValueError(f"Length of data and target don't agree: {len(data)} != {len(target)}")
        self.data = data
        self.target = target
        self.normalization = normalization
        self.augmentation = augmentation
        self.image_shape = image_shape

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def resize(self, x):
        """@private
        """
        # Every entry along the first axis of x is resampled on its own to `self.image_shape`;
        # `preserve_range=True` asks skimage to keep the original value range.
        resized = [resize(channel, self.image_shape, preserve_range=True)[None] for channel in x]
        # The singleton axis added via [None] lets the results be joined back along axis 0.
        return np.concatenate(resized, axis=0)

    def __getitem__(self, index):
        """Return the processed sample and its target for the given index.

        Normalization, resizing and augmentation are applied in this order,
        each one only if the corresponding attribute is not None.

        Args:
            index: The index of the sample.

        Returns:
            The processed sample and the corresponding target value.

        Raises:
            AssertionError: If the augmentation changes the shape of the sample.
        """
        x, y = self.data[index], self.target[index]

        # apply normalization
        if self.normalization is not None:
            x = self.normalization(x)

        # resize to sample shape if it was given
        if self.image_shape is not None:
            x = self.resize(x)

        # apply augmentations (if any)
        if self.augmentation is not None:
            expected_shape = x.shape
            # The augmentation adds an unwanted batch axis: take the first
            # returned element and strip its leading axis again.
            x = self.augmentation(x)[0][0]
            # Raised explicitly (instead of a bare `assert`) so that the check
            # is still performed when Python runs with -O.
            if x.shape != expected_shape:
                raise AssertionError(
                    f"The augmentation changed the sample shape: expected {expected_shape}, got {x.shape}"
                )

        return x, y
class ClassificationDataset(torch.utils.data.Dataset):
class ClassificationDataset(torch.utils.data.Dataset):
    """Dataset for classification training.

    Args:
        data: The input data for classification. Expects a sequence of array-like data.
            The data can be two or three dimensional.
        target: The target data for classification. Expects a sequence of the same length as `data`.
            Each value in the sequence must be a scalar.
        normalization: The normalization function. If None, the data is not normalized.
        augmentation: The augmentation function. If None, no augmentation is applied.
        image_shape: The target shape of the data. If given, each sample will be resampled to this size.

    Raises:
        ValueError: If `data` and `target` have different lengths.
    """
    def __init__(
        self,
        data: Sequence[ArrayLike],
        target: Sequence[ArrayLike],
        normalization: callable = None,
        augmentation: callable = None,
        image_shape: Tuple[int, ...] = None,
    ):
        if len(data) != len(target):
            raise ValueError(f"Length of data and target don't agree: {len(data)} != {len(target)}")
        self.data = data
        self.target = target
        self.normalization = normalization
        self.augmentation = augmentation
        self.image_shape = image_shape

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def resize(self, x):
        """@private
        """
        # Every entry along the first axis of x is resampled on its own to `self.image_shape`;
        # `preserve_range=True` asks skimage to keep the original value range.
        resized = [resize(channel, self.image_shape, preserve_range=True)[None] for channel in x]
        # The singleton axis added via [None] lets the results be joined back along axis 0.
        return np.concatenate(resized, axis=0)

    def __getitem__(self, index):
        """Return the processed sample and its target for the given index.

        Normalization, resizing and augmentation are applied in this order,
        each one only if the corresponding attribute is not None.

        Args:
            index: The index of the sample.

        Returns:
            The processed sample and the corresponding target value.

        Raises:
            AssertionError: If the augmentation changes the shape of the sample.
        """
        x, y = self.data[index], self.target[index]

        # apply normalization
        if self.normalization is not None:
            x = self.normalization(x)

        # resize to sample shape if it was given
        if self.image_shape is not None:
            x = self.resize(x)

        # apply augmentations (if any)
        if self.augmentation is not None:
            expected_shape = x.shape
            # The augmentation adds an unwanted batch axis: take the first
            # returned element and strip its leading axis again.
            x = self.augmentation(x)[0][0]
            # Raised explicitly (instead of a bare `assert`) so that the check
            # is still performed when Python runs with -O.
            if x.shape != expected_shape:
                raise AssertionError(
                    f"The augmentation changed the sample shape: expected {expected_shape}, got {x.shape}"
                )

        return x, y
Dataset for classification training.
Arguments:
- data: The input data for classification. Expects a sequence of array-like data. The data can be two or three dimensional.
- target: The target data for classification. Expects a sequence of the same length as `data`. Each value in the sequence must be a scalar.
- normalization: The normalization function.
- augmentation: The augmentation function.
- image_shape: The target shape of the data. If given, each sample will be resampled to this size.
ClassificationDataset( data: Sequence[numpy.typing.ArrayLike], target: Sequence[numpy.typing.ArrayLike], normalization: callable, augmentation: callable, image_shape: Tuple[int, ...])
def __init__(
    self,
    data: Sequence[ArrayLike],
    target: Sequence[ArrayLike],
    normalization: callable = None,
    augmentation: callable = None,
    image_shape: Tuple[int, ...] = None,
):
    """Store the data, targets and the optional processing settings.

    Args:
        data: The input data for classification. Expects a sequence of array-like data.
        target: The target data for classification. Expects a sequence of the same length as `data`.
        normalization: The normalization function. May be None.
        augmentation: The augmentation function. May be None.
        image_shape: The target shape of the data. May be None.

    Raises:
        ValueError: If `data` and `target` have different lengths.
    """
    # Every sample needs exactly one target, so reject mismatched lengths early.
    if len(data) != len(target):
        raise ValueError(f"Length of data and target don't agree: {len(data)} != {len(target)}")
    self.data = data
    self.target = target
    self.normalization = normalization
    self.augmentation = augmentation
    self.image_shape = image_shape