feat: add confusion matrix callback

This commit is contained in:
Alexander Engelsberger 2022-06-09 14:55:59 +02:00
parent 696719600b
commit a7df7be1c8
2 changed files with 103 additions and 34 deletions

View File

@ -5,6 +5,7 @@ Network architecture for Component based Learning.
""" """
from dataclasses import dataclass from dataclasses import dataclass
from typing import ( from typing import (
Callable,
Dict, Dict,
Set, Set,
Type, Type,
@ -13,7 +14,6 @@ from typing import (
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from torchmetrics import Metric from torchmetrics import Metric
from torchmetrics.classification.accuracy import Accuracy
class BaseYArchitecture(pl.LightningModule): class BaseYArchitecture(pl.LightningModule):
@ -22,9 +22,11 @@ class BaseYArchitecture(pl.LightningModule):
class HyperParameters: class HyperParameters:
... ...
# Fields
registered_metrics: Dict[Type[Metric], Metric] = {} registered_metrics: Dict[Type[Metric], Metric] = {}
registered_metric_names: Dict[Type[Metric], Set[str]] = {} registered_metric_callbacks: Dict[Type[Metric], Set[Callable]] = {}
# Type Hints for Necessary Fields
components_layer: torch.nn.Module components_layer: torch.nn.Module
def __init__(self, hparams) -> None: def __init__(self, hparams) -> None:
@ -42,22 +44,6 @@ class BaseYArchitecture(pl.LightningModule):
# Inference Steps # Inference Steps
self.init_inference(hparams) self.init_inference(hparams)
# Initialize Model Metrics
self.init_model_metrics()
# internal API, called by models and callbacks
def register_torchmetric(
self,
name: str,
metric: Type[Metric],
**metric_kwargs,
):
if metric not in self.registered_metrics:
self.registered_metrics[metric] = metric(**metric_kwargs)
self.registered_metric_names[metric] = {name}
else:
self.registered_metric_names[metric].add(name)
# external API # external API
def get_competition(self, batch, components): def get_competition(self, batch, components):
latent_batch, latent_components = self.latent(batch, components) latent_batch, latent_components = self.latent(batch, components)
@ -99,7 +85,6 @@ class BaseYArchitecture(pl.LightningModule):
return self.loss(comparison_tensor, batch, components) return self.loss(comparison_tensor, batch, components)
# Empty Initialization # Empty Initialization
# TODO: Type hints
# TODO: Docs # TODO: Docs
def init_components(self, hparams: HyperParameters) -> None: def init_components(self, hparams: HyperParameters) -> None:
... ...
@ -119,9 +104,6 @@ class BaseYArchitecture(pl.LightningModule):
def init_inference(self, hparams: HyperParameters) -> None: def init_inference(self, hparams: HyperParameters) -> None:
... ...
def init_model_metrics(self) -> None:
self.register_torchmetric('accuracy', Accuracy)
# Empty Steps # Empty Steps
# TODO: Type hints # TODO: Type hints
def components(self): def components(self):
@ -177,11 +159,26 @@ class BaseYArchitecture(pl.LightningModule):
raise NotImplementedError( raise NotImplementedError(
"The inference step has no reasonable default.") "The inference step has no reasonable default.")
def update_metrics_step(self, batch): # Y Architecture Hooks
x, y = batch
# internal API, called by models and callbacks
def register_torchmetric(
self,
name: Callable,
metric: Type[Metric],
**metric_kwargs,
):
if metric not in self.registered_metrics:
self.registered_metrics[metric] = metric(**metric_kwargs)
self.registered_metric_callbacks[metric] = {name}
else:
self.registered_metric_callbacks[metric].add(name)
def update_metrics_step(self, batch):
# Prediction Metrics # Prediction Metrics
preds = self(x) preds = self(batch)
x, y = batch
for metric in self.registered_metrics: for metric in self.registered_metrics:
instance = self.registered_metrics[metric].to(self.device) instance = self.registered_metrics[metric].to(self.device)
instance(y, preds) instance(y, preds)
@ -191,22 +188,25 @@ class BaseYArchitecture(pl.LightningModule):
instance = self.registered_metrics[metric].to(self.device) instance = self.registered_metrics[metric].to(self.device)
value = instance.compute() value = instance.compute()
for name in self.registered_metric_names[metric]: for callback in self.registered_metric_callbacks[metric]:
self.log(name, value) callback(value, self)
instance.reset() instance.reset()
# Lightning Hooks # Lightning Hooks
# Steps
def training_step(self, batch, batch_idx, optimizer_idx=None): def training_step(self, batch, batch_idx, optimizer_idx=None):
self.update_metrics_step(batch) self.update_metrics_step([torch.clone(el) for el in batch])
return self.loss_forward(batch) return self.loss_forward(batch)
def training_epoch_end(self, outs) -> None:
self.update_metrics_epoch()
def validation_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx):
return self.loss_forward(batch) return self.loss_forward(batch)
def test_step(self, batch, batch_idx): def test_step(self, batch, batch_idx):
return self.loss_forward(batch) return self.loss_forward(batch)
# Other Hooks
def training_epoch_end(self, outs) -> None:
self.update_metrics_epoch()

View File

@ -13,8 +13,18 @@ from prototorch.y.library.gmlvq import GMLVQ
from pytorch_lightning.loggers import TensorBoardLogger from pytorch_lightning.loggers import TensorBoardLogger
DIVERGING_COLOR_MAPS = [ DIVERGING_COLOR_MAPS = [
'PiYG', 'PRGn', 'BrBG', 'PuOr', 'RdGy', 'RdBu', 'RdYlBu', 'RdYlGn', 'PiYG',
'Spectral', 'coolwarm', 'bwr', 'seismic' 'PRGn',
'BrBG',
'PuOr',
'RdGy',
'RdBu',
'RdYlBu',
'RdYlGn',
'Spectral',
'coolwarm',
'bwr',
'seismic',
] ]
@ -40,13 +50,72 @@ class LogTorchmetricCallback(pl.Callback):
) -> None: ) -> None:
if self.on == "prediction": if self.on == "prediction":
pl_module.register_torchmetric( pl_module.register_torchmetric(
self.name, self,
self.metric, self.metric,
**self.metric_kwargs, **self.metric_kwargs,
) )
else: else:
raise ValueError(f"{self.on} is no valid metric hook") raise ValueError(f"{self.on} is no valid metric hook")
def __call__(self, value, pl_module: BaseYArchitecture):
pl_module.log(self.name, value)
class LogConfusionMatrix(LogTorchmetricCallback):
def __init__(
self,
num_classes,
name="confusion",
on='prediction',
**kwargs,
):
super().__init__(
name,
torchmetrics.ConfusionMatrix,
on=on,
num_classes=num_classes,
**kwargs,
)
def __call__(self, value, pl_module: BaseYArchitecture):
fig, ax = plt.subplots()
ax.imshow(value.detach().cpu().numpy())
# Show all ticks and label them with the respective list entries
# ax.set_xticks(np.arange(len(farmers)), labels=farmers)
# ax.set_yticks(np.arange(len(vegetables)), labels=vegetables)
# Rotate the tick labels and set their alignment.
plt.setp(
ax.get_xticklabels(),
rotation=45,
ha="right",
rotation_mode="anchor",
)
# Loop over data dimensions and create text annotations.
for i in range(len(value)):
for j in range(len(value)):
text = ax.text(
j,
i,
value[i, j].item(),
ha="center",
va="center",
color="w",
)
ax.set_title(self.name)
fig.tight_layout()
pl_module.logger.experiment.add_figure(
tag=self.name,
figure=fig,
close=True,
global_step=pl_module.global_step,
)
class VisGLVQ2D(Vis2DAbstract): class VisGLVQ2D(Vis2DAbstract):