torch_em.trainer.wandb_logger
import os
from datetime import datetime
from typing import Optional

import numpy as np

from .logger_base import TorchEmLogger
from .tensorboard_logger import make_grid_image, normalize_im

try:
    import wandb
except ImportError:
    wandb = None

try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal  # type: ignore


class WandbLogger(TorchEmLogger):
    """Logger to write training progress to weights and biases.

    Args:
        trainer: The instantiated trainer.
        save_root: The root directory for writing checkpoints and log files.
        project_name: The name of the weights and biases project for these logs.
        log_model: Which model values to log: gradients, parameters, or both ("all").
        log_model_freq: The frequency (in training steps) for logging the model.
        log_model_graph: Whether to log the model graph.
        mode: The logging mode; one of "online", "offline" or "disabled".
        config: The run configuration; it is updated with the trainer's init data and logged to wandb.
        resume: The id of an existing wandb run to resume.
    """
    def __init__(
        self,
        trainer,
        save_root: str,
        *,
        project_name: Optional[str] = None,
        log_model: Optional[Literal["gradients", "parameters", "all"]] = "all",
        log_model_freq: int = 1,
        log_model_graph: bool = True,
        mode: Literal["online", "offline", "disabled"] = "online",
        config: Optional[dict] = None,
        resume: Optional[str] = None,
        **unused_kwargs,
    ):
        if wandb is None:
            raise RuntimeError("WandbLogger is not available")

        super().__init__(trainer, save_root)

        self.log_dir = "./logs" if save_root is None else os.path.join(save_root, "logs")
        os.makedirs(self.log_dir, exist_ok=True)

        config = dict(config or {})
        config.update(trainer.init_data)
        self.wand_run = wandb.init(
            id=resume, project=project_name, name=trainer.name, dir=self.log_dir,
            mode=mode, config=config, resume="allow"
        )
        trainer.id = self.wand_run.id

        if trainer.name is None:
            if mode == "online":
                trainer.name = self.wand_run.name
            elif mode in ("offline", "disabled"):
                trainer.name = f"{mode}_{datetime.now():%Y-%m-%d_%H-%M-%S}"
                trainer.id = trainer.name  # if we don't upload the log, a name with time stamp is a better run id
            else:
                raise ValueError(mode)

        self.log_image_interval = trainer.log_image_interval

        wandb.watch(trainer.model, log=log_model, log_freq=log_model_freq, log_graph=log_model_graph)

    def _log_images(self, step, x, y, prediction, name, gradients=None):
        # select the first sample of the batch; for 3d inputs take the central slice along the depth axis
        selection = np.s_[0] if x.ndim == 4 else np.s_[0, :, x.shape[2] // 2]

        image = normalize_im(x[selection].cpu())
        grid_image, grid_name = make_grid_image(image, y, prediction, selection, gradients)

        # to numpy and channel last
        image = image.numpy().transpose((1, 2, 0))
        wandb.log({f"images_{name}/input": [wandb.Image(image, caption="Input Data")]}, step=step)

        grid_image = grid_image.numpy().transpose((1, 2, 0))
        wandb.log({f"images_{name}/{grid_name}": [wandb.Image(grid_image, caption=grid_name)]}, step=step)

    def log_train(self, step, loss, lr, x, y, prediction, log_gradients=False):
        """@private
        """
        wandb.log({"train/loss": loss}, step=step)
        # track the best (lowest) train loss in the run summary
        if loss < self.wand_run.summary.get("train/loss", np.inf):
            self.wand_run.summary["train/loss"] = loss

        if step % self.log_image_interval == 0:
            gradients = prediction.grad if log_gradients else None
            self._log_images(step, x, y, prediction, "train", gradients=gradients)

    def log_validation(self, step, metric, loss, x, y, prediction):
        """@private
        """
        wandb.log({"validation/loss": loss, "validation/metric": metric}, step=step)
        # track the best (lowest) validation loss and metric in the run summary
        if loss < self.wand_run.summary.get("validation/loss", np.inf):
            self.wand_run.summary["validation/loss"] = loss

        if metric < self.wand_run.summary.get("validation/metric", np.inf):
            self.wand_run.summary["validation/metric"] = metric

        self._log_images(step, x, y, prediction, "validation")

    def get_wandb(self):
        """@private
        """
        return wandb
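In practice the logger is not instantiated directly; it is passed as a class to a torch_em trainer, which constructs it and forwards logger_kwargs to its keyword arguments. A minimal sketch, assuming torch_em's default_segmentation_trainer with its logger and logger_kwargs parameters; the loaders are stand-ins for a real data setup:

import torch_em
from torch_em.model import UNet2d
from torch_em.trainer.wandb_logger import WandbLogger

# stand-ins: build real loaders, e.g. with torch_em.default_segmentation_loader
train_loader = ...
val_loader = ...
model = UNet2d(in_channels=1, out_channels=1)

trainer = torch_em.default_segmentation_trainer(
    name="wandb-example",
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    logger=WandbLogger,
    logger_kwargs={"project_name": "my-project", "mode": "online"},
)
trainer.fit(iterations=10000)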
class WandbLogger(TorchEmLogger):
Logger to write training progress to weights and biases.

Arguments:
- trainer: The instantiated trainer.
- save_root: The root directory for writing checkpoints and log files.
- project_name: The name of the weights and biases project for these logs.
- log_model: Which model values to log: gradients, parameters, or both ("all").
- log_model_freq: The frequency (in training steps) for logging the model.
- log_model_graph: Whether to log the model graph.
- mode: The logging mode; one of "online", "offline" or "disabled".
- config: The run configuration; it is updated with the trainer's init data and logged to wandb.
- resume: The id of an existing wandb run to resume (see the example below).
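Since the logger writes the wandb run id back to trainer.id, an interrupted training can be continued under the same wandb run by passing that id via resume. A minimal sketch; the run id value here is hypothetical:

logger_kwargs = {
    "project_name": "my-project",
    "resume": "1a2b3c4d",  # hypothetical: the id of the interrupted run (trainer.id)
}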
WandbLogger( trainer, save_root: str, *, project_name: Optional[str] = None, log_model: Optional[Literal['gradients', 'parameters', 'all']] = 'all', log_model_freq: int = 1, log_model_graph: bool = True, mode: Literal['online', 'offline', 'disabled'] = 'online', config: Optional[dict] = None, resume: Optional[str] = None, **unused_kwargs)
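With mode="offline", wandb stores the run locally under the log directory instead of uploading it, and it can be synced later with the wandb CLI. A sketch for syncing such runs, assuming wandb's standard offline-run-<timestamp>-<id> directory layout and the default save_root=None, so that logs land in ./logs:

import glob
import subprocess

# offline runs are written to <log_dir>/wandb/offline-run-* by wandb
for run_dir in glob.glob("./logs/wandb/offline-run-*"):
    subprocess.run(["wandb", "sync", run_dir], check=True)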