feat: metrics can be assigned to the different phases

This commit is contained in:
Alexander Engelsberger 2022-06-24 15:04:35 +02:00
parent 46ec7b07d7
commit 736565b768
6 changed files with 237 additions and 69 deletions

View File

@ -23,6 +23,13 @@ ProtoTorch Models Plugins
custom custom
.. toctree::
:hidden:
:maxdepth: 3
:caption: Proto Y Architecture
y-architecture
About About
----------------------------------------- -----------------------------------------
`Prototorch Models <https://github.com/si-cim/prototorch_models>`_ is a Plugin `Prototorch Models <https://github.com/si-cim/prototorch_models>`_ is a Plugin
@ -33,8 +40,10 @@ prototype-based Machine Learning algorithms using `PyTorch-Lightning
Library Library
----------------------------------------- -----------------------------------------
Prototorch Models delivers many application ready models. Prototorch Models delivers many application ready models.
These models have been published in the past and have been adapted to the Prototorch library. These models have been published in the past and have been adapted to the
Prototorch library.
Customizable Customizable
----------------------------------------- -----------------------------------------
Prototorch Models also contains the building blocks to build own models with PyTorch-Lightning and Prototorch. Prototorch Models also contains the building blocks to build own models with
PyTorch-Lightning and Prototorch.

View File

@ -71,7 +71,7 @@ Probabilistic Models
Probabilistic variants assume, that the prototypes generate a probability distribution over the classes. Probabilistic variants assume, that the prototypes generate a probability distribution over the classes.
For a test sample they return a distribution instead of a class assignment. For a test sample they return a distribution instead of a class assignment.
The following two algorihms were presented by :cite:t:`seo2003` . The following two algorithms were presented by :cite:t:`seo2003` .
Every prototypes is a center of a gaussian distribution of its class, generating a mixture model. Every prototypes is a center of a gaussian distribution of its class, generating a mixture model.
.. autoclass:: prototorch.models.probabilistic.SLVQ .. autoclass:: prototorch.models.probabilistic.SLVQ
@ -80,7 +80,7 @@ Every prototypes is a center of a gaussian distribution of its class, generating
.. autoclass:: prototorch.models.probabilistic.RSLVQ .. autoclass:: prototorch.models.probabilistic.RSLVQ
:members: :members:
:cite:t:`villmann2018` proposed two changes to RSLVQ: First incooperate the winning rank into the prior probability calculation. :cite:t:`villmann2018` proposed two changes to RSLVQ: First incorporate the winning rank into the prior probability calculation.
And second use divergence as loss function. And second use divergence as loss function.
.. autoclass:: prototorch.models.probabilistic.PLVQ .. autoclass:: prototorch.models.probabilistic.PLVQ
@ -106,7 +106,7 @@ Visualization
Visualization is very specific to its application. Visualization is very specific to its application.
PrototorchModels delivers visualization for two dimensional data and image data. PrototorchModels delivers visualization for two dimensional data and image data.
The visulizations can be shown in a seperate window and inside a tensorboard. The visualizations can be shown in a separate window and inside a tensorboard.
.. automodule:: prototorch.models.vis .. automodule:: prototorch.models.vis
:members: :members:

View File

@ -0,0 +1,71 @@
.. Documentation of the updated Architecture.
Proto Y Architecture
========================================
Overview
****************************************
The Proto Y Architecture is a framework for abstract prototype learning methods.
It divides the problem into multiple steps:
* **Components** : Recalling the position and metadata of the components/prototypes.
* **Backbone** : Apply a mapping function to data and prototypes.
* **Comparison** : Calculate a dissimilarity based on the latent positions.
* **Competition** : Calculate competition values based on the comparison and the metadata.
* **Loss** : Calculate the loss based on the competition values
* **Inference** : Predict the output based on the competition values.
Depending on the phase (Training or Testing) Loss or Inference is used.
Inheritance Structure
****************************************
The Proto Y Architecture has a single base class that defines all steps and hooks
of the architecture.
.. autoclass:: prototorch.y.architectures.base.BaseYArchitecture
**Steps**
Components
.. automethod:: init_components
.. automethod:: components
Backbone
.. automethod:: init_backbone
.. automethod:: backbone
Comparison
.. automethod:: init_comparison
.. automethod:: comparison
Competition
.. automethod:: init_competition
.. automethod:: competition
Loss
.. automethod:: init_loss
.. automethod:: loss
Inference
.. automethod:: init_inference
.. automethod:: inference
**Hooks**
Torchmetric
.. automethod:: register_torchmetric
Hyperparameters
****************************************
Every model implemented with the Proto Y Architecture has a set of hyperparameters,
which is stored in the ``HyperParameters`` attribute of the architecture.

View File

@ -1,7 +1,10 @@
import logging
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torchmetrics import torchmetrics
from prototorch.core import SMCI from prototorch.core import SMCI
from prototorch.y.architectures.base import Steps
from prototorch.y.callbacks import ( from prototorch.y.callbacks import (
LogTorchmetricCallback, LogTorchmetricCallback,
PlotLambdaMatrixToTensorboard, PlotLambdaMatrixToTensorboard,
@ -9,7 +12,9 @@ from prototorch.y.callbacks import (
) )
from prototorch.y.library.gmlvq import GMLVQ from prototorch.y.library.gmlvq import GMLVQ
from pytorch_lightning.callbacks import EarlyStopping from pytorch_lightning.callbacks import EarlyStopping
from torch.utils.data import DataLoader from torch.utils.data import DataLoader, random_split
logging.basicConfig(level=logging.INFO)
# ############################################################################## # ##############################################################################
@ -20,22 +25,42 @@ def main():
# ------------------------------------------------------------ # ------------------------------------------------------------
# Dataset # Dataset
train_ds = pt.datasets.Iris() full_dataset = pt.datasets.Iris()
full_count = len(full_dataset)
train_count = int(full_count * 0.5)
val_count = int(full_count * 0.4)
test_count = int(full_count * 0.1)
train_dataset, val_dataset, test_dataset = random_split(
full_dataset, (train_count, val_count, test_count))
# Dataloader # Dataloader
train_loader = DataLoader( train_loader = DataLoader(
train_ds, train_dataset,
batch_size=32, batch_size=1,
num_workers=0, num_workers=4,
shuffle=True, shuffle=True,
) )
val_loader = DataLoader(
val_dataset,
batch_size=1,
num_workers=4,
shuffle=False,
)
test_loader = DataLoader(
test_dataset,
batch_size=1,
num_workers=0,
shuffle=False,
)
# ------------------------------------------------------------ # ------------------------------------------------------------
# HYPERPARAMETERS # HYPERPARAMETERS
# ------------------------------------------------------------ # ------------------------------------------------------------
# Select Initializer # Select Initializer
components_initializer = SMCI(train_ds) components_initializer = SMCI(full_dataset)
# Define Hyperparameters # Define Hyperparameters
hyperparameters = GMLVQ.HyperParameters( hyperparameters = GMLVQ.HyperParameters(
@ -51,17 +76,23 @@ def main():
# Create Model # Create Model
model = GMLVQ(hyperparameters) model = GMLVQ(hyperparameters)
print(model.hparams)
# ------------------------------------------------------------ # ------------------------------------------------------------
# TRAINING # TRAINING
# ------------------------------------------------------------ # ------------------------------------------------------------
# Controlling Callbacks # Controlling Callbacks
stopping_criterion = LogTorchmetricCallback( recall = LogTorchmetricCallback(
'recall', 'training_recall',
torchmetrics.Recall, torchmetrics.Recall,
num_classes=3, num_classes=3,
step=Steps.TRAINING,
)
stopping_criterion = LogTorchmetricCallback(
'validation_recall',
torchmetrics.Recall,
num_classes=3,
step=Steps.VALIDATION,
) )
es = EarlyStopping( es = EarlyStopping(
@ -71,18 +102,23 @@ def main():
) )
# Visualization Callback # Visualization Callback
vis = VisGMLVQ2D(data=train_ds) vis = VisGMLVQ2D(data=full_dataset)
# Define trainer # Define trainer
trainer = pl.Trainer(callbacks=[ trainer = pl.Trainer(
vis, callbacks=[
stopping_criterion, vis,
es, recall,
PlotLambdaMatrixToTensorboard(), stopping_criterion,
], ) es,
PlotLambdaMatrixToTensorboard(),
],
max_epochs=100,
)
# Train # Train
trainer.fit(model, train_loader) trainer.fit(model, train_loader, val_loader)
trainer.test(model, test_loader)
# Manual save # Manual save
trainer.save_checkpoint("./y_arch.ckpt") trainer.save_checkpoint("./y_arch.ckpt")
@ -93,8 +129,6 @@ def main():
strict=True, strict=True,
) )
print(new_model.hparams)
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -13,15 +13,34 @@ import torch
from torchmetrics import Metric from torchmetrics import Metric
class Steps(enumerate):
TRAINING = "training"
VALIDATION = "validation"
TEST = "test"
PREDICT = "predict"
class BaseYArchitecture(pl.LightningModule): class BaseYArchitecture(pl.LightningModule):
@dataclass @dataclass
class HyperParameters: class HyperParameters:
"""
Add all hyperparameters in the inherited class.
"""
... ...
# Fields # Fields
registered_metrics: dict[type[Metric], Metric] = {} registered_metrics: dict[str, dict[type[Metric], Metric]] = {
registered_metric_callbacks: dict[type[Metric], set[Callable]] = {} Steps.TRAINING: {},
Steps.VALIDATION: {},
Steps.TEST: {},
}
registered_metric_callbacks: dict[str, dict[type[Metric],
set[Callable]]] = {
Steps.TRAINING: {},
Steps.VALIDATION: {},
Steps.TEST: {},
}
# Type Hints for Necessary Fields # Type Hints for Necessary Fields
components_layer: torch.nn.Module components_layer: torch.nn.Module
@ -41,7 +60,7 @@ class BaseYArchitecture(pl.LightningModule):
# Common Steps # Common Steps
self.init_components(hparams) self.init_components(hparams)
self.init_latent(hparams) self.init_backbone(hparams)
self.init_comparison(hparams) self.init_comparison(hparams)
self.init_competition(hparams) self.init_competition(hparams)
@ -53,7 +72,7 @@ class BaseYArchitecture(pl.LightningModule):
# external API # external API
def get_competition(self, batch, components): def get_competition(self, batch, components):
latent_batch, latent_components = self.latent(batch, components) latent_batch, latent_components = self.backbone(batch, components)
# TODO: => Latent Hook # TODO: => Latent Hook
comparison_tensor = self.comparison(latent_batch, latent_components) comparison_tensor = self.comparison(latent_batch, latent_components)
# TODO: => Comparison Hook # TODO: => Comparison Hook
@ -92,27 +111,43 @@ class BaseYArchitecture(pl.LightningModule):
return self.loss(comparison_tensor, batch, components) return self.loss(comparison_tensor, batch, components)
# Empty Initialization # Empty Initialization
# TODO: Docs
def init_components(self, hparams: HyperParameters) -> None: def init_components(self, hparams: HyperParameters) -> None:
"""
All initialization necessary for the components step.
"""
... ...
def init_latent(self, hparams: HyperParameters) -> None: def init_backbone(self, hparams: HyperParameters) -> None:
"""
All initialization necessary for the backbone step.
"""
... ...
def init_comparison(self, hparams: HyperParameters) -> None: def init_comparison(self, hparams: HyperParameters) -> None:
"""
All initialization necessary for the comparison step.
"""
... ...
def init_competition(self, hparams: HyperParameters) -> None: def init_competition(self, hparams: HyperParameters) -> None:
"""
All initialization necessary for the competition step.
"""
... ...
def init_loss(self, hparams: HyperParameters) -> None: def init_loss(self, hparams: HyperParameters) -> None:
"""
All initialization necessary for the loss step.
"""
... ...
def init_inference(self, hparams: HyperParameters) -> None: def init_inference(self, hparams: HyperParameters) -> None:
"""
All initialization necessary for the inference step.
"""
... ...
# Empty Steps # Empty Steps
# TODO: Type hints
def components(self): def components(self):
""" """
This step has no input. This step has no input.
@ -122,9 +157,9 @@ class BaseYArchitecture(pl.LightningModule):
raise NotImplementedError( raise NotImplementedError(
"The components step has no reasonable default.") "The components step has no reasonable default.")
def latent(self, batch, components): def backbone(self, batch, components):
""" """
The latent step receives the data batch and the components. The backbone step receives the data batch and the components.
It can transform both by an arbitrary function. It can transform both by an arbitrary function.
It returns the transformed batch and components, each of the same length as the original input. It returns the transformed batch and components, each of the same length as the original input.
@ -173,52 +208,72 @@ class BaseYArchitecture(pl.LightningModule):
self, self,
name: Callable, name: Callable,
metric: type[Metric], metric: type[Metric],
step: str = Steps.TRAINING,
**metric_kwargs, **metric_kwargs,
): ):
if metric not in self.registered_metrics: if step == Steps.PREDICT:
self.registered_metrics[metric] = metric(**metric_kwargs) raise ValueError("Prediction metrics are not supported.")
self.registered_metric_callbacks[metric] = {name}
else:
self.registered_metric_callbacks[metric].add(name)
def update_metrics_step(self, batch): if metric not in self.registered_metrics:
self.registered_metrics[step][metric] = metric(**metric_kwargs)
self.registered_metric_callbacks[step][metric] = {name}
else:
self.registered_metric_callbacks[step][metric].add(name)
def update_metrics_step(self, batch, step):
# Prediction Metrics # Prediction Metrics
preds = self(batch) preds = self(batch)
x, y = batch x, y = batch
for metric in self.registered_metrics: for metric in self.registered_metrics[step]:
instance = self.registered_metrics[metric].to(self.device) instance = self.registered_metrics[step][metric].to(self.device)
instance(y, preds) instance(y, preds)
def update_metrics_epoch(self): def update_metrics_epoch(self, step):
for metric in self.registered_metrics: for metric in self.registered_metrics[step]:
instance = self.registered_metrics[metric].to(self.device) instance = self.registered_metrics[step][metric].to(self.device)
value = instance.compute() value = instance.compute()
for callback in self.registered_metric_callbacks[metric]: for callback in self.registered_metric_callbacks[step][metric]:
callback(value, self) callback(value, self)
instance.reset() instance.reset()
# Lightning Hooks # Lightning steps
# -------------------------------------------------------------------------
# Steps # >>>> Training
def training_step(self, batch, batch_idx, optimizer_idx=None): def training_step(self, batch, batch_idx, optimizer_idx=None):
self.update_metrics_step([torch.clone(el) for el in batch]) self.update_metrics_step(batch, Steps.TRAINING)
return self.loss_forward(batch) return self.loss_forward(batch)
def validation_step(self, batch, batch_idx):
return self.loss_forward(batch)
def test_step(self, batch, batch_idx):
return self.loss_forward(batch)
# Other Hooks
def training_epoch_end(self, outs) -> None: def training_epoch_end(self, outs) -> None:
self.update_metrics_epoch() self.update_metrics_epoch(Steps.TRAINING)
# >>>> Validation
def validation_step(self, batch, batch_idx):
self.update_metrics_step(batch, Steps.VALIDATION)
return self.loss_forward(batch)
def validation_epoch_end(self, outs) -> None:
self.update_metrics_epoch(Steps.VALIDATION)
# >>>> Test
def test_step(self, batch, batch_idx):
self.update_metrics_step(batch, Steps.TEST)
return self.loss_forward(batch)
def test_epoch_end(self, outs) -> None:
self.update_metrics_epoch(Steps.TEST)
# >>>> Prediction
def predict_step(self, batch, batch_idx, dataloader_idx=0):
return self.predict(batch)
# Check points
def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None: def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
# Compatible with Lightning
checkpoint["hyper_parameters"] = { checkpoint["hyper_parameters"] = {
'hparams': checkpoint["hyper_parameters"] 'hparams': checkpoint["hyper_parameters"]
} }

View File

@ -1,3 +1,4 @@
import logging
import warnings import warnings
from typing import Optional, Type from typing import Optional, Type
@ -8,7 +9,7 @@ import torchmetrics
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from prototorch.models.vis import Vis2DAbstract from prototorch.models.vis import Vis2DAbstract
from prototorch.utils.utils import mesh2d from prototorch.utils.utils import mesh2d
from prototorch.y.architectures.base import BaseYArchitecture from prototorch.y.architectures.base import BaseYArchitecture, Steps
from prototorch.y.library.gmlvq import GMLVQ from prototorch.y.library.gmlvq import GMLVQ
from pytorch_lightning.loggers import TensorBoardLogger from pytorch_lightning.loggers import TensorBoardLogger
@ -34,13 +35,13 @@ class LogTorchmetricCallback(pl.Callback):
self, self,
name, name,
metric: Type[torchmetrics.Metric], metric: Type[torchmetrics.Metric],
on="prediction", step: str = Steps.TRAINING,
**metric_kwargs, **metric_kwargs,
) -> None: ) -> None:
self.name = name self.name = name
self.metric = metric self.metric = metric
self.metric_kwargs = metric_kwargs self.metric_kwargs = metric_kwargs
self.on = on self.step = step
def setup( def setup(
self, self,
@ -48,14 +49,12 @@ class LogTorchmetricCallback(pl.Callback):
pl_module: BaseYArchitecture, pl_module: BaseYArchitecture,
stage: Optional[str] = None, stage: Optional[str] = None,
) -> None: ) -> None:
if self.on == "prediction": pl_module.register_torchmetric(
pl_module.register_torchmetric( self,
self, self.metric,
self.metric, step=self.step,
**self.metric_kwargs, **self.metric_kwargs,
) )
else:
raise ValueError(f"{self.on} is no valid metric hook")
def __call__(self, value, pl_module: BaseYArchitecture): def __call__(self, value, pl_module: BaseYArchitecture):
pl_module.log(self.name, value) pl_module.log(self.name, value)