torch_em.classification.classification
"""Loaders, metric wrapper and trainer factory for classification with torch_em."""
from functools import partial

import sklearn.metrics as metrics
import torch
import torch_em

from .classification_dataset import ClassificationDataset
from .classification_logger import ClassificationLogger
from .classification_trainer import ClassificationTrainer


class ClassificationMetric:
    """Wrap a ``sklearn.metrics`` score as an error value (1 - score).

    Lower values are better, which matches the convention the trainer uses
    for recording validation metrics.

    Args:
        metric_name: Name of a callable in ``sklearn.metrics``
            (default: ``accuracy_score``).
        metric_kwargs: Keyword arguments forwarded to the metric callable.

    Raises:
        ValueError: If ``metric_name`` is not an attribute of ``sklearn.metrics``.
    """
    def __init__(self, metric_name="accuracy_score", **metric_kwargs):
        if not hasattr(metrics, metric_name):
            raise ValueError(f"Invalid metric_name {metric_name}")
        self.metric = getattr(metrics, metric_name)
        self.metric_kwargs = metric_kwargs

    def __call__(self, y_true, y_pred):
        # Convert the score (higher = better) into an error (lower = better).
        metric_error = 1.0 - self.metric(y_true, y_pred, **self.metric_kwargs)
        return metric_error


def default_classification_loader(
    data, target, batch_size, normalization=None, augmentation=None, image_shape=None, **loader_kwargs,
):
    """Create a data loader for a classification dataset.

    Args:
        data: Sequence of input samples; ``data[0].ndim`` is interpreted as
            channel axis + spatial axes (assumes a leading channel axis —
            ``ndim`` is computed as ``data[0].ndim - 1``).
        target: Classification targets corresponding to ``data``.
        batch_size: The batch size for the loader.
        normalization: Normalization transform; defaults to standardization
            over the spatial axes.
        augmentation: Augmentation transform; defaults to
            ``torch_em.transform.get_augmentations`` for the data dimensionality.
        image_shape: Shape passed to ``ClassificationDataset``.
        loader_kwargs: Additional keyword arguments for
            ``torch_em.segmentation.get_data_loader``.

    Returns:
        The data loader wrapping a ``ClassificationDataset``.

    Raises:
        ValueError: If the spatial dimensionality of the data is not 2 or 3.
    """
    ndim = data[0].ndim - 1
    if ndim not in (2, 3):
        raise ValueError(f"Expect input data of dimensionality 2 or 3, got {ndim}")

    if normalization is None:
        # Standardize over the spatial axes only; axis 0 is the channel axis.
        axis = (1, 2) if ndim == 2 else (1, 2, 3)
        normalization = partial(torch_em.transform.raw.standardize, axis=axis)

    if augmentation is None:
        augmentation = torch_em.transform.get_augmentations(ndim=ndim)

    dataset = ClassificationDataset(data, target, normalization, augmentation, image_shape)
    loader = torch_em.segmentation.get_data_loader(dataset, batch_size, **loader_kwargs)
    return loader


# TODO implement a classification loader that reads its data from zarr.
def zarr_classification_loader():
    """Placeholder for a zarr-backed classification loader.

    Raises:
        NotImplementedError: Always; use ``default_classification_loader`` instead.
    """
    # The previous stub delegated to default_classification_loader() without the
    # required arguments and could only ever raise a TypeError; fail explicitly.
    raise NotImplementedError("zarr_classification_loader is not implemented yet.")


def default_classification_trainer(
    name,
    model,
    train_loader,
    val_loader,
    loss=None,
    metric=None,
    logger=ClassificationLogger,
    trainer_class=ClassificationTrainer,
    **kwargs,
):
    """Create a trainer for a classification task.

    Args:
        name: The checkpoint name for the trainer.
        model: The model to train.
        train_loader: The data loader for training.
        val_loader: The data loader for validation.
        loss: The loss function; defaults to ``torch.nn.CrossEntropyLoss``.
        metric: The validation metric; defaults to ``ClassificationMetric``
            (records 1 - accuracy, so lower = better).
        logger: The logger class for the trainer.
        trainer_class: The trainer class to instantiate.
        kwargs: Additional keyword arguments passed to
            ``torch_em.default_segmentation_trainer``.

    Returns:
        The configured trainer.
    """
    # Set the default loss and metric (if no values were passed).
    loss = torch.nn.CrossEntropyLoss() if loss is None else loss
    # NOTE: lower metric = better, hence we record the accuracy *error*.
    metric = ClassificationMetric() if metric is None else metric

    trainer = torch_em.default_segmentation_trainer(
        name, model, train_loader, val_loader,
        loss=loss, metric=metric,
        logger=logger, trainer_class=trainer_class,
        **kwargs,
    )
    return trainer
class ClassificationMetric:
    """Wrap a ``sklearn.metrics`` score as an error value (1 - score).

    Lower values are better, matching the trainer's metric convention.

    Args:
        metric_name: Name of a callable in ``sklearn.metrics``
            (default: ``accuracy_score``).
        metric_kwargs: Keyword arguments forwarded to the metric callable.

    Raises:
        ValueError: If ``metric_name`` is not an attribute of ``sklearn.metrics``.
    """
    def __init__(self, metric_name="accuracy_score", **metric_kwargs):
        if not hasattr(metrics, metric_name):
            raise ValueError(f"Invalid metric_name {metric_name}")
        self.metric = getattr(metrics, metric_name)
        self.metric_kwargs = metric_kwargs

    def __call__(self, y_true, y_pred):
        # Convert the score (higher = better) into an error (lower = better).
        metric_error = 1.0 - self.metric(y_true, y_pred, **self.metric_kwargs)
        return metric_error
def default_classification_loader(
    data, target, batch_size, normalization=None, augmentation=None, image_shape=None, **loader_kwargs,
):
    """Create a data loader for a classification dataset.

    Args:
        data: Sequence of input samples; ``data[0].ndim`` is interpreted as
            channel axis + spatial axes (assumes a leading channel axis —
            ``ndim`` is computed as ``data[0].ndim - 1``).
        target: Classification targets corresponding to ``data``.
        batch_size: The batch size for the loader.
        normalization: Normalization transform; defaults to standardization
            over the spatial axes.
        augmentation: Augmentation transform; defaults to
            ``torch_em.transform.get_augmentations`` for the data dimensionality.
        image_shape: Shape passed to ``ClassificationDataset``.
        loader_kwargs: Additional keyword arguments for
            ``torch_em.segmentation.get_data_loader``.

    Returns:
        The data loader wrapping a ``ClassificationDataset``.

    Raises:
        ValueError: If the spatial dimensionality of the data is not 2 or 3.
    """
    ndim = data[0].ndim - 1
    if ndim not in (2, 3):
        raise ValueError(f"Expect input data of dimensionality 2 or 3, got {ndim}")

    if normalization is None:
        # Standardize over the spatial axes only; axis 0 is the channel axis.
        axis = (1, 2) if ndim == 2 else (1, 2, 3)
        normalization = partial(torch_em.transform.raw.standardize, axis=axis)

    if augmentation is None:
        augmentation = torch_em.transform.get_augmentations(ndim=ndim)

    dataset = ClassificationDataset(data, target, normalization, augmentation, image_shape)
    loader = torch_em.segmentation.get_data_loader(dataset, batch_size, **loader_kwargs)
    return loader
# TODO implement a classification loader that reads its data from zarr.
def zarr_classification_loader():
    """Placeholder for a zarr-backed classification loader.

    Raises:
        NotImplementedError: Always; use ``default_classification_loader`` instead.
    """
    # The previous stub delegated to default_classification_loader() without the
    # required arguments and could only ever raise a TypeError; fail explicitly.
    raise NotImplementedError("zarr_classification_loader is not implemented yet.")
def default_classification_trainer(
    name,
    model,
    train_loader,
    val_loader,
    loss=None,
    metric=None,
    logger=ClassificationLogger,
    trainer_class=ClassificationTrainer,
    **kwargs,
):
    """Create a trainer for a classification task.

    Args:
        name: The checkpoint name for the trainer.
        model: The model to train.
        train_loader: The data loader for training.
        val_loader: The data loader for validation.
        loss: The loss function; defaults to ``torch.nn.CrossEntropyLoss``.
        metric: The validation metric; defaults to ``ClassificationMetric``
            (records 1 - accuracy, so lower = better).
        logger: The logger class for the trainer.
        trainer_class: The trainer class to instantiate.
        kwargs: Additional keyword arguments passed to
            ``torch_em.default_segmentation_trainer``.

    Returns:
        The configured trainer.
    """
    # Set the default loss and metric (if no values were passed).
    loss = torch.nn.CrossEntropyLoss() if loss is None else loss
    # NOTE: lower metric = better, hence we record the accuracy *error*.
    metric = ClassificationMetric() if metric is None else metric

    trainer = torch_em.default_segmentation_trainer(
        name, model, train_loader, val_loader,
        loss=loss, metric=metric,
        logger=logger, trainer_class=trainer_class,
        **kwargs,
    )
    return trainer