Clean visualization callbacks
This commit is contained in:
parent
6e7d80be88
commit
77b7b59bad
@ -6,262 +6,11 @@ import torch
|
|||||||
import torchvision
|
import torchvision
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
from matplotlib.offsetbox import AnchoredText
|
from matplotlib.offsetbox import AnchoredText
|
||||||
from torch.utils.data import DataLoader, Dataset
|
|
||||||
|
|
||||||
from prototorch.utils.celluloid import Camera
|
from prototorch.utils.celluloid import Camera
|
||||||
from prototorch.utils.colors import color_scheme
|
from prototorch.utils.colors import color_scheme
|
||||||
from prototorch.utils.utils import (gif_from_dir, make_directory,
|
from prototorch.utils.utils import (gif_from_dir, make_directory,
|
||||||
prettify_string)
|
prettify_string)
|
||||||
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
|
||||||
class VisWeights(pl.Callback):
|
|
||||||
"""Abstract weight visualization callback."""
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
data=None,
|
|
||||||
ignore_last_output_row=False,
|
|
||||||
label_map=None,
|
|
||||||
project_mesh=False,
|
|
||||||
project_protos=False,
|
|
||||||
voronoi=False,
|
|
||||||
axis_off=True,
|
|
||||||
cmap="viridis",
|
|
||||||
show=True,
|
|
||||||
display_logs=True,
|
|
||||||
display_logs_settings={},
|
|
||||||
pause_time=0.5,
|
|
||||||
border=1,
|
|
||||||
resolution=10,
|
|
||||||
interval=False,
|
|
||||||
save=False,
|
|
||||||
snap=True,
|
|
||||||
save_dir="./img",
|
|
||||||
make_gif=False,
|
|
||||||
make_mp4=False,
|
|
||||||
verbose=True,
|
|
||||||
dpi=500,
|
|
||||||
fps=5,
|
|
||||||
figsize=(11, 8.5), # standard paper in inches
|
|
||||||
prefix="",
|
|
||||||
distance_layer_index=-1,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self.data = data
|
|
||||||
self.ignore_last_output_row = ignore_last_output_row
|
|
||||||
self.label_map = label_map
|
|
||||||
self.voronoi = voronoi
|
|
||||||
self.axis_off = True
|
|
||||||
self.project_mesh = project_mesh
|
|
||||||
self.project_protos = project_protos
|
|
||||||
self.cmap = cmap
|
|
||||||
self.show = show
|
|
||||||
self.display_logs = display_logs
|
|
||||||
self.display_logs_settings = display_logs_settings
|
|
||||||
self.pause_time = pause_time
|
|
||||||
self.border = border
|
|
||||||
self.resolution = resolution
|
|
||||||
self.interval = interval
|
|
||||||
self.save = save
|
|
||||||
self.snap = snap
|
|
||||||
self.save_dir = save_dir
|
|
||||||
self.make_gif = make_gif
|
|
||||||
self.make_mp4 = make_mp4
|
|
||||||
self.verbose = verbose
|
|
||||||
self.dpi = dpi
|
|
||||||
self.fps = fps
|
|
||||||
self.figsize = figsize
|
|
||||||
self.prefix = prefix
|
|
||||||
self.distance_layer_index = distance_layer_index
|
|
||||||
self.title = "Weights Visualization"
|
|
||||||
make_directory(self.save_dir)
|
|
||||||
|
|
||||||
def _skip_epoch(self, epoch):
|
|
||||||
if self.interval:
|
|
||||||
if epoch % self.interval != 0:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _clean_and_setup_ax(self):
|
|
||||||
ax = self.ax
|
|
||||||
if not self.snap:
|
|
||||||
ax.cla()
|
|
||||||
ax.set_title(self.title)
|
|
||||||
if self.axis_off:
|
|
||||||
ax.axis("off")
|
|
||||||
|
|
||||||
def _savefig(self, fignum, orientation="horizontal"):
|
|
||||||
figname = f"{self.save_dir}/{self.prefix}{fignum:05d}.png"
|
|
||||||
figsize = self.figsize
|
|
||||||
if orientation == "vertical":
|
|
||||||
figsize = figsize[::-1]
|
|
||||||
elif orientation == "horizontal":
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
pass
|
|
||||||
self.fig.set_size_inches(figsize, forward=False)
|
|
||||||
self.fig.savefig(figname, dpi=self.dpi)
|
|
||||||
|
|
||||||
def _show_and_save(self, epoch):
|
|
||||||
if self.show:
|
|
||||||
plt.pause(self.pause_time)
|
|
||||||
if self.save:
|
|
||||||
self._savefig(epoch)
|
|
||||||
if self.snap:
|
|
||||||
self.camera.snap()
|
|
||||||
|
|
||||||
def _display_logs(self, ax, epoch, logs):
|
|
||||||
if self.display_logs:
|
|
||||||
settings = dict(
|
|
||||||
loc="lower right",
|
|
||||||
# padding between the text and bounding box
|
|
||||||
pad=0.5,
|
|
||||||
# padding between the bounding box and the axes
|
|
||||||
borderpad=1.0,
|
|
||||||
# https://matplotlib.org/api/text_api.html#matplotlib.text.Text
|
|
||||||
prop=dict(
|
|
||||||
fontfamily="monospace",
|
|
||||||
fontweight="medium",
|
|
||||||
fontsize=12,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Override settings with self.display_logs_settings.
|
|
||||||
settings = {**settings, **self.display_logs_settings}
|
|
||||||
|
|
||||||
log_string = f"""Epoch: {epoch:04d},
|
|
||||||
val_loss: {logs.get('val_loss', np.nan):.03f},
|
|
||||||
val_acc: {logs.get('val_acc', np.nan):.03f},
|
|
||||||
loss: {logs.get('loss', np.nan):.03f},
|
|
||||||
acc: {logs.get('acc', np.nan):.03f}
|
|
||||||
"""
|
|
||||||
log_string = prettify_string(log_string, end="")
|
|
||||||
# https://matplotlib.org/api/offsetbox_api.html#matplotlib.offsetbox.AnchoredText
|
|
||||||
anchored_text = AnchoredText(log_string, **settings)
|
|
||||||
self.ax.add_artist(anchored_text)
|
|
||||||
|
|
||||||
def on_train_start(self, trainer, pl_module, logs={}):
|
|
||||||
self.fig = plt.figure(self.title)
|
|
||||||
self.fig.set_size_inches(self.figsize, forward=False)
|
|
||||||
self.ax = self.fig.add_subplot(111)
|
|
||||||
self.camera = Camera(self.fig)
|
|
||||||
|
|
||||||
def on_train_end(self, trainer, pl_module, logs={}):
|
|
||||||
if self.make_gif:
|
|
||||||
gif_from_dir(directory=self.save_dir,
|
|
||||||
prefix=self.prefix,
|
|
||||||
duration=1.0 / self.fps)
|
|
||||||
if self.snap and self.make_mp4:
|
|
||||||
animation = self.camera.animate()
|
|
||||||
vid = os.path.join(self.save_dir, f"{self.prefix}animation.mp4")
|
|
||||||
if self.verbose:
|
|
||||||
print(f"Saving mp4 under {vid}.")
|
|
||||||
animation.save(vid, fps=self.fps, dpi=self.dpi)
|
|
||||||
|
|
||||||
|
|
||||||
class VisPointProtos(VisWeights):
|
|
||||||
"""Visualization of prototypes.
|
|
||||||
.. TODO::
|
|
||||||
Still in Progress.
|
|
||||||
"""
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self.title = "Point Prototypes Visualization"
|
|
||||||
self.data_scatter_settings = {
|
|
||||||
"marker": "o",
|
|
||||||
"s": 30,
|
|
||||||
"edgecolor": "k",
|
|
||||||
"cmap": self.cmap,
|
|
||||||
}
|
|
||||||
self.protos_scatter_settings = {
|
|
||||||
"marker": "D",
|
|
||||||
"s": 50,
|
|
||||||
"edgecolor": "k",
|
|
||||||
"cmap": self.cmap,
|
|
||||||
}
|
|
||||||
|
|
||||||
def on_epoch_start(self, trainer, pl_module, logs={}):
|
|
||||||
epoch = trainer.current_epoch
|
|
||||||
if self._skip_epoch(epoch):
|
|
||||||
return True
|
|
||||||
|
|
||||||
self._clean_and_setup_ax()
|
|
||||||
|
|
||||||
protos = pl_module.prototypes
|
|
||||||
labels = pl_module.proto_layer.prototype_labels.detach().cpu().numpy()
|
|
||||||
|
|
||||||
if self.project_protos:
|
|
||||||
protos = self.model.projection(protos).numpy()
|
|
||||||
|
|
||||||
color_map = color_scheme(n=len(set(labels)),
|
|
||||||
cmap=self.cmap,
|
|
||||||
zero_indexed=True)
|
|
||||||
# TODO Get rid of the assumption y values in [0, num_of_classes]
|
|
||||||
label_colors = [color_map[l] for l in labels]
|
|
||||||
|
|
||||||
if self.data is not None:
|
|
||||||
x, y = self.data
|
|
||||||
# TODO Get rid of the assumption y values in [0, num_of_classes]
|
|
||||||
y_colors = [color_map[l] for l in y]
|
|
||||||
# x = self.model.projection(x)
|
|
||||||
if not isinstance(x, np.ndarray):
|
|
||||||
x = x.numpy()
|
|
||||||
|
|
||||||
# Plot data points.
|
|
||||||
self.ax.scatter(x[:, 0],
|
|
||||||
x[:, 1],
|
|
||||||
c=y_colors,
|
|
||||||
**self.data_scatter_settings)
|
|
||||||
|
|
||||||
# Paint decision regions.
|
|
||||||
if self.voronoi:
|
|
||||||
border = self.border
|
|
||||||
resolution = self.resolution
|
|
||||||
x = np.vstack((x, protos))
|
|
||||||
x_min, x_max = x[:, 0].min(), x[:, 0].max()
|
|
||||||
y_min, y_max = x[:, 1].min(), x[:, 1].max()
|
|
||||||
x_min, x_max = x_min - border, x_max + border
|
|
||||||
y_min, y_max = y_min - border, y_max + border
|
|
||||||
try:
|
|
||||||
xx, yy = np.meshgrid(
|
|
||||||
np.arange(x_min, x_max, (x_max - x_min) / resolution),
|
|
||||||
np.arange(y_min, y_max, (x_max - x_min) / resolution),
|
|
||||||
)
|
|
||||||
except ValueError as ve:
|
|
||||||
print(ve)
|
|
||||||
raise ValueError(f"x_min: {x_min}, x_max: {x_max}. "
|
|
||||||
f"x_min - x_max is {x_max - x_min}.")
|
|
||||||
except MemoryError as me:
|
|
||||||
print(me)
|
|
||||||
raise ValueError("Too many points. "
|
|
||||||
"Try reducing the resolution.")
|
|
||||||
mesh_input = np.c_[xx.ravel(), yy.ravel()]
|
|
||||||
|
|
||||||
# Predict mesh labels.
|
|
||||||
if self.project_mesh:
|
|
||||||
mesh_input = self.model.projection(mesh_input)
|
|
||||||
|
|
||||||
y_pred = pl_module.predict(torch.Tensor(mesh_input))
|
|
||||||
y_pred = y_pred.reshape(xx.shape)
|
|
||||||
|
|
||||||
# Plot voronoi regions.
|
|
||||||
self.ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
|
|
||||||
|
|
||||||
self.ax.set_xlim(left=x_min + 0, right=x_max - 0)
|
|
||||||
self.ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
|
|
||||||
|
|
||||||
# Plot prototypes.
|
|
||||||
self.ax.scatter(protos[:, 0],
|
|
||||||
protos[:, 1],
|
|
||||||
c=label_colors,
|
|
||||||
**self.protos_scatter_settings)
|
|
||||||
|
|
||||||
# self._show_and_save(epoch)
|
|
||||||
|
|
||||||
def on_epoch_end(self, trainer, pl_module, logs={}):
|
|
||||||
epoch = trainer.current_epoch
|
|
||||||
self._display_logs(self.ax, epoch, logs)
|
|
||||||
self._show_and_save(epoch)
|
|
||||||
|
|
||||||
|
|
||||||
class Vis2DAbstract(pl.Callback):
|
class Vis2DAbstract(pl.Callback):
|
||||||
@ -269,8 +18,9 @@ class Vis2DAbstract(pl.Callback):
|
|||||||
data,
|
data,
|
||||||
title="Prototype Visualization",
|
title="Prototype Visualization",
|
||||||
cmap="viridis",
|
cmap="viridis",
|
||||||
border=1,
|
border=0.1,
|
||||||
resolution=50,
|
resolution=100,
|
||||||
|
axis_off=False,
|
||||||
show_protos=True,
|
show_protos=True,
|
||||||
show=True,
|
show=True,
|
||||||
tensorboard=False,
|
tensorboard=False,
|
||||||
@ -292,6 +42,7 @@ class Vis2DAbstract(pl.Callback):
|
|||||||
self.cmap = cmap
|
self.cmap = cmap
|
||||||
self.border = border
|
self.border = border
|
||||||
self.resolution = resolution
|
self.resolution = resolution
|
||||||
|
self.axis_off = axis_off
|
||||||
self.show_protos = show_protos
|
self.show_protos = show_protos
|
||||||
self.show = show
|
self.show = show
|
||||||
self.tensorboard = tensorboard
|
self.tensorboard = tensorboard
|
||||||
@ -309,18 +60,21 @@ class Vis2DAbstract(pl.Callback):
|
|||||||
ax = self.fig.gca()
|
ax = self.fig.gca()
|
||||||
ax.cla()
|
ax.cla()
|
||||||
ax.set_title(self.title)
|
ax.set_title(self.title)
|
||||||
ax.axis("off")
|
|
||||||
if xlabel:
|
if xlabel:
|
||||||
ax.set_xlabel("Data dimension 1")
|
ax.set_xlabel("Data dimension 1")
|
||||||
if ylabel:
|
if ylabel:
|
||||||
ax.set_ylabel("Data dimension 2")
|
ax.set_ylabel("Data dimension 2")
|
||||||
|
if self.axis_off:
|
||||||
|
ax.axis("off")
|
||||||
return ax
|
return ax
|
||||||
|
|
||||||
def get_mesh_input(self, x):
|
def get_mesh_input(self, x):
|
||||||
x_min, x_max = x[:, 0].min() - self.border, x[:, 0].max() + self.border
|
x_shift = self.border * np.ptp(x[:, 0])
|
||||||
y_min, y_max = x[:, 1].min() - self.border, x[:, 1].max() + self.border
|
y_shift = self.border * np.ptp(x[:, 1])
|
||||||
xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / self.resolution),
|
x_min, x_max = x[:, 0].min() - x_shift, x[:, 0].max() + x_shift
|
||||||
np.arange(y_min, y_max, 1 / self.resolution))
|
y_min, y_max = x[:, 1].min() - y_shift, x[:, 1].max() + y_shift
|
||||||
|
xx, yy = np.meshgrid(np.linspace(x_min, x_max, self.resolution),
|
||||||
|
np.linspace(y_min, y_max, self.resolution))
|
||||||
mesh_input = np.c_[xx.ravel(), yy.ravel()]
|
mesh_input = np.c_[xx.ravel(), yy.ravel()]
|
||||||
return mesh_input, xx, yy
|
return mesh_input, xx, yy
|
||||||
|
|
||||||
@ -381,8 +135,8 @@ class VisGLVQ2D(Vis2DAbstract):
|
|||||||
x = np.vstack((x_train, protos))
|
x = np.vstack((x_train, protos))
|
||||||
mesh_input, xx, yy = self.get_mesh_input(x)
|
mesh_input, xx, yy = self.get_mesh_input(x)
|
||||||
_components = pl_module.proto_layer._components
|
_components = pl_module.proto_layer._components
|
||||||
y_pred = pl_module.predict(
|
mesh_input = torch.Tensor(mesh_input).type_as(_components)
|
||||||
torch.Tensor(mesh_input).type_as(_components))
|
y_pred = pl_module.predict(mesh_input)
|
||||||
y_pred = y_pred.cpu().reshape(xx.shape)
|
y_pred = y_pred.cpu().reshape(xx.shape)
|
||||||
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
|
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
|
||||||
|
|
||||||
@ -401,11 +155,14 @@ class VisSiameseGLVQ2D(Vis2DAbstract):
|
|||||||
protos = pl_module.prototypes
|
protos = pl_module.prototypes
|
||||||
plabels = pl_module.prototype_labels
|
plabels = pl_module.prototype_labels
|
||||||
x_train, y_train = self.x_train, self.y_train
|
x_train, y_train = self.x_train, self.y_train
|
||||||
x_train = pl_module.backbone(
|
device = pl_module.device
|
||||||
torch.Tensor(x_train).to(pl_module.device)).cpu().detach()
|
with torch.no_grad():
|
||||||
|
x_train = pl_module.backbone(torch.Tensor(x_train).to(device))
|
||||||
|
x_train = x_train.cpu().detach()
|
||||||
if self.map_protos:
|
if self.map_protos:
|
||||||
protos = pl_module.backbone(
|
with torch.no_grad():
|
||||||
torch.Tensor(protos).to(pl_module.device)).cpu().detach()
|
protos = pl_module.backbone(torch.Tensor(protos).to(device))
|
||||||
|
protos = protos.cpu().detach()
|
||||||
ax = self.setup_ax()
|
ax = self.setup_ax()
|
||||||
self.plot_data(ax, x_train, y_train)
|
self.plot_data(ax, x_train, y_train)
|
||||||
if self.show_protos:
|
if self.show_protos:
|
||||||
@ -415,8 +172,9 @@ class VisSiameseGLVQ2D(Vis2DAbstract):
|
|||||||
else:
|
else:
|
||||||
mesh_input, xx, yy = self.get_mesh_input(x_train)
|
mesh_input, xx, yy = self.get_mesh_input(x_train)
|
||||||
_components = pl_module.proto_layer._components
|
_components = pl_module.proto_layer._components
|
||||||
y_pred = pl_module.predict_latent(
|
mesh_input = torch.Tensor(mesh_input).type_as(_components)
|
||||||
torch.Tensor(mesh_input).type_as(_components))
|
y_pred = pl_module.predict_latent(mesh_input,
|
||||||
|
map_protos=self.map_protos)
|
||||||
y_pred = y_pred.cpu().reshape(xx.shape)
|
y_pred = y_pred.cpu().reshape(xx.shape)
|
||||||
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
|
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user