[BUGFIX] Fix visualization callbacks bug
This commit is contained in:
parent
e87563e10d
commit
b38acd58a8
@ -298,7 +298,8 @@ class Vis2DAbstract(pl.Callback):
|
|||||||
def precheck(self, trainer):
|
def precheck(self, trainer):
|
||||||
if self.show_last_only:
|
if self.show_last_only:
|
||||||
if trainer.current_epoch != trainer.max_epochs - 1:
|
if trainer.current_epoch != trainer.max_epochs - 1:
|
||||||
return
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
def setup_ax(self, xlabel=None, ylabel=None):
|
def setup_ax(self, xlabel=None, ylabel=None):
|
||||||
ax = self.fig.gca()
|
ax = self.fig.gca()
|
||||||
@ -362,7 +363,8 @@ class Vis2DAbstract(pl.Callback):
|
|||||||
|
|
||||||
class VisGLVQ2D(Vis2DAbstract):
|
class VisGLVQ2D(Vis2DAbstract):
|
||||||
def on_epoch_end(self, trainer, pl_module):
|
def on_epoch_end(self, trainer, pl_module):
|
||||||
self.precheck(trainer)
|
if not self.precheck(trainer):
|
||||||
|
return True
|
||||||
|
|
||||||
protos = pl_module.prototypes
|
protos = pl_module.prototypes
|
||||||
plabels = pl_module.prototype_labels
|
plabels = pl_module.prototype_labels
|
||||||
@ -386,7 +388,8 @@ class VisSiameseGLVQ2D(Vis2DAbstract):
|
|||||||
self.map_protos = map_protos
|
self.map_protos = map_protos
|
||||||
|
|
||||||
def on_epoch_end(self, trainer, pl_module):
|
def on_epoch_end(self, trainer, pl_module):
|
||||||
self.precheck(trainer)
|
if not self.precheck(trainer):
|
||||||
|
return True
|
||||||
|
|
||||||
protos = pl_module.prototypes
|
protos = pl_module.prototypes
|
||||||
plabels = pl_module.prototype_labels
|
plabels = pl_module.prototype_labels
|
||||||
@ -411,14 +414,15 @@ class VisSiameseGLVQ2D(Vis2DAbstract):
|
|||||||
|
|
||||||
class VisCBC2D(Vis2DAbstract):
|
class VisCBC2D(Vis2DAbstract):
|
||||||
def on_epoch_end(self, trainer, pl_module):
|
def on_epoch_end(self, trainer, pl_module):
|
||||||
self.precheck(trainer)
|
if not self.precheck(trainer):
|
||||||
|
return True
|
||||||
|
|
||||||
x_train, y_train = self.x_train, self.y_train
|
x_train, y_train = self.x_train, self.y_train
|
||||||
protos = pl_module.components
|
protos = pl_module.components
|
||||||
ax = self.setup_ax(xlabel="Data dimension 1",
|
ax = self.setup_ax(xlabel="Data dimension 1",
|
||||||
ylabel="Data dimension 2")
|
ylabel="Data dimension 2")
|
||||||
self.plot_data(ax, x_train, y_train)
|
self.plot_data(ax, x_train, y_train)
|
||||||
self.plot_protos(ax, protos, plabels)
|
self.plot_protos(ax, protos, "w")
|
||||||
x = np.vstack((x_train, protos))
|
x = np.vstack((x_train, protos))
|
||||||
mesh_input, xx, yy = self.get_mesh_input(x)
|
mesh_input, xx, yy = self.get_mesh_input(x)
|
||||||
y_pred = pl_module.predict(torch.Tensor(mesh_input))
|
y_pred = pl_module.predict(torch.Tensor(mesh_input))
|
||||||
@ -431,7 +435,8 @@ class VisCBC2D(Vis2DAbstract):
|
|||||||
|
|
||||||
class VisNG2D(Vis2DAbstract):
|
class VisNG2D(Vis2DAbstract):
|
||||||
def on_epoch_end(self, trainer, pl_module):
|
def on_epoch_end(self, trainer, pl_module):
|
||||||
self.precheck(trainer)
|
if not self.precheck(trainer):
|
||||||
|
return True
|
||||||
|
|
||||||
x_train, y_train = self.x_train, self.y_train
|
x_train, y_train = self.x_train, self.y_train
|
||||||
protos = pl_module.prototypes
|
protos = pl_module.prototypes
|
||||||
|
Loading…
Reference in New Issue
Block a user