456 lines
15 KiB
Python
456 lines
15 KiB
Python
import os
|
|
|
|
import numpy as np
|
|
import pytorch_lightning as pl
|
|
import torch
|
|
from matplotlib import pyplot as plt
|
|
from matplotlib.offsetbox import AnchoredText
|
|
from prototorch.utils.celluloid import Camera
|
|
from prototorch.utils.colors import color_scheme
|
|
from prototorch.utils.utils import (gif_from_dir, make_directory,
|
|
prettify_string)
|
|
from torch.utils.data import DataLoader, Dataset
|
|
|
|
|
|
class VisWeights(pl.Callback):
|
|
"""Abstract weight visualization callback."""
|
|
def __init__(
|
|
self,
|
|
data=None,
|
|
ignore_last_output_row=False,
|
|
label_map=None,
|
|
project_mesh=False,
|
|
project_protos=False,
|
|
voronoi=False,
|
|
axis_off=True,
|
|
cmap="viridis",
|
|
show=True,
|
|
display_logs=True,
|
|
display_logs_settings={},
|
|
pause_time=0.5,
|
|
border=1,
|
|
resolution=10,
|
|
interval=False,
|
|
save=False,
|
|
snap=True,
|
|
save_dir="./img",
|
|
make_gif=False,
|
|
make_mp4=False,
|
|
verbose=True,
|
|
dpi=500,
|
|
fps=5,
|
|
figsize=(11, 8.5), # standard paper in inches
|
|
prefix="",
|
|
distance_layer_index=-1,
|
|
**kwargs,
|
|
):
|
|
super().__init__(**kwargs)
|
|
self.data = data
|
|
self.ignore_last_output_row = ignore_last_output_row
|
|
self.label_map = label_map
|
|
self.voronoi = voronoi
|
|
self.axis_off = True
|
|
self.project_mesh = project_mesh
|
|
self.project_protos = project_protos
|
|
self.cmap = cmap
|
|
self.show = show
|
|
self.display_logs = display_logs
|
|
self.display_logs_settings = display_logs_settings
|
|
self.pause_time = pause_time
|
|
self.border = border
|
|
self.resolution = resolution
|
|
self.interval = interval
|
|
self.save = save
|
|
self.snap = snap
|
|
self.save_dir = save_dir
|
|
self.make_gif = make_gif
|
|
self.make_mp4 = make_mp4
|
|
self.verbose = verbose
|
|
self.dpi = dpi
|
|
self.fps = fps
|
|
self.figsize = figsize
|
|
self.prefix = prefix
|
|
self.distance_layer_index = distance_layer_index
|
|
self.title = "Weights Visualization"
|
|
make_directory(self.save_dir)
|
|
|
|
def _skip_epoch(self, epoch):
|
|
if self.interval:
|
|
if epoch % self.interval != 0:
|
|
return True
|
|
return False
|
|
|
|
def _clean_and_setup_ax(self):
|
|
ax = self.ax
|
|
if not self.snap:
|
|
ax.cla()
|
|
ax.set_title(self.title)
|
|
if self.axis_off:
|
|
ax.axis("off")
|
|
|
|
def _savefig(self, fignum, orientation="horizontal"):
|
|
figname = f"{self.save_dir}/{self.prefix}{fignum:05d}.png"
|
|
figsize = self.figsize
|
|
if orientation == "vertical":
|
|
figsize = figsize[::-1]
|
|
elif orientation == "horizontal":
|
|
pass
|
|
else:
|
|
pass
|
|
self.fig.set_size_inches(figsize, forward=False)
|
|
self.fig.savefig(figname, dpi=self.dpi)
|
|
|
|
def _show_and_save(self, epoch):
|
|
if self.show:
|
|
plt.pause(self.pause_time)
|
|
if self.save:
|
|
self._savefig(epoch)
|
|
if self.snap:
|
|
self.camera.snap()
|
|
|
|
def _display_logs(self, ax, epoch, logs):
|
|
if self.display_logs:
|
|
settings = dict(
|
|
loc="lower right",
|
|
# padding between the text and bounding box
|
|
pad=0.5,
|
|
# padding between the bounding box and the axes
|
|
borderpad=1.0,
|
|
# https://matplotlib.org/api/text_api.html#matplotlib.text.Text
|
|
prop=dict(
|
|
fontfamily="monospace",
|
|
fontweight="medium",
|
|
fontsize=12,
|
|
),
|
|
)
|
|
|
|
# Override settings with self.display_logs_settings.
|
|
settings = {**settings, **self.display_logs_settings}
|
|
|
|
log_string = f"""Epoch: {epoch:04d},
|
|
val_loss: {logs.get('val_loss', np.nan):.03f},
|
|
val_acc: {logs.get('val_acc', np.nan):.03f},
|
|
loss: {logs.get('loss', np.nan):.03f},
|
|
acc: {logs.get('acc', np.nan):.03f}
|
|
"""
|
|
log_string = prettify_string(log_string, end="")
|
|
# https://matplotlib.org/api/offsetbox_api.html#matplotlib.offsetbox.AnchoredText
|
|
anchored_text = AnchoredText(log_string, **settings)
|
|
self.ax.add_artist(anchored_text)
|
|
|
|
def on_train_start(self, trainer, pl_module, logs={}):
|
|
self.fig = plt.figure(self.title)
|
|
self.fig.set_size_inches(self.figsize, forward=False)
|
|
self.ax = self.fig.add_subplot(111)
|
|
self.camera = Camera(self.fig)
|
|
|
|
def on_train_end(self, trainer, pl_module, logs={}):
|
|
if self.make_gif:
|
|
gif_from_dir(directory=self.save_dir,
|
|
prefix=self.prefix,
|
|
duration=1.0 / self.fps)
|
|
if self.snap and self.make_mp4:
|
|
animation = self.camera.animate()
|
|
vid = os.path.join(self.save_dir, f"{self.prefix}animation.mp4")
|
|
if self.verbose:
|
|
print(f"Saving mp4 under {vid}.")
|
|
animation.save(vid, fps=self.fps, dpi=self.dpi)
|
|
|
|
|
|
class VisPointProtos(VisWeights):
|
|
"""Visualization of prototypes.
|
|
.. TODO::
|
|
Still in Progress.
|
|
"""
|
|
def __init__(self, **kwargs):
|
|
super().__init__(**kwargs)
|
|
self.title = "Point Prototypes Visualization"
|
|
self.data_scatter_settings = {
|
|
"marker": "o",
|
|
"s": 30,
|
|
"edgecolor": "k",
|
|
"cmap": self.cmap,
|
|
}
|
|
self.protos_scatter_settings = {
|
|
"marker": "D",
|
|
"s": 50,
|
|
"edgecolor": "k",
|
|
"cmap": self.cmap,
|
|
}
|
|
|
|
def on_epoch_start(self, trainer, pl_module, logs={}):
|
|
epoch = trainer.current_epoch
|
|
if self._skip_epoch(epoch):
|
|
return True
|
|
|
|
self._clean_and_setup_ax()
|
|
|
|
protos = pl_module.prototypes
|
|
labels = pl_module.proto_layer.prototype_labels.detach().cpu().numpy()
|
|
|
|
if self.project_protos:
|
|
protos = self.model.projection(protos).numpy()
|
|
|
|
color_map = color_scheme(n=len(set(labels)),
|
|
cmap=self.cmap,
|
|
zero_indexed=True)
|
|
# TODO Get rid of the assumption y values in [0, num_of_classes]
|
|
label_colors = [color_map[l] for l in labels]
|
|
|
|
if self.data is not None:
|
|
x, y = self.data
|
|
# TODO Get rid of the assumption y values in [0, num_of_classes]
|
|
y_colors = [color_map[l] for l in y]
|
|
# x = self.model.projection(x)
|
|
if not isinstance(x, np.ndarray):
|
|
x = x.numpy()
|
|
|
|
# Plot data points.
|
|
self.ax.scatter(x[:, 0],
|
|
x[:, 1],
|
|
c=y_colors,
|
|
**self.data_scatter_settings)
|
|
|
|
# Paint decision regions.
|
|
if self.voronoi:
|
|
border = self.border
|
|
resolution = self.resolution
|
|
x = np.vstack((x, protos))
|
|
x_min, x_max = x[:, 0].min(), x[:, 0].max()
|
|
y_min, y_max = x[:, 1].min(), x[:, 1].max()
|
|
x_min, x_max = x_min - border, x_max + border
|
|
y_min, y_max = y_min - border, y_max + border
|
|
try:
|
|
xx, yy = np.meshgrid(
|
|
np.arange(x_min, x_max, (x_max - x_min) / resolution),
|
|
np.arange(y_min, y_max, (x_max - x_min) / resolution),
|
|
)
|
|
except ValueError as ve:
|
|
print(ve)
|
|
raise ValueError(f"x_min: {x_min}, x_max: {x_max}. "
|
|
f"x_min - x_max is {x_max - x_min}.")
|
|
except MemoryError as me:
|
|
print(me)
|
|
raise ValueError("Too many points. "
|
|
"Try reducing the resolution.")
|
|
mesh_input = np.c_[xx.ravel(), yy.ravel()]
|
|
|
|
# Predict mesh labels.
|
|
if self.project_mesh:
|
|
mesh_input = self.model.projection(mesh_input)
|
|
|
|
y_pred = pl_module.predict(torch.Tensor(mesh_input))
|
|
y_pred = y_pred.reshape(xx.shape)
|
|
|
|
# Plot voronoi regions.
|
|
self.ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
|
|
|
|
self.ax.set_xlim(left=x_min + 0, right=x_max - 0)
|
|
self.ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
|
|
|
|
# Plot prototypes.
|
|
self.ax.scatter(protos[:, 0],
|
|
protos[:, 1],
|
|
c=label_colors,
|
|
**self.protos_scatter_settings)
|
|
|
|
# self._show_and_save(epoch)
|
|
|
|
def on_epoch_end(self, trainer, pl_module, logs={}):
|
|
epoch = trainer.current_epoch
|
|
self._display_logs(self.ax, epoch, logs)
|
|
self._show_and_save(epoch)
|
|
|
|
|
|
class Vis2DAbstract(pl.Callback):
|
|
def __init__(self,
|
|
data,
|
|
title="Prototype Visualization",
|
|
cmap="viridis",
|
|
border=1,
|
|
resolution=50,
|
|
show_protos=True,
|
|
tensorboard=False,
|
|
show_last_only=False,
|
|
pause_time=0.1,
|
|
block=False):
|
|
super().__init__()
|
|
|
|
if isinstance(data, Dataset):
|
|
x, y = next(iter(DataLoader(data, batch_size=len(data))))
|
|
x = x.view(len(data), -1) # flatten
|
|
else:
|
|
x, y = data
|
|
self.x_train = x
|
|
self.y_train = y
|
|
|
|
self.title = title
|
|
self.fig = plt.figure(self.title)
|
|
self.cmap = cmap
|
|
self.border = border
|
|
self.resolution = resolution
|
|
self.show_protos = show_protos
|
|
self.tensorboard = tensorboard
|
|
self.show_last_only = show_last_only
|
|
self.pause_time = pause_time
|
|
self.block = block
|
|
|
|
def precheck(self, trainer):
|
|
if self.show_last_only:
|
|
if trainer.current_epoch != trainer.max_epochs - 1:
|
|
return
|
|
|
|
def setup_ax(self, xlabel=None, ylabel=None):
|
|
ax = self.fig.gca()
|
|
ax.cla()
|
|
ax.set_title(self.title)
|
|
ax.axis("off")
|
|
if xlabel:
|
|
ax.set_xlabel("Data dimension 1")
|
|
if ylabel:
|
|
ax.set_ylabel("Data dimension 2")
|
|
return ax
|
|
|
|
def get_mesh_input(self, x):
|
|
x_min, x_max = x[:, 0].min() - self.border, x[:, 0].max() + self.border
|
|
y_min, y_max = x[:, 1].min() - self.border, x[:, 1].max() + self.border
|
|
xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / self.resolution),
|
|
np.arange(y_min, y_max, 1 / self.resolution))
|
|
mesh_input = np.c_[xx.ravel(), yy.ravel()]
|
|
return mesh_input, xx, yy
|
|
|
|
def plot_data(self, ax, x, y):
|
|
ax.scatter(
|
|
x[:, 0],
|
|
x[:, 1],
|
|
c=y,
|
|
cmap=self.cmap,
|
|
edgecolor="k",
|
|
marker="o",
|
|
s=30,
|
|
)
|
|
|
|
def plot_protos(self, ax, protos, plabels):
|
|
ax.scatter(
|
|
protos[:, 0],
|
|
protos[:, 1],
|
|
c=plabels,
|
|
cmap=self.cmap,
|
|
edgecolor="k",
|
|
marker="D",
|
|
s=50,
|
|
)
|
|
|
|
def add_to_tensorboard(self, trainer, pl_module):
|
|
tb = pl_module.logger.experiment
|
|
tb.add_figure(tag=f"{self.title}",
|
|
figure=self.fig,
|
|
global_step=trainer.current_epoch,
|
|
close=False)
|
|
|
|
def log_and_display(self, trainer, pl_module):
|
|
if self.tensorboard:
|
|
self.add_to_tensorboard(trainer, pl_module)
|
|
if not self.block:
|
|
plt.pause(self.pause_time)
|
|
else:
|
|
plt.show(block=True)
|
|
|
|
def on_train_end(self, trainer, pl_module):
|
|
plt.show()
|
|
|
|
|
|
class VisGLVQ2D(Vis2DAbstract):
|
|
def on_epoch_end(self, trainer, pl_module):
|
|
self.precheck(trainer)
|
|
|
|
protos = pl_module.prototypes
|
|
plabels = pl_module.prototype_labels
|
|
x_train, y_train = self.x_train, self.y_train
|
|
ax = self.setup_ax(xlabel="Data dimension 1",
|
|
ylabel="Data dimension 2")
|
|
self.plot_data(ax, x_train, y_train)
|
|
self.plot_protos(ax, protos, plabels)
|
|
x = np.vstack((x_train, protos))
|
|
mesh_input, xx, yy = self.get_mesh_input(x)
|
|
y_pred = pl_module.predict(torch.Tensor(mesh_input))
|
|
y_pred = y_pred.reshape(xx.shape)
|
|
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
|
|
|
|
self.log_and_display(trainer, pl_module)
|
|
|
|
|
|
class VisSiameseGLVQ2D(Vis2DAbstract):
|
|
def __init__(self, *args, map_protos=True, **kwargs):
|
|
super().__init__(*args, **kwargs)
|
|
self.map_protos = map_protos
|
|
|
|
def on_epoch_end(self, trainer, pl_module):
|
|
self.precheck(trainer)
|
|
|
|
protos = pl_module.prototypes
|
|
plabels = pl_module.prototype_labels
|
|
x_train, y_train = self.x_train, self.y_train
|
|
x_train = pl_module.backbone(torch.Tensor(x_train)).detach()
|
|
if self.map_protos:
|
|
protos = pl_module.backbone(torch.Tensor(protos)).detach()
|
|
ax = self.setup_ax()
|
|
self.plot_data(ax, x_train, y_train)
|
|
if self.show_protos:
|
|
self.plot_protos(ax, protos, plabels)
|
|
x = np.vstack((x_train, protos))
|
|
mesh_input, xx, yy = self.get_mesh_input(x)
|
|
else:
|
|
mesh_input, xx, yy = self.get_mesh_input(x_train)
|
|
y_pred = pl_module.predict_latent(torch.Tensor(mesh_input))
|
|
y_pred = y_pred.reshape(xx.shape)
|
|
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
|
|
|
|
self.log_and_display(trainer, pl_module)
|
|
|
|
|
|
class VisCBC2D(Vis2DAbstract):
|
|
def on_epoch_end(self, trainer, pl_module):
|
|
self.precheck(trainer)
|
|
|
|
x_train, y_train = self.x_train, self.y_train
|
|
protos = pl_module.components
|
|
ax = self.setup_ax(xlabel="Data dimension 1",
|
|
ylabel="Data dimension 2")
|
|
self.plot_data(ax, x_train, y_train)
|
|
self.plot_protos(ax, protos, plabels)
|
|
x = np.vstack((x_train, protos))
|
|
mesh_input, xx, yy = self.get_mesh_input(x)
|
|
y_pred = pl_module.predict(torch.Tensor(mesh_input))
|
|
y_pred = y_pred.reshape(xx.shape)
|
|
|
|
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
|
|
|
|
self.log_and_display(trainer, pl_module)
|
|
|
|
|
|
class VisNG2D(Vis2DAbstract):
|
|
def on_epoch_end(self, trainer, pl_module):
|
|
self.precheck(trainer)
|
|
|
|
x_train, y_train = self.x_train, self.y_train
|
|
protos = pl_module.prototypes
|
|
cmat = pl_module.topology_layer.cmat.cpu().numpy()
|
|
|
|
ax = self.setup_ax(xlabel="Data dimension 1",
|
|
ylabel="Data dimension 2")
|
|
self.plot_data(ax, x_train, y_train)
|
|
self.plot_protos(ax, protos, "w")
|
|
|
|
# Draw connections
|
|
for i in range(len(protos)):
|
|
for j in range(i, len(protos)):
|
|
if cmat[i][j]:
|
|
ax.plot(
|
|
[protos[i, 0], protos[j, 0]],
|
|
[protos[i, 1], protos[j, 1]],
|
|
"k-",
|
|
)
|
|
|
|
self.log_and_display(trainer, pl_module)
|