fix: ignore prototype_win_ratios by loading with strict=False

This commit is contained in:
Jensun Ravichandran 2022-02-02 21:52:01 +01:00
parent 197b728c63
commit 15e7232747
No known key found for this signature in database
GPG Key ID: 7612C0CAB643D921
2 changed files with 6 additions and 4 deletions

View File

@ -59,5 +59,7 @@ if __name__ == "__main__":
# Load saved model # Load saved model
new_model = pt.models.GLVQ.load_from_checkpoint( new_model = pt.models.GLVQ.load_from_checkpoint(
checkpoint_path="./glvq_iris.ckpt") checkpoint_path="./glvq_iris.ckpt",
strict=False,
)
print(new_model) print(new_model)

View File

@ -39,9 +39,9 @@ class GLVQ(SupervisedPrototypeModel):
beta=self.hparams.transfer_beta, beta=self.hparams.transfer_beta,
) )
def on_save_checkpoint(self, checkpoint): # def on_save_checkpoint(self, checkpoint):
if "prototype_win_ratios" in checkpoint["state_dict"]: # if "prototype_win_ratios" in checkpoint["state_dict"]:
del checkpoint["state_dict"]["prototype_win_ratios"] # del checkpoint["state_dict"]["prototype_win_ratios"]
def initialize_prototype_win_ratios(self): def initialize_prototype_win_ratios(self):
self.register_buffer( self.register_buffer(