feat: add xlabel and ylabel arguments to visualizers
This commit is contained in:
parent
e21e6c7e02
commit
b31c8cc707
@ -16,6 +16,8 @@ class Vis2DAbstract(pl.Callback):
|
|||||||
data,
|
data,
|
||||||
title="Prototype Visualization",
|
title="Prototype Visualization",
|
||||||
cmap="viridis",
|
cmap="viridis",
|
||||||
|
xlabel="Data dimension 1",
|
||||||
|
ylabel="Data dimension 2",
|
||||||
border=0.1,
|
border=0.1,
|
||||||
resolution=100,
|
resolution=100,
|
||||||
flatten_data=True,
|
flatten_data=True,
|
||||||
@ -46,6 +48,8 @@ class Vis2DAbstract(pl.Callback):
|
|||||||
self.y_train = y
|
self.y_train = y
|
||||||
|
|
||||||
self.title = title
|
self.title = title
|
||||||
|
self.xlabel = xlabel
|
||||||
|
self.ylabel = ylabel
|
||||||
self.fig = plt.figure(self.title)
|
self.fig = plt.figure(self.title)
|
||||||
self.cmap = cmap
|
self.cmap = cmap
|
||||||
self.border = border
|
self.border = border
|
||||||
@ -64,14 +68,12 @@ class Vis2DAbstract(pl.Callback):
|
|||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def setup_ax(self, xlabel=None, ylabel=None):
|
def setup_ax(self):
|
||||||
ax = self.fig.gca()
|
ax = self.fig.gca()
|
||||||
ax.cla()
|
ax.cla()
|
||||||
ax.set_title(self.title)
|
ax.set_title(self.title)
|
||||||
if xlabel:
|
ax.set_xlabel(self.xlabel)
|
||||||
ax.set_xlabel("Data dimension 1")
|
ax.set_ylabel(self.ylabel)
|
||||||
if ylabel:
|
|
||||||
ax.set_ylabel("Data dimension 2")
|
|
||||||
if self.axis_off:
|
if self.axis_off:
|
||||||
ax.axis("off")
|
ax.axis("off")
|
||||||
return ax
|
return ax
|
||||||
@ -130,8 +132,7 @@ class VisGLVQ2D(Vis2DAbstract):
|
|||||||
protos = pl_module.prototypes
|
protos = pl_module.prototypes
|
||||||
plabels = pl_module.prototype_labels
|
plabels = pl_module.prototype_labels
|
||||||
x_train, y_train = self.x_train, self.y_train
|
x_train, y_train = self.x_train, self.y_train
|
||||||
ax = self.setup_ax(xlabel="Data dimension 1",
|
ax = self.setup_ax()
|
||||||
ylabel="Data dimension 2")
|
|
||||||
self.plot_data(ax, x_train, y_train)
|
self.plot_data(ax, x_train, y_train)
|
||||||
self.plot_protos(ax, protos, plabels)
|
self.plot_protos(ax, protos, plabels)
|
||||||
x = np.vstack((x_train, protos))
|
x = np.vstack((x_train, protos))
|
||||||
|
Loading…
Reference in New Issue
Block a user