adjust omega_matrix property
This commit is contained in:
parent
824dfced92
commit
1786031b4e
@ -238,7 +238,8 @@ class GMLMLVQ(GLVQ):
|
||||
"""
|
||||
|
||||
# Parameters
|
||||
_omega: torch.Tensor
|
||||
_omega_0: torch.Tensor
|
||||
_omega_1: torch.Tensor
|
||||
|
||||
def __init__(self, hparams, **kwargs):
|
||||
distance_fn = kwargs.pop("distance_fn", ML_omega_distance)
|
||||
@ -255,8 +256,8 @@ class GMLMLVQ(GLVQ):
|
||||
self.mask_1 = masks[1]
|
||||
|
||||
@property
|
||||
def omega_matrix(self):
|
||||
return self._omega.detach().cpu()
|
||||
def omega_matrices(self):
|
||||
return [self._omega_0.detach().cpu(), self._omega_1.detach().cpu()]
|
||||
|
||||
@property
|
||||
def lambda_matrix(self):
|
||||
|
Loading…
Reference in New Issue
Block a user