fix: GLVQ can now be restored from checkpoint

This commit is contained in:
Alexander Engelsberger 2022-02-02 16:17:11 +01:00
parent 75a39f5b03
commit d5855dbe97
No known key found for this signature in database
GPG Key ID: BE3F5909FF0D83E3
2 changed files with 12 additions and 1 deletions

View File

@ -3,6 +3,7 @@
import pytorch_lightning as pl
import torch
import torchmetrics
from prototorch.core.initializers import ZerosCompInitializer
from ..core.competitions import WTAC
from ..core.components import Components, LabeledComponents
@ -120,6 +121,13 @@ class SupervisedPrototypeModel(PrototypeModel):
components_initializer=prototypes_initializer,
labels_initializer=labels_initializer,
)
self.hparams.initialized_proto_dims = self.proto_layer.components.shape[
1:]
else:
self.proto_layer = LabeledComponents(
self.hparams.distribution,
ZerosCompInitializer(self.hparams.initialized_proto_dims),
)
self.competition_layer = WTAC()
@property
@ -177,7 +185,6 @@ class SupervisedPrototypeModel(PrototypeModel):
class ProtoTorchMixin(object):
"""All mixins are ProtoTorchMixins."""
pass
class NonGradientMixin(ProtoTorchMixin):

View File

@ -39,6 +39,10 @@ class GLVQ(SupervisedPrototypeModel):
beta=self.hparams.transfer_beta,
)
def on_save_checkpoint(self, checkpoint):
if "prototype_win_ratios" in checkpoint["state_dict"]:
del checkpoint["state_dict"]["prototype_win_ratios"]
def initialize_prototype_win_ratios(self):
self.register_buffer(
"prototype_win_ratios",