Close matplotlib figure on train end

This commit is contained in:
Jensun Ravichandran 2021-05-18 10:13:22 +02:00
parent 538256dcb7
commit 4957e821f6

View File

@ -114,10 +114,10 @@ class Vis2DAbstract(pl.Callback):
if not self.block: if not self.block:
plt.pause(self.pause_time) plt.pause(self.pause_time)
else: else:
plt.show(block=True) plt.show(block=self.block)
def on_train_end(self, trainer, pl_module): def on_train_end(self, trainer, pl_module):
plt.show() plt.close()
class VisGLVQ2D(Vis2DAbstract): class VisGLVQ2D(Vis2DAbstract):
@ -243,17 +243,6 @@ class VisImgComp(Vis2DAbstract):
self.dataformats = dataformats self.dataformats = dataformats
self.nrow = nrow self.nrow = nrow
def on_epoch_end(self, trainer, pl_module):
if not self.precheck(trainer):
return True
if self.show:
components = pl_module.components
grid = torchvision.utils.make_grid(components, nrow=self.nrow)
plt.imshow(grid.permute((1, 2, 0)).cpu(), cmap=self.cmap)
self.log_and_display(trainer, pl_module)
def add_to_tensorboard(self, trainer, pl_module): def add_to_tensorboard(self, trainer, pl_module):
tb = pl_module.logger.experiment tb = pl_module.logger.experiment
@ -276,3 +265,14 @@ class VisImgComp(Vis2DAbstract):
img_tensor=grid, img_tensor=grid,
global_step=trainer.current_epoch, global_step=trainer.current_epoch,
dataformats=self.dataformats) dataformats=self.dataformats)
def on_epoch_end(self, trainer, pl_module):
if not self.precheck(trainer):
return True
if self.show:
components = pl_module.components
grid = torchvision.utils.make_grid(components, nrow=self.nrow)
plt.imshow(grid.permute((1, 2, 0)).cpu(), cmap=self.cmap)
self.log_and_display(trainer, pl_module)