diff --git a/prototorch/models/vis.py b/prototorch/models/vis.py index d57fc4f..697421a 100644 --- a/prototorch/models/vis.py +++ b/prototorch/models/vis.py @@ -16,6 +16,8 @@ class Vis2DAbstract(pl.Callback): data, title="Prototype Visualization", cmap="viridis", + xlabel="Data dimension 1", + ylabel="Data dimension 2", border=0.1, resolution=100, flatten_data=True, @@ -46,6 +48,8 @@ class Vis2DAbstract(pl.Callback): self.y_train = y self.title = title + self.xlabel = xlabel + self.ylabel = ylabel self.fig = plt.figure(self.title) self.cmap = cmap self.border = border @@ -64,14 +68,12 @@ class Vis2DAbstract(pl.Callback): return False return True - def setup_ax(self, xlabel=None, ylabel=None): + def setup_ax(self): ax = self.fig.gca() ax.cla() ax.set_title(self.title) - if xlabel: - ax.set_xlabel("Data dimension 1") - if ylabel: - ax.set_ylabel("Data dimension 2") + ax.set_xlabel(self.xlabel) + ax.set_ylabel(self.ylabel) if self.axis_off: ax.axis("off") return ax @@ -130,8 +132,7 @@ class VisGLVQ2D(Vis2DAbstract): protos = pl_module.prototypes plabels = pl_module.prototype_labels x_train, y_train = self.x_train, self.y_train - ax = self.setup_ax(xlabel="Data dimension 1", - ylabel="Data dimension 2") + ax = self.setup_ax() self.plot_data(ax, x_train, y_train) self.plot_protos(ax, protos, plabels) x = np.vstack((x_train, protos))