[BUGFIX] Fix visualization callbacks bug

This commit is contained in:
Jensun Ravichandran 2021-05-11 16:09:27 +02:00
parent e87563e10d
commit b38acd58a8

View File

@ -298,7 +298,8 @@ class Vis2DAbstract(pl.Callback):
def precheck(self, trainer): def precheck(self, trainer):
if self.show_last_only: if self.show_last_only:
if trainer.current_epoch != trainer.max_epochs - 1: if trainer.current_epoch != trainer.max_epochs - 1:
return return False
return True
def setup_ax(self, xlabel=None, ylabel=None): def setup_ax(self, xlabel=None, ylabel=None):
ax = self.fig.gca() ax = self.fig.gca()
@ -362,7 +363,8 @@ class Vis2DAbstract(pl.Callback):
class VisGLVQ2D(Vis2DAbstract): class VisGLVQ2D(Vis2DAbstract):
def on_epoch_end(self, trainer, pl_module): def on_epoch_end(self, trainer, pl_module):
self.precheck(trainer) if not self.precheck(trainer):
return True
protos = pl_module.prototypes protos = pl_module.prototypes
plabels = pl_module.prototype_labels plabels = pl_module.prototype_labels
@ -386,7 +388,8 @@ class VisSiameseGLVQ2D(Vis2DAbstract):
self.map_protos = map_protos self.map_protos = map_protos
def on_epoch_end(self, trainer, pl_module): def on_epoch_end(self, trainer, pl_module):
self.precheck(trainer) if not self.precheck(trainer):
return True
protos = pl_module.prototypes protos = pl_module.prototypes
plabels = pl_module.prototype_labels plabels = pl_module.prototype_labels
@ -411,14 +414,15 @@ class VisSiameseGLVQ2D(Vis2DAbstract):
class VisCBC2D(Vis2DAbstract): class VisCBC2D(Vis2DAbstract):
def on_epoch_end(self, trainer, pl_module): def on_epoch_end(self, trainer, pl_module):
self.precheck(trainer) if not self.precheck(trainer):
return True
x_train, y_train = self.x_train, self.y_train x_train, y_train = self.x_train, self.y_train
protos = pl_module.components protos = pl_module.components
ax = self.setup_ax(xlabel="Data dimension 1", ax = self.setup_ax(xlabel="Data dimension 1",
ylabel="Data dimension 2") ylabel="Data dimension 2")
self.plot_data(ax, x_train, y_train) self.plot_data(ax, x_train, y_train)
self.plot_protos(ax, protos, plabels) self.plot_protos(ax, protos, "w")
x = np.vstack((x_train, protos)) x = np.vstack((x_train, protos))
mesh_input, xx, yy = self.get_mesh_input(x) mesh_input, xx, yy = self.get_mesh_input(x)
y_pred = pl_module.predict(torch.Tensor(mesh_input)) y_pred = pl_module.predict(torch.Tensor(mesh_input))
@ -431,7 +435,8 @@ class VisCBC2D(Vis2DAbstract):
class VisNG2D(Vis2DAbstract): class VisNG2D(Vis2DAbstract):
def on_epoch_end(self, trainer, pl_module): def on_epoch_end(self, trainer, pl_module):
self.precheck(trainer) if not self.precheck(trainer):
return True
x_train, y_train = self.x_train, self.y_train x_train, y_train = self.x_train, self.y_train
protos = pl_module.prototypes protos = pl_module.prototypes