fix: GLVQ can now be restored from checkpoint
This commit is contained in:
parent
75a39f5b03
commit
d5855dbe97
@ -3,6 +3,7 @@
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torchmetrics
|
||||
from prototorch.core.initializers import ZerosCompInitializer
|
||||
|
||||
from ..core.competitions import WTAC
|
||||
from ..core.components import Components, LabeledComponents
|
||||
@ -120,6 +121,13 @@ class SupervisedPrototypeModel(PrototypeModel):
|
||||
components_initializer=prototypes_initializer,
|
||||
labels_initializer=labels_initializer,
|
||||
)
|
||||
self.hparams.initialized_proto_dims = self.proto_layer.components.shape[
|
||||
1:]
|
||||
else:
|
||||
self.proto_layer = LabeledComponents(
|
||||
self.hparams.distribution,
|
||||
ZerosCompInitializer(self.hparams.initialized_proto_dims),
|
||||
)
|
||||
self.competition_layer = WTAC()
|
||||
|
||||
@property
|
||||
@ -177,7 +185,6 @@ class SupervisedPrototypeModel(PrototypeModel):
|
||||
|
||||
class ProtoTorchMixin(object):
|
||||
"""All mixins are ProtoTorchMixins."""
|
||||
pass
|
||||
|
||||
|
||||
class NonGradientMixin(ProtoTorchMixin):
|
||||
|
@ -39,6 +39,10 @@ class GLVQ(SupervisedPrototypeModel):
|
||||
beta=self.hparams.transfer_beta,
|
||||
)
|
||||
|
||||
def on_save_checkpoint(self, checkpoint):
|
||||
if "prototype_win_ratios" in checkpoint["state_dict"]:
|
||||
del checkpoint["state_dict"]["prototype_win_ratios"]
|
||||
|
||||
def initialize_prototype_win_ratios(self):
|
||||
self.register_buffer(
|
||||
"prototype_win_ratios",
|
||||
|
Loading…
Reference in New Issue
Block a user