adjust omega_matrix property
This commit is contained in:
parent
824dfced92
commit
1786031b4e
@ -238,7 +238,8 @@ class GMLMLVQ(GLVQ):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# Parameters
|
# Parameters
|
||||||
_omega: torch.Tensor
|
_omega_0: torch.Tensor
|
||||||
|
_omega_1: torch.Tensor
|
||||||
|
|
||||||
def __init__(self, hparams, **kwargs):
|
def __init__(self, hparams, **kwargs):
|
||||||
distance_fn = kwargs.pop("distance_fn", ML_omega_distance)
|
distance_fn = kwargs.pop("distance_fn", ML_omega_distance)
|
||||||
@ -255,8 +256,8 @@ class GMLMLVQ(GLVQ):
|
|||||||
self.mask_1 = masks[1]
|
self.mask_1 = masks[1]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def omega_matrix(self):
|
def omega_matrices(self):
|
||||||
return self._omega.detach().cpu()
|
return [self._omega_0.detach().cpu(), self._omega_1.detach().cpu()]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def lambda_matrix(self):
|
def lambda_matrix(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user