Fix: saving GMLVQ and GRLVQ fixed

This commit is contained in:
Alexander Engelsberger
2023-03-09 15:50:13 +01:00
parent 87fa3f0729
commit 46dfb82371
4 changed files with 82 additions and 5 deletions

View File

@@ -71,7 +71,7 @@ class PrototypeModel(ProtoTorchBolt):
super().__init__(hparams, **kwargs)
distance_fn = kwargs.get("distance_fn", euclidean_distance)
self.distance_layer = LambdaLayer(distance_fn)
self.distance_layer = LambdaLayer(distance_fn, name="distance_fn")
@property
def num_prototypes(self):

View File

@@ -209,9 +209,12 @@ class GRLVQ(SiameseGLVQ):
self.register_parameter("_relevances", Parameter(relevances))
# Override the backbone
self.backbone = LambdaLayer(lambda x: x @ torch.diag(self._relevances),
self.backbone = LambdaLayer(self._apply_relevances,
name="relevance scaling")
def _apply_relevances(self, x):
return x @ torch.diag(self._relevances)
@property
def relevance_profile(self):
return self._relevances.detach().cpu()
@@ -271,9 +274,7 @@ class GMLVQ(GLVQ):
omega = omega_initializer.generate(self.hparams["input_dim"],
self.hparams["latent_dim"])
self.register_parameter("_omega", Parameter(omega))
self.backbone = LambdaLayer(lambda x: x @ self._omega,
name="omega matrix")
@property
def omega_matrix(self):
return self._omega.detach().cpu()