Update GMLVQ model
This commit is contained in:
parent
17315ff242
commit
d7972a69e8
@ -191,6 +191,26 @@ class GMLVQ(GLVQ):
|
|||||||
self.hparams.latent_dim,
|
self.hparams.latent_dim,
|
||||||
bias=False)
|
bias=False)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def omega_matrix(self):
|
||||||
|
return self.omega_layer.weight.detach().cpu()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lambda_matrix(self):
|
||||||
|
omega = self.omega_layer.weight
|
||||||
|
lam = omega @ omega.T
|
||||||
|
return lam.detach().cpu()
|
||||||
|
|
||||||
|
def show_lambda(self):
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
title = "Lambda matrix"
|
||||||
|
plt.figure(title)
|
||||||
|
plt.title(title)
|
||||||
|
plt.imshow(self.lambda_matrix, cmap="gray")
|
||||||
|
plt.axis("off")
|
||||||
|
plt.colorbar()
|
||||||
|
plt.show(block=True)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
protos, _ = self.proto_layer()
|
protos, _ = self.proto_layer()
|
||||||
latent_x = self.omega_layer(x)
|
latent_x = self.omega_layer(x)
|
||||||
|
Loading…
Reference in New Issue
Block a user