diff --git a/prototorch/models/glvq.py b/prototorch/models/glvq.py index 308bfae..bd8c3be 100644 --- a/prototorch/models/glvq.py +++ b/prototorch/models/glvq.py @@ -200,8 +200,8 @@ class GMLVQ(GLVQ): @property def lambda_matrix(self): - omega = self.omega_layer.weight - lam = omega @ omega.T + omega = self.omega_layer.weight # (latent_dim, input_dim) + lam = omega.T @ omega return lam.detach().cpu() def show_lambda(self):