diff --git a/examples/liramlvq_tecator.py b/examples/liramlvq_tecator.py index 53219ab..f948b87 100644 --- a/examples/liramlvq_tecator.py +++ b/examples/liramlvq_tecator.py @@ -37,3 +37,12 @@ if __name__ == "__main__": # Training loop trainer.fit(model, train_loader) + + # Save the model + torch.save(model, "liramlvq_tecator.pt") + + # Load a saved model + saved_model = torch.load("liramlvq_tecator.pt") + + # Display the Lambda matrix + saved_model.show_lambda() diff --git a/prototorch/models/vis.py b/prototorch/models/vis.py index 7f21e48..07cba8f 100644 --- a/prototorch/models/vis.py +++ b/prototorch/models/vis.py @@ -356,6 +356,9 @@ class Vis2DAbstract(pl.Callback): else: plt.show(block=True) + def on_train_end(self, trainer, pl_module): + plt.show() + class VisGLVQ2D(Vis2DAbstract): def on_epoch_end(self, trainer, pl_module):