torch_em.util.reporting

 1from .util import get_trainer
 2
 3
 4def _get_n_images(loader):
 5    ds = loader.dataset
 6    n_images = None
 7    if "ImageCollectionDataset" in str(ds):
 8        n_images = len(ds.raw_images)
 9    # TODO cover other cases
10    return n_images
11
12
def _parse_loss_name(loss):
    """Return a compact name for ``loss`` derived from its string representation.

    If the representation starts with "LossWrapper", the name is taken from its
    second line (the text after the first ":"). All spaces and parentheses are
    stripped from the result.
    """
    loss_repr = str(loss)
    if loss_repr.startswith("LossWrapper"):
        lines = loss_repr.split("\n")
        # Guard against a single-line repr, which would otherwise raise IndexError.
        if len(lines) > 1:
            wrapped = lines[1]
            loss_repr = wrapped[wrapped.find(":") + 1:]
    return loss_repr.replace(" ", "").replace(")", "").replace("(", "")


def _parse_learning_rate(optimizer_repr):
    """Return the first ``lr`` value in an optimizer repr (spaces removed), or None if absent."""
    for line in optimizer_repr.split("\n"):
        key, sep, value = line.partition(":")
        # Compare the whole key: a plain substring search for "lr:" also matches
        # e.g. "initial_lr:", which LR schedulers add to the param groups and which
        # sorts before "lr" in the optimizer repr.
        if sep and key.strip() == "lr":
            return value.replace(" ", "")
    return None


def _get_training_summary(trainer, lr):
    """Print and return a summary of how ``trainer`` was trained.

    Args:
        trainer: trainer exposing ``epoch``, ``train_loader``, ``val_loader``,
            ``loss`` and ``optimizer``.
        lr: learning rate to report. If None it is parsed from the optimizer's
            repr, which gives the final rather than the initial learning rate.

    Returns:
        dict with keys n_epochs, batches_per_epoch, batch_size, loss_function,
        optimizer, learning_rate, n_train_images and n_validation_images (the image
        counts may be None), plus n_val_images when the validation count is known.
    """
    n_epochs = trainer.epoch
    batches_per_epoch = len(trainer.train_loader)
    batch_size = trainer.train_loader.batch_size
    print("The model was trained for", n_epochs, "epochs with length", batches_per_epoch, "and batch size", batch_size)

    loss = _parse_loss_name(trainer.loss)
    print("It was trained with", loss, "as loss function")

    opt_repr = str(trainer.optimizer)
    if lr is None:
        print("Learning rate is determined from optimizer - this will be the final, not initial learning rate")
        lr = _parse_learning_rate(opt_repr)
    # The text before the first space is taken as the optimizer name (presumably its
    # class name). partition() also copes with a repr without any space, where the
    # previous slice at find(" ") == -1 dropped the last character.
    opt = opt_repr.partition(" ")[0]
    print("And using the", opt, "optimizer with learning rate", lr)

    n_train = _get_n_images(trainer.train_loader)
    n_val = _get_n_images(trainer.val_loader)
    print(n_train, "images were used for training and", n_val, "for validation")

    report = dict(
        n_epochs=n_epochs, batches_per_epoch=batches_per_epoch, batch_size=batch_size,
        loss_function=loss, optimizer=opt, learning_rate=lr,
        n_train_images=n_train, n_validation_images=n_val
    )
    # NOTE(review): "n_val_images" duplicates "n_validation_images"; it is kept only
    # for backward compatibility with callers that may read it.
    if n_val is not None:
        report["n_val_images"] = n_val
    return report
51
52
def get_training_summary(ckpt, lr=None, model_name="best", to_md=False):
    """Load the trainer stored in ``ckpt`` and summarize how its model was trained.

    Args:
        ckpt: checkpoint passed to ``get_trainer``.
        lr: learning rate to report; if None it is read from the optimizer.
        model_name: name of the saved model to load (default "best").
        to_md: if True, return the summary as markdown bullet lines instead of a dict.

    Returns:
        The summary dict, or a string with one "- key: value" line per entry.
    """
    trainer = get_trainer(ckpt, name=model_name)
    print("Model summary for", ckpt, "using the", model_name, "model")
    summary = _get_training_summary(trainer, lr)
    if not to_md:
        return summary
    bullet_lines = [f"- {key}: {value}" for key, value in summary.items()]
    return "\n".join(bullet_lines)
def get_training_summary(ckpt, lr=None, model_name="best", to_md=False):
    """Summarize the training of the model saved in checkpoint ``ckpt``.

    ``lr`` overrides the learning rate shown in the summary (None means it is read
    from the optimizer), ``model_name`` selects which saved model to load, and
    ``to_md`` switches the return value from a dict to "- key: value" markdown lines.
    """
    loaded_trainer = get_trainer(ckpt, name=model_name)
    print("Model summary for", ckpt, "using the", model_name, "model")
    report = _get_training_summary(loaded_trainer, lr)
    # Render as a markdown bullet list only when requested.
    return "\n".join(map(lambda item: f"- {item[0]}: {item[1]}", report.items())) if to_md else report