fix(vis): plot prototypes after data

This commit is contained in:
Jensun Ravichandran 2022-06-14 19:59:13 +02:00
parent 46ec7b07d7
commit 94730f492b
No known key found for this signature in database
GPG Key ID: 7612C0CAB643D921

View File

@ -169,13 +169,13 @@ class VisGLVQ2D(Vis2DAbstract):
plabels = pl_module.prototype_labels
x_train, y_train = self.x_train, self.y_train
ax = self.setup_ax()
self.plot_protos(ax, protos, plabels)
if x_train is not None:
self.plot_data(ax, x_train, y_train)
mesh_input, xx, yy = mesh2d(np.vstack([x_train, protos]),
self.border, self.resolution)
else:
mesh_input, xx, yy = mesh2d(protos, self.border, self.resolution)
self.plot_protos(ax, protos, plabels)
_components = pl_module.proto_layer._components
mesh_input = torch.from_numpy(mesh_input).type_as(_components)
y_pred = pl_module.predict(mesh_input)