feat: add xlabel and ylabel arguments to visualizers
This commit is contained in:
parent
e21e6c7e02
commit
b31c8cc707
@ -16,6 +16,8 @@ class Vis2DAbstract(pl.Callback):
|
||||
data,
|
||||
title="Prototype Visualization",
|
||||
cmap="viridis",
|
||||
xlabel="Data dimension 1",
|
||||
ylabel="Data dimension 2",
|
||||
border=0.1,
|
||||
resolution=100,
|
||||
flatten_data=True,
|
||||
@ -46,6 +48,8 @@ class Vis2DAbstract(pl.Callback):
|
||||
self.y_train = y
|
||||
|
||||
self.title = title
|
||||
self.xlabel = xlabel
|
||||
self.ylabel = ylabel
|
||||
self.fig = plt.figure(self.title)
|
||||
self.cmap = cmap
|
||||
self.border = border
|
||||
@ -64,14 +68,12 @@ class Vis2DAbstract(pl.Callback):
|
||||
return False
|
||||
return True
|
||||
|
||||
def setup_ax(self, xlabel=None, ylabel=None):
|
||||
def setup_ax(self):
|
||||
ax = self.fig.gca()
|
||||
ax.cla()
|
||||
ax.set_title(self.title)
|
||||
if xlabel:
|
||||
ax.set_xlabel("Data dimension 1")
|
||||
if ylabel:
|
||||
ax.set_ylabel("Data dimension 2")
|
||||
ax.set_xlabel(self.xlabel)
|
||||
ax.set_ylabel(self.ylabel)
|
||||
if self.axis_off:
|
||||
ax.axis("off")
|
||||
return ax
|
||||
@ -130,8 +132,7 @@ class VisGLVQ2D(Vis2DAbstract):
|
||||
protos = pl_module.prototypes
|
||||
plabels = pl_module.prototype_labels
|
||||
x_train, y_train = self.x_train, self.y_train
|
||||
ax = self.setup_ax(xlabel="Data dimension 1",
|
||||
ylabel="Data dimension 2")
|
||||
ax = self.setup_ax()
|
||||
self.plot_data(ax, x_train, y_train)
|
||||
self.plot_protos(ax, protos, plabels)
|
||||
x = np.vstack((x_train, protos))
|
||||
|
Loading…
Reference in New Issue
Block a user