torch_em.trainer.wandb_logger
Module source. wandb is an optional dependency, so its import is guarded and the logger raises at construction time if it is missing.

```python
import os
from datetime import datetime
from typing import Optional

import numpy as np

from .logger_base import TorchEmLogger
from .tensorboard_logger import make_grid_image, normalize_im

try:
    import wandb
except ImportError:
    wandb = None

try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal  # type: ignore


class WandbLogger(TorchEmLogger):
    def __init__(
        self,
        trainer,
        save_root,
        *,
        project_name: Optional[str] = None,
        log_model: Optional[Literal["gradients", "parameters", "all"]] = "all",
        log_model_freq: int = 1,
        log_model_graph: bool = True,
        mode: Literal["online", "offline", "disabled"] = "online",
        config: Optional[dict] = None,
        resume: Optional[str] = None,
        **unused_kwargs,
    ):
        if wandb is None:
            raise RuntimeError("WandbLogger is not available")

        super().__init__(trainer, save_root)

        self.log_dir = "./logs" if save_root is None else os.path.join(save_root, "logs")
        os.makedirs(self.log_dir, exist_ok=True)

        config = dict(config or {})
        config.update(trainer.init_data)
        self.wand_run = wandb.init(
            id=resume, project=project_name, name=trainer.name,
            dir=self.log_dir, mode=mode, config=config, resume="allow",
        )
        trainer.id = self.wand_run.id

        if trainer.name is None:
            if mode == "online":
                trainer.name = self.wand_run.name
            elif mode in ("offline", "disabled"):
                trainer.name = f"{mode}_{datetime.now():%Y-%m-%d_%H-%M-%S}"
                # If we don't upload the log, a name with a time stamp is a better run id.
                trainer.id = trainer.name
            else:
                raise ValueError(mode)

        self.log_image_interval = trainer.log_image_interval

        wandb.watch(trainer.model, log=log_model, log_freq=log_model_freq, log_graph=log_model_graph)

    def _log_images(self, step, x, y, prediction, name, gradients=None):
        # For 4d input (B, C, H, W) log the first sample;
        # for 5d input (B, C, D, H, W) log the central slice of the first sample.
        selection = np.s_[0] if x.ndim == 4 else np.s_[0, :, x.shape[2] // 2]

        image = normalize_im(x[selection].cpu())
        grid_image, grid_name = make_grid_image(image, y, prediction, selection, gradients)

        # Convert to numpy and channel-last layout, as expected by wandb.Image.
        image = image.numpy().transpose((1, 2, 0))
        wandb.log({f"images_{name}/input": [wandb.Image(image, caption="Input Data")]}, step=step)

        grid_image = grid_image.numpy().transpose((1, 2, 0))
        wandb.log({f"images_{name}/{grid_name}": [wandb.Image(grid_image, caption=grid_name)]}, step=step)

    def log_train(self, step, loss, lr, x, y, prediction, log_gradients=False):
        wandb.log({"train/loss": loss}, step=step)
        # Keep the best (lowest) training loss seen so far in the run summary.
        if loss < self.wand_run.summary.get("train/loss", np.inf):
            self.wand_run.summary["train/loss"] = loss

        if step % self.log_image_interval == 0:
            gradients = prediction.grad if log_gradients else None
            self._log_images(step, x, y, prediction, "train", gradients=gradients)

    def log_validation(self, step, metric, loss, x, y, prediction):
        wandb.log({"validation/loss": loss, "validation/metric": metric}, step=step)
        # Keep the best (lowest) validation loss and metric in the run summary.
        if loss < self.wand_run.summary.get("validation/loss", np.inf):
            self.wand_run.summary["validation/loss"] = loss
        if metric < self.wand_run.summary.get("validation/metric", np.inf):
            self.wand_run.summary["validation/metric"] = metric

        self._log_images(step, x, y, prediction, "validation")

    def get_wandb(self):
        return wandb
```
class WandbLogger(TorchEmLogger)
```python
WandbLogger(
    trainer,
    save_root,
    *,
    project_name: Optional[str] = None,
    log_model: Optional[Literal["gradients", "parameters", "all"]] = "all",
    log_model_freq: int = 1,
    log_model_graph: bool = True,
    mode: Literal["online", "offline", "disabled"] = "online",
    config: Optional[dict] = None,
    resume: Optional[str] = None,
    **unused_kwargs,
)
```
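The constructor only needs a trainer object that exposes a handful of attributes, which it reads (init_data, name, log_image_interval, model) or fills in (name, id). The sketch below illustrates that contract with a hypothetical SimpleNamespace stand-in; in practice the logger is constructed by a torch_em trainer, not by hand. It assumes wandb is installed and uses offline mode so no login is required.

```python
# Minimal sketch of the trainer contract WandbLogger.__init__ relies on.
# The SimpleNamespace stand-in is hypothetical; a real torch_em trainer
# provides these attributes itself. Assumes wandb is installed.
from types import SimpleNamespace

import torch.nn as nn

from torch_em.trainer.wandb_logger import WandbLogger

trainer = SimpleNamespace(
    name=None,                 # set from the wandb run name (or a timestamp when offline)
    id=None,                   # set to the wandb run id by the logger
    init_data={},              # merged into the wandb config
    log_image_interval=100,    # controls how often log_train() emits image panels
    model=nn.Conv2d(1, 1, 3, padding=1),  # any nn.Module; passed to wandb.watch
)

logger = WandbLogger(trainer, save_root=None, mode="offline", project_name="demo")
print(trainer.name, trainer.id)  # both were filled in by the logger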
```python
def log_train(self, step, loss, lr, x, y, prediction, log_gradients=False)
```
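log_train writes the batch loss at every step and additionally keeps the lowest loss seen so far in wand_run.summary, so the W&B runs table can show the best value rather than the most recent one; image panels are only emitted every log_image_interval steps. A hedged sketch of a hand-rolled loop driving it, reusing the hypothetical trainer and logger stand-ins from the constructor example with synthetic loss values:

```python
# Hedged sketch: drive log_train() from a hand-rolled loop, reusing the
# `trainer` and `logger` stand-ins from the constructor example above.
# Loss values are synthetic; steps start at 1 so the image-logging branch
# (step % log_image_interval == 0) is not triggered here.
import torch

for step, loss in enumerate([0.9, 0.5, 0.7], start=1):
    x = torch.rand(1, 1, 64, 64)   # (batch, channels, height, width)
    y = torch.rand(1, 1, 64, 64)
    prediction = trainer.model(x)
    logger.log_train(step, loss, lr=1e-3, x=x, y=y, prediction=prediction)
```

After this loop the explicit summary bookkeeping has recorded 0.5, the lowest of the three losses, as the best training loss so far.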
```python
def log_validation(self, step, metric, loss, x, y, prediction)
```
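log_validation does the same best-so-far bookkeeping for both the validation loss and the metric, but it writes the image panels on every call instead of every log_image_interval steps. Continuing the same hypothetical stand-ins:

```python
# Hedged sketch of a validation call, reusing the stand-ins from above.
# Unlike log_train, log_validation always emits the image panels, so y and
# prediction need shapes that make_grid_image can lay out alongside x.
import torch

x = torch.rand(1, 1, 64, 64)
y = torch.rand(1, 1, 64, 64)
with torch.no_grad():
    prediction = trainer.model(x)

logger.log_validation(step=3, metric=0.42, loss=0.5, x=x, y=y, prediction=prediction)
```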