From 15e723274720106215bbf9a9e2b47592a79dcec8 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Wed, 2 Feb 2022 21:52:01 +0100 Subject: [PATCH] fix: ignore `prototype_win_ratios` by loading with `strict=False` --- examples/glvq_iris.py | 4 +++- prototorch/models/glvq.py | 6 +++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/examples/glvq_iris.py b/examples/glvq_iris.py index 844b014..d4aa595 100644 --- a/examples/glvq_iris.py +++ b/examples/glvq_iris.py @@ -59,5 +59,7 @@ if __name__ == "__main__": # Load saved model new_model = pt.models.GLVQ.load_from_checkpoint( - checkpoint_path="./glvq_iris.ckpt") + checkpoint_path="./glvq_iris.ckpt", + strict=False, + ) print(new_model) diff --git a/prototorch/models/glvq.py b/prototorch/models/glvq.py index f41ec99..71573b3 100644 --- a/prototorch/models/glvq.py +++ b/prototorch/models/glvq.py @@ -39,9 +39,9 @@ class GLVQ(SupervisedPrototypeModel): beta=self.hparams.transfer_beta, ) - def on_save_checkpoint(self, checkpoint): - if "prototype_win_ratios" in checkpoint["state_dict"]: - del checkpoint["state_dict"]["prototype_win_ratios"] + # def on_save_checkpoint(self, checkpoint): + # if "prototype_win_ratios" in checkpoint["state_dict"]: + # del checkpoint["state_dict"]["prototype_win_ratios"] def initialize_prototype_win_ratios(self): self.register_buffer(