prototorch_models/prototorch/models/proto_y_architecture/callbacks.py

64 lines
1.9 KiB
Python
Raw Normal View History

2022-05-17 15:25:51 +00:00
from typing import Optional, Type
2022-05-17 14:25:43 +00:00
2022-05-17 15:25:51 +00:00
import numpy as np
2022-05-17 14:25:43 +00:00
import pytorch_lightning as pl
import torch
import torchmetrics
from prototorch.models.proto_y_architecture.base import BaseYArchitecture
2022-05-17 15:25:51 +00:00
from prototorch.models.vis import Vis2DAbstract
from prototorch.utils.utils import mesh2d
2022-05-17 14:25:43 +00:00
class LogTorchmetricCallback(pl.Callback):
    """Lightning callback that registers a torchmetrics metric on a module.

    The metric *class* (not an instance) and its keyword arguments are
    stored at construction time and forwarded to
    ``pl_module.register_torchmetric`` when :meth:`setup` runs.
    """

    def __init__(
        self,
        name,
        metric: Type[torchmetrics.Metric],
        on: str = "prediction",
        **metric_kwargs,
    ) -> None:
        """Store the metric configuration.

        Args:
            name: Passed as the first argument to
                ``pl_module.register_torchmetric`` (presumably the key the
                metric is logged under — confirm against BaseYArchitecture).
            metric: The ``torchmetrics.Metric`` subclass to register.
            on: Hook the metric is attached to. Only ``"prediction"`` is
                supported; any other value makes :meth:`setup` raise
                ``ValueError``.
            **metric_kwargs: Forwarded unchanged to ``register_torchmetric``.
        """
        # Cooperative init so pl.Callback / any mixins in the MRO initialise.
        super().__init__()
        self.name = name
        self.metric = metric
        self.metric_kwargs = metric_kwargs
        self.on = on

    def setup(
        self,
        trainer: pl.Trainer,
        pl_module: BaseYArchitecture,
        stage: Optional[str] = None,
    ) -> None:
        """Register the configured metric on ``pl_module``.

        Args:
            trainer: The Lightning trainer (unused here).
            pl_module: Module exposing ``register_torchmetric``.
            stage: Lightning stage name (unused here).

        Raises:
            ValueError: If ``self.on`` is not ``"prediction"``.
        """
        # Only the prediction hook is implemented; fail loudly on anything
        # else instead of silently skipping registration.
        if self.on != "prediction":
            raise ValueError(f"{self.on} is no valid metric hook")
        # NOTE(review): Lightning may invoke setup() once per stage
        # (fit/validate/test); confirm register_torchmetric tolerates
        # being called repeatedly with the same name.
        pl_module.register_torchmetric(
            self.name,
            self.metric,
            **self.metric_kwargs,
        )
2022-05-17 15:25:51 +00:00
class VisGLVQ2D(Vis2DAbstract):
    """2D visualisation of prototypes, optional training data and the
    model's predictions over a mesh grid."""

    def visualize(self, pl_module):
        """Draw prototypes, training data (if any) and prediction regions.

        The mesh grid is spanned by the training data stacked with the
        prototypes when training data is available, otherwise by the
        prototypes alone. Every grid point is passed through
        ``pl_module.predict`` and the result is drawn as filled contours.

        Args:
            pl_module: Module providing ``prototypes``, ``prototype_labels``,
                ``components_layer.components`` and ``predict``.
        """
        prototypes = pl_module.prototypes
        prototype_labels = pl_module.prototype_labels
        data_x, data_y = self.x_train, self.y_train

        axis = self.setup_ax()
        self.plot_protos(axis, prototypes, prototype_labels)

        if data_x is None:
            # No training data: span the grid over the prototypes only.
            grid_points = prototypes
        else:
            self.plot_data(axis, data_x, data_y)
            grid_points = np.vstack([data_x, prototypes])
        mesh_input, xx, yy = mesh2d(grid_points, self.border, self.resolution)

        # Match the dtype/device of the model's components before predicting.
        reference = pl_module.components_layer.components
        grid_tensor = torch.from_numpy(mesh_input).type_as(reference)
        predictions = pl_module.predict(grid_tensor).cpu().reshape(xx.shape)
        axis.contourf(xx, yy, predictions, cmap=self.cmap, alpha=0.35)