fix: fix problems with y architecture and checkpoint

This commit is contained in:
Alexander Engelsberger
2022-06-12 10:36:15 +02:00
parent fe729781fc
commit be7d7f43bd
4 changed files with 47 additions and 32 deletions

View File

@@ -13,8 +13,8 @@ from torch.utils.data import DataLoader
# ##############################################################################
if __name__ == "__main__":
def main():
# ------------------------------------------------------------
# DATA
# ------------------------------------------------------------
@@ -51,7 +51,7 @@ if __name__ == "__main__":
# Create Model
model = GMLVQ(hyperparameters)
print(model)
print(model.hparams)
# ------------------------------------------------------------
# TRAINING
@@ -74,15 +74,27 @@ if __name__ == "__main__":
vis = VisGMLVQ2D(data=train_ds)
# Define trainer
trainer = pl.Trainer(
callbacks=[
vis,
stopping_criterion,
es,
PlotLambdaMatrixToTensorboard(),
],
max_epochs=1000,
)
trainer = pl.Trainer(callbacks=[
vis,
stopping_criterion,
es,
PlotLambdaMatrixToTensorboard(),
], )
# Train
trainer.fit(model, train_loader)
# Manual save
trainer.save_checkpoint("./y_arch.ckpt")
# Load saved model
new_model = GMLVQ.load_from_checkpoint(
checkpoint_path="./y_arch.ckpt",
strict=True,
)
print(new_model.hparams)
if __name__ == "__main__":
main()