From 1786031b4e339d4a3f2da9d509c1b471d4534cd8 Mon Sep 17 00:00:00 2001 From: julius Date: Mon, 6 Nov 2023 16:32:57 +0100 Subject: [PATCH] adjust omega_matrix property --- src/prototorch/models/glvq.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/prototorch/models/glvq.py b/src/prototorch/models/glvq.py index 1ba9f06..4de706b 100644 --- a/src/prototorch/models/glvq.py +++ b/src/prototorch/models/glvq.py @@ -238,7 +238,8 @@ class GMLMLVQ(GLVQ): """ # Parameters - _omega: torch.Tensor + _omega_0: torch.Tensor + _omega_1: torch.Tensor def __init__(self, hparams, **kwargs): distance_fn = kwargs.pop("distance_fn", ML_omega_distance) @@ -255,8 +256,8 @@ class GMLMLVQ(GLVQ): self.mask_1 = masks[1] @property - def omega_matrix(self): - return self._omega.detach().cpu() + def omega_matrices(self): + return [self._omega_0.detach().cpu(), self._omega_1.detach().cpu()] @property def lambda_matrix(self):