torch_em.classification.classification

from functools import partial
from typing import Callable, Optional, Sequence, Tuple, Union

import numpy as np
import sklearn.metrics as metrics
import torch
import torch_em
from numpy.typing import ArrayLike

from .classification_dataset import ClassificationDataset
from .classification_logger import ClassificationLogger
from .classification_trainer import ClassificationTrainer

class ClassificationMetric:
    """Metric for classification training.

    Args:
        metric_name: The name of the metric. The name will be looked up in `sklearn.metrics`,
            so it must be a valid identifier in that Python package.
        metric_kwargs: Keyword arguments for the metric.
    """
    def __init__(self, metric_name: str = "accuracy_score", **metric_kwargs):
        if not hasattr(metrics, metric_name):
            raise ValueError(f"Invalid metric_name {metric_name}.")
        self.metric = getattr(metrics, metric_name)
        self.metric_kwargs = metric_kwargs

    def __call__(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
        """Evaluate model prediction against classification labels.

        Args:
            y_true: The classification labels.
            y_pred: The model predictions.

        Returns:
            The metric value, expressed as an error so that lower values are better.
        """
        metric_error = 1.0 - self.metric(y_true, y_pred, **self.metric_kwargs)
        return metric_error

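To make the metric's behavior concrete, here is a small usage sketch. The values follow directly from `sklearn.metrics.accuracy_score`; the `f1_score` variant is just one example of a valid metric name:

import numpy as np
from torch_em.classification.classification import ClassificationMetric

y_true = np.array([0, 1, 1, 2])
y_pred = np.array([0, 1, 2, 2])

# The default wraps sklearn.metrics.accuracy_score and returns 1 - accuracy,
# so lower is better: 3 of 4 predictions are correct here.
metric = ClassificationMetric()
print(metric(y_true, y_pred))  # 0.25

# Any function from sklearn.metrics can be selected by name, e.g. macro-averaged F1.
f1_error = ClassificationMetric("f1_score", average="macro")
print(f1_error(y_true, y_pred))
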
def default_classification_loader(
    data: Sequence[ArrayLike],
    target: Sequence[ArrayLike],
    batch_size: int,
    normalization: Optional[Callable] = None,
    augmentation: Optional[Callable] = None,
    image_shape: Optional[Tuple[int, ...]] = None,
    **loader_kwargs,
) -> torch.utils.data.DataLoader:
    """Get a data loader for classification training.

    Args:
        data: The input data for classification. Expects a sequence of array-like data.
            The data can be two or three dimensional; each sample must have an additional leading channel axis.
        target: The target data for classification. Expects a sequence of the same length as `data`.
            Each value in the sequence must be a scalar.
        batch_size: The batch size for the data loader.
        normalization: The normalization function. If None, data standardization will be used.
        augmentation: The augmentation function. If None, the default augmentations will be used.
        image_shape: The target shape of the data. If given, each sample will be resampled to this size.
        loader_kwargs: Additional keyword arguments for `torch.utils.data.DataLoader`.

    Returns:
        The data loader.
    """
    # The first axis of each sample is the channel axis, so the spatial
    # dimensionality is one less than the array dimensionality.
    ndim = data[0].ndim - 1
    if ndim not in (2, 3):
        raise ValueError(f"Expected input data of dimensionality 2 or 3, got {ndim}.")

    if normalization is None:
        axis = (1, 2) if ndim == 2 else (1, 2, 3)
        normalization = partial(torch_em.transform.raw.standardize, axis=axis)

    if augmentation is None:
        augmentation = torch_em.transform.get_augmentations(ndim=ndim)

    dataset = ClassificationDataset(data, target, normalization, augmentation, image_shape)
    loader = torch_em.segmentation.get_data_loader(dataset, batch_size, **loader_kwargs)
    return loader

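A minimal sketch of constructing a loader from in-memory arrays. The channel-first sample layout follows from the `ndim` computation above; `shuffle` is simply forwarded to `torch.utils.data.DataLoader`:

import numpy as np
from torch_em.classification.classification import default_classification_loader

# 20 single-channel 2D images: each sample is (channels, height, width),
# so data[0].ndim - 1 == 2 spatial dimensions.
data = [np.random.rand(1, 64, 64).astype("float32") for _ in range(20)]
target = np.random.randint(0, 3, size=20)  # one scalar class label per image

loader = default_classification_loader(
    data, target, batch_size=4,
    image_shape=(64, 64),  # resample every sample to this spatial shape
    shuffle=True,          # forwarded to torch.utils.data.DataLoader
)
x, y = next(iter(loader))
print(x.shape)  # e.g. torch.Size([4, 1, 64, 64])
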
def default_classification_trainer(
    name: str,
    model: torch.nn.Module,
    train_loader: torch.utils.data.DataLoader,
    val_loader: torch.utils.data.DataLoader,
    loss: Optional[Union[torch.nn.Module, Callable]] = None,
    metric: Optional[Union[torch.nn.Module, Callable]] = None,
    logger=ClassificationLogger,
    trainer_class=ClassificationTrainer,
    **kwargs,
):
    """Get a trainer for a classification task.

    This will create an instance of `torch_em.classification.ClassificationTrainer`.
    Check out its documentation string for details on how to configure and use the trainer.

    Args:
        name: The name for the checkpoint created by the trainer.
        model: The classification model to train.
        train_loader: The data loader for training.
        val_loader: The data loader for validation.
        loss: The loss function. If None, cross entropy will be used.
        metric: The metric function. If None, the accuracy error will be used.
        logger: The logger for keeping track of the training progress.
        trainer_class: The trainer class.
        kwargs: Keyword arguments for the trainer class.

    Returns:
        The classification trainer.
    """
    # Set the default loss and metric (if no values were passed).
    loss = torch.nn.CrossEntropyLoss() if loss is None else loss
    metric = ClassificationMetric() if metric is None else metric

    # Note that the trainer assumes lower metric = better,
    # so we record the accuracy error instead of the accuracy.
    trainer = torch_em.default_segmentation_trainer(
        name, model, train_loader, val_loader,
        loss=loss, metric=metric,
        logger=logger, trainer_class=trainer_class,
        **kwargs,
    )
    return trainer
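
Putting the pieces together, an end-to-end sketch: the random data and the tiny CNN are placeholders, and calling `trainer.fit` with an iteration count follows the usual torch_em trainer convention:

import numpy as np
import torch
from torch_em.classification.classification import (
    default_classification_loader,
    default_classification_trainer,
)

# Placeholder data: single-channel 2D images with labels from 3 classes.
train_data = [np.random.rand(1, 64, 64).astype("float32") for _ in range(40)]
train_target = np.random.randint(0, 3, size=40)
val_data = [np.random.rand(1, 64, 64).astype("float32") for _ in range(8)]
val_target = np.random.randint(0, 3, size=8)

train_loader = default_classification_loader(train_data, train_target, batch_size=4)
val_loader = default_classification_loader(val_data, val_target, batch_size=4)

# Any model that maps (B, 1, H, W) inputs to (B, 3) logits works; a tiny CNN here.
model = torch.nn.Sequential(
    torch.nn.Conv2d(1, 16, 3, padding=1), torch.nn.ReLU(),
    torch.nn.AdaptiveAvgPool2d(1), torch.nn.Flatten(),
    torch.nn.Linear(16, 3),
)

trainer = default_classification_trainer(
    name="my-classifier", model=model,
    train_loader=train_loader, val_loader=val_loader,
)
trainer.fit(iterations=100)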