feat: add confusion matrix callback
This commit is contained in:
parent
696719600b
commit
a7df7be1c8
@ -5,6 +5,7 @@ Network architecture for Component based Learning.
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
Callable,
|
||||
Dict,
|
||||
Set,
|
||||
Type,
|
||||
@ -13,7 +14,6 @@ from typing import (
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from torchmetrics import Metric
|
||||
from torchmetrics.classification.accuracy import Accuracy
|
||||
|
||||
|
||||
class BaseYArchitecture(pl.LightningModule):
|
||||
@ -22,9 +22,11 @@ class BaseYArchitecture(pl.LightningModule):
|
||||
class HyperParameters:
|
||||
...
|
||||
|
||||
# Fields
|
||||
registered_metrics: Dict[Type[Metric], Metric] = {}
|
||||
registered_metric_names: Dict[Type[Metric], Set[str]] = {}
|
||||
registered_metric_callbacks: Dict[Type[Metric], Set[Callable]] = {}
|
||||
|
||||
# Type Hints for Necessary Fields
|
||||
components_layer: torch.nn.Module
|
||||
|
||||
def __init__(self, hparams) -> None:
|
||||
@ -42,22 +44,6 @@ class BaseYArchitecture(pl.LightningModule):
|
||||
# Inference Steps
|
||||
self.init_inference(hparams)
|
||||
|
||||
# Initialize Model Metrics
|
||||
self.init_model_metrics()
|
||||
|
||||
# internal API, called by models and callbacks
|
||||
def register_torchmetric(
|
||||
self,
|
||||
name: str,
|
||||
metric: Type[Metric],
|
||||
**metric_kwargs,
|
||||
):
|
||||
if metric not in self.registered_metrics:
|
||||
self.registered_metrics[metric] = metric(**metric_kwargs)
|
||||
self.registered_metric_names[metric] = {name}
|
||||
else:
|
||||
self.registered_metric_names[metric].add(name)
|
||||
|
||||
# external API
|
||||
def get_competition(self, batch, components):
|
||||
latent_batch, latent_components = self.latent(batch, components)
|
||||
@ -99,7 +85,6 @@ class BaseYArchitecture(pl.LightningModule):
|
||||
return self.loss(comparison_tensor, batch, components)
|
||||
|
||||
# Empty Initialization
|
||||
# TODO: Type hints
|
||||
# TODO: Docs
|
||||
def init_components(self, hparams: HyperParameters) -> None:
|
||||
...
|
||||
@ -119,9 +104,6 @@ class BaseYArchitecture(pl.LightningModule):
|
||||
def init_inference(self, hparams: HyperParameters) -> None:
|
||||
...
|
||||
|
||||
def init_model_metrics(self) -> None:
|
||||
self.register_torchmetric('accuracy', Accuracy)
|
||||
|
||||
# Empty Steps
|
||||
# TODO: Type hints
|
||||
def components(self):
|
||||
@ -177,11 +159,26 @@ class BaseYArchitecture(pl.LightningModule):
|
||||
raise NotImplementedError(
|
||||
"The inference step has no reasonable default.")
|
||||
|
||||
def update_metrics_step(self, batch):
|
||||
x, y = batch
|
||||
# Y Architecture Hooks
|
||||
|
||||
# internal API, called by models and callbacks
|
||||
def register_torchmetric(
|
||||
self,
|
||||
name: Callable,
|
||||
metric: Type[Metric],
|
||||
**metric_kwargs,
|
||||
):
|
||||
if metric not in self.registered_metrics:
|
||||
self.registered_metrics[metric] = metric(**metric_kwargs)
|
||||
self.registered_metric_callbacks[metric] = {name}
|
||||
else:
|
||||
self.registered_metric_callbacks[metric].add(name)
|
||||
|
||||
def update_metrics_step(self, batch):
|
||||
# Prediction Metrics
|
||||
preds = self(x)
|
||||
preds = self(batch)
|
||||
|
||||
x, y = batch
|
||||
for metric in self.registered_metrics:
|
||||
instance = self.registered_metrics[metric].to(self.device)
|
||||
instance(y, preds)
|
||||
@ -191,22 +188,25 @@ class BaseYArchitecture(pl.LightningModule):
|
||||
instance = self.registered_metrics[metric].to(self.device)
|
||||
value = instance.compute()
|
||||
|
||||
for name in self.registered_metric_names[metric]:
|
||||
self.log(name, value)
|
||||
for callback in self.registered_metric_callbacks[metric]:
|
||||
callback(value, self)
|
||||
|
||||
instance.reset()
|
||||
|
||||
# Lightning Hooks
|
||||
|
||||
# Steps
|
||||
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
||||
self.update_metrics_step(batch)
|
||||
self.update_metrics_step([torch.clone(el) for el in batch])
|
||||
|
||||
return self.loss_forward(batch)
|
||||
|
||||
def training_epoch_end(self, outs) -> None:
|
||||
self.update_metrics_epoch()
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
return self.loss_forward(batch)
|
||||
|
||||
def test_step(self, batch, batch_idx):
|
||||
return self.loss_forward(batch)
|
||||
|
||||
# Other Hooks
|
||||
def training_epoch_end(self, outs) -> None:
|
||||
self.update_metrics_epoch()
|
||||
|
@ -13,8 +13,18 @@ from prototorch.y.library.gmlvq import GMLVQ
|
||||
from pytorch_lightning.loggers import TensorBoardLogger
|
||||
|
||||
DIVERGING_COLOR_MAPS = [
|
||||
'PiYG', 'PRGn', 'BrBG', 'PuOr', 'RdGy', 'RdBu', 'RdYlBu', 'RdYlGn',
|
||||
'Spectral', 'coolwarm', 'bwr', 'seismic'
|
||||
'PiYG',
|
||||
'PRGn',
|
||||
'BrBG',
|
||||
'PuOr',
|
||||
'RdGy',
|
||||
'RdBu',
|
||||
'RdYlBu',
|
||||
'RdYlGn',
|
||||
'Spectral',
|
||||
'coolwarm',
|
||||
'bwr',
|
||||
'seismic',
|
||||
]
|
||||
|
||||
|
||||
@ -40,13 +50,72 @@ class LogTorchmetricCallback(pl.Callback):
|
||||
) -> None:
|
||||
if self.on == "prediction":
|
||||
pl_module.register_torchmetric(
|
||||
self.name,
|
||||
self,
|
||||
self.metric,
|
||||
**self.metric_kwargs,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"{self.on} is no valid metric hook")
|
||||
|
||||
def __call__(self, value, pl_module: BaseYArchitecture):
|
||||
pl_module.log(self.name, value)
|
||||
|
||||
|
||||
class LogConfusionMatrix(LogTorchmetricCallback):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_classes,
|
||||
name="confusion",
|
||||
on='prediction',
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
name,
|
||||
torchmetrics.ConfusionMatrix,
|
||||
on=on,
|
||||
num_classes=num_classes,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def __call__(self, value, pl_module: BaseYArchitecture):
|
||||
fig, ax = plt.subplots()
|
||||
ax.imshow(value.detach().cpu().numpy())
|
||||
|
||||
# Show all ticks and label them with the respective list entries
|
||||
# ax.set_xticks(np.arange(len(farmers)), labels=farmers)
|
||||
# ax.set_yticks(np.arange(len(vegetables)), labels=vegetables)
|
||||
|
||||
# Rotate the tick labels and set their alignment.
|
||||
plt.setp(
|
||||
ax.get_xticklabels(),
|
||||
rotation=45,
|
||||
ha="right",
|
||||
rotation_mode="anchor",
|
||||
)
|
||||
|
||||
# Loop over data dimensions and create text annotations.
|
||||
for i in range(len(value)):
|
||||
for j in range(len(value)):
|
||||
text = ax.text(
|
||||
j,
|
||||
i,
|
||||
value[i, j].item(),
|
||||
ha="center",
|
||||
va="center",
|
||||
color="w",
|
||||
)
|
||||
|
||||
ax.set_title(self.name)
|
||||
fig.tight_layout()
|
||||
|
||||
pl_module.logger.experiment.add_figure(
|
||||
tag=self.name,
|
||||
figure=fig,
|
||||
close=True,
|
||||
global_step=pl_module.global_step,
|
||||
)
|
||||
|
||||
|
||||
class VisGLVQ2D(Vis2DAbstract):
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user