torch_em.trainer.wandb_logger

import os
from datetime import datetime
from typing import Optional

import numpy as np

from .logger_base import TorchEmLogger
from .tensorboard_logger import make_grid_image, normalize_im

try:
    import wandb
except ImportError:
    wandb = None

try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal  # type: ignore
20
class WandbLogger(TorchEmLogger):
    """Logger that reports training progress of a torch_em trainer to Weights & Biases.

    Raises ``RuntimeError`` on construction if the ``wandb`` package is not installed.
    """

    def __init__(
        self,
        trainer,
        save_root,
        *,
        project_name: Optional[str] = None,
        log_model: Optional[Literal["gradients", "parameters", "all"]] = "all",
        log_model_freq: int = 1,
        log_model_graph: bool = True,
        mode: Literal["online", "offline", "disabled"] = "online",
        config: Optional[dict] = None,
        resume: Optional[str] = None,
        **unused_kwargs,
    ):
        """Start (or resume) a wandb run for the given trainer.

        Args:
            trainer: The trainer being logged. Its ``name``, ``id``, ``init_data``,
                ``log_image_interval`` and ``model`` attributes are read here, and
                ``name``/``id`` may be updated from the wandb run.
            save_root: Root folder for the log directory; ``./logs`` is used if None.
            project_name: Name of the wandb project to log to.
            log_model: What ``wandb.watch`` should track ("gradients", "parameters", "all" or None).
            log_model_freq: Logging frequency passed to ``wandb.watch``.
            log_model_graph: Whether ``wandb.watch`` should log the model graph.
            mode: wandb run mode ("online", "offline" or "disabled").
            config: Additional run config; ``trainer.init_data`` entries take precedence.
            resume: Id of a run to resume, or None to start a fresh run.
        """
        if wandb is None:
            raise RuntimeError("WandbLogger is not available")

        super().__init__(trainer, save_root)

        self.log_dir = "./logs" if save_root is None else os.path.join(save_root, "logs")
        os.makedirs(self.log_dir, exist_ok=True)

        # Copy so we never mutate the caller's config dict.
        config = dict(config or {})
        config.update(trainer.init_data)
        self.wand_run = wandb.init(
            id=resume, project=project_name, name=trainer.name, dir=self.log_dir,
            mode=mode, config=config, resume="allow",
        )
        trainer.id = self.wand_run.id

        if trainer.name is None:
            if mode == "online":
                # Adopt the name wandb generated for this run.
                trainer.name = self.wand_run.name
            elif mode in ("offline", "disabled"):
                trainer.name = f"{mode}_{datetime.now():%Y-%m-%d_%H-%M-%S}"
                trainer.id = trainer.name  # if we don't upload the log, name with time stamp is a better run id
            else:
                raise ValueError(mode)

        self.log_image_interval = trainer.log_image_interval

        wandb.watch(trainer.model, log=log_model, log_freq=log_model_freq, log_graph=log_model_graph)

    def _log_images(self, step, x, y, prediction, name, gradients=None):
        """Log the input image and the label/prediction grid image for this step.

        For 4d input batches the first sample is logged directly; for 5d (volumetric)
        batches the central slice of the first sample is logged.
        NOTE(review): assumes layout (batch, channel, spatial...) — confirm with callers.
        """
        selection = np.s_[0] if x.ndim == 4 else np.s_[0, :, x.shape[2] // 2]

        image = normalize_im(x[selection].cpu())
        grid_image, grid_name = make_grid_image(image, y, prediction, selection, gradients)

        # to numpy and channel last
        image = image.numpy().transpose((1, 2, 0))
        wandb.log({f"images_{name}/input": [wandb.Image(image, caption="Input Data")]}, step=step)

        grid_image = grid_image.numpy().transpose((1, 2, 0))

        wandb.log({f"images_{name}/{grid_name}": [wandb.Image(grid_image, caption=grid_name)]}, step=step)

    def log_train(self, step, loss, lr, x, y, prediction, log_gradients=False):
        """Log the training loss and learning rate; log images every ``log_image_interval`` steps."""
        # FIX: the learning rate was accepted but never logged before.
        wandb.log({"train/loss": loss, "train/lr": lr}, step=step)
        # Keep the best (lowest) training loss in the run summary.
        if loss < self.wand_run.summary.get("train/loss", np.inf):
            self.wand_run.summary["train/loss"] = loss

        if step % self.log_image_interval == 0:
            gradients = prediction.grad if log_gradients else None
            self._log_images(step, x, y, prediction, "train", gradients=gradients)

    def log_validation(self, step, metric, loss, x, y, prediction):
        """Log validation loss and metric, update their run-summary minima and log images."""
        wandb.log({"validation/loss": loss, "validation/metric": metric}, step=step)
        # Keep the best (lowest) values in the run summary.
        if loss < self.wand_run.summary.get("validation/loss", np.inf):
            self.wand_run.summary["validation/loss"] = loss

        if metric < self.wand_run.summary.get("validation/metric", np.inf):
            self.wand_run.summary["validation/metric"] = metric

        self._log_images(step, x, y, prediction, "validation")

    def get_wandb(self):
        """Return the wandb module used by this logger."""
        return wandb
class WandbLogger(torch_em.trainer.logger_base.TorchEmLogger):
 22class WandbLogger(TorchEmLogger):
 23    def __init__(
 24        self,
 25        trainer,
 26        save_root,
 27        *,
 28        project_name: Optional[str] = None,
 29        log_model: Optional[Literal["gradients", "parameters", "all"]] = "all",
 30        log_model_freq: int = 1,
 31        log_model_graph: bool = True,
 32        mode: Literal["online", "offline", "disabled"] = "online",
 33        config: Optional[dict] = None,
 34        resume: Optional[str] = None,
 35        **unused_kwargs,
 36    ):
 37        if wandb is None:
 38            raise RuntimeError("WandbLogger is not available")
 39
 40        super().__init__(trainer, save_root)
 41
 42        self.log_dir = "./logs" if save_root is None else os.path.join(save_root, "logs")
 43        os.makedirs(self.log_dir, exist_ok=True)
 44
 45        config = dict(config or {})
 46        config.update(trainer.init_data)
 47        self.wand_run = wandb.init(
 48            id=resume, project=project_name, name=trainer.name, dir=self.log_dir, mode=mode, config=config, resume="allow"
 49        )
 50        trainer.id = self.wand_run.id
 51
 52        if trainer.name is None:
 53            if mode == "online":
 54                trainer.name = self.wand_run.name
 55            elif mode in ("offline", "disabled"):
 56                trainer.name = f"{mode}_{datetime.now():%Y-%m-%d_%H-%M-%S}"
 57                trainer.id = trainer.name  # if we don't upload the log, name with time stamp is a better run id
 58            else:
 59                raise ValueError(mode)
 60
 61        self.log_image_interval = trainer.log_image_interval
 62
 63        wandb.watch(trainer.model, log=log_model, log_freq=log_model_freq, log_graph=log_model_graph)
 64
 65    def _log_images(self, step, x, y, prediction, name, gradients=None):
 66
 67        selection = np.s_[0] if x.ndim == 4 else np.s_[0, :, x.shape[2] // 2]
 68
 69        image = normalize_im(x[selection].cpu())
 70        grid_image, grid_name = make_grid_image(image, y, prediction, selection, gradients)
 71
 72        # to numpy and channel last
 73        image = image.numpy().transpose((1, 2, 0))
 74        wandb.log({f"images_{name}/input": [wandb.Image(image, caption="Input Data")]}, step=step)
 75
 76        grid_image = grid_image.numpy().transpose((1, 2, 0))
 77
 78        wandb.log({f"images_{name}/{grid_name}": [wandb.Image(grid_image, caption=grid_name)]}, step=step)
 79
 80    def log_train(self, step, loss, lr, x, y, prediction, log_gradients=False):
 81        wandb.log({"train/loss": loss}, step=step)
 82        if loss < self.wand_run.summary.get("train/loss", np.inf):
 83            self.wand_run.summary["train/loss"] = loss
 84
 85        if step % self.log_image_interval == 0:
 86            gradients = prediction.grad if log_gradients else None
 87            self._log_images(step, x, y, prediction, "train", gradients=gradients)
 88
 89    def log_validation(self, step, metric, loss, x, y, prediction):
 90        wandb.log({"validation/loss": loss, "validation/metric": metric}, step=step)
 91        if loss < self.wand_run.summary.get("validation/loss", np.inf):
 92            self.wand_run.summary["validation/loss"] = loss
 93
 94        if metric < self.wand_run.summary.get("validation/metric", np.inf):
 95            self.wand_run.summary["validation/metric"] = metric
 96
 97        self._log_images(step, x, y, prediction, "validation")
 98
 99    def get_wandb(self):
100        return wandb
WandbLogger( trainer, save_root, *, project_name: Optional[str] = None, log_model: Optional[Literal['gradients', 'parameters', 'all']] = 'all', log_model_freq: int = 1, log_model_graph: bool = True, mode: Literal['online', 'offline', 'disabled'] = 'online', config: Optional[dict] = None, resume: Optional[str] = None, **unused_kwargs)
23    def __init__(
24        self,
25        trainer,
26        save_root,
27        *,
28        project_name: Optional[str] = None,
29        log_model: Optional[Literal["gradients", "parameters", "all"]] = "all",
30        log_model_freq: int = 1,
31        log_model_graph: bool = True,
32        mode: Literal["online", "offline", "disabled"] = "online",
33        config: Optional[dict] = None,
34        resume: Optional[str] = None,
35        **unused_kwargs,
36    ):
37        if wandb is None:
38            raise RuntimeError("WandbLogger is not available")
39
40        super().__init__(trainer, save_root)
41
42        self.log_dir = "./logs" if save_root is None else os.path.join(save_root, "logs")
43        os.makedirs(self.log_dir, exist_ok=True)
44
45        config = dict(config or {})
46        config.update(trainer.init_data)
47        self.wand_run = wandb.init(
48            id=resume, project=project_name, name=trainer.name, dir=self.log_dir, mode=mode, config=config, resume="allow"
49        )
50        trainer.id = self.wand_run.id
51
52        if trainer.name is None:
53            if mode == "online":
54                trainer.name = self.wand_run.name
55            elif mode in ("offline", "disabled"):
56                trainer.name = f"{mode}_{datetime.now():%Y-%m-%d_%H-%M-%S}"
57                trainer.id = trainer.name  # if we don't upload the log, name with time stamp is a better run id
58            else:
59                raise ValueError(mode)
60
61        self.log_image_interval = trainer.log_image_interval
62
63        wandb.watch(trainer.model, log=log_model, log_freq=log_model_freq, log_graph=log_model_graph)
log_dir
wand_run
log_image_interval
def log_train(self, step, loss, lr, x, y, prediction, log_gradients=False):
80    def log_train(self, step, loss, lr, x, y, prediction, log_gradients=False):
81        wandb.log({"train/loss": loss}, step=step)
82        if loss < self.wand_run.summary.get("train/loss", np.inf):
83            self.wand_run.summary["train/loss"] = loss
84
85        if step % self.log_image_interval == 0:
86            gradients = prediction.grad if log_gradients else None
87            self._log_images(step, x, y, prediction, "train", gradients=gradients)
def log_validation(self, step, metric, loss, x, y, prediction):
89    def log_validation(self, step, metric, loss, x, y, prediction):
90        wandb.log({"validation/loss": loss, "validation/metric": metric}, step=step)
91        if loss < self.wand_run.summary.get("validation/loss", np.inf):
92            self.wand_run.summary["validation/loss"] = loss
93
94        if metric < self.wand_run.summary.get("validation/metric", np.inf):
95            self.wand_run.summary["validation/metric"] = metric
96
97        self._log_images(step, x, y, prediction, "validation")
def get_wandb(self):
 99    def get_wandb(self):
100        return wandb