feat: add xlabel and ylabel arguments to visualizers

This commit is contained in:
Jensun Ravichandran 2022-03-09 13:59:19 +01:00
parent e21e6c7e02
commit b31c8cc707
No known key found for this signature in database
GPG Key ID: 4E9348239810B51F

View File

@ -16,6 +16,8 @@ class Vis2DAbstract(pl.Callback):
data,
title="Prototype Visualization",
cmap="viridis",
xlabel="Data dimension 1",
ylabel="Data dimension 2",
border=0.1,
resolution=100,
flatten_data=True,
@ -46,6 +48,8 @@ class Vis2DAbstract(pl.Callback):
self.y_train = y
self.title = title
self.xlabel = xlabel
self.ylabel = ylabel
self.fig = plt.figure(self.title)
self.cmap = cmap
self.border = border
@ -64,14 +68,12 @@ class Vis2DAbstract(pl.Callback):
return False
return True
def setup_ax(self, xlabel=None, ylabel=None):
def setup_ax(self):
ax = self.fig.gca()
ax.cla()
ax.set_title(self.title)
if xlabel:
ax.set_xlabel("Data dimension 1")
if ylabel:
ax.set_ylabel("Data dimension 2")
ax.set_xlabel(self.xlabel)
ax.set_ylabel(self.ylabel)
if self.axis_off:
ax.axis("off")
return ax
@ -130,8 +132,7 @@ class VisGLVQ2D(Vis2DAbstract):
protos = pl_module.prototypes
plabels = pl_module.prototype_labels
x_train, y_train = self.x_train, self.y_train
ax = self.setup_ax(xlabel="Data dimension 1",
ylabel="Data dimension 2")
ax = self.setup_ax()
self.plot_data(ax, x_train, y_train)
self.plot_protos(ax, protos, plabels)
x = np.vstack((x_train, protos))