fix: GLVQ can now be restored from checkpoint

This commit is contained in:
Alexander Engelsberger 2022-02-02 16:17:11 +01:00
parent 75a39f5b03
commit d5855dbe97
No known key found for this signature in database
GPG Key ID: BE3F5909FF0D83E3
2 changed files with 12 additions and 1 deletions

View File

@ -3,6 +3,7 @@
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
import torchmetrics import torchmetrics
from prototorch.core.initializers import ZerosCompInitializer
from ..core.competitions import WTAC from ..core.competitions import WTAC
from ..core.components import Components, LabeledComponents from ..core.components import Components, LabeledComponents
@ -120,6 +121,13 @@ class SupervisedPrototypeModel(PrototypeModel):
components_initializer=prototypes_initializer, components_initializer=prototypes_initializer,
labels_initializer=labels_initializer, labels_initializer=labels_initializer,
) )
self.hparams.initialized_proto_dims = self.proto_layer.components.shape[
1:]
else:
self.proto_layer = LabeledComponents(
self.hparams.distribution,
ZerosCompInitializer(self.hparams.initialized_proto_dims),
)
self.competition_layer = WTAC() self.competition_layer = WTAC()
@property @property
@ -177,7 +185,6 @@ class SupervisedPrototypeModel(PrototypeModel):
class ProtoTorchMixin(object): class ProtoTorchMixin(object):
"""All mixins are ProtoTorchMixins.""" """All mixins are ProtoTorchMixins."""
pass
class NonGradientMixin(ProtoTorchMixin): class NonGradientMixin(ProtoTorchMixin):

View File

@ -39,6 +39,10 @@ class GLVQ(SupervisedPrototypeModel):
beta=self.hparams.transfer_beta, beta=self.hparams.transfer_beta,
) )
def on_save_checkpoint(self, checkpoint):
if "prototype_win_ratios" in checkpoint["state_dict"]:
del checkpoint["state_dict"]["prototype_win_ratios"]
def initialize_prototype_win_ratios(self): def initialize_prototype_win_ratios(self):
self.register_buffer( self.register_buffer(
"prototype_win_ratios", "prototype_win_ratios",