Add tensorboard argument to visualization callbacks

This commit is contained in:
Jensun Ravichandran 2021-05-03 13:19:23 +02:00
parent 6dd9b1492c
commit 042b3fcaa2

View File

@ -268,6 +268,7 @@ class Vis2DAbstract(pl.Callback):
title="Prototype Visualization",
cmap="viridis",
border=1,
tensorboard=False,
show_last_only=False,
block=False):
super().__init__()
@ -277,9 +278,17 @@ class Vis2DAbstract(pl.Callback):
self.fig = plt.figure(self.title)
self.cmap = cmap
self.border = border
self.tensorboard = tensorboard
self.show_last_only = show_last_only
self.block = block
def add_to_tensorboard(self, trainer, pl_module):
tb = pl_module.logger.experiment
tb.add_figure(tag=f"{self.title}",
figure=self.fig,
global_step=trainer.current_epoch,
close=False)
class VisGLVQ2D(Vis2DAbstract):
def on_epoch_end(self, trainer, pl_module):
@ -317,6 +326,8 @@ class VisGLVQ2D(Vis2DAbstract):
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
ax.set_xlim(left=x_min + 0, right=x_max - 0)
ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
if self.tensorboard:
self.add_to_tensorboard(trainer, pl_module)
if not self.block:
plt.pause(0.01)
else:
@ -364,6 +375,8 @@ class VisSiameseGLVQ2D(Vis2DAbstract):
close=False,
)
if self.tensorboard:
self.add_to_tensorboard(trainer, pl_module)
if not self.block:
plt.pause(0.01)
else:
@ -404,6 +417,8 @@ class VisNG2D(Vis2DAbstract):
"k-",
)
if self.tensorboard:
self.add_to_tensorboard(trainer, pl_module)
if not self.block:
plt.pause(0.01)
else: