feat: add xlabel and ylabel arguments to visualizers

This commit is contained in:
Jensun Ravichandran 2022-03-09 13:59:19 +01:00
parent e21e6c7e02
commit b31c8cc707
No known key found for this signature in database
GPG Key ID: 4E9348239810B51F

View File

@ -16,6 +16,8 @@ class Vis2DAbstract(pl.Callback):
data, data,
title="Prototype Visualization", title="Prototype Visualization",
cmap="viridis", cmap="viridis",
xlabel="Data dimension 1",
ylabel="Data dimension 2",
border=0.1, border=0.1,
resolution=100, resolution=100,
flatten_data=True, flatten_data=True,
@ -46,6 +48,8 @@ class Vis2DAbstract(pl.Callback):
self.y_train = y self.y_train = y
self.title = title self.title = title
self.xlabel = xlabel
self.ylabel = ylabel
self.fig = plt.figure(self.title) self.fig = plt.figure(self.title)
self.cmap = cmap self.cmap = cmap
self.border = border self.border = border
@ -64,14 +68,12 @@ class Vis2DAbstract(pl.Callback):
return False return False
return True return True
def setup_ax(self, xlabel=None, ylabel=None): def setup_ax(self):
ax = self.fig.gca() ax = self.fig.gca()
ax.cla() ax.cla()
ax.set_title(self.title) ax.set_title(self.title)
if xlabel: ax.set_xlabel(self.xlabel)
ax.set_xlabel("Data dimension 1") ax.set_ylabel(self.ylabel)
if ylabel:
ax.set_ylabel("Data dimension 2")
if self.axis_off: if self.axis_off:
ax.axis("off") ax.axis("off")
return ax return ax
@ -130,8 +132,7 @@ class VisGLVQ2D(Vis2DAbstract):
protos = pl_module.prototypes protos = pl_module.prototypes
plabels = pl_module.prototype_labels plabels = pl_module.prototype_labels
x_train, y_train = self.x_train, self.y_train x_train, y_train = self.x_train, self.y_train
ax = self.setup_ax(xlabel="Data dimension 1", ax = self.setup_ax()
ylabel="Data dimension 2")
self.plot_data(ax, x_train, y_train) self.plot_data(ax, x_train, y_train)
self.plot_protos(ax, protos, plabels) self.plot_protos(ax, protos, plabels)
x = np.vstack((x_train, protos)) x = np.vstack((x_train, protos))