# prototorch_models/prototorch/models/vis.py
# (456 lines, 15 KiB, Python)

import os
import numpy as np
import pytorch_lightning as pl
import torch
from matplotlib import pyplot as plt
from matplotlib.offsetbox import AnchoredText
from prototorch.utils.celluloid import Camera
from prototorch.utils.colors import color_scheme
from prototorch.utils.utils import (gif_from_dir, make_directory,
prettify_string)
from torch.utils.data import DataLoader, Dataset
class VisWeights(pl.Callback):
    """Abstract weight visualization callback.

    Stores the plotting, log-display, saving and animation options shared by
    the concrete visualization callbacks, and manages the matplotlib figure,
    axes and ``Camera`` across the training run. Subclasses draw into
    ``self.ax`` and then call ``_display_logs`` and ``_show_and_save``.

    Args:
        data: Optional data stored as ``self.data`` for subclasses.
        axis_off: Hide the axes when preparing ``self.ax``.
        cmap: Matplotlib colormap name.
        show: Call ``plt.pause(pause_time)`` after each drawn frame.
        display_logs: Draw a text box with epoch/loss/accuracy metrics.
        display_logs_settings: ``AnchoredText`` keyword overrides, merged over
            the defaults used by ``_display_logs``. ``None`` means no
            overrides.
        interval: If truthy, only epochs divisible by it are drawn.
        save: Save each drawn frame as a PNG under ``save_dir``.
        snap: Record a ``Camera`` snapshot per frame instead of clearing the
            axes between frames.
        save_dir: Output directory. It is created on construction.
        make_gif: Build a GIF from the saved PNGs at the end of training.
        make_mp4: Save the recorded snapshots as an mp4 (needs ``snap``).
        dpi, fps, figsize, prefix: Output resolution, frame rate, figure size
            in inches and file-name prefix.
        ignore_last_output_row, label_map, project_mesh, project_protos,
        voronoi, border, resolution, distance_layer_index: Stored as
            attributes for subclasses; not used by this class itself.
        **kwargs: Forwarded to ``pl.Callback.__init__``.
    """

    def __init__(
        self,
        data=None,
        ignore_last_output_row=False,
        label_map=None,
        project_mesh=False,
        project_protos=False,
        voronoi=False,
        axis_off=True,
        cmap="viridis",
        show=True,
        display_logs=True,
        display_logs_settings=None,
        pause_time=0.5,
        border=1,
        resolution=10,
        interval=False,
        save=False,
        snap=True,
        save_dir="./img",
        make_gif=False,
        make_mp4=False,
        verbose=True,
        dpi=500,
        fps=5,
        figsize=(11, 8.5),  # standard paper in inches
        prefix="",
        distance_layer_index=-1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.data = data
        self.ignore_last_output_row = ignore_last_output_row
        self.label_map = label_map
        self.voronoi = voronoi
        # Respect the caller's ``axis_off`` choice.
        self.axis_off = axis_off
        self.project_mesh = project_mesh
        self.project_protos = project_protos
        self.cmap = cmap
        self.show = show
        self.display_logs = display_logs
        # ``None`` default avoids a dict literal shared across all instances.
        self.display_logs_settings = ({} if display_logs_settings is None
                                      else display_logs_settings)
        self.pause_time = pause_time
        self.border = border
        self.resolution = resolution
        self.interval = interval
        self.save = save
        self.snap = snap
        self.save_dir = save_dir
        self.make_gif = make_gif
        self.make_mp4 = make_mp4
        self.verbose = verbose
        self.dpi = dpi
        self.fps = fps
        self.figsize = figsize
        self.prefix = prefix
        self.distance_layer_index = distance_layer_index
        self.title = "Weights Visualization"
        # Created eagerly, even when ``save`` is False.
        make_directory(self.save_dir)

    def _skip_epoch(self, epoch):
        """Return True if ``epoch`` is not a multiple of a truthy ``interval``."""
        if self.interval:
            if epoch % self.interval != 0:
                return True
        return False

    def _clean_and_setup_ax(self):
        """Prepare ``self.ax`` for a new frame.

        The axes are only cleared when not snapping, because the camera
        records successive states of the same axes.
        """
        ax = self.ax
        if not self.snap:
            ax.cla()
        ax.set_title(self.title)
        if self.axis_off:
            ax.axis("off")

    def _savefig(self, fignum, orientation="horizontal"):
        """Save the figure as ``<save_dir>/<prefix><fignum:05d>.png``.

        ``orientation="vertical"`` swaps the width and height of ``figsize``.
        Any other value keeps ``figsize`` as configured.
        """
        figname = f"{self.save_dir}/{self.prefix}{fignum:05d}.png"
        figsize = self.figsize
        if orientation == "vertical":
            figsize = figsize[::-1]
        self.fig.set_size_inches(figsize, forward=False)
        self.fig.savefig(figname, dpi=self.dpi)

    def _show_and_save(self, epoch):
        """Display, save and/or snapshot the current frame per configuration."""
        if self.show:
            plt.pause(self.pause_time)
        if self.save:
            self._savefig(epoch)
        if self.snap:
            self.camera.snap()

    def _display_logs(self, ax, epoch, logs):
        """Add a text box with the epoch number and metrics to ``ax``.

        Missing ``val_loss``/``val_acc``/``loss``/``acc`` entries in ``logs``
        are shown as ``nan``.
        """
        if self.display_logs:
            settings = dict(
                loc="lower right",
                # padding between the text and bounding box
                pad=0.5,
                # padding between the bounding box and the axes
                borderpad=1.0,
                # https://matplotlib.org/api/text_api.html#matplotlib.text.Text
                prop=dict(
                    fontfamily="monospace",
                    fontweight="medium",
                    fontsize=12,
                ),
            )
            # Override settings with self.display_logs_settings.
            settings = {**settings, **self.display_logs_settings}
            log_string = (f"Epoch: {epoch:04d},\n"
                          f"val_loss: {logs.get('val_loss', np.nan):.03f},\n"
                          f"val_acc: {logs.get('val_acc', np.nan):.03f},\n"
                          f"loss: {logs.get('loss', np.nan):.03f},\n"
                          f"acc: {logs.get('acc', np.nan):.03f}\n")
            log_string = prettify_string(log_string, end="")
            # https://matplotlib.org/api/offsetbox_api.html#matplotlib.offsetbox.AnchoredText
            anchored_text = AnchoredText(log_string, **settings)
            ax.add_artist(anchored_text)

    def on_train_start(self, trainer, pl_module, logs=None):
        """Create the figure, the single axes and the animation camera."""
        self.fig = plt.figure(self.title)
        self.fig.set_size_inches(self.figsize, forward=False)
        self.ax = self.fig.add_subplot(111)
        self.camera = Camera(self.fig)

    def on_train_end(self, trainer, pl_module, logs=None):
        """Export the GIF and/or mp4 animation if requested."""
        if self.make_gif:
            gif_from_dir(directory=self.save_dir,
                         prefix=self.prefix,
                         duration=1.0 / self.fps)
        if self.snap and self.make_mp4:
            animation = self.camera.animate()
            vid = os.path.join(self.save_dir, f"{self.prefix}animation.mp4")
            if self.verbose:
                print(f"Saving mp4 under {vid}.")
            animation.save(vid, fps=self.fps, dpi=self.dpi)
class VisPointProtos(VisWeights):
    """Visualization of point prototypes.

    Each drawn epoch shows the stored data (when ``data`` was given) and the
    prototypes coloured by their labels. With ``voronoi=True`` it also shows
    the decision regions obtained from ``pl_module.predict`` on a mesh grid.
    Only the first two coordinates of data and prototypes are plotted.

    .. TODO::
        Still in Progress.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.title = "Point Prototypes Visualization"
        self.data_scatter_settings = {
            "marker": "o",
            "s": 30,
            "edgecolor": "k",
            "cmap": self.cmap,
        }
        self.protos_scatter_settings = {
            "marker": "D",
            "s": 50,
            "edgecolor": "k",
            "cmap": self.cmap,
        }

    @staticmethod
    def _to_numpy(values):
        """Return ``values`` as a NumPy array, detaching torch tensors first."""
        if isinstance(values, torch.Tensor):
            return values.detach().cpu().numpy()
        return np.asarray(values)

    def _paint_decision_regions(self, pl_module, x, protos):
        """Shade the predicted classes on a mesh around ``x`` and ``protos``.

        The mesh spans the bounding box of both point sets, enlarged by
        ``self.border``, with ``self.resolution`` steps along each axis. The
        axes limits are then set to that box.

        Raises:
            ValueError: If the mesh cannot be built (invalid bounds or too
                many points).
        """
        border = self.border
        resolution = self.resolution
        points = np.vstack((x, protos))
        x_min = points[:, 0].min() - border
        x_max = points[:, 0].max() + border
        y_min = points[:, 1].min() - border
        y_max = points[:, 1].max() + border
        try:
            xx, yy = np.meshgrid(
                np.arange(x_min, x_max, (x_max - x_min) / resolution),
                # The y step comes from the y range, so both axes get
                # ``resolution`` steps.
                np.arange(y_min, y_max, (y_max - y_min) / resolution),
            )
        except ValueError as ve:
            raise ValueError(f"x_min: {x_min}, x_max: {x_max}. "
                             f"x_max - x_min is {x_max - x_min}. "
                             f"y_min: {y_min}, y_max: {y_max}. "
                             f"y_max - y_min is {y_max - y_min}.") from ve
        except MemoryError as me:
            raise ValueError("Too many points. "
                             "Try reducing the resolution.") from me
        mesh_input = np.c_[xx.ravel(), yy.ravel()]

        # Predict mesh labels.
        if self.project_mesh:
            # NOTE(review): assumes the LightningModule provides
            # ``projection`` (this callback never sets a ``model``
            # attribute) — TODO confirm.
            mesh_input = pl_module.projection(mesh_input)
        y_pred = pl_module.predict(torch.Tensor(mesh_input))
        y_pred = self._to_numpy(y_pred).reshape(xx.shape)

        # Plot voronoi regions.
        self.ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
        self.ax.set_xlim(left=x_min, right=x_max)
        self.ax.set_ylim(bottom=y_min, top=y_max)

    def on_epoch_start(self, trainer, pl_module, logs=None):
        """Draw the data, optional decision regions and prototypes."""
        epoch = trainer.current_epoch
        if self._skip_epoch(epoch):
            return True

        self._clean_and_setup_ax()

        protos = pl_module.prototypes
        labels = pl_module.proto_layer.prototype_labels.detach().cpu().numpy()
        if self.project_protos:
            # NOTE(review): assumes the LightningModule provides
            # ``projection`` (this callback never sets a ``model``
            # attribute) — TODO confirm.
            protos = pl_module.projection(protos)
        protos = self._to_numpy(protos)

        color_map = color_scheme(n=len(set(labels)),
                                 cmap=self.cmap,
                                 zero_indexed=True)
        # TODO Get rid of the assumption y values in [0, num_of_classes]
        label_colors = [color_map[label] for label in labels]

        if self.data is not None:
            x, y = self.data
            x = self._to_numpy(x)
            # NumPy labels, so tensor elements are never used as lookup keys.
            y = self._to_numpy(y)
            # TODO Get rid of the assumption y values in [0, num_of_classes]
            y_colors = [color_map[label] for label in y]

            # Plot data points.
            self.ax.scatter(x[:, 0],
                            x[:, 1],
                            c=y_colors,
                            **self.data_scatter_settings)

            # Paint decision regions.
            if self.voronoi:
                self._paint_decision_regions(pl_module, x, protos)

        # Plot prototypes.
        self.ax.scatter(protos[:, 0],
                        protos[:, 1],
                        c=label_colors,
                        **self.protos_scatter_settings)

    def on_epoch_end(self, trainer, pl_module, logs=None):
        """Annotate, show and save the frame drawn in ``on_epoch_start``."""
        epoch = trainer.current_epoch
        # Nothing was drawn for skipped epochs, so do not annotate or save a
        # stale frame.
        if self._skip_epoch(epoch):
            return
        logs = {} if logs is None else logs
        self._display_logs(self.ax, epoch, logs)
        self._show_and_save(epoch)
class Vis2DAbstract(pl.Callback):
    """Base callback for plotting 2D data, prototypes and decision regions.

    Args:
        data: Either a ``Dataset``, which is loaded as a single batch with
            each sample flattened, or an ``(x, y)`` pair.
        title: Figure title. Also used as the TensorBoard tag.
        cmap: Matplotlib colormap name.
        border: Margin added around the data extent when building the mesh.
        resolution: Mesh points per unit length (step ``1 / resolution``).
        show_protos: Stored for subclasses that optionally draw prototypes.
        tensorboard: Also log the figure to ``pl_module.logger.experiment``.
        show_last_only: ``precheck`` only approves the final epoch.
        pause_time: Seconds passed to ``plt.pause`` after each frame.
        block: Use a blocking ``plt.show`` instead of ``plt.pause``.
    """

    def __init__(self,
                 data,
                 title="Prototype Visualization",
                 cmap="viridis",
                 border=1,
                 resolution=50,
                 show_protos=True,
                 tensorboard=False,
                 show_last_only=False,
                 pause_time=0.1,
                 block=False):
        super().__init__()

        if isinstance(data, Dataset):
            # Load the whole dataset as one batch.
            x, y = next(iter(DataLoader(data, batch_size=len(data))))
            x = x.view(len(data), -1)  # flatten
        else:
            x, y = data
        self.x_train = x
        self.y_train = y

        self.title = title
        self.fig = plt.figure(self.title)
        self.cmap = cmap
        self.border = border
        self.resolution = resolution
        self.show_protos = show_protos
        self.tensorboard = tensorboard
        self.show_last_only = show_last_only
        self.pause_time = pause_time
        self.block = block

    def precheck(self, trainer):
        """Return whether the current epoch should be visualized.

        With ``show_last_only`` only epoch ``max_epochs - 1`` qualifies;
        otherwise every epoch does.
        """
        if self.show_last_only:
            return trainer.current_epoch == trainer.max_epochs - 1
        return True

    def setup_ax(self, xlabel=None, ylabel=None):
        """Clear the current axes, set the title and optional axis labels."""
        ax = self.fig.gca()
        ax.cla()
        ax.set_title(self.title)
        ax.axis("off")
        # NOTE(review): the axis is switched off above, which should also
        # hide these labels — confirm whether ``axis("off")`` is intended.
        if xlabel:
            ax.set_xlabel(xlabel)
        if ylabel:
            ax.set_ylabel(ylabel)
        return ax

    def get_mesh_input(self, x):
        """Build a mesh grid over the first two columns of ``x``.

        Returns:
            ``(mesh_input, xx, yy)``, where ``mesh_input`` stacks the raveled
            ``xx`` and ``yy`` coordinates column-wise.
        """
        # Accept tensors as well as arrays.
        x = np.asarray(x)
        x_min, x_max = x[:, 0].min() - self.border, x[:, 0].max() + self.border
        y_min, y_max = x[:, 1].min() - self.border, x[:, 1].max() + self.border
        xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / self.resolution),
                             np.arange(y_min, y_max, 1 / self.resolution))
        mesh_input = np.c_[xx.ravel(), yy.ravel()]
        return mesh_input, xx, yy

    def plot_data(self, ax, x, y):
        """Scatter the data points, coloured by ``y``."""
        ax.scatter(
            x[:, 0],
            x[:, 1],
            c=y,
            cmap=self.cmap,
            edgecolor="k",
            marker="o",
            s=30,
        )

    def plot_protos(self, ax, protos, plabels):
        """Scatter the prototypes as diamonds, coloured by ``plabels``."""
        ax.scatter(
            protos[:, 0],
            protos[:, 1],
            c=plabels,
            cmap=self.cmap,
            edgecolor="k",
            marker="D",
            s=50,
        )

    def add_to_tensorboard(self, trainer, pl_module):
        """Log the figure under ``self.title`` at the current epoch."""
        tb = pl_module.logger.experiment
        tb.add_figure(tag=f"{self.title}",
                      figure=self.fig,
                      global_step=trainer.current_epoch,
                      close=False)

    def log_and_display(self, trainer, pl_module):
        """Optionally log to TensorBoard, then display the figure."""
        if self.tensorboard:
            self.add_to_tensorboard(trainer, pl_module)
        if not self.block:
            plt.pause(self.pause_time)
        else:
            plt.show(block=True)

    def on_train_end(self, trainer, pl_module):
        plt.show()
class VisGLVQ2D(Vis2DAbstract):
    """Plot data, prototypes and ``pl_module.predict`` regions each epoch."""

    def on_epoch_end(self, trainer, pl_module):
        # Skip only on an explicit False from precheck(); None/True mean draw.
        if self.precheck(trainer) is False:
            return

        protos = pl_module.prototypes
        plabels = pl_module.prototype_labels
        x_train, y_train = self.x_train, self.y_train

        ax = self.setup_ax(xlabel="Data dimension 1",
                           ylabel="Data dimension 2")
        self.plot_data(ax, x_train, y_train)
        if self.show_protos:
            self.plot_protos(ax, protos, plabels)

        # The mesh always covers the prototypes, even when they are hidden.
        x = np.vstack((x_train, protos))
        mesh_input, xx, yy = self.get_mesh_input(x)
        y_pred = pl_module.predict(torch.Tensor(mesh_input))
        y_pred = y_pred.reshape(xx.shape)
        ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)

        self.log_and_display(trainer, pl_module)
class VisSiameseGLVQ2D(Vis2DAbstract):
    """Plot data and decision regions in the space produced by a backbone.

    The stored data is passed through ``pl_module.backbone`` before plotting.
    With ``map_protos=True`` the prototypes are mapped the same way. Decision
    regions come from ``pl_module.predict_latent`` on a mesh in that space.
    """

    def __init__(self, *args, map_protos=True, **kwargs):
        super().__init__(*args, **kwargs)
        self.map_protos = map_protos

    def on_epoch_end(self, trainer, pl_module):
        # Skip only on an explicit False from precheck(); None/True mean draw.
        if self.precheck(trainer) is False:
            return

        protos = pl_module.prototypes
        plabels = pl_module.prototype_labels
        x_train, y_train = self.x_train, self.y_train
        x_train = pl_module.backbone(torch.Tensor(x_train)).detach()
        if self.map_protos:
            protos = pl_module.backbone(torch.Tensor(protos)).detach()

        ax = self.setup_ax()
        self.plot_data(ax, x_train, y_train)
        if self.show_protos:
            self.plot_protos(ax, protos, plabels)
            x = np.vstack((x_train, protos))
            mesh_input, xx, yy = self.get_mesh_input(x)
        else:
            # Hand the mesh builder a NumPy array, like the vstack branch does.
            mesh_input, xx, yy = self.get_mesh_input(x_train.cpu().numpy())
        y_pred = pl_module.predict_latent(torch.Tensor(mesh_input))
        y_pred = y_pred.reshape(xx.shape)
        ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)

        self.log_and_display(trainer, pl_module)
class VisCBC2D(Vis2DAbstract):
    """Plot data, components and ``pl_module.predict`` regions each epoch."""

    def on_epoch_end(self, trainer, pl_module):
        # Skip only on an explicit False from precheck(); None/True mean draw.
        if self.precheck(trainer) is False:
            return

        x_train, y_train = self.x_train, self.y_train
        protos = pl_module.components

        ax = self.setup_ax(xlabel="Data dimension 1",
                           ylabel="Data dimension 2")
        self.plot_data(ax, x_train, y_train)
        if self.show_protos:
            # Only ``components`` are read from the module (no labels), so
            # draw them in a single colour instead of an undefined label set.
            self.plot_protos(ax, protos, "w")

        x = np.vstack((x_train, protos))
        mesh_input, xx, yy = self.get_mesh_input(x)
        y_pred = pl_module.predict(torch.Tensor(mesh_input))
        y_pred = y_pred.reshape(xx.shape)
        ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)

        self.log_and_display(trainer, pl_module)
class VisNG2D(Vis2DAbstract):
    """Plot data, prototypes and the connection graph from ``topology_layer``."""

    def on_epoch_end(self, trainer, pl_module):
        # Skip only on an explicit False from precheck(); None/True mean draw.
        if self.precheck(trainer) is False:
            return

        x_train, y_train = self.x_train, self.y_train
        protos = pl_module.prototypes
        cmat = pl_module.topology_layer.cmat.cpu().numpy()

        ax = self.setup_ax(xlabel="Data dimension 1",
                           ylabel="Data dimension 2")
        self.plot_data(ax, x_train, y_train)
        if self.show_protos:
            self.plot_protos(ax, protos, "w")

        # Draw connections: one edge per nonzero entry in the strict upper
        # triangle of ``cmat`` (diagonal entries would be zero-length lines).
        for i, j in np.argwhere(np.triu(cmat, k=1)).tolist():
            ax.plot(
                [protos[i, 0], protos[j, 0]],
                [protos[i, 1], protos[j, 1]],
                "k-",
            )

        self.log_and_display(trainer, pl_module)