torch_em.classification.classification_trainer

 1import warnings
 2
 3import numpy as np
 4import torch
 5import torch_em
 6
 7
 8class ClassificationTrainer(torch_em.trainer.DefaultTrainer):
 9    def _validate_impl(self, forward_context):
10        self.model.eval()
11
12        loss_val = 0.0
13
14        # we use the syntax from sklearn.metrics to compute metrics
15        # over all the preditions
16        y_true, y_pred = [], []
17
18        with torch.no_grad():
19            for x, y in self.val_loader:
20                x, y = x.to(self.device), y.to(self.device)
21                with forward_context():
22                    pred, loss = self._forward_and_loss(x, y)
23                loss_val += loss.item()
24                y_true.append(y.detach().cpu().numpy())
25                y_pred.append(pred.max(1)[1].detach().cpu().numpy())
26
27        if torch.isnan(pred).any():
28            warnings.warn("Predictions are NaN")
29        loss_val /= len(self.val_loader)
30
31        y_true, y_pred = np.concatenate(y_true), np.concatenate(y_pred)
32        metric_val = self.metric(y_true, y_pred)
33
34        if self.logger is not None:
35            self.logger.log_validation(self._iteration, metric_val, loss_val, x, y, pred, y_true, y_pred)
36        return metric_val
class ClassificationTrainer(torch_em.trainer.DefaultTrainer):
    """Trainer for classification models.

    Overrides the validation step of ``DefaultTrainer``: the target labels and
    the predicted class indices (argmax over dim 1 of the model output) are
    collected for the whole validation set, and ``self.metric`` is evaluated
    on them with the ``metric(y_true, y_pred)`` convention of
    ``sklearn.metrics``.
    """

    def _validate_impl(self, forward_context):
        """Run one pass over the validation loader and return the metric.

        Args:
            forward_context: Callable returning a context manager that wraps
                the forward pass; presumably provided by the base trainer
                (e.g. for mixed precision) — confirm against DefaultTrainer.

        Returns:
            The value of ``self.metric(y_true, y_pred)``, computed over the
            concatenated labels and predicted class indices of all batches.

        Raises:
            ValueError: If the validation loader yields no batches.
        """
        self.model.eval()

        loss_val = 0.0
        n_batches = 0
        has_nan = False

        # We use the syntax from sklearn.metrics to compute metrics
        # over all the predictions, so collect them for every batch.
        y_true, y_pred = [], []

        with torch.no_grad():
            for x, y in self.val_loader:
                x, y = x.to(self.device), y.to(self.device)
                with forward_context():
                    pred, loss = self._forward_and_loss(x, y)
                loss_val += loss.item()
                n_batches += 1
                # Check every batch (not only the last one) for NaN predictions;
                # short-circuits once a NaN has been seen.
                has_nan = has_nan or bool(torch.isnan(pred).any())
                y_true.append(y.detach().cpu().numpy())
                y_pred.append(pred.argmax(dim=1).detach().cpu().numpy())

        if n_batches == 0:
            # Without this guard the code below fails with an unbound `pred`,
            # a division by zero or np.concatenate on an empty list.
            raise ValueError("The validation loader did not yield any batches.")

        if has_nan:
            warnings.warn("Predictions are NaN")
        # Average over the batches actually seen; this also works for loaders
        # that do not implement __len__.
        loss_val /= n_batches

        y_true, y_pred = np.concatenate(y_true), np.concatenate(y_pred)
        metric_val = self.metric(y_true, y_pred)

        if self.logger is not None:
            # x, y and pred are those of the last validation batch.
            self.logger.log_validation(self._iteration, metric_val, loss_val, x, y, pred, y_true, y_pred)
        return metric_val

Trainer class for 2d/3d classification training on a single GPU.