Example to save and reload a model
This commit is contained in:
parent
ed03ab168e
commit
49100f43f5
@ -37,3 +37,12 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# Training loop
|
# Training loop
|
||||||
trainer.fit(model, train_loader)
|
trainer.fit(model, train_loader)
|
||||||
|
|
||||||
|
# Save the model
|
||||||
|
torch.save(model, "liramlvq_tecator.pt")
|
||||||
|
|
||||||
|
# Load a saved model
|
||||||
|
saved_model = torch.load("liramlvq_tecator.pt")
|
||||||
|
|
||||||
|
# Display the Lambda matrix
|
||||||
|
saved_model.show_lambda()
|
||||||
|
@ -356,6 +356,9 @@ class Vis2DAbstract(pl.Callback):
|
|||||||
else:
|
else:
|
||||||
plt.show(block=True)
|
plt.show(block=True)
|
||||||
|
|
||||||
|
def on_train_end(self, trainer, pl_module):
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
class VisGLVQ2D(Vis2DAbstract):
|
class VisGLVQ2D(Vis2DAbstract):
|
||||||
def on_epoch_end(self, trainer, pl_module):
|
def on_epoch_end(self, trainer, pl_module):
|
||||||
|
Loading…
Reference in New Issue
Block a user