Add tensorboard argument to visualization callbacks
This commit is contained in:
parent
6dd9b1492c
commit
042b3fcaa2
@ -268,6 +268,7 @@ class Vis2DAbstract(pl.Callback):
|
|||||||
title="Prototype Visualization",
|
title="Prototype Visualization",
|
||||||
cmap="viridis",
|
cmap="viridis",
|
||||||
border=1,
|
border=1,
|
||||||
|
tensorboard=False,
|
||||||
show_last_only=False,
|
show_last_only=False,
|
||||||
block=False):
|
block=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -277,9 +278,17 @@ class Vis2DAbstract(pl.Callback):
|
|||||||
self.fig = plt.figure(self.title)
|
self.fig = plt.figure(self.title)
|
||||||
self.cmap = cmap
|
self.cmap = cmap
|
||||||
self.border = border
|
self.border = border
|
||||||
|
self.tensorboard = tensorboard
|
||||||
self.show_last_only = show_last_only
|
self.show_last_only = show_last_only
|
||||||
self.block = block
|
self.block = block
|
||||||
|
|
||||||
|
def add_to_tensorboard(self, trainer, pl_module):
|
||||||
|
tb = pl_module.logger.experiment
|
||||||
|
tb.add_figure(tag=f"{self.title}",
|
||||||
|
figure=self.fig,
|
||||||
|
global_step=trainer.current_epoch,
|
||||||
|
close=False)
|
||||||
|
|
||||||
|
|
||||||
class VisGLVQ2D(Vis2DAbstract):
|
class VisGLVQ2D(Vis2DAbstract):
|
||||||
def on_epoch_end(self, trainer, pl_module):
|
def on_epoch_end(self, trainer, pl_module):
|
||||||
@ -317,6 +326,8 @@ class VisGLVQ2D(Vis2DAbstract):
|
|||||||
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
|
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
|
||||||
ax.set_xlim(left=x_min + 0, right=x_max - 0)
|
ax.set_xlim(left=x_min + 0, right=x_max - 0)
|
||||||
ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
|
ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
|
||||||
|
if self.tensorboard:
|
||||||
|
self.add_to_tensorboard(trainer, pl_module)
|
||||||
if not self.block:
|
if not self.block:
|
||||||
plt.pause(0.01)
|
plt.pause(0.01)
|
||||||
else:
|
else:
|
||||||
@ -364,6 +375,8 @@ class VisSiameseGLVQ2D(Vis2DAbstract):
|
|||||||
close=False,
|
close=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.tensorboard:
|
||||||
|
self.add_to_tensorboard(trainer, pl_module)
|
||||||
if not self.block:
|
if not self.block:
|
||||||
plt.pause(0.01)
|
plt.pause(0.01)
|
||||||
else:
|
else:
|
||||||
@ -404,6 +417,8 @@ class VisNG2D(Vis2DAbstract):
|
|||||||
"k-",
|
"k-",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.tensorboard:
|
||||||
|
self.add_to_tensorboard(trainer, pl_module)
|
||||||
if not self.block:
|
if not self.block:
|
||||||
plt.pause(0.01)
|
plt.pause(0.01)
|
||||||
else:
|
else:
|
||||||
|
Loading…
Reference in New Issue
Block a user