torch_em.classification.classification_trainer
import warnings

import numpy as np
import torch
import torch_em


class ClassificationTrainer(torch_em.trainer.DefaultTrainer):
    """Trainer for classification models.

    Reuses ``torch_em.trainer.DefaultTrainer`` and overrides the validation step
    so that the predicted labels for the whole validation set are collected and
    ``self.metric`` is evaluated once on all of them.
    """

    def _validate_impl(self, forward_context):
        """Run one pass over ``self.val_loader`` and return the validation metric.

        Args:
            forward_context: Callable returning a context manager that wraps the
                forward pass (NOTE(review): presumably autocast / nullcontext
                depending on mixed precision — confirm against DefaultTrainer).

        Returns:
            The value of ``self.metric(y_true, y_pred)``, where ``y_true`` holds
            all validation targets and ``y_pred`` the predicted labels.

        Raises:
            ValueError: If the validation loader does not yield any batch.
        """
        self.model.eval()

        loss_val = 0.0
        n_batches = 0
        found_nan = False

        # We use the syntax from sklearn.metrics, metric(y_true, y_pred), and
        # compute the metric over the predictions for the full validation set.
        y_true, y_pred = [], []

        with torch.no_grad():
            for x, y in self.val_loader:
                x, y = x.to(self.device), y.to(self.device)
                with forward_context():
                    pred, loss = self._forward_and_loss(x, y)
                loss_val += loss.item()
                n_batches += 1
                # Check every batch (not only the last one) so NaN predictions
                # anywhere in the validation set are reported.
                found_nan = found_nan or bool(torch.isnan(pred).any())
                y_true.append(y.cpu().numpy())
                # Predicted label = index of the highest score along dim 1
                # (assumes pred is (batch, n_classes, ...) — TODO confirm).
                y_pred.append(pred.argmax(dim=1).cpu().numpy())

        if n_batches == 0:
            raise ValueError("The validation loader did not yield any batches.")

        if found_nan:
            warnings.warn("Predictions are NaN")
        # Average over the batches actually seen; robust to loaders without a reliable len().
        loss_val /= n_batches

        y_true, y_pred = np.concatenate(y_true), np.concatenate(y_pred)
        metric_val = self.metric(y_true, y_pred)

        if self.logger is not None:
            # x, y and pred refer to the last validation batch.
            self.logger.log_validation(self._iteration, metric_val, loss_val, x, y, pred, y_true, y_pred)
        return metric_val
class ClassificationTrainer(torch_em.trainer.DefaultTrainer):
    """Trainer for classification models.

    Reuses ``torch_em.trainer.DefaultTrainer`` and overrides the validation step
    so that the predicted labels for the whole validation set are collected and
    ``self.metric`` is evaluated once on all of them.
    """

    def _validate_impl(self, forward_context):
        """Run one pass over ``self.val_loader`` and return the validation metric.

        Args:
            forward_context: Callable returning a context manager that wraps the
                forward pass (NOTE(review): presumably autocast / nullcontext
                depending on mixed precision — confirm against DefaultTrainer).

        Returns:
            The value of ``self.metric(y_true, y_pred)``, where ``y_true`` holds
            all validation targets and ``y_pred`` the predicted labels.

        Raises:
            ValueError: If the validation loader does not yield any batch.
        """
        self.model.eval()

        loss_val = 0.0
        n_batches = 0
        found_nan = False

        # We use the syntax from sklearn.metrics, metric(y_true, y_pred), and
        # compute the metric over the predictions for the full validation set.
        y_true, y_pred = [], []

        with torch.no_grad():
            for x, y in self.val_loader:
                x, y = x.to(self.device), y.to(self.device)
                with forward_context():
                    pred, loss = self._forward_and_loss(x, y)
                loss_val += loss.item()
                n_batches += 1
                # Check every batch (not only the last one) so NaN predictions
                # anywhere in the validation set are reported.
                found_nan = found_nan or bool(torch.isnan(pred).any())
                y_true.append(y.cpu().numpy())
                # Predicted label = index of the highest score along dim 1
                # (assumes pred is (batch, n_classes, ...) — TODO confirm).
                y_pred.append(pred.argmax(dim=1).cpu().numpy())

        if n_batches == 0:
            raise ValueError("The validation loader did not yield any batches.")

        if found_nan:
            warnings.warn("Predictions are NaN")
        # Average over the batches actually seen; robust to loaders without a reliable len().
        loss_val /= n_batches

        y_true, y_pred = np.concatenate(y_true), np.concatenate(y_pred)
        metric_val = self.metric(y_true, y_pred)

        if self.logger is not None:
            # x, y and pred refer to the last validation batch.
            self.logger.log_validation(self._iteration, metric_val, loss_val, x, y, pred, y_true, y_pred)
        return metric_val
Trainer class for 2d/3d training on a single GPU.
Inherited Members
- torch_em.trainer.default_trainer.DefaultTrainer
- DefaultTrainer
- name
- id_
- train_loader
- val_loader
- model
- loss
- optimizer
- metric
- device
- lr_scheduler
- log_image_interval
- save_root
- compile_model
- mixed_precision
- early_stopping
- train_time
- scaler
- logger_class
- logger_kwargs
- checkpoint_folder
- iteration
- epoch
- Deserializer
- from_checkpoint
- Serializer
- save_checkpoint
- load_checkpoint
- fit