From 94730f492b81055fd49286f0c5309eb78b6675c0 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Tue, 14 Jun 2022 19:59:13 +0200 Subject: [PATCH] fix(vis): plot prototypes after data --- prototorch/models/vis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prototorch/models/vis.py b/prototorch/models/vis.py index 7f0623a..26e15c0 100644 --- a/prototorch/models/vis.py +++ b/prototorch/models/vis.py @@ -169,13 +169,13 @@ class VisGLVQ2D(Vis2DAbstract): plabels = pl_module.prototype_labels x_train, y_train = self.x_train, self.y_train ax = self.setup_ax() - self.plot_protos(ax, protos, plabels) if x_train is not None: self.plot_data(ax, x_train, y_train) mesh_input, xx, yy = mesh2d(np.vstack([x_train, protos]), self.border, self.resolution) else: mesh_input, xx, yy = mesh2d(protos, self.border, self.resolution) + self.plot_protos(ax, protos, plabels) _components = pl_module.proto_layer._components mesh_input = torch.from_numpy(mesh_input).type_as(_components) y_pred = pl_module.predict(mesh_input)