feat(vis): add flag to save visualization frames

This commit is contained in:
Jensun Ravichandran 2022-06-02 19:55:03 +02:00
parent b7992c01db
commit 93b1d0bd46
No known key found for this signature in database
GPG Key ID: 7612C0CAB643D921

View File

@ -1,5 +1,6 @@
"""Visualization Callbacks.""" """Visualization Callbacks."""
import os
import warnings import warnings
from typing import Sized from typing import Sized
@ -32,6 +33,10 @@ class Vis2DAbstract(pl.Callback):
tensorboard=False, tensorboard=False,
show_last_only=False, show_last_only=False,
pause_time=0.1, pause_time=0.1,
save=False,
save_dir="./img",
fig_size=(5, 4),
dpi=500,
block=False): block=False):
super().__init__() super().__init__()
@ -75,8 +80,16 @@ class Vis2DAbstract(pl.Callback):
self.tensorboard = tensorboard self.tensorboard = tensorboard
self.show_last_only = show_last_only self.show_last_only = show_last_only
self.pause_time = pause_time self.pause_time = pause_time
self.save = save
self.save_dir = save_dir
self.fig_size = fig_size
self.dpi = dpi
self.block = block self.block = block
if save:
if not os.path.exists(save_dir):
os.makedirs(save_dir)
def precheck(self, trainer): def precheck(self, trainer):
if self.show_last_only: if self.show_last_only:
if trainer.current_epoch != trainer.max_epochs - 1: if trainer.current_epoch != trainer.max_epochs - 1:
@ -125,6 +138,11 @@ class Vis2DAbstract(pl.Callback):
def log_and_display(self, trainer, pl_module): def log_and_display(self, trainer, pl_module):
if self.tensorboard: if self.tensorboard:
self.add_to_tensorboard(trainer, pl_module) self.add_to_tensorboard(trainer, pl_module)
if self.save:
plt.tight_layout()
self.fig.set_size_inches(*self.fig_size, forward=False)
plt.savefig(f"{self.save_dir}/{trainer.current_epoch}.png",
dpi=self.dpi)
if self.show: if self.show:
if not self.block: if not self.block:
plt.pause(self.pause_time) plt.pause(self.pause_time)