torch_em.classification.classification_trainer
import warnings

import numpy as np
import torch
import torch_em


class ClassificationTrainer(torch_em.trainer.DefaultTrainer):
    """Trainer for classification tasks.

    This class inherits from `torch_em.trainer.DefaultTrainer` with minor changes for
    classification instead of segmentation training. Check out the documentation of
    the default trainer class for details on how to configure and use the trainer.
    """
    def _validate_impl(self, forward_context):
        """Run one pass over the validation loader and compute the validation metric.

        The predicted class of each sample is the index of the maximum along dimension 1
        of the model output. The metric is called once on the concatenated ground-truth and
        predicted labels of all batches, following the `metric(y_true, y_pred)` call syntax
        of `sklearn.metrics`.

        Args:
            forward_context: Callable returning the context manager the forward pass runs in.

        Returns:
            The value returned by `self.metric` for the whole validation set.

        Raises:
            ValueError: If the validation loader does not yield any batches.
        """
        self.model.eval()

        loss_val = 0.0
        n_batches = 0
        found_nan = False

        # We use the syntax from sklearn.metrics to compute metrics
        # over all the predictions at once.
        y_true, y_pred = [], []

        with torch.no_grad():
            for x, y in self.val_loader:
                x, y = x.to(self.device), y.to(self.device)
                with forward_context():
                    pred, loss = self._forward_and_loss(x, y)
                loss_val += loss.item()
                n_batches += 1
                # Check every batch: inspecting only the last batch after the loop
                # would miss NaN predictions produced by earlier batches.
                if not found_nan and torch.isnan(pred).any():
                    found_nan = True
                y_true.append(y.detach().cpu().numpy())
                y_pred.append(pred.max(1)[1].detach().cpu().numpy())

        # Without any batch there is no prediction, loss or label to aggregate.
        if n_batches == 0:
            raise ValueError("The validation loader did not yield any batches.")

        if found_nan:
            warnings.warn("Predictions are NaN")
        # Average over the batches actually seen (equals len(self.val_loader) for
        # regular loaders, and also works for loaders without __len__).
        loss_val /= n_batches

        y_true, y_pred = np.concatenate(y_true), np.concatenate(y_pred)
        metric_val = self.metric(y_true, y_pred)

        if self.logger is not None:
            # x, y and pred are those of the last validation batch.
            self.logger.log_validation(self._iteration, metric_val, loss_val, x, y, pred, y_true, y_pred)
        return metric_val
class ClassificationTrainer(torch_em.trainer.DefaultTrainer):
    """Trainer for classification tasks.

    This class inherits from `torch_em.trainer.DefaultTrainer` with minor changes for
    classification instead of segmentation training. Check out the documentation of
    the default trainer class for details on how to configure and use the trainer.
    """
    def _validate_impl(self, forward_context):
        """Run one pass over the validation loader and compute the validation metric.

        The predicted class of each sample is the index of the maximum along dimension 1
        of the model output. The metric is called once on the concatenated ground-truth and
        predicted labels of all batches, following the `metric(y_true, y_pred)` call syntax
        of `sklearn.metrics`.

        Args:
            forward_context: Callable returning the context manager the forward pass runs in.

        Returns:
            The value returned by `self.metric` for the whole validation set.

        Raises:
            ValueError: If the validation loader does not yield any batches.
        """
        self.model.eval()

        loss_val = 0.0
        n_batches = 0
        found_nan = False

        # We use the syntax from sklearn.metrics to compute metrics
        # over all the predictions at once.
        y_true, y_pred = [], []

        with torch.no_grad():
            for x, y in self.val_loader:
                x, y = x.to(self.device), y.to(self.device)
                with forward_context():
                    pred, loss = self._forward_and_loss(x, y)
                loss_val += loss.item()
                n_batches += 1
                # Check every batch: inspecting only the last batch after the loop
                # would miss NaN predictions produced by earlier batches.
                if not found_nan and torch.isnan(pred).any():
                    found_nan = True
                y_true.append(y.detach().cpu().numpy())
                y_pred.append(pred.max(1)[1].detach().cpu().numpy())

        # Without any batch there is no prediction, loss or label to aggregate.
        if n_batches == 0:
            raise ValueError("The validation loader did not yield any batches.")

        if found_nan:
            warnings.warn("Predictions are NaN")
        # Average over the batches actually seen (equals len(self.val_loader) for
        # regular loaders, and also works for loaders without __len__).
        loss_val /= n_batches

        y_true, y_pred = np.concatenate(y_true), np.concatenate(y_pred)
        metric_val = self.metric(y_true, y_pred)

        if self.logger is not None:
            # x, y and pred are those of the last validation batch.
            self.logger.log_validation(self._iteration, metric_val, loss_val, x, y, pred, y_true, y_pred)
        return metric_val
Trainer for classification tasks.
This class inherits from torch_em.trainer.DefaultTrainer with minor changes for
classification instead of segmentation training. Check out the documentation of
the default trainer class for details on how to configure and use the trainer.
Inherited Members
- torch_em.trainer.default_trainer.DefaultTrainer
- DefaultTrainer
- name
- id_
- train_loader
- val_loader
- model
- loss
- optimizer
- metric
- device
- lr_scheduler
- log_image_interval
- save_root
- compile_model
- rank
- mixed_precision
- early_stopping
- train_time
- logger_class
- logger_kwargs
- checkpoint_folder
- iteration
- epoch
- Deserializer
- Serializer
- fit