Example to save and reload a model
This commit is contained in:
		@@ -37,3 +37,12 @@ if __name__ == "__main__":
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    # Training loop
 | 
					    # Training loop
 | 
				
			||||||
    trainer.fit(model, train_loader)
 | 
					    trainer.fit(model, train_loader)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Save the model
 | 
				
			||||||
 | 
					    torch.save(model, "liramlvq_tecator.pt")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Load a saved model
 | 
				
			||||||
 | 
					    saved_model = torch.load("liramlvq_tecator.pt")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Display the Lambda matrix
 | 
				
			||||||
 | 
					    saved_model.show_lambda()
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -356,6 +356,9 @@ class Vis2DAbstract(pl.Callback):
 | 
				
			|||||||
        else:
 | 
					        else:
 | 
				
			||||||
            plt.show(block=True)
 | 
					            plt.show(block=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def on_train_end(self, trainer, pl_module):
 | 
				
			||||||
 | 
					        plt.show()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class VisGLVQ2D(Vis2DAbstract):
 | 
					class VisGLVQ2D(Vis2DAbstract):
 | 
				
			||||||
    def on_epoch_end(self, trainer, pl_module):
 | 
					    def on_epoch_end(self, trainer, pl_module):
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user