fix: GLVQ can now be restored from checkpoint
This commit is contained in:
parent
75a39f5b03
commit
d5855dbe97
@ -3,6 +3,7 @@
|
|||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
import torchmetrics
|
import torchmetrics
|
||||||
|
from prototorch.core.initializers import ZerosCompInitializer
|
||||||
|
|
||||||
from ..core.competitions import WTAC
|
from ..core.competitions import WTAC
|
||||||
from ..core.components import Components, LabeledComponents
|
from ..core.components import Components, LabeledComponents
|
||||||
@ -120,6 +121,13 @@ class SupervisedPrototypeModel(PrototypeModel):
|
|||||||
components_initializer=prototypes_initializer,
|
components_initializer=prototypes_initializer,
|
||||||
labels_initializer=labels_initializer,
|
labels_initializer=labels_initializer,
|
||||||
)
|
)
|
||||||
|
self.hparams.initialized_proto_dims = self.proto_layer.components.shape[
|
||||||
|
1:]
|
||||||
|
else:
|
||||||
|
self.proto_layer = LabeledComponents(
|
||||||
|
self.hparams.distribution,
|
||||||
|
ZerosCompInitializer(self.hparams.initialized_proto_dims),
|
||||||
|
)
|
||||||
self.competition_layer = WTAC()
|
self.competition_layer = WTAC()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -177,7 +185,6 @@ class SupervisedPrototypeModel(PrototypeModel):
|
|||||||
|
|
||||||
class ProtoTorchMixin(object):
|
class ProtoTorchMixin(object):
|
||||||
"""All mixins are ProtoTorchMixins."""
|
"""All mixins are ProtoTorchMixins."""
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class NonGradientMixin(ProtoTorchMixin):
|
class NonGradientMixin(ProtoTorchMixin):
|
||||||
|
@ -39,6 +39,10 @@ class GLVQ(SupervisedPrototypeModel):
|
|||||||
beta=self.hparams.transfer_beta,
|
beta=self.hparams.transfer_beta,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def on_save_checkpoint(self, checkpoint):
|
||||||
|
if "prototype_win_ratios" in checkpoint["state_dict"]:
|
||||||
|
del checkpoint["state_dict"]["prototype_win_ratios"]
|
||||||
|
|
||||||
def initialize_prototype_win_ratios(self):
|
def initialize_prototype_win_ratios(self):
|
||||||
self.register_buffer(
|
self.register_buffer(
|
||||||
"prototype_win_ratios",
|
"prototype_win_ratios",
|
||||||
|
Loading…
Reference in New Issue
Block a user