From 49100f43f57d16e49939c0a8c4882f67d2db1a7b Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Mon, 10 May 2021 14:30:02 +0200 Subject: [PATCH] Example to save and reload a model --- examples/liramlvq_tecator.py | 9 +++++++++ prototorch/models/vis.py | 3 +++ 2 files changed, 12 insertions(+) diff --git a/examples/liramlvq_tecator.py b/examples/liramlvq_tecator.py index 53219ab..f948b87 100644 --- a/examples/liramlvq_tecator.py +++ b/examples/liramlvq_tecator.py @@ -37,3 +37,12 @@ if __name__ == "__main__": # Training loop trainer.fit(model, train_loader) + + # Save the model + torch.save(model, "liramlvq_tecator.pt") + + # Load a saved model + saved_model = torch.load("liramlvq_tecator.pt") + + # Display the Lambda matrix + saved_model.show_lambda() diff --git a/prototorch/models/vis.py b/prototorch/models/vis.py index 7f21e48..07cba8f 100644 --- a/prototorch/models/vis.py +++ b/prototorch/models/vis.py @@ -356,6 +356,9 @@ class Vis2DAbstract(pl.Callback): else: plt.show(block=True) + def on_train_end(self, trainer, pl_module): + plt.show() + class VisGLVQ2D(Vis2DAbstract): def on_epoch_end(self, trainer, pl_module):