Clean visualization callbacks

This commit is contained in:
Jensun Ravichandran 2021-05-17 16:59:22 +02:00
parent 6e7d80be88
commit 77b7b59bad

View File

@ -6,262 +6,11 @@ import torch
import torchvision import torchvision
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from matplotlib.offsetbox import AnchoredText from matplotlib.offsetbox import AnchoredText
from torch.utils.data import DataLoader, Dataset
from prototorch.utils.celluloid import Camera from prototorch.utils.celluloid import Camera
from prototorch.utils.colors import color_scheme from prototorch.utils.colors import color_scheme
from prototorch.utils.utils import (gif_from_dir, make_directory, from prototorch.utils.utils import (gif_from_dir, make_directory,
prettify_string) prettify_string)
from torch.utils.data import DataLoader, Dataset
class VisWeights(pl.Callback):
"""Abstract weight visualization callback."""
def __init__(
self,
data=None,
ignore_last_output_row=False,
label_map=None,
project_mesh=False,
project_protos=False,
voronoi=False,
axis_off=True,
cmap="viridis",
show=True,
display_logs=True,
display_logs_settings={},
pause_time=0.5,
border=1,
resolution=10,
interval=False,
save=False,
snap=True,
save_dir="./img",
make_gif=False,
make_mp4=False,
verbose=True,
dpi=500,
fps=5,
figsize=(11, 8.5), # standard paper in inches
prefix="",
distance_layer_index=-1,
**kwargs,
):
super().__init__(**kwargs)
self.data = data
self.ignore_last_output_row = ignore_last_output_row
self.label_map = label_map
self.voronoi = voronoi
self.axis_off = True
self.project_mesh = project_mesh
self.project_protos = project_protos
self.cmap = cmap
self.show = show
self.display_logs = display_logs
self.display_logs_settings = display_logs_settings
self.pause_time = pause_time
self.border = border
self.resolution = resolution
self.interval = interval
self.save = save
self.snap = snap
self.save_dir = save_dir
self.make_gif = make_gif
self.make_mp4 = make_mp4
self.verbose = verbose
self.dpi = dpi
self.fps = fps
self.figsize = figsize
self.prefix = prefix
self.distance_layer_index = distance_layer_index
self.title = "Weights Visualization"
make_directory(self.save_dir)
def _skip_epoch(self, epoch):
if self.interval:
if epoch % self.interval != 0:
return True
return False
def _clean_and_setup_ax(self):
ax = self.ax
if not self.snap:
ax.cla()
ax.set_title(self.title)
if self.axis_off:
ax.axis("off")
def _savefig(self, fignum, orientation="horizontal"):
figname = f"{self.save_dir}/{self.prefix}{fignum:05d}.png"
figsize = self.figsize
if orientation == "vertical":
figsize = figsize[::-1]
elif orientation == "horizontal":
pass
else:
pass
self.fig.set_size_inches(figsize, forward=False)
self.fig.savefig(figname, dpi=self.dpi)
def _show_and_save(self, epoch):
if self.show:
plt.pause(self.pause_time)
if self.save:
self._savefig(epoch)
if self.snap:
self.camera.snap()
def _display_logs(self, ax, epoch, logs):
if self.display_logs:
settings = dict(
loc="lower right",
# padding between the text and bounding box
pad=0.5,
# padding between the bounding box and the axes
borderpad=1.0,
# https://matplotlib.org/api/text_api.html#matplotlib.text.Text
prop=dict(
fontfamily="monospace",
fontweight="medium",
fontsize=12,
),
)
# Override settings with self.display_logs_settings.
settings = {**settings, **self.display_logs_settings}
log_string = f"""Epoch: {epoch:04d},
val_loss: {logs.get('val_loss', np.nan):.03f},
val_acc: {logs.get('val_acc', np.nan):.03f},
loss: {logs.get('loss', np.nan):.03f},
acc: {logs.get('acc', np.nan):.03f}
"""
log_string = prettify_string(log_string, end="")
# https://matplotlib.org/api/offsetbox_api.html#matplotlib.offsetbox.AnchoredText
anchored_text = AnchoredText(log_string, **settings)
self.ax.add_artist(anchored_text)
def on_train_start(self, trainer, pl_module, logs={}):
self.fig = plt.figure(self.title)
self.fig.set_size_inches(self.figsize, forward=False)
self.ax = self.fig.add_subplot(111)
self.camera = Camera(self.fig)
def on_train_end(self, trainer, pl_module, logs={}):
if self.make_gif:
gif_from_dir(directory=self.save_dir,
prefix=self.prefix,
duration=1.0 / self.fps)
if self.snap and self.make_mp4:
animation = self.camera.animate()
vid = os.path.join(self.save_dir, f"{self.prefix}animation.mp4")
if self.verbose:
print(f"Saving mp4 under {vid}.")
animation.save(vid, fps=self.fps, dpi=self.dpi)
class VisPointProtos(VisWeights):
"""Visualization of prototypes.
.. TODO::
Still in Progress.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.title = "Point Prototypes Visualization"
self.data_scatter_settings = {
"marker": "o",
"s": 30,
"edgecolor": "k",
"cmap": self.cmap,
}
self.protos_scatter_settings = {
"marker": "D",
"s": 50,
"edgecolor": "k",
"cmap": self.cmap,
}
def on_epoch_start(self, trainer, pl_module, logs={}):
epoch = trainer.current_epoch
if self._skip_epoch(epoch):
return True
self._clean_and_setup_ax()
protos = pl_module.prototypes
labels = pl_module.proto_layer.prototype_labels.detach().cpu().numpy()
if self.project_protos:
protos = self.model.projection(protos).numpy()
color_map = color_scheme(n=len(set(labels)),
cmap=self.cmap,
zero_indexed=True)
# TODO Get rid of the assumption y values in [0, num_of_classes]
label_colors = [color_map[l] for l in labels]
if self.data is not None:
x, y = self.data
# TODO Get rid of the assumption y values in [0, num_of_classes]
y_colors = [color_map[l] for l in y]
# x = self.model.projection(x)
if not isinstance(x, np.ndarray):
x = x.numpy()
# Plot data points.
self.ax.scatter(x[:, 0],
x[:, 1],
c=y_colors,
**self.data_scatter_settings)
# Paint decision regions.
if self.voronoi:
border = self.border
resolution = self.resolution
x = np.vstack((x, protos))
x_min, x_max = x[:, 0].min(), x[:, 0].max()
y_min, y_max = x[:, 1].min(), x[:, 1].max()
x_min, x_max = x_min - border, x_max + border
y_min, y_max = y_min - border, y_max + border
try:
xx, yy = np.meshgrid(
np.arange(x_min, x_max, (x_max - x_min) / resolution),
np.arange(y_min, y_max, (x_max - x_min) / resolution),
)
except ValueError as ve:
print(ve)
raise ValueError(f"x_min: {x_min}, x_max: {x_max}. "
f"x_min - x_max is {x_max - x_min}.")
except MemoryError as me:
print(me)
raise ValueError("Too many points. "
"Try reducing the resolution.")
mesh_input = np.c_[xx.ravel(), yy.ravel()]
# Predict mesh labels.
if self.project_mesh:
mesh_input = self.model.projection(mesh_input)
y_pred = pl_module.predict(torch.Tensor(mesh_input))
y_pred = y_pred.reshape(xx.shape)
# Plot voronoi regions.
self.ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
self.ax.set_xlim(left=x_min + 0, right=x_max - 0)
self.ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
# Plot prototypes.
self.ax.scatter(protos[:, 0],
protos[:, 1],
c=label_colors,
**self.protos_scatter_settings)
# self._show_and_save(epoch)
def on_epoch_end(self, trainer, pl_module, logs={}):
epoch = trainer.current_epoch
self._display_logs(self.ax, epoch, logs)
self._show_and_save(epoch)
class Vis2DAbstract(pl.Callback): class Vis2DAbstract(pl.Callback):
@ -269,8 +18,9 @@ class Vis2DAbstract(pl.Callback):
data, data,
title="Prototype Visualization", title="Prototype Visualization",
cmap="viridis", cmap="viridis",
border=1, border=0.1,
resolution=50, resolution=100,
axis_off=False,
show_protos=True, show_protos=True,
show=True, show=True,
tensorboard=False, tensorboard=False,
@ -292,6 +42,7 @@ class Vis2DAbstract(pl.Callback):
self.cmap = cmap self.cmap = cmap
self.border = border self.border = border
self.resolution = resolution self.resolution = resolution
self.axis_off = axis_off
self.show_protos = show_protos self.show_protos = show_protos
self.show = show self.show = show
self.tensorboard = tensorboard self.tensorboard = tensorboard
@ -309,18 +60,21 @@ class Vis2DAbstract(pl.Callback):
ax = self.fig.gca() ax = self.fig.gca()
ax.cla() ax.cla()
ax.set_title(self.title) ax.set_title(self.title)
ax.axis("off")
if xlabel: if xlabel:
ax.set_xlabel("Data dimension 1") ax.set_xlabel("Data dimension 1")
if ylabel: if ylabel:
ax.set_ylabel("Data dimension 2") ax.set_ylabel("Data dimension 2")
if self.axis_off:
ax.axis("off")
return ax return ax
def get_mesh_input(self, x): def get_mesh_input(self, x):
x_min, x_max = x[:, 0].min() - self.border, x[:, 0].max() + self.border x_shift = self.border * np.ptp(x[:, 0])
y_min, y_max = x[:, 1].min() - self.border, x[:, 1].max() + self.border y_shift = self.border * np.ptp(x[:, 1])
xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / self.resolution), x_min, x_max = x[:, 0].min() - x_shift, x[:, 0].max() + x_shift
np.arange(y_min, y_max, 1 / self.resolution)) y_min, y_max = x[:, 1].min() - y_shift, x[:, 1].max() + y_shift
xx, yy = np.meshgrid(np.linspace(x_min, x_max, self.resolution),
np.linspace(y_min, y_max, self.resolution))
mesh_input = np.c_[xx.ravel(), yy.ravel()] mesh_input = np.c_[xx.ravel(), yy.ravel()]
return mesh_input, xx, yy return mesh_input, xx, yy
@ -381,8 +135,8 @@ class VisGLVQ2D(Vis2DAbstract):
x = np.vstack((x_train, protos)) x = np.vstack((x_train, protos))
mesh_input, xx, yy = self.get_mesh_input(x) mesh_input, xx, yy = self.get_mesh_input(x)
_components = pl_module.proto_layer._components _components = pl_module.proto_layer._components
y_pred = pl_module.predict( mesh_input = torch.Tensor(mesh_input).type_as(_components)
torch.Tensor(mesh_input).type_as(_components)) y_pred = pl_module.predict(mesh_input)
y_pred = y_pred.cpu().reshape(xx.shape) y_pred = y_pred.cpu().reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35) ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
@ -401,11 +155,14 @@ class VisSiameseGLVQ2D(Vis2DAbstract):
protos = pl_module.prototypes protos = pl_module.prototypes
plabels = pl_module.prototype_labels plabels = pl_module.prototype_labels
x_train, y_train = self.x_train, self.y_train x_train, y_train = self.x_train, self.y_train
x_train = pl_module.backbone( device = pl_module.device
torch.Tensor(x_train).to(pl_module.device)).cpu().detach() with torch.no_grad():
x_train = pl_module.backbone(torch.Tensor(x_train).to(device))
x_train = x_train.cpu().detach()
if self.map_protos: if self.map_protos:
protos = pl_module.backbone( with torch.no_grad():
torch.Tensor(protos).to(pl_module.device)).cpu().detach() protos = pl_module.backbone(torch.Tensor(protos).to(device))
protos = protos.cpu().detach()
ax = self.setup_ax() ax = self.setup_ax()
self.plot_data(ax, x_train, y_train) self.plot_data(ax, x_train, y_train)
if self.show_protos: if self.show_protos:
@ -415,8 +172,9 @@ class VisSiameseGLVQ2D(Vis2DAbstract):
else: else:
mesh_input, xx, yy = self.get_mesh_input(x_train) mesh_input, xx, yy = self.get_mesh_input(x_train)
_components = pl_module.proto_layer._components _components = pl_module.proto_layer._components
y_pred = pl_module.predict_latent( mesh_input = torch.Tensor(mesh_input).type_as(_components)
torch.Tensor(mesh_input).type_as(_components)) y_pred = pl_module.predict_latent(mesh_input,
map_protos=self.map_protos)
y_pred = y_pred.cpu().reshape(xx.shape) y_pred = y_pred.cpu().reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35) ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)