From d7972a69e8592c44916c5a5d7366d523b7820107 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Fri, 7 May 2021 15:24:47 +0200 Subject: [PATCH] Update GMLVQ model --- prototorch/models/glvq.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/prototorch/models/glvq.py b/prototorch/models/glvq.py index 868d4c9..6fe76d3 100644 --- a/prototorch/models/glvq.py +++ b/prototorch/models/glvq.py @@ -191,6 +191,26 @@ class GMLVQ(GLVQ): self.hparams.latent_dim, bias=False) + @property + def omega_matrix(self): + return self.omega_layer.weight.detach().cpu() + + @property + def lambda_matrix(self): + omega = self.omega_layer.weight + lam = omega @ omega.T + return lam.detach().cpu() + + def show_lambda(self): + import matplotlib.pyplot as plt + title = "Lambda matrix" + plt.figure(title) + plt.title(title) + plt.imshow(self.lambda_matrix, cmap="gray") + plt.axis("off") + plt.colorbar() + plt.show(block=True) + def forward(self, x): protos, _ = self.proto_layer() latent_x = self.omega_layer(x)