torch_em.util.reporting
from typing import Dict, Optional, Union

from .util import get_trainer


def _get_n_images(loader):
    """Return the number of images in the dataset wrapped by a data loader.

    Args:
        loader: A data loader whose ``dataset`` attribute is inspected.

    Returns:
        The number of images, or None if the dataset type is not supported.
    """
    ds = loader.dataset
    n_images = None
    # The dataset type is detected via its repr so the class does not have to be imported here.
    if "ImageCollectionDataset" in str(ds):
        n_images = len(ds.raw_images)
    # TODO cover other cases
    return n_images


def _get_training_summary(trainer, lr):
    """Extract and print a summary of the training run recorded in a trainer.

    Args:
        trainer: The deserialized trainer object.
        lr: The initial learning rate. If None, the learning rate is parsed from
            the optimizer's string representation; note that this is the final,
            not the initial, learning rate.

    Returns:
        The training summary as a dictionary.
    """
    n_epochs = trainer.epoch
    batches_per_epoch = len(trainer.train_loader)
    batch_size = trainer.train_loader.batch_size
    print("The model was trained for", n_epochs, "epochs with length", batches_per_epoch, "and batch size", batch_size)

    # Derive the loss name from its repr. For a LossWrapper the wrapped loss is
    # on the second line of the repr, after a colon.
    loss = str(trainer.loss)
    if loss.startswith("LossWrapper"):
        loss = loss.split("\n")[1]
        index = loss.find(":")
        loss = loss[index+1:]
        loss = loss.replace(" ", "").replace(")", "").replace("(", "")
    print("It was trained with", loss, "as loss function")

    opt_ = str(trainer.optimizer)
    if lr is None:
        print("Learning rate is determined from optimizer - this will be the final, not initial learning rate")
        # Parse the value following 'lr:' up to the end of its line in the optimizer repr.
        i0 = opt_.find("lr:")
        i1 = opt_.find("\n", i0)
        lr = opt_[i0+3:i1].replace(" ", "")
    opt = opt_[:opt_.find(" ")]
    print("And using the", opt, "optimizer with learning rate", lr)

    n_train = _get_n_images(trainer.train_loader)
    n_val = _get_n_images(trainer.val_loader)
    print(n_train, "images were used for training and", n_val, "for validation")

    # NOTE(fix): the image counts were previously added to the report twice,
    # once as 'n_validation_images' and again under the inconsistent key
    # 'n_val_images'. They are now reported exactly once, under consistent keys
    # (values may be None when the dataset type is not supported by _get_n_images).
    report = dict(
        n_epochs=n_epochs, batches_per_epoch=batches_per_epoch, batch_size=batch_size,
        loss_function=loss, optimizer=opt, learning_rate=lr,
        n_train_images=n_train, n_validation_images=n_val,
    )
    return report


def get_training_summary(
    ckpt: str,
    lr: Optional[float] = None,
    model_name: str = "best",
    to_md: bool = False,
) -> Union[str, Dict]:
    """Summarize the training process of a checkpoint.

    Args:
        ckpt: The checkpoint.
        lr: The initial learning rate this model was trained with.
            The initial learning rate cannot be read from the checkpoint.
        model_name: The name of the checkpoint to load.
        to_md: Whether to translate the training summary to markdown.

    Returns:
        The training summary, either as a dictionary or markdown str (if `to_md=True`).
    """
    trainer = get_trainer(ckpt, name=model_name)
    print("Model summary for", ckpt, "using the", model_name, "model")
    training_summary = _get_training_summary(trainer, lr)
    if to_md:
        training_summary = "\n".join(f"- {k}: {v}" for k, v in training_summary.items())
    return training_summary
def get_training_summary(
    ckpt: str,
    lr: Optional[float] = None,
    model_name: str = "best",
    to_md: bool = False,
) -> Union[str, Dict]:
    """Summarize the training process of a checkpoint.

    Args:
        ckpt: The checkpoint.
        lr: The initial learning rate this model was trained with.
            The initial learning rate cannot be read from the checkpoint.
        model_name: The name of the checkpoint to load.
        to_md: Whether to translate the training summary to markdown.

    Returns:
        The training summary, either as a dictionary or markdown str (if `to_md=True`).
    """
    trainer = get_trainer(ckpt, name=model_name)
    print("Model summary for", ckpt, "using the", model_name, "model")
    training_summary = _get_training_summary(trainer, lr)
    if to_md:
        training_summary = "\n".join(f"- {k}: {v}" for k, v in training_summary.items())
    return training_summary
Summarize the training process of a checkpoint.
Arguments:
- ckpt: The checkpoint.
- lr: The initial learning rate this model was trained with. The initial learning rate cannot be read from the checkpoint.
- model_name: The name of the checkpoint to load.
- to_md: Whether to translate the training summary to markdown.
Returns:
The training summary, either as a dictionary or markdown str (if `to_md=True`).