Close matplotlib figure on train end
This commit is contained in:
parent
538256dcb7
commit
4957e821f6
@ -114,10 +114,10 @@ class Vis2DAbstract(pl.Callback):
|
||||
if not self.block:
|
||||
plt.pause(self.pause_time)
|
||||
else:
|
||||
plt.show(block=True)
|
||||
plt.show(block=self.block)
|
||||
|
||||
def on_train_end(self, trainer, pl_module):
|
||||
plt.show()
|
||||
plt.close()
|
||||
|
||||
|
||||
class VisGLVQ2D(Vis2DAbstract):
|
||||
@ -243,17 +243,6 @@ class VisImgComp(Vis2DAbstract):
|
||||
self.dataformats = dataformats
|
||||
self.nrow = nrow
|
||||
|
||||
def on_epoch_end(self, trainer, pl_module):
|
||||
if not self.precheck(trainer):
|
||||
return True
|
||||
|
||||
if self.show:
|
||||
components = pl_module.components
|
||||
grid = torchvision.utils.make_grid(components, nrow=self.nrow)
|
||||
plt.imshow(grid.permute((1, 2, 0)).cpu(), cmap=self.cmap)
|
||||
|
||||
self.log_and_display(trainer, pl_module)
|
||||
|
||||
def add_to_tensorboard(self, trainer, pl_module):
|
||||
tb = pl_module.logger.experiment
|
||||
|
||||
@ -276,3 +265,14 @@ class VisImgComp(Vis2DAbstract):
|
||||
img_tensor=grid,
|
||||
global_step=trainer.current_epoch,
|
||||
dataformats=self.dataformats)
|
||||
|
||||
def on_epoch_end(self, trainer, pl_module):
|
||||
if not self.precheck(trainer):
|
||||
return True
|
||||
|
||||
if self.show:
|
||||
components = pl_module.components
|
||||
grid = torchvision.utils.make_grid(components, nrow=self.nrow)
|
||||
plt.imshow(grid.permute((1, 2, 0)).cpu(), cmap=self.cmap)
|
||||
|
||||
self.log_and_display(trainer, pl_module)
|
||||
|
Loading…
Reference in New Issue
Block a user