torch_em.util.reporting

from typing import Dict, Optional, Union

from .util import get_trainer


def _get_n_images(loader):
    """Determine the number of images in a data loader, if possible."""
    ds = loader.dataset
    n_images = None
    if "ImageCollectionDataset" in str(ds):
        n_images = len(ds.raw_images)
    # TODO cover other cases
    return n_images


def _get_training_summary(trainer, lr):

    n_epochs = trainer.epoch
    batches_per_epoch = len(trainer.train_loader)
    batch_size = trainer.train_loader.batch_size
    print("The model was trained for", n_epochs, "epochs with", batches_per_epoch, "batches per epoch and batch size", batch_size)

    # Parse the name of the loss function from its string representation,
    # unwrapping it if it is wrapped in a LossWrapper.
    loss = str(trainer.loss)
    if loss.startswith("LossWrapper"):
        loss = loss.split("\n")[1]
        index = loss.find(":")
        loss = loss[index+1:]
    loss = loss.replace(" ", "").replace(")", "").replace("(", "")
    print("It was trained with", loss, "as loss function")

    # Parse the optimizer name and, if no learning rate was given,
    # read the learning rate from the optimizer's string representation.
    opt_ = str(trainer.optimizer)
    if lr is None:
        print("The learning rate is determined from the optimizer - this will be the final, not the initial learning rate")
        i0 = opt_.find("lr:")
        i1 = opt_.find("\n", i0)
        lr = opt_[i0+3:i1].replace(" ", "")
    opt = opt_[:opt_.find(" ")]
    print("And using the", opt, "optimizer with learning rate", lr)

    n_train = _get_n_images(trainer.train_loader)
    n_val = _get_n_images(trainer.val_loader)
    print(n_train, "images were used for training and", n_val, "for validation")

    report = dict(
        n_epochs=n_epochs, batches_per_epoch=batches_per_epoch, batch_size=batch_size,
        loss_function=loss, optimizer=opt, learning_rate=lr,
    )
    # Only report the image counts if they could be determined from the loaders.
    if n_train is not None:
        report["n_train_images"] = n_train
    if n_val is not None:
        report["n_val_images"] = n_val
    return report


def get_training_summary(
    ckpt: str,
    lr: Optional[float] = None,
    model_name: str = "best",
    to_md: bool = False,
) -> Union[str, Dict]:
    """Summarize the training process of a checkpoint.

    Args:
        ckpt: The path to the checkpoint.
        lr: The initial learning rate this model was trained with.
            The initial learning rate cannot be read from the checkpoint.
        model_name: The name of the checkpoint to load.
        to_md: Whether to translate the training summary to markdown.

    Returns:
        The training summary, either as a dictionary or as a markdown string (if `to_md=True`).
    """
    trainer = get_trainer(ckpt, name=model_name)
    print("Model summary for", ckpt, "using the", model_name, "model")
    training_summary = _get_training_summary(trainer, lr)
    if to_md:
        training_summary = "\n".join(f"- {k}: {v}" for k, v in training_summary.items())
    return training_summary
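
A minimal usage sketch. The checkpoint path and the learning rate below are hypothetical examples, not values defined by this module:

from torch_em.util.reporting import get_training_summary

# Summarize a training run from a hypothetical checkpoint directory,
# passing the initial learning rate, which cannot be read from the checkpoint.
summary = get_training_summary("./checkpoints/my-model", lr=1e-4)
print(summary["n_epochs"], summary["optimizer"], summary["learning_rate"])

# Render the same summary as a markdown bullet list instead.
print(get_training_summary("./checkpoints/my-model", lr=1e-4, to_md=True))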