feat: metrics can be assigned to the different phases
This commit is contained in:
parent
46ec7b07d7
commit
736565b768
@ -23,6 +23,13 @@ ProtoTorch Models Plugins
|
||||
|
||||
custom
|
||||
|
||||
.. toctree::
|
||||
:hidden:
|
||||
:maxdepth: 3
|
||||
:caption: Proto Y Architecture
|
||||
|
||||
y-architecture
|
||||
|
||||
About
|
||||
-----------------------------------------
|
||||
`Prototorch Models <https://github.com/si-cim/prototorch_models>`_ is a Plugin
|
||||
@ -33,8 +40,10 @@ prototype-based Machine Learning algorithms using `PyTorch-Lightning
|
||||
Library
|
||||
-----------------------------------------
|
||||
Prototorch Models delivers many application ready models.
|
||||
These models have been published in the past and have been adapted to the Prototorch library.
|
||||
These models have been published in the past and have been adapted to the
|
||||
Prototorch library.
|
||||
|
||||
Customizable
|
||||
-----------------------------------------
|
||||
Prototorch Models also contains the building blocks to build own models with PyTorch-Lightning and Prototorch.
|
||||
Prototorch Models also contains the building blocks to build your own models with
|
||||
PyTorch-Lightning and Prototorch.
|
||||
|
@ -71,7 +71,7 @@ Probabilistic Models
|
||||
Probabilistic variants assume that the prototypes generate a probability distribution over the classes.
|
||||
For a test sample they return a distribution instead of a class assignment.
|
||||
|
||||
The following two algorihms were presented by :cite:t:`seo2003` .
|
||||
The following two algorithms were presented by :cite:t:`seo2003` .
|
||||
Every prototype is the center of a Gaussian distribution of its class, generating a mixture model.
|
||||
|
||||
.. autoclass:: prototorch.models.probabilistic.SLVQ
|
||||
@ -80,7 +80,7 @@ Every prototypes is a center of a gaussian distribution of its class, generating
|
||||
.. autoclass:: prototorch.models.probabilistic.RSLVQ
|
||||
:members:
|
||||
|
||||
:cite:t:`villmann2018` proposed two changes to RSLVQ: First incooperate the winning rank into the prior probability calculation.
|
||||
:cite:t:`villmann2018` proposed two changes to RSLVQ: First incorporate the winning rank into the prior probability calculation.
|
||||
And second use divergence as loss function.
|
||||
|
||||
.. autoclass:: prototorch.models.probabilistic.PLVQ
|
||||
@ -106,7 +106,7 @@ Visualization
|
||||
Visualization is very specific to its application.
|
||||
Prototorch Models delivers visualization for two-dimensional data and image data.
|
||||
|
||||
The visulizations can be shown in a seperate window and inside a tensorboard.
|
||||
The visualizations can be shown in a separate window and inside a tensorboard.
|
||||
|
||||
.. automodule:: prototorch.models.vis
|
||||
:members:
|
||||
|
71
docs/source/y-architecture.rst
Normal file
71
docs/source/y-architecture.rst
Normal file
@ -0,0 +1,71 @@
|
||||
.. Documentation of the updated Architecture.
|
||||
|
||||
Proto Y Architecture
|
||||
========================================
|
||||
|
||||
Overview
|
||||
****************************************
|
||||
|
||||
The Proto Y Architecture is a framework for abstract prototype learning methods.
|
||||
|
||||
It divides the problem into multiple steps:
|
||||
|
||||
* **Components** : Recall the position and metadata of the components/prototypes.
|
||||
* **Backbone** : Apply a mapping function to data and prototypes.
|
||||
* **Comparison** : Calculate a dissimilarity based on the latent positions.
|
||||
* **Competition** : Calculate competition values based on the comparison and the metadata.
|
||||
* **Loss** : Calculate the loss based on the competition values.
|
||||
* **Inference** : Predict the output based on the competition values.
|
||||
|
||||
Depending on the phase (training or testing), either Loss or Inference is used.
|
||||
|
||||
Inheritance Structure
|
||||
****************************************
|
||||
|
||||
The Proto Y Architecture has a single base class that defines all steps and hooks
|
||||
of the architecture.
|
||||
|
||||
.. autoclass:: prototorch.y.architectures.base.BaseYArchitecture
|
||||
|
||||
**Steps**
|
||||
|
||||
Components
|
||||
|
||||
.. automethod:: init_components
|
||||
.. automethod:: components
|
||||
|
||||
Backbone
|
||||
|
||||
.. automethod:: init_backbone
|
||||
.. automethod:: backbone
|
||||
|
||||
Comparison
|
||||
|
||||
.. automethod:: init_comparison
|
||||
.. automethod:: comparison
|
||||
|
||||
Competition
|
||||
|
||||
.. automethod:: init_competition
|
||||
.. automethod:: competition
|
||||
|
||||
Loss
|
||||
|
||||
.. automethod:: init_loss
|
||||
.. automethod:: loss
|
||||
|
||||
Inference
|
||||
|
||||
.. automethod:: init_inference
|
||||
.. automethod:: inference
|
||||
|
||||
**Hooks**
|
||||
|
||||
Torchmetric
|
||||
|
||||
.. automethod:: register_torchmetric
|
||||
|
||||
Hyperparameters
|
||||
****************************************
|
||||
Every model implemented with the Proto Y Architecture has a set of hyperparameters,
|
||||
which is stored in the ``HyperParameters`` attribute of the architecture.
|
@ -1,7 +1,10 @@
|
||||
import logging
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torchmetrics
|
||||
from prototorch.core import SMCI
|
||||
from prototorch.y.architectures.base import Steps
|
||||
from prototorch.y.callbacks import (
|
||||
LogTorchmetricCallback,
|
||||
PlotLambdaMatrixToTensorboard,
|
||||
@ -9,7 +12,9 @@ from prototorch.y.callbacks import (
|
||||
)
|
||||
from prototorch.y.library.gmlvq import GMLVQ
|
||||
from pytorch_lightning.callbacks import EarlyStopping
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data import DataLoader, random_split
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
# ##############################################################################
|
||||
|
||||
@ -20,22 +25,42 @@ def main():
|
||||
# ------------------------------------------------------------
|
||||
|
||||
# Dataset
|
||||
train_ds = pt.datasets.Iris()
|
||||
full_dataset = pt.datasets.Iris()
|
||||
full_count = len(full_dataset)
|
||||
|
||||
train_count = int(full_count * 0.5)
|
||||
val_count = int(full_count * 0.4)
|
||||
test_count = int(full_count * 0.1)
|
||||
|
||||
train_dataset, val_dataset, test_dataset = random_split(
|
||||
full_dataset, (train_count, val_count, test_count))
|
||||
|
||||
# Dataloader
|
||||
train_loader = DataLoader(
|
||||
train_ds,
|
||||
batch_size=32,
|
||||
num_workers=0,
|
||||
train_dataset,
|
||||
batch_size=1,
|
||||
num_workers=4,
|
||||
shuffle=True,
|
||||
)
|
||||
val_loader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=1,
|
||||
num_workers=4,
|
||||
shuffle=False,
|
||||
)
|
||||
test_loader = DataLoader(
|
||||
test_dataset,
|
||||
batch_size=1,
|
||||
num_workers=0,
|
||||
shuffle=False,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# HYPERPARAMETERS
|
||||
# ------------------------------------------------------------
|
||||
|
||||
# Select Initializer
|
||||
components_initializer = SMCI(train_ds)
|
||||
components_initializer = SMCI(full_dataset)
|
||||
|
||||
# Define Hyperparameters
|
||||
hyperparameters = GMLVQ.HyperParameters(
|
||||
@ -51,17 +76,23 @@ def main():
|
||||
# Create Model
|
||||
model = GMLVQ(hyperparameters)
|
||||
|
||||
print(model.hparams)
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# TRAINING
|
||||
# ------------------------------------------------------------
|
||||
|
||||
# Controlling Callbacks
|
||||
stopping_criterion = LogTorchmetricCallback(
|
||||
'recall',
|
||||
recall = LogTorchmetricCallback(
|
||||
'training_recall',
|
||||
torchmetrics.Recall,
|
||||
num_classes=3,
|
||||
step=Steps.TRAINING,
|
||||
)
|
||||
|
||||
stopping_criterion = LogTorchmetricCallback(
|
||||
'validation_recall',
|
||||
torchmetrics.Recall,
|
||||
num_classes=3,
|
||||
step=Steps.VALIDATION,
|
||||
)
|
||||
|
||||
es = EarlyStopping(
|
||||
@ -71,18 +102,23 @@ def main():
|
||||
)
|
||||
|
||||
# Visualization Callback
|
||||
vis = VisGMLVQ2D(data=train_ds)
|
||||
vis = VisGMLVQ2D(data=full_dataset)
|
||||
|
||||
# Define trainer
|
||||
trainer = pl.Trainer(callbacks=[
|
||||
trainer = pl.Trainer(
|
||||
callbacks=[
|
||||
vis,
|
||||
recall,
|
||||
stopping_criterion,
|
||||
es,
|
||||
PlotLambdaMatrixToTensorboard(),
|
||||
], )
|
||||
],
|
||||
max_epochs=100,
|
||||
)
|
||||
|
||||
# Train
|
||||
trainer.fit(model, train_loader)
|
||||
trainer.fit(model, train_loader, val_loader)
|
||||
trainer.test(model, test_loader)
|
||||
|
||||
# Manual save
|
||||
trainer.save_checkpoint("./y_arch.ckpt")
|
||||
@ -93,8 +129,6 @@ def main():
|
||||
strict=True,
|
||||
)
|
||||
|
||||
print(new_model.hparams)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
@ -13,15 +13,34 @@ import torch
|
||||
from torchmetrics import Metric
|
||||
|
||||
|
||||
class Steps:
    """Names of the phases a metric can be assigned to.

    This is a plain namespace of string constants. The original declaration
    derived from the builtin ``enumerate`` iterator type, which has nothing
    to do with enumerations; the base class is dropped here while every
    member keeps its exact string value, so code that uses the members as
    dictionary keys or as ``step: str`` arguments is unaffected.
    """

    TRAINING = "training"
    VALIDATION = "validation"
    TEST = "test"
    PREDICT = "predict"
|
||||
|
||||
|
||||
class BaseYArchitecture(pl.LightningModule):
|
||||
|
||||
@dataclass
|
||||
class HyperParameters:
|
||||
"""
|
||||
Add all hyperparameters in the inherited class.
|
||||
"""
|
||||
...
|
||||
|
||||
# Fields
|
||||
registered_metrics: dict[type[Metric], Metric] = {}
|
||||
registered_metric_callbacks: dict[type[Metric], set[Callable]] = {}
|
||||
registered_metrics: dict[str, dict[type[Metric], Metric]] = {
|
||||
Steps.TRAINING: {},
|
||||
Steps.VALIDATION: {},
|
||||
Steps.TEST: {},
|
||||
}
|
||||
registered_metric_callbacks: dict[str, dict[type[Metric],
|
||||
set[Callable]]] = {
|
||||
Steps.TRAINING: {},
|
||||
Steps.VALIDATION: {},
|
||||
Steps.TEST: {},
|
||||
}
|
||||
|
||||
# Type Hints for Necessary Fields
|
||||
components_layer: torch.nn.Module
|
||||
@ -41,7 +60,7 @@ class BaseYArchitecture(pl.LightningModule):
|
||||
|
||||
# Common Steps
|
||||
self.init_components(hparams)
|
||||
self.init_latent(hparams)
|
||||
self.init_backbone(hparams)
|
||||
self.init_comparison(hparams)
|
||||
self.init_competition(hparams)
|
||||
|
||||
@ -53,7 +72,7 @@ class BaseYArchitecture(pl.LightningModule):
|
||||
|
||||
# external API
|
||||
def get_competition(self, batch, components):
|
||||
latent_batch, latent_components = self.latent(batch, components)
|
||||
latent_batch, latent_components = self.backbone(batch, components)
|
||||
# TODO: => Latent Hook
|
||||
comparison_tensor = self.comparison(latent_batch, latent_components)
|
||||
# TODO: => Comparison Hook
|
||||
@ -92,27 +111,43 @@ class BaseYArchitecture(pl.LightningModule):
|
||||
return self.loss(comparison_tensor, batch, components)
|
||||
|
||||
# Empty Initialization
|
||||
# TODO: Docs
|
||||
def init_components(self, hparams: HyperParameters) -> None:
|
||||
"""
|
||||
All initialization necessary for the components step.
|
||||
"""
|
||||
...
|
||||
|
||||
def init_latent(self, hparams: HyperParameters) -> None:
|
||||
def init_backbone(self, hparams: HyperParameters) -> None:
|
||||
"""
|
||||
All initialization necessary for the backbone step.
|
||||
"""
|
||||
...
|
||||
|
||||
def init_comparison(self, hparams: HyperParameters) -> None:
|
||||
"""
|
||||
All initialization necessary for the comparison step.
|
||||
"""
|
||||
...
|
||||
|
||||
def init_competition(self, hparams: HyperParameters) -> None:
|
||||
"""
|
||||
All initialization necessary for the competition step.
|
||||
"""
|
||||
...
|
||||
|
||||
def init_loss(self, hparams: HyperParameters) -> None:
|
||||
"""
|
||||
All initialization necessary for the loss step.
|
||||
"""
|
||||
...
|
||||
|
||||
def init_inference(self, hparams: HyperParameters) -> None:
|
||||
"""
|
||||
All initialization necessary for the inference step.
|
||||
"""
|
||||
...
|
||||
|
||||
# Empty Steps
|
||||
# TODO: Type hints
|
||||
def components(self):
|
||||
"""
|
||||
This step has no input.
|
||||
@ -122,9 +157,9 @@ class BaseYArchitecture(pl.LightningModule):
|
||||
raise NotImplementedError(
|
||||
"The components step has no reasonable default.")
|
||||
|
||||
def latent(self, batch, components):
|
||||
def backbone(self, batch, components):
|
||||
"""
|
||||
The latent step receives the data batch and the components.
|
||||
The backbone step receives the data batch and the components.
|
||||
It can transform both by an arbitrary function.
|
||||
|
||||
It returns the transformed batch and components, each of the same length as the original input.
|
||||
@ -173,52 +208,72 @@ class BaseYArchitecture(pl.LightningModule):
|
||||
self,
|
||||
name: Callable,
|
||||
metric: type[Metric],
|
||||
step: str = Steps.TRAINING,
|
||||
**metric_kwargs,
|
||||
):
|
||||
if metric not in self.registered_metrics:
|
||||
self.registered_metrics[metric] = metric(**metric_kwargs)
|
||||
self.registered_metric_callbacks[metric] = {name}
|
||||
else:
|
||||
self.registered_metric_callbacks[metric].add(name)
|
||||
if step == Steps.PREDICT:
|
||||
raise ValueError("Prediction metrics are not supported.")
|
||||
|
||||
def update_metrics_step(self, batch):
|
||||
if metric not in self.registered_metrics:
|
||||
self.registered_metrics[step][metric] = metric(**metric_kwargs)
|
||||
self.registered_metric_callbacks[step][metric] = {name}
|
||||
else:
|
||||
self.registered_metric_callbacks[step][metric].add(name)
|
||||
|
||||
def update_metrics_step(self, batch, step):
|
||||
# Prediction Metrics
|
||||
preds = self(batch)
|
||||
|
||||
x, y = batch
|
||||
for metric in self.registered_metrics:
|
||||
instance = self.registered_metrics[metric].to(self.device)
|
||||
for metric in self.registered_metrics[step]:
|
||||
instance = self.registered_metrics[step][metric].to(self.device)
|
||||
instance(y, preds)
|
||||
|
||||
def update_metrics_epoch(self):
|
||||
for metric in self.registered_metrics:
|
||||
instance = self.registered_metrics[metric].to(self.device)
|
||||
def update_metrics_epoch(self, step):
|
||||
for metric in self.registered_metrics[step]:
|
||||
instance = self.registered_metrics[step][metric].to(self.device)
|
||||
value = instance.compute()
|
||||
|
||||
for callback in self.registered_metric_callbacks[metric]:
|
||||
for callback in self.registered_metric_callbacks[step][metric]:
|
||||
callback(value, self)
|
||||
|
||||
instance.reset()
|
||||
|
||||
# Lightning Hooks
|
||||
|
||||
# Steps
|
||||
# Lightning steps
|
||||
# -------------------------------------------------------------------------
|
||||
# >>>> Training
|
||||
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
||||
self.update_metrics_step([torch.clone(el) for el in batch])
|
||||
self.update_metrics_step(batch, Steps.TRAINING)
|
||||
|
||||
return self.loss_forward(batch)
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
return self.loss_forward(batch)
|
||||
|
||||
def test_step(self, batch, batch_idx):
|
||||
return self.loss_forward(batch)
|
||||
|
||||
# Other Hooks
|
||||
def training_epoch_end(self, outs) -> None:
|
||||
self.update_metrics_epoch()
|
||||
self.update_metrics_epoch(Steps.TRAINING)
|
||||
|
||||
# >>>> Validation
|
||||
def validation_step(self, batch, batch_idx):
|
||||
self.update_metrics_step(batch, Steps.VALIDATION)
|
||||
|
||||
return self.loss_forward(batch)
|
||||
|
||||
def validation_epoch_end(self, outs) -> None:
|
||||
self.update_metrics_epoch(Steps.VALIDATION)
|
||||
|
||||
# >>>> Test
|
||||
def test_step(self, batch, batch_idx):
|
||||
self.update_metrics_step(batch, Steps.TEST)
|
||||
return self.loss_forward(batch)
|
||||
|
||||
def test_epoch_end(self, outs) -> None:
|
||||
self.update_metrics_epoch(Steps.TEST)
|
||||
|
||||
# >>>> Prediction
|
||||
def predict_step(self, batch, batch_idx, dataloader_idx=0):
|
||||
return self.predict(batch)
|
||||
|
||||
# Check points
|
||||
def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
|
||||
# Compatible with Lightning
|
||||
checkpoint["hyper_parameters"] = {
|
||||
'hparams': checkpoint["hyper_parameters"]
|
||||
}
|
||||
|
@ -1,3 +1,4 @@
|
||||
import logging
|
||||
import warnings
|
||||
from typing import Optional, Type
|
||||
|
||||
@ -8,7 +9,7 @@ import torchmetrics
|
||||
from matplotlib import pyplot as plt
|
||||
from prototorch.models.vis import Vis2DAbstract
|
||||
from prototorch.utils.utils import mesh2d
|
||||
from prototorch.y.architectures.base import BaseYArchitecture
|
||||
from prototorch.y.architectures.base import BaseYArchitecture, Steps
|
||||
from prototorch.y.library.gmlvq import GMLVQ
|
||||
from pytorch_lightning.loggers import TensorBoardLogger
|
||||
|
||||
@ -34,13 +35,13 @@ class LogTorchmetricCallback(pl.Callback):
|
||||
self,
|
||||
name,
|
||||
metric: Type[torchmetrics.Metric],
|
||||
on="prediction",
|
||||
step: str = Steps.TRAINING,
|
||||
**metric_kwargs,
|
||||
) -> None:
|
||||
self.name = name
|
||||
self.metric = metric
|
||||
self.metric_kwargs = metric_kwargs
|
||||
self.on = on
|
||||
self.step = step
|
||||
|
||||
def setup(
|
||||
self,
|
||||
@ -48,14 +49,12 @@ class LogTorchmetricCallback(pl.Callback):
|
||||
pl_module: BaseYArchitecture,
|
||||
stage: Optional[str] = None,
|
||||
) -> None:
|
||||
if self.on == "prediction":
|
||||
pl_module.register_torchmetric(
|
||||
self,
|
||||
self.metric,
|
||||
step=self.step,
|
||||
**self.metric_kwargs,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"{self.on} is no valid metric hook")
|
||||
|
||||
def __call__(self, value, pl_module: BaseYArchitecture):
|
||||
pl_module.log(self.name, value)
|
||||
|
Loading…
Reference in New Issue
Block a user