torch_em.classification.classification
from functools import partial
from typing import Optional, Sequence, Tuple, Union

import numpy as np
import sklearn.metrics as metrics
import torch
import torch_em
from numpy.typing import ArrayLike

from .classification_dataset import ClassificationDataset
from .classification_logger import ClassificationLogger
from .classification_trainer import ClassificationTrainer


class ClassificationMetric:
    """Metric for classification training.

    The metric is reported as an error, i.e. `1 - score`, so that lower values are better.
    This assumes that the chosen sklearn metric is a score where higher is better,
    such as `accuracy_score` or `f1_score`.

    Args:
        metric_name: The name of the metric. The name will be looked up in `sklearn.metrics`,
            so it must be the name of a callable in that python package.
        metric_kwargs: Keyword arguments for the metric.
    """
    def __init__(self, metric_name: str = "accuracy_score", **metric_kwargs):
        # callable() also rejects names that exist in sklearn.metrics but are not
        # metric functions (such as submodules), which would otherwise only fail at call time.
        metric = getattr(metrics, metric_name, None)
        if not callable(metric):
            raise ValueError(f"Invalid metric_name {metric_name}.")
        self.metric = metric
        self.metric_kwargs = metric_kwargs

    def __call__(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
        """Evaluate model prediction against classification labels.

        Args:
            y_true: The classification labels.
            y_pred: The model predictions.

        Returns:
            The metric error, i.e. `1 - score`.
        """
        score = self.metric(y_true, y_pred, **self.metric_kwargs)
        return 1.0 - score


def default_classification_loader(
    data: Sequence[ArrayLike],
    target: Sequence[ArrayLike],
    batch_size: int,
    normalization: Optional[callable] = None,
    augmentation: Optional[callable] = None,
    image_shape: Optional[Tuple[int, ...]] = None,
    **loader_kwargs,
) -> torch.utils.data.DataLoader:
    """Get a data loader for classification training.

    Args:
        data: The input data for classification. Expects a sequence of array-like data.
            The data can be two or three dimensional. Each sample must have one additional
            leading axis (e.g. the channel axis), i.e. 3 or 4 axes in total.
        target: The target data for classification. Expects a sequence of the same length as `data`.
            Each value in the sequence must be a scalar.
        batch_size: The batch size for the data loader.
        normalization: The normalization function. If None, data standardization over all
            axes except the leading one will be used.
        augmentation: The augmentation function. If None, the default augmentations will be used.
        image_shape: The target shape of the data. If given, each sample will be resampled to this size.
        loader_kwargs: Additional keyword arguments for `torch.utils.data.DataLoader`.

    Returns:
        The data loader.

    Raises:
        ValueError: If `data` is empty, or if its samples do not have 2 or 3 dimensions
            in addition to the leading axis.
    """
    if len(data) == 0:
        raise ValueError("Expect at least one sample in data, got an empty sequence.")

    # The leading axis of each sample does not count towards the data dimensionality.
    sample_ndim = data[0].ndim
    ndim = sample_ndim - 1
    if ndim not in (2, 3):
        raise ValueError(
            "Expect input data of dimensionality 2 or 3 plus a leading (channel) axis, "
            f"got a sample with {sample_ndim} axes."
        )

    if normalization is None:
        # Standardize over all axes except the leading one.
        axis = (1, 2) if ndim == 2 else (1, 2, 3)
        normalization = partial(torch_em.transform.raw.standardize, axis=axis)

    if augmentation is None:
        augmentation = torch_em.transform.get_augmentations(ndim=ndim)

    dataset = ClassificationDataset(data, target, normalization, augmentation, image_shape)
    loader = torch_em.segmentation.get_data_loader(dataset, batch_size, **loader_kwargs)
    return loader


def default_classification_trainer(
    name: str,
    model: torch.nn.Module,
    train_loader: torch.utils.data.DataLoader,
    val_loader: torch.utils.data.DataLoader,
    loss: Optional[Union[torch.nn.Module, callable]] = None,
    metric: Optional[Union[torch.nn.Module, callable]] = None,
    logger=ClassificationLogger,
    trainer_class=ClassificationTrainer,
    **kwargs,
):
    """Get a trainer for a classification task.

    The trainer is built via `torch_em.default_segmentation_trainer` and is, by default, an
    instance of `torch_em.classification.ClassificationTrainer`. Check out its documentation
    string for details on how to configure and use the trainer.

    Args:
        name: The name for the checkpoint created by the trainer.
        model: The classification model to train.
        train_loader: The data loader for training.
        val_loader: The data loader for validation.
        loss: The loss function. If None, will use cross entropy.
        metric: The metric function. If None, will use the accuracy error.
        logger: The logger for keeping track of the training progress.
        trainer_class: The trainer class.
        kwargs: Additional keyword arguments, forwarded to `torch_em.default_segmentation_trainer`.

    Returns:
        The classification trainer.
    """
    if loss is None:
        loss = torch.nn.CrossEntropyLoss()

    # Lower metric values count as better, so the default metric is the accuracy
    # error (1 - accuracy) rather than the accuracy itself.
    if metric is None:
        metric = ClassificationMetric()

    return torch_em.default_segmentation_trainer(
        name,
        model,
        train_loader,
        val_loader,
        loss=loss,
        metric=metric,
        logger=logger,
        trainer_class=trainer_class,
        **kwargs,
    )
class
ClassificationMetric:
class ClassificationMetric:
    """Metric for classification training.

    The metric is reported as an error, i.e. `1 - score`, so that lower values are better.
    This assumes that the chosen sklearn metric is a score where higher is better,
    such as `accuracy_score` or `f1_score`.

    Args:
        metric_name: The name of the metric. The name will be looked up in `sklearn.metrics`,
            so it must be the name of a callable in that python package.
        metric_kwargs: Keyword arguments for the metric.
    """
    def __init__(self, metric_name: str = "accuracy_score", **metric_kwargs):
        # callable() also rejects names that exist in sklearn.metrics but are not
        # metric functions (such as submodules), which would otherwise only fail at call time.
        metric = getattr(metrics, metric_name, None)
        if not callable(metric):
            raise ValueError(f"Invalid metric_name {metric_name}.")
        self.metric = metric
        self.metric_kwargs = metric_kwargs

    def __call__(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
        """Evaluate model prediction against classification labels.

        Args:
            y_true: The classification labels.
            y_pred: The model predictions.

        Returns:
            The metric error, i.e. `1 - score`.
        """
        score = self.metric(y_true, y_pred, **self.metric_kwargs)
        return 1.0 - score
Metric for classification training.
Arguments:
- metric_name: The name of the metric. The name will be looked up in `sklearn.metrics`, so it must be a valid identifier in that python package.
- metric_kwargs: Keyword arguments for the metric.
def
default_classification_loader( data: Sequence[Union[Buffer, numpy._typing._array_like._SupportsArray[numpy.dtype[Any]], numpy._typing._nested_sequence._NestedSequence[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]]], bool, int, float, complex, str, bytes, numpy._typing._nested_sequence._NestedSequence[Union[bool, int, float, complex, str, bytes]]]], target: Sequence[Union[Buffer, numpy._typing._array_like._SupportsArray[numpy.dtype[Any]], numpy._typing._nested_sequence._NestedSequence[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]]], bool, int, float, complex, str, bytes, numpy._typing._nested_sequence._NestedSequence[Union[bool, int, float, complex, str, bytes]]]], batch_size: int, normalization: Optional[<built-in function callable>] = None, augmentation: Optional[<built-in function callable>] = None, image_shape: Optional[Tuple[int, ...]] = None, **loader_kwargs) -> torch.utils.data.dataloader.DataLoader:
def default_classification_loader(
    data: Sequence[ArrayLike],
    target: Sequence[ArrayLike],
    batch_size: int,
    normalization: Optional[callable] = None,
    augmentation: Optional[callable] = None,
    image_shape: Optional[Tuple[int, ...]] = None,
    **loader_kwargs,
) -> torch.utils.data.DataLoader:
    """Get a data loader for classification training.

    Args:
        data: The input data for classification. Expects a sequence of array-like data.
            The data can be two or three dimensional. Each sample must have one additional
            leading axis (e.g. the channel axis), i.e. 3 or 4 axes in total.
        target: The target data for classification. Expects a sequence of the same length as `data`.
            Each value in the sequence must be a scalar.
        batch_size: The batch size for the data loader.
        normalization: The normalization function. If None, data standardization over all
            axes except the leading one will be used.
        augmentation: The augmentation function. If None, the default augmentations will be used.
        image_shape: The target shape of the data. If given, each sample will be resampled to this size.
        loader_kwargs: Additional keyword arguments for `torch.utils.data.DataLoader`.

    Returns:
        The data loader.

    Raises:
        ValueError: If `data` is empty, or if its samples do not have 2 or 3 dimensions
            in addition to the leading axis.
    """
    if len(data) == 0:
        raise ValueError("Expect at least one sample in data, got an empty sequence.")

    # The leading axis of each sample does not count towards the data dimensionality.
    sample_ndim = data[0].ndim
    ndim = sample_ndim - 1
    if ndim not in (2, 3):
        raise ValueError(
            "Expect input data of dimensionality 2 or 3 plus a leading (channel) axis, "
            f"got a sample with {sample_ndim} axes."
        )

    if normalization is None:
        # Standardize over all axes except the leading one.
        axis = (1, 2) if ndim == 2 else (1, 2, 3)
        normalization = partial(torch_em.transform.raw.standardize, axis=axis)

    if augmentation is None:
        augmentation = torch_em.transform.get_augmentations(ndim=ndim)

    dataset = ClassificationDataset(data, target, normalization, augmentation, image_shape)
    loader = torch_em.segmentation.get_data_loader(dataset, batch_size, **loader_kwargs)
    return loader
Get a data loader for classification training.
Arguments:
- data: The input data for classification. Expects a sequence of array-like data. The data can be two or three dimensional, and each sample must have one additional leading axis (e.g. the channel axis), i.e. 3 or 4 axes in total.
- target: The target data for classification. Expects a sequence of the same length as `data`. Each value in the sequence must be a scalar.
- batch_size: The batch size for the data loader.
- normalization: The normalization function. If None, data standardization will be used.
- augmentation: The augmentation function. If None, the default augmentations will be used.
- image_shape: The target shape of the data. If given, each sample will be resampled to this size.
- loader_kwargs: Additional keyword arguments for `torch.utils.data.DataLoader`.
Returns:
The data loader.
def
default_classification_trainer( name: str, model: torch.nn.modules.module.Module, train_loader: torch.utils.data.dataloader.DataLoader, val_loader: torch.utils.data.dataloader.DataLoader, loss: Union[torch.nn.modules.module.Module, <built-in function callable>, NoneType] = None, metric: Union[torch.nn.modules.module.Module, <built-in function callable>, NoneType] = None, logger=<class 'torch_em.classification.classification_logger.ClassificationLogger'>, trainer_class=<class 'torch_em.classification.classification_trainer.ClassificationTrainer'>, **kwargs):
def default_classification_trainer(
    name: str,
    model: torch.nn.Module,
    train_loader: torch.utils.data.DataLoader,
    val_loader: torch.utils.data.DataLoader,
    loss: Optional[Union[torch.nn.Module, callable]] = None,
    metric: Optional[Union[torch.nn.Module, callable]] = None,
    logger=ClassificationLogger,
    trainer_class=ClassificationTrainer,
    **kwargs,
):
    """Create a trainer for a classification task.

    The trainer is built via `torch_em.default_segmentation_trainer` and is, by default, an
    instance of `torch_em.classification.ClassificationTrainer`. See that class's documentation
    string for details on configuring and using the trainer.

    Args:
        name: The name for the checkpoint created by the trainer.
        model: The classification model to train.
        train_loader: The data loader for training.
        val_loader: The data loader for validation.
        loss: The loss function. Cross entropy is used if None.
        metric: The metric function. The accuracy error is used if None.
        logger: The logger for keeping track of the training progress.
        trainer_class: The trainer class.
        kwargs: Additional keyword arguments, forwarded to `torch_em.default_segmentation_trainer`.

    Returns:
        The classification trainer.
    """
    if loss is None:
        loss = torch.nn.CrossEntropyLoss()

    # Lower metric values count as better, so the default metric is the accuracy
    # error (1 - accuracy) rather than the accuracy itself.
    if metric is None:
        metric = ClassificationMetric()

    return torch_em.default_segmentation_trainer(
        name,
        model,
        train_loader,
        val_loader,
        loss=loss,
        metric=metric,
        logger=logger,
        trainer_class=trainer_class,
        **kwargs,
    )
Get a trainer for a classification task.
This will create an instance of `torch_em.classification.ClassificationTrainer` (or of `trainer_class`, if given).
Check out its documentation string for details on how to configure and use the trainer.
Arguments:
- name: The name for the checkpoint created by the trainer.
- model: The classification model to train.
- train_loader: The data loader for training.
- val_loader: The data loader for validation.
- loss: The loss function. If None, will use cross entropy.
- metric: The metric function. If None, will use the accuracy error.
- logger: The logger for keeping track of the training progress.
- trainer_class: The trainer class.
- kwargs: Keyword arguments for the trainer class.
Returns:
The classification trainer.