2022-05-19 14:13:08 +00:00
|
|
|
import warnings
|
|
|
|
from typing import Optional, Type
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
import pytorch_lightning as pl
|
|
|
|
import torch
|
|
|
|
import torchmetrics
|
|
|
|
from matplotlib import pyplot as plt
|
|
|
|
from prototorch.models.vis import Vis2DAbstract
|
|
|
|
from prototorch.utils.utils import mesh2d
|
2022-06-03 08:39:11 +00:00
|
|
|
from prototorch.y.architectures.base import BaseYArchitecture
|
|
|
|
from prototorch.y.library.gmlvq import GMLVQ
|
2022-05-19 14:13:08 +00:00
|
|
|
from pytorch_lightning.loggers import TensorBoardLogger
|
|
|
|
|
|
|
|
# Names of matplotlib colour maps treated as "diverging" by this module.
# PlotLambdaMatrixToTensorboard warns when given a string name not in this list.
DIVERGING_COLOR_MAPS = [
    "PiYG",
    "PRGn",
    "BrBG",
    "PuOr",
    "RdGy",
    "RdBu",
    "RdYlBu",
    "RdYlGn",
    "Spectral",
    "coolwarm",
    "bwr",
    "seismic",
]
|
|
|
|
|
|
|
|
|
|
|
|
class LogTorchmetricCallback(pl.Callback):
    """Callback that registers a torchmetric on the module during ``setup``.

    The metric class and its keyword arguments are only stored at
    construction time; the actual registration happens in :meth:`setup` by
    calling ``pl_module.register_torchmetric``.
    """

    def __init__(
        self,
        name,
        metric: Type[torchmetrics.Metric],
        on="prediction",
        **metric_kwargs,
    ) -> None:
        """Store the metric configuration.

        Args:
            name: Name passed on to ``register_torchmetric``.
            metric: Metric class (not an instance) to register.
            on: Hook selector. ``"prediction"`` is the only value accepted
                by :meth:`setup`.
            **metric_kwargs: Extra keyword arguments forwarded to
                ``register_torchmetric`` together with ``metric``.
        """
        self.on = on
        self.name = name
        self.metric = metric
        self.metric_kwargs = metric_kwargs

    def setup(
        self,
        trainer: pl.Trainer,
        pl_module: BaseYArchitecture,
        stage: Optional[str] = None,
    ) -> None:
        """Register the configured metric on ``pl_module``.

        Raises:
            ValueError: If ``self.on`` is anything other than
                ``"prediction"``.
        """
        # Validation is deferred to setup, so an invalid hook name only
        # surfaces once the callback is attached to a run.
        if self.on != "prediction":
            raise ValueError(f"{self.on} is no valid metric hook")

        pl_module.register_torchmetric(
            self.name,
            self.metric,
            **self.metric_kwargs,
        )
|
|
|
|
|
|
|
|
|
|
|
|
class VisGLVQ2D(Vis2DAbstract):
    """Draw prototypes, optional training data and predicted regions in 2D.

    Relies on helpers and attributes that are not defined in this class
    (``setup_ax``, ``plot_protos``, ``plot_data``, ``x_train``, ``y_train``,
    ``border``, ``resolution``, ``cmap``) - presumably provided by
    ``Vis2DAbstract``.
    """

    def visualize(self, pl_module):
        """Plot the current state of ``pl_module`` onto a fresh axis.

        Args:
            pl_module: Model exposing ``prototypes``, ``prototype_labels``,
                ``components_layer.components`` and ``predict``.
        """
        prototypes = pl_module.prototypes
        prototype_labels = pl_module.prototype_labels
        x_train, y_train = self.x_train, self.y_train

        ax = self.setup_ax()
        self.plot_protos(ax, prototypes, prototype_labels)

        # The mesh spans the training data as well when it is available,
        # otherwise only the prototypes.
        if x_train is None:
            mesh_source = prototypes
        else:
            self.plot_data(ax, x_train, y_train)
            mesh_source = np.vstack([x_train, prototypes])
        mesh_input, xx, yy = mesh2d(mesh_source, self.border, self.resolution)

        # Match dtype/device of the model components before predicting.
        reference = pl_module.components_layer.components
        mesh_tensor = torch.from_numpy(mesh_input).type_as(reference)
        predictions = pl_module.predict(mesh_tensor)
        predictions = predictions.cpu().reshape(xx.shape)

        ax.contourf(xx, yy, predictions, cmap=self.cmap, alpha=0.35)
|
|
|
|
|
|
|
|
|
|
|
|
class VisGMLVQ2D(Vis2DAbstract):
    """Plot training data (and optionally prototypes) in a 2D projection.

    The projection basis is taken from ``torch.pca_lowrank`` applied to
    ``omega @ omega.T`` with ``q=2``, where ``omega`` is ``pl_module._omega``.
    """

    def __init__(self, *args, ev_proj=True, **kwargs):
        """Forward all arguments to the parent and remember ``ev_proj``.

        Args:
            *args: Positional arguments for the parent constructor.
            ev_proj: Flag stored on the instance.
            **kwargs: Keyword arguments for the parent constructor.
        """
        super().__init__(*args, **kwargs)
        # NOTE(review): stored but not read anywhere inside this class.
        self.ev_proj = ev_proj

    def visualize(self, pl_module):
        """Project data with the model's relevance basis and plot it.

        Args:
            pl_module: Model exposing ``prototypes``, ``prototype_labels``,
                ``device`` and ``_omega``.
        """
        prototypes = pl_module.prototypes
        prototype_labels = pl_module.prototype_labels
        x_train, y_train = self.x_train, self.y_train
        target_device = pl_module.device

        omega = pl_module._omega.detach()
        basis, _, _ = torch.pca_lowrank(omega @ omega.T, q=2)

        def project(points):
            # Move to the model device, project, and return a CPU tensor.
            with torch.no_grad():
                projected = torch.Tensor(points).to(target_device)
                projected = projected @ basis
                return projected.cpu().detach()

        x_train = project(x_train)
        if self.show_protos:
            prototypes = project(prototypes)

        ax = self.setup_ax()
        self.plot_data(ax, x_train, y_train)
        if self.show_protos:
            self.plot_protos(ax, prototypes, prototype_labels)
|
|
|
|
|
|
|
|
|
|
|
|
class PlotLambdaMatrixToTensorboard(pl.Callback):
    """Log an image of the model's lambda matrix to TensorBoard.

    The matrix is plotted once when training starts and again at the end of
    every training epoch. With any logger other than ``TensorBoardLogger``
    nothing is logged and a warning is emitted instead.
    """

    def __init__(self, cmap='seismic') -> None:
        """Store the colour map and warn about non-diverging names.

        Args:
            cmap: Colour map handed to ``imshow``. A string that is not in
                ``DIVERGING_COLOR_MAPS`` triggers a warning; non-string
                values are accepted without a check.
        """
        super().__init__()
        self.cmap = cmap

        # Check the type first so that only string names are compared
        # against the list of known diverging colour maps.
        if isinstance(self.cmap, str) and self.cmap not in DIVERGING_COLOR_MAPS:
            warnings.warn(
                f"{self.cmap} is not a diverging color map. "
                f"We recommend to use one of the following: {DIVERGING_COLOR_MAPS}"
            )

    def on_train_start(self, trainer, pl_module: GMLVQ):
        """Plot the initial lambda matrix."""
        self.plot_lambda(trainer, pl_module)

    def on_train_epoch_end(self, trainer, pl_module: GMLVQ):
        """Plot the lambda matrix after each training epoch."""
        self.plot_lambda(trainer, pl_module)

    def plot_lambda(self, trainer, pl_module: GMLVQ):
        """Render ``pl_module.lambda_matrix`` and send it to the logger.

        The matrix is scaled by its largest absolute entry so that values
        fall into ``[-1, 1]``, matching the fixed ``vmin``/``vmax`` of the
        image. The figure is always closed afterwards so that repeated calls
        do not accumulate open matplotlib figures.

        Args:
            trainer: Trainer providing ``logger`` and ``global_step``.
            pl_module: Model exposing ``lambda_matrix``.
        """
        self.fig, self.ax = plt.subplots(1, 1)
        try:
            # Move to CPU before converting: ``Tensor.numpy()`` only works
            # for CPU tensors.
            l_matrix = pl_module.lambda_matrix.detach().cpu()

            # Normalize; skip for an all-zero matrix to avoid a NaN image.
            max_abs = torch.max(torch.abs(l_matrix))
            if max_abs > 0:
                l_matrix = l_matrix / max_abs

            image = self.ax.imshow(
                l_matrix.numpy(),
                cmap=self.cmap,
                vmin=-1,
                vmax=1,
            )
            self.fig.colorbar(image)
            self.ax.set_title('Lambda Matrix')

            if isinstance(trainer.logger, TensorBoardLogger):
                trainer.logger.experiment.add_figure(
                    "lambda_matrix",
                    self.fig,
                    trainer.global_step,
                )
            else:
                warnings.warn(
                    f"{self.__class__.__name__} is not compatible with "
                    f"{trainer.logger.__class__.__name__} as logger. "
                    "Use TensorBoardLogger instead."
                )
        finally:
            # Closing an already closed figure is harmless; this mainly
            # covers the non-TensorBoard and error paths.
            plt.close(self.fig)
|