feat: add confusion matrix callback

This commit is contained in:
Alexander Engelsberger
2022-06-09 14:55:59 +02:00
parent 696719600b
commit a7df7be1c8
2 changed files with 103 additions and 34 deletions

View File

@@ -5,6 +5,7 @@ Network architecture for Component based Learning.
"""
from dataclasses import dataclass
from typing import (
Callable,
Dict,
Set,
Type,
@@ -13,7 +14,6 @@ from typing import (
import pytorch_lightning as pl
import torch
from torchmetrics import Metric
from torchmetrics.classification.accuracy import Accuracy
class BaseYArchitecture(pl.LightningModule):
@@ -22,9 +22,11 @@ class BaseYArchitecture(pl.LightningModule):
class HyperParameters:
...
# Fields
registered_metrics: Dict[Type[Metric], Metric] = {}
registered_metric_names: Dict[Type[Metric], Set[str]] = {}
registered_metric_callbacks: Dict[Type[Metric], Set[Callable]] = {}
# Type Hints for Necessary Fields
components_layer: torch.nn.Module
def __init__(self, hparams) -> None:
@@ -42,22 +44,6 @@ class BaseYArchitecture(pl.LightningModule):
# Inference Steps
self.init_inference(hparams)
# Initialize Model Metrics
self.init_model_metrics()
# internal API, called by models and callbacks
def register_torchmetric(
self,
name: str,
metric: Type[Metric],
**metric_kwargs,
):
if metric not in self.registered_metrics:
self.registered_metrics[metric] = metric(**metric_kwargs)
self.registered_metric_names[metric] = {name}
else:
self.registered_metric_names[metric].add(name)
# external API
def get_competition(self, batch, components):
latent_batch, latent_components = self.latent(batch, components)
@@ -99,7 +85,6 @@ class BaseYArchitecture(pl.LightningModule):
return self.loss(comparison_tensor, batch, components)
# Empty Initialization
# TODO: Type hints
# TODO: Docs
def init_components(self, hparams: HyperParameters) -> None:
...
@@ -119,9 +104,6 @@ class BaseYArchitecture(pl.LightningModule):
def init_inference(self, hparams: HyperParameters) -> None:
...
def init_model_metrics(self) -> None:
self.register_torchmetric('accuracy', Accuracy)
# Empty Steps
# TODO: Type hints
def components(self):
@@ -177,11 +159,26 @@ class BaseYArchitecture(pl.LightningModule):
raise NotImplementedError(
"The inference step has no reasonable default.")
def update_metrics_step(self, batch):
x, y = batch
# Y Architecture Hooks
# internal API, called by models and callbacks
def register_torchmetric(
self,
name: Callable,
metric: Type[Metric],
**metric_kwargs,
):
if metric not in self.registered_metrics:
self.registered_metrics[metric] = metric(**metric_kwargs)
self.registered_metric_callbacks[metric] = {name}
else:
self.registered_metric_callbacks[metric].add(name)
def update_metrics_step(self, batch):
# Prediction Metrics
preds = self(x)
preds = self(batch)
x, y = batch
for metric in self.registered_metrics:
instance = self.registered_metrics[metric].to(self.device)
instance(y, preds)
@@ -191,22 +188,25 @@ class BaseYArchitecture(pl.LightningModule):
instance = self.registered_metrics[metric].to(self.device)
value = instance.compute()
for name in self.registered_metric_names[metric]:
self.log(name, value)
for callback in self.registered_metric_callbacks[metric]:
callback(value, self)
instance.reset()
# Lightning Hooks
# Steps
def training_step(self, batch, batch_idx, optimizer_idx=None):
self.update_metrics_step(batch)
self.update_metrics_step([torch.clone(el) for el in batch])
return self.loss_forward(batch)
def training_epoch_end(self, outs) -> None:
self.update_metrics_epoch()
def validation_step(self, batch, batch_idx):
return self.loss_forward(batch)
def test_step(self, batch, batch_idx):
return self.loss_forward(batch)
# Other Hooks
def training_epoch_end(self, outs) -> None:
self.update_metrics_epoch()