feat: CLCC register torchmetrics added

This commit is contained in:
Alexander Engelsberger 2021-10-15 15:18:02 +02:00
parent d1985571b3
commit 5ce326ce62
No known key found for this signature in database
GPG Key ID: BE3F5909FF0D83E3
2 changed files with 89 additions and 4 deletions

View File

@ -8,11 +8,18 @@ CLCC is a LVQ scheme containing 4 steps
- Competition - Competition
""" """
from typing import Dict, Set, Type
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
import torchmetrics
class CLCCScheme(pl.LightningModule): class CLCCScheme(pl.LightningModule):
registered_metrics: Dict[Type[torchmetrics.Metric],
torchmetrics.Metric] = {}
registered_metric_names: Dict[Type[torchmetrics.Metric], Set[str]] = {}
def __init__(self, hparams) -> None: def __init__(self, hparams) -> None:
super().__init__() super().__init__()
@ -28,7 +35,18 @@ class CLCCScheme(pl.LightningModule):
# Inference Steps # Inference Steps
self.init_inference(hparams) self.init_inference(hparams)
# API # Initialize Model Metrics
self.init_model_metrics()
# internal API, called by models and callbacks
def register_torchmetric(self, name: str, metric: torchmetrics.Metric):
if metric not in self.registered_metrics:
self.registered_metrics[metric] = metric()
self.registered_metric_names[metric] = {name}
else:
self.registered_metric_names[metric].add(name)
# external API
def get_competion(self, batch, components): def get_competion(self, batch, components):
latent_batch, latent_components = self.latent(batch, components) latent_batch, latent_components = self.latent(batch, components)
# TODO: => Latent Hook # TODO: => Latent Hook
@ -81,6 +99,9 @@ class CLCCScheme(pl.LightningModule):
def init_inference(self, hparams): def init_inference(self, hparams):
... ...
def init_model_metrics(self):
self.register_torchmetric('train_accuracy', torchmetrics.Accuracy)
# Empty Steps # Empty Steps
# TODO: Type hints # TODO: Type hints
def components(self): def components(self):
@ -136,10 +157,34 @@ class CLCCScheme(pl.LightningModule):
raise NotImplementedError( raise NotImplementedError(
"The inference step has no reasonable default.") "The inference step has no reasonable default.")
def update_metrics_step(self, batch):
x, y = batch
preds = self(x)
for metric in self.registered_metrics:
instance = self.registered_metrics[metric].to(self.device)
value = instance(y, preds)
for name in self.registered_metric_names[metric]:
self.log(name, value)
def update_metrics_epoch(self):
for metric in self.registered_metrics:
instance = self.registered_metrics[metric].to(self.device)
value = instance.compute()
for name in self.registered_metric_names[metric]:
self.log(name, value)
# Lightning Hooks # Lightning Hooks
def training_step(self, batch, batch_idx, optimizer_idx=None): def training_step(self, batch, batch_idx, optimizer_idx=None):
self.update_metrics_step(batch)
return self.loss_forward(batch) return self.loss_forward(batch)
def train_epoch_end(self, outs) -> None:
self.update_metrics_epoch()
def validation_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx):
return self.loss_forward(batch) return self.loss_forward(batch)

View File

@ -1,20 +1,47 @@
from typing import Optional
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
import torchmetrics
from prototorch.core.initializers import SMCI, RandomNormalCompInitializer from prototorch.core.initializers import SMCI, RandomNormalCompInitializer
from prototorch.models.clcc.clcc_glvq import GLVQ, GLVQhparams from prototorch.models.clcc.clcc_glvq import GLVQ, GLVQhparams
from prototorch.models.clcc.clcc_scheme import CLCCScheme
from prototorch.models.vis import Visualize2DVoronoiCallback from prototorch.models.vis import Visualize2DVoronoiCallback
# NEW STUFF # NEW STUFF
# ############################################################################## # ##############################################################################
# TODO: Metrics
class MetricsTestCallback(pl.Callback):
metric_name = "test_cb_acc"
def setup(self,
trainer: pl.Trainer,
pl_module: CLCCScheme,
stage: Optional[str] = None) -> None:
pl_module.register_torchmetric(self.metric_name, torchmetrics.Accuracy)
def on_epoch_end(self, trainer: pl.Trainer,
pl_module: pl.LightningModule) -> None:
metric = trainer.logged_metrics[self.metric_name]
if metric > 0.95:
trainer.should_stop = True
# TODO: Pruning
# ############################################################################## # ##############################################################################
if __name__ == "__main__": if __name__ == "__main__":
# Dataset # Dataset
train_ds = pt.datasets.Iris(dims=[0, 2]) train_ds = pt.datasets.Iris(dims=[0, 2])
# Dataloaders # Dataloaders
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64) train_loader = torch.utils.data.DataLoader(train_ds,
batch_size=64,
num_workers=8)
components_initializer = SMCI(train_ds) components_initializer = SMCI(train_ds)
@ -29,8 +56,21 @@ if __name__ == "__main__":
print(model) print(model)
# Callbacks # Callbacks
vis = Visualize2DVoronoiCallback(data=train_ds, resolution=500) vis = Visualize2DVoronoiCallback(
data=train_ds,
resolution=500,
)
metrics = MetricsTestCallback()
# Train # Train
trainer = pl.Trainer(callbacks=[vis], gpus=1, max_epochs=100) trainer = pl.Trainer(
callbacks=[
#vis,
metrics,
],
gpus=1,
max_epochs=100,
weights_summary=None,
log_every_n_steps=1,
)
trainer.fit(model, train_loader) trainer.fit(model, train_loader)