Merge branch 'feature/better-hparams' of github.com:si-cim/prototorch_models into feature/better-hparams
This commit is contained in:
commit
bcf9c6bdb1
@ -169,13 +169,13 @@ class VisGLVQ2D(Vis2DAbstract):
|
||||
plabels = pl_module.prototype_labels
|
||||
x_train, y_train = self.x_train, self.y_train
|
||||
ax = self.setup_ax()
|
||||
self.plot_protos(ax, protos, plabels)
|
||||
if x_train is not None:
|
||||
self.plot_data(ax, x_train, y_train)
|
||||
mesh_input, xx, yy = mesh2d(np.vstack([x_train, protos]),
|
||||
self.border, self.resolution)
|
||||
else:
|
||||
mesh_input, xx, yy = mesh2d(protos, self.border, self.resolution)
|
||||
self.plot_protos(ax, protos, plabels)
|
||||
_components = pl_module.proto_layer._components
|
||||
mesh_input = torch.from_numpy(mesh_input).type_as(_components)
|
||||
y_pred = pl_module.predict(mesh_input)
|
||||
|
Loading…
Reference in New Issue
Block a user