torch_em.trainer.wandb_logger

import os
from datetime import datetime
from typing import Optional

import numpy as np

from .logger_base import TorchEmLogger
from .tensorboard_logger import make_grid_image, normalize_im

try:
    import wandb
except ImportError:
    wandb = None

try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal  # type: ignore


class WandbLogger(TorchEmLogger):
 22    """Logger to write training progress to weights and biases.
 23
 24    Args:
 25        trainer: The instantiated trainer.
 26        save_root: The root directury for writing checkpoints and log files.
 27        project_name: The name of the weights and biases project for these logs.
 28        log_model_freq: The frequency for logging the model.
 29        log_model_graph: Whether to log the model graph.
 30        mode: The logging mode.
 31        config: The configuration.
 32        resume:
 33    """
    def __init__(
        self,
        trainer,
        save_root: str,
        *,
        project_name: Optional[str] = None,
        log_model: Optional[Literal["gradients", "parameters", "all"]] = "all",
        log_model_freq: int = 1,
        log_model_graph: bool = True,
        mode: Literal["online", "offline", "disabled"] = "online",
        config: Optional[dict] = None,
        resume: Optional[str] = None,
        **unused_kwargs,
    ):
        if wandb is None:
            raise RuntimeError("WandbLogger is not available")

        super().__init__(trainer, save_root)

        self.log_dir = "./logs" if save_root is None else os.path.join(save_root, "logs")
        os.makedirs(self.log_dir, exist_ok=True)
        config = dict(config or {})
        config.update(trainer.init_data)
        self.wand_run = wandb.init(
            id=resume, project=project_name, name=trainer.name, dir=self.log_dir,
            mode=mode, config=config, resume="allow"
        )
        trainer.id = self.wand_run.id

        if trainer.name is None:
            if mode == "online":
                trainer.name = self.wand_run.name
            elif mode in ("offline", "disabled"):
                trainer.name = f"{mode}_{datetime.now():%Y-%m-%d_%H-%M-%S}"
                trainer.id = trainer.name  # if we don't upload the log, a name with a time stamp is a better run id
            else:
                raise ValueError(mode)

        self.log_image_interval = trainer.log_image_interval
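
        # Watch the model so that wandb tracks its gradients and / or parameters.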
        wandb.watch(trainer.model, log=log_model, log_freq=log_model_freq, log_graph=log_model_graph)

    def _log_images(self, step, x, y, prediction, name, gradients=None):
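        # For 2D data of shape (batch, channels, y, x) select the first sample;
        # for 3D data of shape (batch, channels, z, y, x) select the central z-slice of the first sample.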
        selection = np.s_[0] if x.ndim == 4 else np.s_[0, :, x.shape[2] // 2]

        image = normalize_im(x[selection].cpu())
        grid_image, grid_name = make_grid_image(image, y, prediction, selection, gradients)

        # to numpy and channel last
        image = image.numpy().transpose((1, 2, 0))
        wandb.log({f"images_{name}/input": [wandb.Image(image, caption="Input Data")]}, step=step)

        grid_image = grid_image.numpy().transpose((1, 2, 0))
        wandb.log({f"images_{name}/{grid_name}": [wandb.Image(grid_image, caption=grid_name)]}, step=step)
 91
 92    def log_train(self, step, loss, lr, x, y, prediction, log_gradients=False):
 93        """@private
 94        """
 95        wandb.log({"train/loss": loss}, step=step)
 96        if loss < self.wand_run.summary.get("train/loss", np.inf):
 97            self.wand_run.summary["train/loss"] = loss
 98
 99        if step % self.log_image_interval == 0:
100            gradients = prediction.grad if log_gradients else None
101            self._log_images(step, x, y, prediction, "train", gradients=gradients)
102
103    def log_validation(self, step, metric, loss, x, y, prediction):
104        """@private
105        """
106        wandb.log({"validation/loss": loss, "validation/metric": metric}, step=step)
107        if loss < self.wand_run.summary.get("validation/loss", np.inf):
108            self.wand_run.summary["validation/loss"] = loss
109
110        if metric < self.wand_run.summary.get("validation/metric", np.inf):
111            self.wand_run.summary["validation/metric"] = metric
112
113        self._log_images(step, x, y, prediction, "validation")
114

    def get_wandb(self):
        """@private
        """
        return wandb
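
The logger is usually not constructed by hand: the trainer instantiates it and passes itself as the first argument. Below is a minimal usage sketch, assuming that torch_em.default_segmentation_trainer accepts a logger class via logger and its keyword arguments via logger_kwargs, and that model, train_loader, and val_loader are placeholders for your own model and data loaders.

import torch_em
from torch_em.trainer.wandb_logger import WandbLogger

# model, train_loader and val_loader are placeholders for your own
# torch.nn.Module and data loaders.
trainer = torch_em.default_segmentation_trainer(
    name="my-experiment",
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    logger=WandbLogger,  # the trainer instantiates the logger class
    logger_kwargs={"project_name": "my-wandb-project", "mode": "offline"},
)
trainer.fit(iterations=10000)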