torch_em.classification.classification_trainer

 1import warnings
 2
 3import numpy as np
 4import torch
 5import torch_em
 6
 7
 8class ClassificationTrainer(torch_em.trainer.DefaultTrainer):
 9    """Trainer for classification tasks.
10
11    This class inherits from `torch_em.trainer.DefaultTrainer` with minor changes for
12    classification instead of segmentation training. Check out the documentation of
13    the default trainer class for details on how to configure and use the trainer
14    """
15    def _validate_impl(self, forward_context):
16        self.model.eval()
17
18        loss_val = 0.0
19
20        # we use the syntax from sklearn.metrics to compute metrics
21        # over all the preditions
22        y_true, y_pred = [], []
23
24        with torch.no_grad():
25            for x, y in self.val_loader:
26                x, y = x.to(self.device), y.to(self.device)
27                with forward_context():
28                    pred, loss = self._forward_and_loss(x, y)
29                loss_val += loss.item()
30                y_true.append(y.detach().cpu().numpy())
31                y_pred.append(pred.max(1)[1].detach().cpu().numpy())
32
33        if torch.isnan(pred).any():
34            warnings.warn("Predictions are NaN")
35        loss_val /= len(self.val_loader)
36
37        y_true, y_pred = np.concatenate(y_true), np.concatenate(y_pred)
38        metric_val = self.metric(y_true, y_pred)
39
40        if self.logger is not None:
41            self.logger.log_validation(self._iteration, metric_val, loss_val, x, y, pred, y_true, y_pred)
42        return metric_val
class ClassificationTrainer(torch_em.trainer.default_trainer.DefaultTrainer):
 9class ClassificationTrainer(torch_em.trainer.DefaultTrainer):
10    """Trainer for classification tasks.
11
12    This class inherits from `torch_em.trainer.DefaultTrainer` with minor changes for
13    classification instead of segmentation training. Check out the documentation of
14    the default trainer class for details on how to configure and use the trainer
15    """
16    def _validate_impl(self, forward_context):
17        self.model.eval()
18
19        loss_val = 0.0
20
21        # we use the syntax from sklearn.metrics to compute metrics
22        # over all the preditions
23        y_true, y_pred = [], []
24
25        with torch.no_grad():
26            for x, y in self.val_loader:
27                x, y = x.to(self.device), y.to(self.device)
28                with forward_context():
29                    pred, loss = self._forward_and_loss(x, y)
30                loss_val += loss.item()
31                y_true.append(y.detach().cpu().numpy())
32                y_pred.append(pred.max(1)[1].detach().cpu().numpy())
33
34        if torch.isnan(pred).any():
35            warnings.warn("Predictions are NaN")
36        loss_val /= len(self.val_loader)
37
38        y_true, y_pred = np.concatenate(y_true), np.concatenate(y_pred)
39        metric_val = self.metric(y_true, y_pred)
40
41        if self.logger is not None:
42            self.logger.log_validation(self._iteration, metric_val, loss_val, x, y, pred, y_true, y_pred)
43        return metric_val

Trainer for classification tasks.

This class inherits from `torch_em.trainer.DefaultTrainer` with minor changes for classification instead of segmentation training. Check out the documentation of the default trainer class for details on how to configure and use the trainer.