Update GMLVQ model

This commit is contained in:
Jensun Ravichandran 2021-05-07 15:24:47 +02:00
parent 17315ff242
commit d7972a69e8

View File

@ -191,6 +191,26 @@ class GMLVQ(GLVQ):
self.hparams.latent_dim, self.hparams.latent_dim,
bias=False) bias=False)
@property
def omega_matrix(self):
return self.omega_layer.weight.detach().cpu()
@property
def lambda_matrix(self):
omega = self.omega_layer.weight
lam = omega @ omega.T
return lam.detach().cpu()
def show_lambda(self):
import matplotlib.pyplot as plt
title = "Lambda matrix"
plt.figure(title)
plt.title(title)
plt.imshow(self.lambda_matrix, cmap="gray")
plt.axis("off")
plt.colorbar()
plt.show(block=True)
def forward(self, x): def forward(self, x):
protos, _ = self.proto_layer() protos, _ = self.proto_layer()
latent_x = self.omega_layer(x) latent_x = self.omega_layer(x)