diff --git a/prototorch/models/vis.py b/prototorch/models/vis.py index 07cba8f..d788e0e 100644 --- a/prototorch/models/vis.py +++ b/prototorch/models/vis.py @@ -298,7 +298,8 @@ class Vis2DAbstract(pl.Callback): def precheck(self, trainer): if self.show_last_only: if trainer.current_epoch != trainer.max_epochs - 1: - return + return False + return True def setup_ax(self, xlabel=None, ylabel=None): ax = self.fig.gca() @@ -362,7 +363,8 @@ class Vis2DAbstract(pl.Callback): class VisGLVQ2D(Vis2DAbstract): def on_epoch_end(self, trainer, pl_module): - self.precheck(trainer) + if not self.precheck(trainer): + return True protos = pl_module.prototypes plabels = pl_module.prototype_labels @@ -386,7 +388,8 @@ class VisSiameseGLVQ2D(Vis2DAbstract): self.map_protos = map_protos def on_epoch_end(self, trainer, pl_module): - self.precheck(trainer) + if not self.precheck(trainer): + return True protos = pl_module.prototypes plabels = pl_module.prototype_labels @@ -411,14 +414,15 @@ class VisSiameseGLVQ2D(Vis2DAbstract): class VisCBC2D(Vis2DAbstract): def on_epoch_end(self, trainer, pl_module): - self.precheck(trainer) + if not self.precheck(trainer): + return True x_train, y_train = self.x_train, self.y_train protos = pl_module.components ax = self.setup_ax(xlabel="Data dimension 1", ylabel="Data dimension 2") self.plot_data(ax, x_train, y_train) - self.plot_protos(ax, protos, plabels) + self.plot_protos(ax, protos, "w") x = np.vstack((x_train, protos)) mesh_input, xx, yy = self.get_mesh_input(x) y_pred = pl_module.predict(torch.Tensor(mesh_input)) @@ -431,7 +435,8 @@ class VisCBC2D(Vis2DAbstract): class VisNG2D(Vis2DAbstract): def on_epoch_end(self, trainer, pl_module): - self.precheck(trainer) + if not self.precheck(trainer): + return True x_train, y_train = self.x_train, self.y_train protos = pl_module.prototypes