torch_em.classification.classification

 1from functools import partial
 2
 3import sklearn.metrics as metrics
 4import torch
 5import torch_em
 6
 7from .classification_dataset import ClassificationDataset
 8from .classification_logger import ClassificationLogger
 9from .classification_trainer import ClassificationTrainer
10
11
class ClassificationMetric:
    """Wrap a ``sklearn.metrics`` function so that it reports an error value.

    The wrapped function's result is returned as ``1.0 - score``, so lower is
    better (the default ``accuracy_score`` becomes the classification error).
    This inversion is only meaningful for metrics where a higher score is better.

    Args:
        metric_name: Name of a function in ``sklearn.metrics``.
        **metric_kwargs: Extra keyword arguments passed to that function on every call.

    Raises:
        ValueError: If ``sklearn.metrics`` has no attribute named ``metric_name``.
    """

    def __init__(self, metric_name="accuracy_score", **metric_kwargs):
        try:
            metric_fn = getattr(metrics, metric_name)
        except AttributeError:
            raise ValueError(f"Invalid metric_name {metric_name}") from None
        self.metric = metric_fn
        self.metric_kwargs = metric_kwargs

    def __call__(self, y_true, y_pred):
        """Return ``1.0 - metric(y_true, y_pred, **metric_kwargs)``."""
        score = self.metric(y_true, y_pred, **self.metric_kwargs)
        return 1.0 - score
22
23
def default_classification_loader(
    data, target, batch_size, normalization=None, augmentation=None, image_shape=None, **loader_kwargs,
):
    """Build a data loader for image classification.

    Args:
        data: Sequence of input images. The spatial dimensionality is taken as
            ``data[0].ndim - 1``, i.e. each item has one leading axis
            (presumably channels - confirm) followed by 2 or 3 spatial axes.
        target: Classification targets, passed unchanged to ``ClassificationDataset``.
        batch_size: Batch size of the loader.
        normalization: Callable applied to the inputs. Defaults to
            ``torch_em.transform.raw.standardize`` over every axis except the first.
        augmentation: Augmentation transform. Defaults to
            ``torch_em.transform.get_augmentations`` for the detected dimensionality.
        image_shape: Forwarded to ``ClassificationDataset``.
        **loader_kwargs: Forwarded to ``torch_em.segmentation.get_data_loader``.

    Returns:
        The loader created by ``torch_em.segmentation.get_data_loader``.

    Raises:
        ValueError: If the spatial dimensionality is neither 2 nor 3.
    """
    spatial_ndim = data[0].ndim - 1
    if spatial_ndim != 2 and spatial_ndim != 3:
        raise ValueError(f"Expect input data of dimensionality 2 or 3, got {spatial_ndim}")

    if normalization is None:
        # Standardize over all axes except the leading one: (1, 2) or (1, 2, 3).
        norm_axes = tuple(range(1, spatial_ndim + 1))
        normalization = partial(torch_em.transform.raw.standardize, axis=norm_axes)

    augmentation = (
        torch_em.transform.get_augmentations(ndim=spatial_ndim) if augmentation is None else augmentation
    )

    dataset = ClassificationDataset(data, target, normalization, augmentation, image_shape)
    return torch_em.segmentation.get_data_loader(dataset, batch_size, **loader_kwargs)
41
42
43# TODO
def zarr_classification_loader():
    """Placeholder for a zarr-based classification loader; not implemented yet.

    Use ``default_classification_loader`` for in-memory data instead.

    Raises:
        NotImplementedError: Always. Raised explicitly so that callers get a
            clear message instead of a ``TypeError`` about missing arguments.
    """
    raise NotImplementedError("zarr_classification_loader is not implemented yet")
46
47
def default_classification_trainer(
    name,
    model,
    train_loader,
    val_loader,
    loss=None,
    metric=None,
    logger=ClassificationLogger,
    trainer_class=ClassificationTrainer,
    **kwargs,
):
    """Create a classification trainer via ``torch_em.default_segmentation_trainer``.

    Args:
        name: Name of the training run, forwarded unchanged.
        model: The model to train.
        train_loader: Loader for the training data.
        val_loader: Loader for the validation data.
        loss: Loss function. Defaults to ``torch.nn.CrossEntropyLoss()``.
        metric: Validation metric, where a lower value means better. Defaults to
            ``ClassificationMetric()``, i.e. ``1 - accuracy_score``.
        logger: Logger class. Defaults to ``ClassificationLogger``.
        trainer_class: Trainer class. Defaults to ``ClassificationTrainer``.
        **kwargs: Further keyword arguments forwarded to
            ``torch_em.default_segmentation_trainer``.

    Returns:
        The trainer returned by ``torch_em.default_segmentation_trainer``.
    """
    # Set the default loss and metric (if no values were passed).
    loss = torch.nn.CrossEntropyLoss() if loss is None else loss
    # The trainer treats a lower metric as better, so the default metric records
    # the accuracy error (1 - accuracy) instead of the accuracy itself.
    metric = ClassificationMetric() if metric is None else metric

    trainer = torch_em.default_segmentation_trainer(
        name, model, train_loader, val_loader,
        loss=loss, metric=metric,
        logger=logger, trainer_class=trainer_class,
        **kwargs,
    )
    return trainer
class ClassificationMetric:
class ClassificationMetric:
    """Wrap a ``sklearn.metrics`` function so that it reports an error value.

    The wrapped function's result is returned as ``1.0 - score``, so lower is
    better (the default ``accuracy_score`` becomes the classification error).
    This inversion is only meaningful for metrics where a higher score is better.

    Args:
        metric_name: Name of a function in ``sklearn.metrics``.
        **metric_kwargs: Extra keyword arguments passed to that function on every call.

    Raises:
        ValueError: If ``sklearn.metrics`` has no attribute named ``metric_name``.
    """

    def __init__(self, metric_name="accuracy_score", **metric_kwargs):
        try:
            metric_fn = getattr(metrics, metric_name)
        except AttributeError:
            raise ValueError(f"Invalid metric_name {metric_name}") from None
        self.metric = metric_fn
        self.metric_kwargs = metric_kwargs

    def __call__(self, y_true, y_pred):
        """Return ``1.0 - metric(y_true, y_pred, **metric_kwargs)``."""
        score = self.metric(y_true, y_pred, **self.metric_kwargs)
        return 1.0 - score
ClassificationMetric(metric_name='accuracy_score', **metric_kwargs)
14    def __init__(self, metric_name="accuracy_score", **metric_kwargs):
15        if not hasattr(metrics, metric_name):
16            raise ValueError(f"Invalid metric_name {metric_name}")
17        self.metric = getattr(metrics, metric_name)
18        self.metric_kwargs = metric_kwargs
metric
metric_kwargs
def default_classification_loader( data, target, batch_size, normalization=None, augmentation=None, image_shape=None, **loader_kwargs):
def default_classification_loader(
    data, target, batch_size, normalization=None, augmentation=None, image_shape=None, **loader_kwargs,
):
    """Build a data loader for image classification.

    Args:
        data: Sequence of input images. The spatial dimensionality is taken as
            ``data[0].ndim - 1``, i.e. each item has one leading axis
            (presumably channels - confirm) followed by 2 or 3 spatial axes.
        target: Classification targets, passed unchanged to ``ClassificationDataset``.
        batch_size: Batch size of the loader.
        normalization: Callable applied to the inputs. Defaults to
            ``torch_em.transform.raw.standardize`` over every axis except the first.
        augmentation: Augmentation transform. Defaults to
            ``torch_em.transform.get_augmentations`` for the detected dimensionality.
        image_shape: Forwarded to ``ClassificationDataset``.
        **loader_kwargs: Forwarded to ``torch_em.segmentation.get_data_loader``.

    Returns:
        The loader created by ``torch_em.segmentation.get_data_loader``.

    Raises:
        ValueError: If the spatial dimensionality is neither 2 nor 3.
    """
    spatial_ndim = data[0].ndim - 1
    if spatial_ndim != 2 and spatial_ndim != 3:
        raise ValueError(f"Expect input data of dimensionality 2 or 3, got {spatial_ndim}")

    if normalization is None:
        # Standardize over all axes except the leading one: (1, 2) or (1, 2, 3).
        norm_axes = tuple(range(1, spatial_ndim + 1))
        normalization = partial(torch_em.transform.raw.standardize, axis=norm_axes)

    augmentation = (
        torch_em.transform.get_augmentations(ndim=spatial_ndim) if augmentation is None else augmentation
    )

    dataset = ClassificationDataset(data, target, normalization, augmentation, image_shape)
    return torch_em.segmentation.get_data_loader(dataset, batch_size, **loader_kwargs)
def zarr_classification_loader():
def zarr_classification_loader():
    """Placeholder for a zarr-based classification loader; not implemented yet.

    Use ``default_classification_loader`` for in-memory data instead.

    Raises:
        NotImplementedError: Always. Raised explicitly so that callers get a
            clear message instead of a ``TypeError`` about missing arguments.
    """
    raise NotImplementedError("zarr_classification_loader is not implemented yet")
def default_classification_trainer( name, model, train_loader, val_loader, loss=None, metric=None, logger=<class 'torch_em.classification.classification_logger.ClassificationLogger'>, trainer_class=<class 'torch_em.classification.classification_trainer.ClassificationTrainer'>, **kwargs):
def default_classification_trainer(
    name,
    model,
    train_loader,
    val_loader,
    loss=None,
    metric=None,
    logger=ClassificationLogger,
    trainer_class=ClassificationTrainer,
    **kwargs,
):
    """Create a classification trainer via ``torch_em.default_segmentation_trainer``.

    Args:
        name: Name of the training run, forwarded unchanged.
        model: The model to train.
        train_loader: Loader for the training data.
        val_loader: Loader for the validation data.
        loss: Loss function. Defaults to ``torch.nn.CrossEntropyLoss()``.
        metric: Validation metric, where a lower value means better. Defaults to
            ``ClassificationMetric()``, i.e. ``1 - accuracy_score``.
        logger: Logger class. Defaults to ``ClassificationLogger``.
        trainer_class: Trainer class. Defaults to ``ClassificationTrainer``.
        **kwargs: Further keyword arguments forwarded to
            ``torch_em.default_segmentation_trainer``.

    Returns:
        The trainer returned by ``torch_em.default_segmentation_trainer``.
    """
    # Set the default loss and metric (if no values were passed).
    loss = torch.nn.CrossEntropyLoss() if loss is None else loss
    # The trainer treats a lower metric as better, so the default metric records
    # the accuracy error (1 - accuracy) instead of the accuracy itself.
    metric = ClassificationMetric() if metric is None else metric

    trainer = torch_em.default_segmentation_trainer(
        name, model, train_loader, val_loader,
        loss=loss, metric=metric,
        logger=logger, trainer_class=trainer_class,
        **kwargs,
    )
    return trainer