diff --git a/examples/glvq_iris.py b/examples/glvq_iris.py index 844b014..d4aa595 100644 --- a/examples/glvq_iris.py +++ b/examples/glvq_iris.py @@ -59,5 +59,7 @@ if __name__ == "__main__": # Load saved model new_model = pt.models.GLVQ.load_from_checkpoint( - checkpoint_path="./glvq_iris.ckpt") + checkpoint_path="./glvq_iris.ckpt", + strict=False, + ) print(new_model) diff --git a/prototorch/models/glvq.py b/prototorch/models/glvq.py index f41ec99..71573b3 100644 --- a/prototorch/models/glvq.py +++ b/prototorch/models/glvq.py @@ -39,9 +39,9 @@ class GLVQ(SupervisedPrototypeModel): beta=self.hparams.transfer_beta, ) - def on_save_checkpoint(self, checkpoint): - if "prototype_win_ratios" in checkpoint["state_dict"]: - del checkpoint["state_dict"]["prototype_win_ratios"] + # def on_save_checkpoint(self, checkpoint): + # if "prototype_win_ratios" in checkpoint["state_dict"]: + # del checkpoint["state_dict"]["prototype_win_ratios"] def initialize_prototype_win_ratios(self): self.register_buffer(