Fix: saving GMLVQ and GRLVQ fixed
This commit is contained in:
@@ -71,7 +71,7 @@ class PrototypeModel(ProtoTorchBolt):
|
||||
super().__init__(hparams, **kwargs)
|
||||
|
||||
distance_fn = kwargs.get("distance_fn", euclidean_distance)
|
||||
self.distance_layer = LambdaLayer(distance_fn)
|
||||
self.distance_layer = LambdaLayer(distance_fn, name="distance_fn")
|
||||
|
||||
@property
|
||||
def num_prototypes(self):
|
||||
|
@@ -209,9 +209,12 @@ class GRLVQ(SiameseGLVQ):
|
||||
self.register_parameter("_relevances", Parameter(relevances))
|
||||
|
||||
# Override the backbone
|
||||
self.backbone = LambdaLayer(lambda x: x @ torch.diag(self._relevances),
|
||||
self.backbone = LambdaLayer(self._apply_relevances,
|
||||
name="relevance scaling")
|
||||
|
||||
def _apply_relevances(self, x):
|
||||
return x @ torch.diag(self._relevances)
|
||||
|
||||
@property
|
||||
def relevance_profile(self):
|
||||
return self._relevances.detach().cpu()
|
||||
@@ -271,9 +274,7 @@ class GMLVQ(GLVQ):
|
||||
omega = omega_initializer.generate(self.hparams["input_dim"],
|
||||
self.hparams["latent_dim"])
|
||||
self.register_parameter("_omega", Parameter(omega))
|
||||
self.backbone = LambdaLayer(lambda x: x @ self._omega,
|
||||
name="omega matrix")
|
||||
|
||||
|
||||
@property
|
||||
def omega_matrix(self):
|
||||
return self._omega.detach().cpu()
|
||||
|
Reference in New Issue
Block a user