fix: fix problems with y architecture and checkpoint
This commit is contained in:
@@ -13,8 +13,8 @@ from torch.utils.data import DataLoader
|
||||
|
||||
# ##############################################################################
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
def main():
|
||||
# ------------------------------------------------------------
|
||||
# DATA
|
||||
# ------------------------------------------------------------
|
||||
@@ -51,7 +51,7 @@ if __name__ == "__main__":
|
||||
# Create Model
|
||||
model = GMLVQ(hyperparameters)
|
||||
|
||||
print(model)
|
||||
print(model.hparams)
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# TRAINING
|
||||
@@ -74,15 +74,27 @@ if __name__ == "__main__":
|
||||
vis = VisGMLVQ2D(data=train_ds)
|
||||
|
||||
# Define trainer
|
||||
trainer = pl.Trainer(
|
||||
callbacks=[
|
||||
vis,
|
||||
stopping_criterion,
|
||||
es,
|
||||
PlotLambdaMatrixToTensorboard(),
|
||||
],
|
||||
max_epochs=1000,
|
||||
)
|
||||
trainer = pl.Trainer(callbacks=[
|
||||
vis,
|
||||
stopping_criterion,
|
||||
es,
|
||||
PlotLambdaMatrixToTensorboard(),
|
||||
], )
|
||||
|
||||
# Train
|
||||
trainer.fit(model, train_loader)
|
||||
|
||||
# Manual save
|
||||
trainer.save_checkpoint("./y_arch.ckpt")
|
||||
|
||||
# Load saved model
|
||||
new_model = GMLVQ.load_from_checkpoint(
|
||||
checkpoint_path="./y_arch.ckpt",
|
||||
strict=True,
|
||||
)
|
||||
|
||||
print(new_model.hparams)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
Reference in New Issue
Block a user