Update GMLVQ model
This commit is contained in:
parent
17315ff242
commit
d7972a69e8
@ -191,6 +191,26 @@ class GMLVQ(GLVQ):
|
||||
self.hparams.latent_dim,
|
||||
bias=False)
|
||||
|
||||
@property
|
||||
def omega_matrix(self):
|
||||
return self.omega_layer.weight.detach().cpu()
|
||||
|
||||
@property
|
||||
def lambda_matrix(self):
|
||||
omega = self.omega_layer.weight
|
||||
lam = omega @ omega.T
|
||||
return lam.detach().cpu()
|
||||
|
||||
def show_lambda(self):
|
||||
import matplotlib.pyplot as plt
|
||||
title = "Lambda matrix"
|
||||
plt.figure(title)
|
||||
plt.title(title)
|
||||
plt.imshow(self.lambda_matrix, cmap="gray")
|
||||
plt.axis("off")
|
||||
plt.colorbar()
|
||||
plt.show(block=True)
|
||||
|
||||
def forward(self, x):
|
||||
protos, _ = self.proto_layer()
|
||||
latent_x = self.omega_layer(x)
|
||||
|
Loading…
Reference in New Issue
Block a user