feat(vis): add flag to save visualization frames
This commit is contained in:
parent
b7992c01db
commit
93b1d0bd46
@ -1,5 +1,6 @@
|
|||||||
"""Visualization Callbacks."""
|
"""Visualization Callbacks."""
|
||||||
|
|
||||||
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Sized
|
from typing import Sized
|
||||||
|
|
||||||
@ -32,6 +33,10 @@ class Vis2DAbstract(pl.Callback):
|
|||||||
tensorboard=False,
|
tensorboard=False,
|
||||||
show_last_only=False,
|
show_last_only=False,
|
||||||
pause_time=0.1,
|
pause_time=0.1,
|
||||||
|
save=False,
|
||||||
|
save_dir="./img",
|
||||||
|
fig_size=(5, 4),
|
||||||
|
dpi=500,
|
||||||
block=False):
|
block=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -75,8 +80,16 @@ class Vis2DAbstract(pl.Callback):
|
|||||||
self.tensorboard = tensorboard
|
self.tensorboard = tensorboard
|
||||||
self.show_last_only = show_last_only
|
self.show_last_only = show_last_only
|
||||||
self.pause_time = pause_time
|
self.pause_time = pause_time
|
||||||
|
self.save = save
|
||||||
|
self.save_dir = save_dir
|
||||||
|
self.fig_size = fig_size
|
||||||
|
self.dpi = dpi
|
||||||
self.block = block
|
self.block = block
|
||||||
|
|
||||||
|
if save:
|
||||||
|
if not os.path.exists(save_dir):
|
||||||
|
os.makedirs(save_dir)
|
||||||
|
|
||||||
def precheck(self, trainer):
|
def precheck(self, trainer):
|
||||||
if self.show_last_only:
|
if self.show_last_only:
|
||||||
if trainer.current_epoch != trainer.max_epochs - 1:
|
if trainer.current_epoch != trainer.max_epochs - 1:
|
||||||
@ -125,6 +138,11 @@ class Vis2DAbstract(pl.Callback):
|
|||||||
def log_and_display(self, trainer, pl_module):
|
def log_and_display(self, trainer, pl_module):
|
||||||
if self.tensorboard:
|
if self.tensorboard:
|
||||||
self.add_to_tensorboard(trainer, pl_module)
|
self.add_to_tensorboard(trainer, pl_module)
|
||||||
|
if self.save:
|
||||||
|
plt.tight_layout()
|
||||||
|
self.fig.set_size_inches(*self.fig_size, forward=False)
|
||||||
|
plt.savefig(f"{self.save_dir}/{trainer.current_epoch}.png",
|
||||||
|
dpi=self.dpi)
|
||||||
if self.show:
|
if self.show:
|
||||||
if not self.block:
|
if not self.block:
|
||||||
plt.pause(self.pause_time)
|
plt.pause(self.pause_time)
|
||||||
|
Loading…
Reference in New Issue
Block a user