Merge branch 'feature/better-hparams' of github.com:si-cim/prototorch_models into feature/better-hparams

This commit is contained in:
Alexander Engelsberger 2022-06-24 15:05:53 +02:00
commit bcf9c6bdb1

View File

@ -169,13 +169,13 @@ class VisGLVQ2D(Vis2DAbstract):
plabels = pl_module.prototype_labels plabels = pl_module.prototype_labels
x_train, y_train = self.x_train, self.y_train x_train, y_train = self.x_train, self.y_train
ax = self.setup_ax() ax = self.setup_ax()
self.plot_protos(ax, protos, plabels)
if x_train is not None: if x_train is not None:
self.plot_data(ax, x_train, y_train) self.plot_data(ax, x_train, y_train)
mesh_input, xx, yy = mesh2d(np.vstack([x_train, protos]), mesh_input, xx, yy = mesh2d(np.vstack([x_train, protos]),
self.border, self.resolution) self.border, self.resolution)
else: else:
mesh_input, xx, yy = mesh2d(protos, self.border, self.resolution) mesh_input, xx, yy = mesh2d(protos, self.border, self.resolution)
self.plot_protos(ax, protos, plabels)
_components = pl_module.proto_layer._components _components = pl_module.proto_layer._components
mesh_input = torch.from_numpy(mesh_input).type_as(_components) mesh_input = torch.from_numpy(mesh_input).type_as(_components)
y_pred = pl_module.predict(mesh_input) y_pred = pl_module.predict(mesh_input)