torch_em.trainer.logger_base

 1try:
 2    from typing import Literal
 3except ImportError:
 4    from typing_extensions import Literal  # type: ignore
 5
 6
 7class TorchEmLogger:
 8    """@private
 9    """
10    def __init__(self, trainer, save_root, **kwargs):
11        self.trainer = trainer
12        self.save_root = save_root
13
14    def log_train(self, step, loss, lr, x, y, prediction, log_gradients=False):
15        raise NotImplementedError
16
17    def log_validation(self, step, metric, loss, x, y, prediction):
18        raise NotImplementedError