Example to save and reload a model

This commit is contained in:
Jensun Ravichandran 2021-05-10 14:30:02 +02:00
parent ed03ab168e
commit 49100f43f5
2 changed files with 12 additions and 0 deletions

View File

@ -37,3 +37,12 @@ if __name__ == "__main__":
# Training loop # Training loop
trainer.fit(model, train_loader) trainer.fit(model, train_loader)
# Save the model
torch.save(model, "liramlvq_tecator.pt")
# Load a saved model
saved_model = torch.load("liramlvq_tecator.pt")
# Display the Lambda matrix
saved_model.show_lambda()

View File

@ -356,6 +356,9 @@ class Vis2DAbstract(pl.Callback):
else: else:
plt.show(block=True) plt.show(block=True)
def on_train_end(self, trainer, pl_module):
plt.show()
class VisGLVQ2D(Vis2DAbstract): class VisGLVQ2D(Vis2DAbstract):
def on_epoch_end(self, trainer, pl_module): def on_epoch_end(self, trainer, pl_module):