[BUGFIX] Fix visualization callbacks bug
This commit is contained in:
parent
e87563e10d
commit
b38acd58a8
@ -298,7 +298,8 @@ class Vis2DAbstract(pl.Callback):
|
||||
def precheck(self, trainer):
|
||||
if self.show_last_only:
|
||||
if trainer.current_epoch != trainer.max_epochs - 1:
|
||||
return
|
||||
return False
|
||||
return True
|
||||
|
||||
def setup_ax(self, xlabel=None, ylabel=None):
|
||||
ax = self.fig.gca()
|
||||
@ -362,7 +363,8 @@ class Vis2DAbstract(pl.Callback):
|
||||
|
||||
class VisGLVQ2D(Vis2DAbstract):
|
||||
def on_epoch_end(self, trainer, pl_module):
|
||||
self.precheck(trainer)
|
||||
if not self.precheck(trainer):
|
||||
return True
|
||||
|
||||
protos = pl_module.prototypes
|
||||
plabels = pl_module.prototype_labels
|
||||
@ -386,7 +388,8 @@ class VisSiameseGLVQ2D(Vis2DAbstract):
|
||||
self.map_protos = map_protos
|
||||
|
||||
def on_epoch_end(self, trainer, pl_module):
|
||||
self.precheck(trainer)
|
||||
if not self.precheck(trainer):
|
||||
return True
|
||||
|
||||
protos = pl_module.prototypes
|
||||
plabels = pl_module.prototype_labels
|
||||
@ -411,14 +414,15 @@ class VisSiameseGLVQ2D(Vis2DAbstract):
|
||||
|
||||
class VisCBC2D(Vis2DAbstract):
|
||||
def on_epoch_end(self, trainer, pl_module):
|
||||
self.precheck(trainer)
|
||||
if not self.precheck(trainer):
|
||||
return True
|
||||
|
||||
x_train, y_train = self.x_train, self.y_train
|
||||
protos = pl_module.components
|
||||
ax = self.setup_ax(xlabel="Data dimension 1",
|
||||
ylabel="Data dimension 2")
|
||||
self.plot_data(ax, x_train, y_train)
|
||||
self.plot_protos(ax, protos, plabels)
|
||||
self.plot_protos(ax, protos, "w")
|
||||
x = np.vstack((x_train, protos))
|
||||
mesh_input, xx, yy = self.get_mesh_input(x)
|
||||
y_pred = pl_module.predict(torch.Tensor(mesh_input))
|
||||
@ -431,7 +435,8 @@ class VisCBC2D(Vis2DAbstract):
|
||||
|
||||
class VisNG2D(Vis2DAbstract):
|
||||
def on_epoch_end(self, trainer, pl_module):
|
||||
self.precheck(trainer)
|
||||
if not self.precheck(trainer):
|
||||
return True
|
||||
|
||||
x_train, y_train = self.x_train, self.y_train
|
||||
protos = pl_module.prototypes
|
||||
|
Loading…
Reference in New Issue
Block a user