prototorch_models/prototorch/models/vis.py

364 lines
12 KiB
Python
Raw Normal View History

2021-06-01 22:49:36 +00:00
"""Visualization Callbacks."""
import warnings
from typing import Sized
import numpy as np
import pytorch_lightning as pl
import torch
import torchvision
from matplotlib import pyplot as plt
2022-05-16 09:12:53 +00:00
from prototorch.utils.colors import get_colors, get_legend_handles
from prototorch.utils.utils import mesh2d
from pytorch_lightning.loggers import TensorBoardLogger
2021-05-17 14:59:22 +00:00
from torch.utils.data import DataLoader, Dataset
class Vis2DAbstract(pl.Callback):
    """Base callback that redraws a matplotlib figure after training epochs.

    Subclasses implement :meth:`visualize`, which draws onto ``self.fig``.
    This class loads the optional data, prepares the axes, and logs and/or
    displays the figure.

    Args:
        data: Optional data to plot. A sized ``Dataset`` is loaded as one
            batch, a ``DataLoader`` is iterated and its batches are
            concatenated, and anything else is unpacked as ``x, y``.
        title: Figure label, axes title and TensorBoard tag.
        cmap: Colormap used for data, prototypes and decision regions.
        xlabel: Label of the x-axis.
        ylabel: Label of the y-axis.
        legend_labels: Labels for subclasses that draw a legend.
        border: ``border`` argument subclasses forward to ``mesh2d``.
        resolution: ``resolution`` argument subclasses forward to ``mesh2d``.
        flatten_data: Reshape the inputs to ``(len(x), -1)`` when True.
        axis_off: Hide the axes when True.
        show_protos: Whether subclasses draw the prototypes.
        show: Display the figure after each visualized epoch.
        tensorboard: Log the figure via ``pl_module.logger.experiment``.
        show_last_only: Only visualize when ``trainer.current_epoch`` equals
            ``trainer.max_epochs - 1``.
        pause_time: Seconds passed to ``plt.pause`` when not blocking.
        block: Use a blocking ``plt.show`` instead of ``plt.pause``.
    """

    def __init__(self,
                 data=None,
                 title="Prototype Visualization",
                 cmap="viridis",
                 xlabel="Data dimension 1",
                 ylabel="Data dimension 2",
                 legend_labels=None,
                 border=0.1,
                 resolution=100,
                 flatten_data=True,
                 axis_off=False,
                 show_protos=True,
                 show=True,
                 tensorboard=False,
                 show_last_only=False,
                 pause_time=0.1,
                 block=False):
        super().__init__()
        if self._has_data(data):
            x, y = self._load_data(data)
            if flatten_data:
                x = x.reshape(len(x), -1)
            self.x_train = x
            self.y_train = y
        else:
            self.x_train = None
            self.y_train = None
        self.title = title
        self.xlabel = xlabel
        self.ylabel = ylabel
        self.legend_labels = legend_labels
        self.fig = plt.figure(self.title)
        self.cmap = cmap
        self.border = border
        self.resolution = resolution
        self.axis_off = axis_off
        self.show_protos = show_protos
        self.show = show
        self.tensorboard = tensorboard
        self.show_last_only = show_last_only
        self.pause_time = pause_time
        self.block = block

    @staticmethod
    def _has_data(data):
        """Return whether ``data`` should be loaded (same as ``bool(data)``).

        ``bool()`` on a ``DataLoader`` calls ``len()``. That raises
        ``TypeError`` when the wrapped dataset has no ``__len__``, so such
        inputs are treated as non-empty instead of crashing.
        """
        if data is None:
            return False
        try:
            return bool(data)
        except TypeError:
            return True

    @staticmethod
    def _load_data(data):
        """Extract an ``(x, y)`` pair from a Dataset, DataLoader or pair.

        Raises:
            NotImplementedError: If ``data`` is a ``Dataset`` without
                ``__len__``.
        """
        if isinstance(data, Dataset):
            if not isinstance(data, Sized):
                # TODO: Add support for non-sized datasets
                raise NotImplementedError(
                    "Data must be a dataset with a __len__ method.")
            x, y = next(iter(DataLoader(data, batch_size=len(data))))
            return x, y
        if isinstance(data, DataLoader):
            x_batches, y_batches = [], []
            for x_b, y_b in data:
                x_batches.append(x_b)
                y_batches.append(y_b)
            if not x_batches:
                # A loader that yields nothing produces empty tensors.
                return torch.tensor([]), torch.tensor([])
            # Concatenate once: linear instead of quadratic copying. The
            # batches also keep their own dtypes, because no float
            # ``torch.tensor([])`` seed takes part in type promotion.
            return torch.cat(x_batches), torch.cat(y_batches)
        x, y = data
        return x, y

    def precheck(self, trainer):
        """Return False when the current epoch should not be visualized."""
        if self.show_last_only:
            if trainer.current_epoch != trainer.max_epochs - 1:
                return False
        return True

    def setup_ax(self):
        """Clear the current axes of ``self.fig`` and apply title/labels."""
        ax = self.fig.gca()
        ax.cla()
        ax.set_title(self.title)
        ax.set_xlabel(self.xlabel)
        ax.set_ylabel(self.ylabel)
        if self.axis_off:
            ax.axis("off")
        return ax

    def plot_data(self, ax, x, y):
        """Scatter the first two input dimensions of ``x``, colored by ``y``."""
        ax.scatter(
            x[:, 0],
            x[:, 1],
            c=y,
            cmap=self.cmap,
            edgecolor="k",
            marker="o",
            s=30,
        )

    def plot_protos(self, ax, protos, plabels):
        """Scatter the prototypes as diamonds colored by ``plabels``."""
        ax.scatter(
            protos[:, 0],
            protos[:, 1],
            c=plabels,
            cmap=self.cmap,
            edgecolor="k",
            marker="D",
            s=50,
        )

    def add_to_tensorboard(self, trainer, pl_module):
        """Log ``self.fig`` to the module's logger, tagged with the title."""
        tb = pl_module.logger.experiment
        tb.add_figure(tag=f"{self.title}",
                      figure=self.fig,
                      global_step=trainer.current_epoch,
                      close=False)

    def log_and_display(self, trainer, pl_module):
        """Log to TensorBoard and/or display the figure, per the flags."""
        if self.tensorboard:
            self.add_to_tensorboard(trainer, pl_module)
        if self.show:
            if not self.block:
                plt.pause(self.pause_time)
            else:
                plt.show(block=self.block)

    def on_train_epoch_end(self, trainer, pl_module):
        """Redraw and log/display the figure, unless ``precheck`` skips."""
        if not self.precheck(trainer):
            return True
        self.visualize(pl_module)
        self.log_and_display(trainer, pl_module)

    def on_train_end(self, trainer, pl_module):
        """Close this callback's figure; other open figures stay untouched."""
        plt.close(self.fig)

    def visualize(self, pl_module):
        """Draw the model state onto ``self.fig``; implemented by subclasses."""
        raise NotImplementedError
class VisGLVQ2D(Vis2DAbstract):
    """Plot data, prototypes and decision regions of a 2D GLVQ model."""

    def visualize(self, pl_module):
        protos = pl_module.prototypes
        plabels = pl_module.prototype_labels
        x_train, y_train = self.x_train, self.y_train
        ax = self.setup_ax()
        if x_train is not None:
            # Draw the data first so the prototype markers end up on top.
            self.plot_data(ax, x_train, y_train)
            mesh_input, xx, yy = mesh2d(np.vstack([x_train, protos]),
                                        self.border, self.resolution)
        else:
            # Without data the mesh only has to span the prototypes.
            mesh_input, xx, yy = mesh2d(protos, self.border, self.resolution)
        if self.show_protos:
            self.plot_protos(ax, protos, plabels)
        # Match dtype/device of the model's components before predicting.
        _components = pl_module.proto_layer._components
        mesh_input = torch.from_numpy(mesh_input).type_as(_components)
        y_pred = pl_module.predict(mesh_input)
        y_pred = y_pred.cpu().reshape(xx.shape)
        ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
class VisSiameseGLVQ2D(Vis2DAbstract):
    """Plot a Siamese GLVQ model in the output space of its backbone.

    Args:
        map_protos: When True, prototypes are passed through the backbone
            before plotting and the flag is forwarded to ``predict_latent``.
    """

    def __init__(self, *args, map_protos=True, **kwargs):
        super().__init__(*args, **kwargs)
        self.map_protos = map_protos

    def visualize(self, pl_module):
        protos = pl_module.prototypes
        plabels = pl_module.prototype_labels
        x_train, y_train = self.x_train, self.y_train
        device = pl_module.device
        if x_train is not None:
            with torch.no_grad():
                x_train = pl_module.backbone(
                    torch.Tensor(x_train).to(device))
            x_train = x_train.cpu().detach()
        if self.map_protos:
            with torch.no_grad():
                protos = pl_module.backbone(torch.Tensor(protos).to(device))
            protos = protos.cpu().detach()
        ax = self.setup_ax()
        # Points the decision-region mesh has to span.
        mesh_parts = []
        if x_train is not None:
            self.plot_data(ax, x_train, y_train)
            mesh_parts.append(x_train)
        if self.show_protos:
            self.plot_protos(ax, protos, plabels)
            mesh_parts.append(protos)
        elif x_train is None:
            # No data and no plotted prototypes: span the prototypes anyway
            # so a mesh can still be built.
            mesh_parts.append(protos)
        mesh_input, xx, yy = mesh2d(np.vstack(mesh_parts), self.border,
                                    self.resolution)
        _components = pl_module.proto_layer._components
        mesh_input = torch.Tensor(mesh_input).type_as(_components)
        y_pred = pl_module.predict_latent(mesh_input,
                                          map_protos=self.map_protos)
        y_pred = y_pred.cpu().reshape(xx.shape)
        ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
class VisGMLVQ2D(Vis2DAbstract):
    """Plot data and prototypes of a GMLVQ model projected to 2D.

    Points are projected onto the two leading eigenvectors of the relevance
    matrix ``lam = omega @ omega.T``.

    Args:
        ev_proj: Stored on the instance.
    """

    def __init__(self, *args, ev_proj=True, **kwargs):
        super().__init__(*args, **kwargs)
        # NOTE(review): not consulted by ``visualize``; the eigenvector
        # projection is always used.
        self.ev_proj = ev_proj

    @staticmethod
    def _project(points, proj):
        """Project ``points`` onto the columns of ``proj``; result on CPU."""
        with torch.no_grad():
            # ``.to(proj)`` aligns dtype and device with the projection.
            points = torch.as_tensor(points).to(proj)
            return (points @ proj).cpu()

    def visualize(self, pl_module):
        protos = pl_module.prototypes
        plabels = pl_module.prototype_labels
        x_train, y_train = self.x_train, self.y_train
        with torch.no_grad():
            omega = pl_module._omega.detach()
            lam = omega @ omega.T
            # Exact, deterministic eigendecomposition of the symmetric
            # matrix ``lam``. The previous ``torch.pca_lowrank`` call is
            # randomized and mean-centers its input by default, so it did
            # not return eigenvectors of ``lam``. ``eigh`` sorts eigenvalues
            # ascending, so the last two columns, reversed, are the leading
            # eigenvectors.
            _, eigvecs = torch.linalg.eigh(lam)
            proj = eigvecs[:, -2:].flip(-1)
        ax = self.setup_ax()
        if x_train is not None:
            self.plot_data(ax, self._project(x_train, proj), y_train)
        if self.show_protos:
            self.plot_protos(ax, self._project(protos, proj), plabels)
class VisCBC2D(Vis2DAbstract):
    """Plot data, components and decision regions of a 2D CBC model."""

    def visualize(self, pl_module):
        x_train, y_train = self.x_train, self.y_train
        protos = pl_module.components
        ax = self.setup_ax()
        if x_train is not None:
            self.plot_data(ax, x_train, y_train)
            mesh_points = np.vstack((x_train, protos))
        else:
            # Without data the mesh only has to span the components.
            mesh_points = protos
        if self.show_protos:
            self.plot_protos(ax, protos, "w")
        mesh_input, xx, yy = mesh2d(mesh_points, self.border, self.resolution)
        # Match dtype/device of the model's components before predicting.
        _components = pl_module.components_layer._components
        y_pred = pl_module.predict(
            torch.Tensor(mesh_input).type_as(_components))
        y_pred = y_pred.cpu().reshape(xx.shape)
        ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
class VisNG2D(Vis2DAbstract):
    """Plot Neural Gas prototypes and the connections between them.

    A pair ``(i, j)`` is connected when ``topology_layer.cmat[i][j]`` is
    nonzero.
    """

    def visualize(self, pl_module):
        x_train, y_train = self.x_train, self.y_train
        protos = pl_module.prototypes
        cmat = pl_module.topology_layer.cmat.cpu().numpy()
        ax = self.setup_ax()
        if x_train is not None:
            self.plot_data(ax, x_train, y_train)
        if self.show_protos:
            self.plot_protos(ax, protos, "w")
        # Visit only i < j (upper triangle without the diagonal). Each pair
        # is drawn once, and zero-length self-connections are skipped.
        num_protos = len(protos)
        upper = np.triu(cmat[:num_protos, :num_protos], k=1)
        for i, j in np.argwhere(upper).tolist():
            ax.plot(
                [protos[i, 0], protos[j, 0]],
                [protos[i, 1], protos[j, 1]],
                "k-",
            )
class VisSpectralProtos(Vis2DAbstract):
    """Draw every prototype as a line plot, colored by its class label.

    A legend is added when ``legend_labels`` was given to the constructor.
    """

    def visualize(self, pl_module):
        spectra = pl_module.prototypes
        labels = pl_module.prototype_labels
        ax = self.setup_ax()
        highest = max(labels)
        lowest = min(labels)
        palette = get_colors(vmin=lowest, vmax=highest)
        for spectrum, label in zip(spectra, labels):
            ax.plot(spectrum, c=palette[int(label)])
        if not self.legend_labels:
            return
        legend_handles = get_legend_handles(
            palette,
            self.legend_labels,
            marker="lines",
        )
        ax.legend(handles=legend_handles)
class VisImgComp(Vis2DAbstract):
    """Show image-shaped components as a grid; optionally log to TensorBoard.

    Args:
        random_data: Number of training samples logged as an image grid when
            training starts (0 disables this).
        dataformats: ``dataformats`` forwarded to ``add_image``.
        num_columns: ``nrow`` forwarded to ``torchvision.utils.make_grid``.
        add_embedding: Log a data embedding when training starts.
        embedding_data: Number of training samples used for the embedding.
    """

    def __init__(self,
                 *args,
                 random_data=0,
                 dataformats="CHW",
                 num_columns=2,
                 add_embedding=False,
                 embedding_data=100,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.random_data = random_data
        self.dataformats = dataformats
        self.num_columns = num_columns
        self.add_embedding = add_embedding
        self.embedding_data = embedding_data

    def _sample_indices(self, size):
        """Draw up to ``size`` distinct random indices into ``x_train``.

        Sampling without replacement cannot return more indices than there
        are samples, so ``size`` is capped at ``len(self.x_train)``.
        """
        num_samples = len(self.x_train)
        return np.random.choice(num_samples,
                                size=min(size, num_samples),
                                replace=False)

    def on_train_start(self, _, pl_module):
        """Log a data embedding and/or a grid of sample images, if enabled.

        Raises:
            ValueError: If a logging flag is set but no data was given.
        """
        if not isinstance(pl_module.logger, TensorBoardLogger):
            warnings.warn(
                f"TensorBoardLogger is required, got {type(pl_module.logger)}")
            return
        tb = pl_module.logger.experiment
        # Add embedding
        if self.add_embedding:
            if self.x_train is None or self.y_train is None:
                raise ValueError("No data for add embedding flag")
            ind = self._sample_indices(self.embedding_data)
            data = self.x_train[ind]
            tb.add_embedding(data.view(len(ind), -1),
                             label_img=data,
                             global_step=None,
                             tag="Data Embedding",
                             metadata=self.y_train[ind],
                             metadata_header=None)
        # Random Data
        if self.random_data:
            if self.x_train is None:
                raise ValueError("No data for random data flag")
            ind = self._sample_indices(self.random_data)
            data = self.x_train[ind]
            grid = torchvision.utils.make_grid(data, nrow=self.num_columns)
            tb.add_image(tag="Data",
                         img_tensor=grid,
                         global_step=None,
                         dataformats=self.dataformats)

    def add_to_tensorboard(self, trainer, pl_module):
        """Log the components as an image grid for the current epoch."""
        tb = pl_module.logger.experiment
        components = pl_module.components
        grid = torchvision.utils.make_grid(components, nrow=self.num_columns)
        tb.add_image(
            tag="Components",
            img_tensor=grid,
            global_step=trainer.current_epoch,
            dataformats=self.dataformats,
        )

    def visualize(self, pl_module):
        """Draw the component grid into this callback's own figure."""
        if not self.show:
            return
        components = pl_module.components
        grid = torchvision.utils.make_grid(components, nrow=self.num_columns)
        # Use self.fig rather than pyplot's current figure, and clear the
        # axes so images from earlier epochs do not pile up.
        ax = self.fig.gca()
        ax.cla()
        # make_grid yields channels-first; imshow expects channels-last.
        ax.imshow(grid.permute((1, 2, 0)).cpu(), cmap=self.cmap)