Merge branch 'feature/better-hparams' of github.com:si-cim/prototorch_models into feature/better-hparams
This commit is contained in:
commit
bcf9c6bdb1
@ -169,13 +169,13 @@ class VisGLVQ2D(Vis2DAbstract):
|
|||||||
plabels = pl_module.prototype_labels
|
plabels = pl_module.prototype_labels
|
||||||
x_train, y_train = self.x_train, self.y_train
|
x_train, y_train = self.x_train, self.y_train
|
||||||
ax = self.setup_ax()
|
ax = self.setup_ax()
|
||||||
self.plot_protos(ax, protos, plabels)
|
|
||||||
if x_train is not None:
|
if x_train is not None:
|
||||||
self.plot_data(ax, x_train, y_train)
|
self.plot_data(ax, x_train, y_train)
|
||||||
mesh_input, xx, yy = mesh2d(np.vstack([x_train, protos]),
|
mesh_input, xx, yy = mesh2d(np.vstack([x_train, protos]),
|
||||||
self.border, self.resolution)
|
self.border, self.resolution)
|
||||||
else:
|
else:
|
||||||
mesh_input, xx, yy = mesh2d(protos, self.border, self.resolution)
|
mesh_input, xx, yy = mesh2d(protos, self.border, self.resolution)
|
||||||
|
self.plot_protos(ax, protos, plabels)
|
||||||
_components = pl_module.proto_layer._components
|
_components = pl_module.proto_layer._components
|
||||||
mesh_input = torch.from_numpy(mesh_input).type_as(_components)
|
mesh_input = torch.from_numpy(mesh_input).type_as(_components)
|
||||||
y_pred = pl_module.predict(mesh_input)
|
y_pred = pl_module.predict(mesh_input)
|
||||||
|
Loading…
Reference in New Issue
Block a user