From 5ce326ce62c7d249c1114b0fd614701aaa07e0bf Mon Sep 17 00:00:00 2001 From: Alexander Engelsberger Date: Fri, 15 Oct 2021 15:18:02 +0200 Subject: [PATCH] feat: CLCC register torchmetrics added --- prototorch/models/clcc/clcc_scheme.py | 47 ++++++++++++++++++++++++++- prototorch/models/clcc/test_clcc.py | 46 ++++++++++++++++++++++++-- 2 files changed, 89 insertions(+), 4 deletions(-) diff --git a/prototorch/models/clcc/clcc_scheme.py b/prototorch/models/clcc/clcc_scheme.py index b0f7135..d0fdb23 100644 --- a/prototorch/models/clcc/clcc_scheme.py +++ b/prototorch/models/clcc/clcc_scheme.py @@ -8,11 +8,18 @@ CLCC is a LVQ scheme containing 4 steps - Competition """ +from typing import Dict, Set, Type + import pytorch_lightning as pl import torch +import torchmetrics class CLCCScheme(pl.LightningModule): + registered_metrics: Dict[Type[torchmetrics.Metric], + torchmetrics.Metric] = {} + registered_metric_names: Dict[Type[torchmetrics.Metric], Set[str]] = {} + def __init__(self, hparams) -> None: super().__init__() @@ -28,7 +35,18 @@ class CLCCScheme(pl.LightningModule): # Inference Steps self.init_inference(hparams) - # API + # Initialize Model Metrics + self.init_model_metrics() + + # internal API, called by models and callbacks + def register_torchmetric(self, name: str, metric: torchmetrics.Metric): + if metric not in self.registered_metrics: + self.registered_metrics[metric] = metric() + self.registered_metric_names[metric] = {name} + else: + self.registered_metric_names[metric].add(name) + + # external API def get_competion(self, batch, components): latent_batch, latent_components = self.latent(batch, components) # TODO: => Latent Hook @@ -81,6 +99,9 @@ class CLCCScheme(pl.LightningModule): def init_inference(self, hparams): ... + def init_model_metrics(self): + self.register_torchmetric('train_accuracy', torchmetrics.Accuracy) + # Empty Steps # TODO: Type hints def components(self): @@ -136,10 +157,34 @@ class CLCCScheme(pl.LightningModule): raise NotImplementedError( "The inference step has no reasonable default.") + def update_metrics_step(self, batch): + x, y = batch + preds = self(x) + + for metric in self.registered_metrics: + instance = self.registered_metrics[metric].to(self.device) + value = instance(y, preds) + + for name in self.registered_metric_names[metric]: + self.log(name, value) + + def update_metrics_epoch(self): + for metric in self.registered_metrics: + instance = self.registered_metrics[metric].to(self.device) + value = instance.compute() + + for name in self.registered_metric_names[metric]: + self.log(name, value) + # Lightning Hooks def training_step(self, batch, batch_idx, optimizer_idx=None): + self.update_metrics_step(batch) + return self.loss_forward(batch) + def train_epoch_end(self, outs) -> None: + self.update_metrics_epoch() + def validation_step(self, batch, batch_idx): return self.loss_forward(batch) diff --git a/prototorch/models/clcc/test_clcc.py b/prototorch/models/clcc/test_clcc.py index ef5234e..f0f39b7 100644 --- a/prototorch/models/clcc/test_clcc.py +++ b/prototorch/models/clcc/test_clcc.py @@ -1,20 +1,47 @@ +from typing import Optional + import matplotlib.pyplot as plt import prototorch as pt import pytorch_lightning as pl import torch +import torchmetrics from prototorch.core.initializers import SMCI, RandomNormalCompInitializer from prototorch.models.clcc.clcc_glvq import GLVQ, GLVQhparams +from prototorch.models.clcc.clcc_scheme import CLCCScheme from prototorch.models.vis import Visualize2DVoronoiCallback # NEW STUFF # ############################################################################## + + +# TODO: Metrics +class MetricsTestCallback(pl.Callback): + metric_name = "test_cb_acc" + + def setup(self, + trainer: pl.Trainer, + pl_module: CLCCScheme, + stage: Optional[str] = None) -> None: + pl_module.register_torchmetric(self.metric_name, torchmetrics.Accuracy) + + def on_epoch_end(self, trainer: pl.Trainer, + pl_module: pl.LightningModule) -> None: + metric = trainer.logged_metrics[self.metric_name] + if metric > 0.95: + trainer.should_stop = True + + +# TODO: Pruning + # ############################################################################## if __name__ == "__main__": # Dataset train_ds = pt.datasets.Iris(dims=[0, 2]) # Dataloaders - train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64) + train_loader = torch.utils.data.DataLoader(train_ds, + batch_size=64, + num_workers=8) components_initializer = SMCI(train_ds) @@ -29,8 +56,21 @@ if __name__ == "__main__": print(model) # Callbacks - vis = Visualize2DVoronoiCallback(data=train_ds, resolution=500) + vis = Visualize2DVoronoiCallback( + data=train_ds, + resolution=500, + ) + metrics = MetricsTestCallback() # Train - trainer = pl.Trainer(callbacks=[vis], gpus=1, max_epochs=100) + trainer = pl.Trainer( + callbacks=[ + #vis, + metrics, + ], + gpus=1, + max_epochs=100, + weights_summary=None, + log_every_n_steps=1, + ) trainer.fit(model, train_loader)