adjust omega_matrix property

This commit is contained in:
julius 2023-11-06 16:32:57 +01:00
parent 824dfced92
commit 1786031b4e
Signed by untrusted user who does not match committer: julius
GPG Key ID: 8AA3791362A8084A

View File

@ -238,7 +238,8 @@ class GMLMLVQ(GLVQ):
"""
# Parameters
_omega: torch.Tensor
_omega_0: torch.Tensor
_omega_1: torch.Tensor
def __init__(self, hparams, **kwargs):
distance_fn = kwargs.pop("distance_fn", ML_omega_distance)
@ -255,8 +256,8 @@ class GMLMLVQ(GLVQ):
self.mask_1 = masks[1]
@property
def omega_matrix(self):
return self._omega.detach().cpu()
def omega_matrices(self):
return [self._omega_0.detach().cpu(), self._omega_1.detach().cpu()]
@property
def lambda_matrix(self):