64 lines
1.9 KiB
Python
64 lines
1.9 KiB
Python
from typing import Optional, Type
|
|
|
|
import numpy as np
|
|
import pytorch_lightning as pl
|
|
import torch
|
|
import torchmetrics
|
|
from prototorch.models.proto_y_architecture.base import BaseYArchitecture
|
|
from prototorch.models.vis import Vis2DAbstract
|
|
from prototorch.utils.utils import mesh2d
|
|
|
|
|
|
class LogTorchmetricCallback(pl.Callback):
|
|
|
|
def __init__(
|
|
self,
|
|
name,
|
|
metric: Type[torchmetrics.Metric],
|
|
on="prediction",
|
|
**metric_kwargs,
|
|
) -> None:
|
|
self.name = name
|
|
self.metric = metric
|
|
self.metric_kwargs = metric_kwargs
|
|
self.on = on
|
|
|
|
def setup(
|
|
self,
|
|
trainer: pl.Trainer,
|
|
pl_module: BaseYArchitecture,
|
|
stage: Optional[str] = None,
|
|
) -> None:
|
|
if self.on == "prediction":
|
|
pl_module.register_torchmetric(
|
|
self.name,
|
|
self.metric,
|
|
**self.metric_kwargs,
|
|
)
|
|
else:
|
|
raise ValueError(f"{self.on} is no valid metric hook")
|
|
|
|
|
|
class VisGLVQ2D(Vis2DAbstract):
|
|
|
|
def visualize(self, pl_module):
|
|
protos = pl_module.prototypes
|
|
plabels = pl_module.prototype_labels
|
|
x_train, y_train = self.x_train, self.y_train
|
|
ax = self.setup_ax()
|
|
self.plot_protos(ax, protos, plabels)
|
|
if x_train is not None:
|
|
self.plot_data(ax, x_train, y_train)
|
|
mesh_input, xx, yy = mesh2d(
|
|
np.vstack([x_train, protos]),
|
|
self.border,
|
|
self.resolution,
|
|
)
|
|
else:
|
|
mesh_input, xx, yy = mesh2d(protos, self.border, self.resolution)
|
|
_components = pl_module.components_layer.components
|
|
mesh_input = torch.from_numpy(mesh_input).type_as(_components)
|
|
y_pred = pl_module.predict(mesh_input)
|
|
y_pred = y_pred.cpu().reshape(xx.shape)
|
|
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
|