diff --git a/prototorch/models/callbacks/visualization.py b/prototorch/models/callbacks/visualization.py index 8692eff..98a1889 100644 --- a/prototorch/models/callbacks/visualization.py +++ b/prototorch/models/callbacks/visualization.py @@ -268,6 +268,7 @@ class Vis2DAbstract(pl.Callback): title="Prototype Visualization", cmap="viridis", border=1, + tensorboard=False, show_last_only=False, block=False): super().__init__() @@ -277,9 +278,17 @@ class Vis2DAbstract(pl.Callback): self.fig = plt.figure(self.title) self.cmap = cmap self.border = border + self.tensorboard = tensorboard self.show_last_only = show_last_only self.block = block + def add_to_tensorboard(self, trainer, pl_module): + tb = pl_module.logger.experiment + tb.add_figure(tag=f"{self.title}", + figure=self.fig, + global_step=trainer.current_epoch, + close=False) + class VisGLVQ2D(Vis2DAbstract): def on_epoch_end(self, trainer, pl_module): @@ -317,6 +326,8 @@ class VisGLVQ2D(Vis2DAbstract): ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35) ax.set_xlim(left=x_min + 0, right=x_max - 0) ax.set_ylim(bottom=y_min + 0, top=y_max - 0) + if self.tensorboard: + self.add_to_tensorboard(trainer, pl_module) if not self.block: plt.pause(0.01) else: @@ -364,6 +375,8 @@ class VisSiameseGLVQ2D(Vis2DAbstract): close=False, ) + if self.tensorboard: + self.add_to_tensorboard(trainer, pl_module) if not self.block: plt.pause(0.01) else: @@ -404,6 +417,8 @@ class VisNG2D(Vis2DAbstract): "k-", ) + if self.tensorboard: + self.add_to_tensorboard(trainer, pl_module) if not self.block: plt.pause(0.01) else: