From 736565b768666321761776ee9001d84e9fe16f41 Mon Sep 17 00:00:00 2001 From: Alexander Engelsberger Date: Fri, 24 Jun 2022 15:04:35 +0200 Subject: [PATCH] feat: metrics can be assigned to the different phases --- docs/source/index.rst | 13 +++- docs/source/library.rst | 6 +- docs/source/y-architecture.rst | 71 +++++++++++++++++ examples/y_architecture_example.py | 74 +++++++++++++----- prototorch/y/architectures/base.py | 121 +++++++++++++++++++++-------- prototorch/y/callbacks.py | 21 +++-- 6 files changed, 237 insertions(+), 69 deletions(-) create mode 100644 docs/source/y-architecture.rst diff --git a/docs/source/index.rst b/docs/source/index.rst index fbb2ff8..9d187b2 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -23,6 +23,13 @@ ProtoTorch Models Plugins custom +.. toctree:: + :hidden: + :maxdepth: 3 + :caption: Proto Y Architecture + + y-architecture + About ----------------------------------------- `Prototorch Models `_ is a Plugin @@ -33,8 +40,10 @@ prototype-based Machine Learning algorithms using `PyTorch-Lightning Library ----------------------------------------- Prototorch Models delivers many application ready models. -These models have been published in the past and have been adapted to the Prototorch library. +These models have been published in the past and have been adapted to the +Prototorch library. Customizable ----------------------------------------- -Prototorch Models also contains the building blocks to build own models with PyTorch-Lightning and Prototorch. +Prototorch Models also contains the building blocks to build own models with +PyTorch-Lightning and Prototorch. diff --git a/docs/source/library.rst b/docs/source/library.rst index a9007db..8d00218 100644 --- a/docs/source/library.rst +++ b/docs/source/library.rst @@ -71,7 +71,7 @@ Probabilistic Models Probabilistic variants assume, that the prototypes generate a probability distribution over the classes. For a test sample they return a distribution instead of a class assignment. -The following two algorihms were presented by :cite:t:`seo2003` . +The following two algorithms were presented by :cite:t:`seo2003` . Every prototypes is a center of a gaussian distribution of its class, generating a mixture model. .. autoclass:: prototorch.models.probabilistic.SLVQ @@ -80,7 +80,7 @@ Every prototypes is a center of a gaussian distribution of its class, generating .. autoclass:: prototorch.models.probabilistic.RSLVQ :members: -:cite:t:`villmann2018` proposed two changes to RSLVQ: First incooperate the winning rank into the prior probability calculation. +:cite:t:`villmann2018` proposed two changes to RSLVQ: First incorporate the winning rank into the prior probability calculation. And second use divergence as loss function. .. autoclass:: prototorch.models.probabilistic.PLVQ @@ -106,7 +106,7 @@ Visualization Visualization is very specific to its application. PrototorchModels delivers visualization for two dimensional data and image data. -The visulizations can be shown in a seperate window and inside a tensorboard. +The visualizations can be shown in a separate window and inside a tensorboard. .. automodule:: prototorch.models.vis :members: diff --git a/docs/source/y-architecture.rst b/docs/source/y-architecture.rst new file mode 100644 index 0000000..9637043 --- /dev/null +++ b/docs/source/y-architecture.rst @@ -0,0 +1,71 @@ +.. Documentation of the updated Architecture. + +Proto Y Architecture +======================================== + +Overview +**************************************** + +The Proto Y Architecture is a framework for abstract prototype learning methods. + +It divides the problem into multiple steps: + + * **Components** : Recalling the position and metadata of the components/prototypes. + * **Backbone** : Apply a mapping function to data and prototypes. + * **Comparison** : Calculate a dissimilarity based on the latent positions. + * **Competition** : Calculate competition values based on the comparison and the metadata. + * **Loss** : Calculate the loss based on the competition values + * **Inference** : Predict the output based on the competition values. + +Depending on the phase (Training or Testing) Loss or Inference is used. + +Inheritance Structure +**************************************** + +The Proto Y Architecture has a single base class that defines all steps and hooks +of the architecture. + +.. autoclass:: prototorch.y.architectures.base.BaseYArchitecture + + **Steps** + + Components + + .. automethod:: init_components + .. automethod:: components + + Backbone + + .. automethod:: init_backbone + .. automethod:: backbone + + Comparison + + .. automethod:: init_comparison + .. automethod:: comparison + + Competition + + .. automethod:: init_competition + .. automethod:: competition + + Loss + + .. automethod:: init_loss + .. automethod:: loss + + Inference + + .. automethod:: init_inference + .. automethod:: inference + + **Hooks** + + Torchmetric + + .. automethod:: register_torchmetric + +Hyperparameters +**************************************** +Every model implemented with the Proto Y Architecture has a set of hyperparameters, +which is stored in the ``HyperParameters`` attribute of the architecture. diff --git a/examples/y_architecture_example.py b/examples/y_architecture_example.py index b66fce9..d704dc9 100644 --- a/examples/y_architecture_example.py +++ b/examples/y_architecture_example.py @@ -1,7 +1,10 @@ +import logging + import prototorch as pt import pytorch_lightning as pl import torchmetrics from prototorch.core import SMCI +from prototorch.y.architectures.base import Steps from prototorch.y.callbacks import ( LogTorchmetricCallback, PlotLambdaMatrixToTensorboard, @@ -9,7 +12,9 @@ from prototorch.y.callbacks import ( ) from prototorch.y.library.gmlvq import GMLVQ from pytorch_lightning.callbacks import EarlyStopping -from torch.utils.data import DataLoader +from torch.utils.data import DataLoader, random_split + +logging.basicConfig(level=logging.INFO) # ############################################################################## @@ -20,22 +25,42 @@ def main(): # ------------------------------------------------------------ # Dataset - train_ds = pt.datasets.Iris() + full_dataset = pt.datasets.Iris() + full_count = len(full_dataset) + + train_count = int(full_count * 0.5) + val_count = int(full_count * 0.4) + test_count = int(full_count * 0.1) + + train_dataset, val_dataset, test_dataset = random_split( + full_dataset, (train_count, val_count, test_count)) # Dataloader train_loader = DataLoader( - train_ds, - batch_size=32, - num_workers=0, + train_dataset, + batch_size=1, + num_workers=4, shuffle=True, ) + val_loader = DataLoader( + val_dataset, + batch_size=1, + num_workers=4, + shuffle=False, + ) + test_loader = DataLoader( + test_dataset, + batch_size=1, + num_workers=0, + shuffle=False, + ) # ------------------------------------------------------------ # HYPERPARAMETERS # ------------------------------------------------------------ # Select Initializer - components_initializer = SMCI(train_ds) + components_initializer = SMCI(full_dataset) # Define Hyperparameters hyperparameters = GMLVQ.HyperParameters( @@ -51,17 +76,23 @@ def main(): # Create Model model = GMLVQ(hyperparameters) - print(model.hparams) - # ------------------------------------------------------------ # TRAINING # ------------------------------------------------------------ # Controlling Callbacks - stopping_criterion = LogTorchmetricCallback( - 'recall', + recall = LogTorchmetricCallback( + 'training_recall', torchmetrics.Recall, num_classes=3, + step=Steps.TRAINING, + ) + + stopping_criterion = LogTorchmetricCallback( + 'validation_recall', + torchmetrics.Recall, + num_classes=3, + step=Steps.VALIDATION, ) es = EarlyStopping( @@ -71,18 +102,23 @@ def main(): ) # Visualization Callback - vis = VisGMLVQ2D(data=train_ds) + vis = VisGMLVQ2D(data=full_dataset) # Define trainer - trainer = pl.Trainer(callbacks=[ - vis, - stopping_criterion, - es, - PlotLambdaMatrixToTensorboard(), - ], ) + trainer = pl.Trainer( + callbacks=[ + vis, + recall, + stopping_criterion, + es, + PlotLambdaMatrixToTensorboard(), + ], + max_epochs=100, + ) # Train - trainer.fit(model, train_loader) + trainer.fit(model, train_loader, val_loader) + trainer.test(model, test_loader) # Manual save trainer.save_checkpoint("./y_arch.ckpt") @@ -93,8 +129,6 @@ def main(): strict=True, ) - print(new_model.hparams) - if __name__ == "__main__": main() diff --git a/prototorch/y/architectures/base.py b/prototorch/y/architectures/base.py index ae47eda..1ca59b6 100644 --- a/prototorch/y/architectures/base.py +++ b/prototorch/y/architectures/base.py @@ -13,15 +13,34 @@ import torch from torchmetrics import Metric +class Steps(enumerate): + TRAINING = "training" + VALIDATION = "validation" + TEST = "test" + PREDICT = "predict" + + class BaseYArchitecture(pl.LightningModule): @dataclass class HyperParameters: + """ + Add all hyperparameters in the inherited class. + """ ... # Fields - registered_metrics: dict[type[Metric], Metric] = {} - registered_metric_callbacks: dict[type[Metric], set[Callable]] = {} + registered_metrics: dict[str, dict[type[Metric], Metric]] = { + Steps.TRAINING: {}, + Steps.VALIDATION: {}, + Steps.TEST: {}, + } + registered_metric_callbacks: dict[str, dict[type[Metric], + set[Callable]]] = { + Steps.TRAINING: {}, + Steps.VALIDATION: {}, + Steps.TEST: {}, + } # Type Hints for Necessary Fields components_layer: torch.nn.Module @@ -41,7 +60,7 @@ class BaseYArchitecture(pl.LightningModule): # Common Steps self.init_components(hparams) - self.init_latent(hparams) + self.init_backbone(hparams) self.init_comparison(hparams) self.init_competition(hparams) @@ -53,7 +72,7 @@ class BaseYArchitecture(pl.LightningModule): # external API def get_competition(self, batch, components): - latent_batch, latent_components = self.latent(batch, components) + latent_batch, latent_components = self.backbone(batch, components) # TODO: => Latent Hook comparison_tensor = self.comparison(latent_batch, latent_components) # TODO: => Comparison Hook @@ -92,27 +111,43 @@ class BaseYArchitecture(pl.LightningModule): return self.loss(comparison_tensor, batch, components) # Empty Initialization - # TODO: Docs def init_components(self, hparams: HyperParameters) -> None: + """ + All initialization necessary for the components step. + """ ... - def init_latent(self, hparams: HyperParameters) -> None: + def init_backbone(self, hparams: HyperParameters) -> None: + """ + All initialization necessary for the backbone step. + """ ... def init_comparison(self, hparams: HyperParameters) -> None: + """ + All initialization necessary for the comparison step. + """ ... def init_competition(self, hparams: HyperParameters) -> None: + """ + All initialization necessary for the competition step. + """ ... def init_loss(self, hparams: HyperParameters) -> None: + """ + All initialization necessary for the loss step. + """ ... def init_inference(self, hparams: HyperParameters) -> None: + """ + All initialization necessary for the inference step. + """ ... # Empty Steps - # TODO: Type hints def components(self): """ This step has no input. @@ -122,9 +157,9 @@ class BaseYArchitecture(pl.LightningModule): raise NotImplementedError( "The components step has no reasonable default.") - def latent(self, batch, components): + def backbone(self, batch, components): """ - The latent step receives the data batch and the components. + The backbone step receives the data batch and the components. It can transform both by an arbitrary function. It returns the transformed batch and components, each of the same length as the original input. @@ -173,52 +208,72 @@ class BaseYArchitecture(pl.LightningModule): self, name: Callable, metric: type[Metric], + step: str = Steps.TRAINING, **metric_kwargs, ): - if metric not in self.registered_metrics: - self.registered_metrics[metric] = metric(**metric_kwargs) - self.registered_metric_callbacks[metric] = {name} - else: - self.registered_metric_callbacks[metric].add(name) + if step == Steps.PREDICT: + raise ValueError("Prediction metrics are not supported.") - def update_metrics_step(self, batch): + if metric not in self.registered_metrics: + self.registered_metrics[step][metric] = metric(**metric_kwargs) + self.registered_metric_callbacks[step][metric] = {name} + else: + self.registered_metric_callbacks[step][metric].add(name) + + def update_metrics_step(self, batch, step): # Prediction Metrics preds = self(batch) x, y = batch - for metric in self.registered_metrics: - instance = self.registered_metrics[metric].to(self.device) + for metric in self.registered_metrics[step]: + instance = self.registered_metrics[step][metric].to(self.device) instance(y, preds) - def update_metrics_epoch(self): - for metric in self.registered_metrics: - instance = self.registered_metrics[metric].to(self.device) + def update_metrics_epoch(self, step): + for metric in self.registered_metrics[step]: + instance = self.registered_metrics[step][metric].to(self.device) value = instance.compute() - for callback in self.registered_metric_callbacks[metric]: + for callback in self.registered_metric_callbacks[step][metric]: callback(value, self) instance.reset() - # Lightning Hooks - - # Steps + # Lightning steps + # ------------------------------------------------------------------------- + # >>>> Training def training_step(self, batch, batch_idx, optimizer_idx=None): - self.update_metrics_step([torch.clone(el) for el in batch]) + self.update_metrics_step(batch, Steps.TRAINING) return self.loss_forward(batch) - def validation_step(self, batch, batch_idx): - return self.loss_forward(batch) - - def test_step(self, batch, batch_idx): - return self.loss_forward(batch) - - # Other Hooks def training_epoch_end(self, outs) -> None: - self.update_metrics_epoch() + self.update_metrics_epoch(Steps.TRAINING) + # >>>> Validation + def validation_step(self, batch, batch_idx): + self.update_metrics_step(batch, Steps.VALIDATION) + + return self.loss_forward(batch) + + def validation_epoch_end(self, outs) -> None: + self.update_metrics_epoch(Steps.VALIDATION) + + # >>>> Test + def test_step(self, batch, batch_idx): + self.update_metrics_step(batch, Steps.TEST) + return self.loss_forward(batch) + + def test_epoch_end(self, outs) -> None: + self.update_metrics_epoch(Steps.TEST) + + # >>>> Prediction + def predict_step(self, batch, batch_idx, dataloader_idx=0): + return self.predict(batch) + + # Check points def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None: + # Compatible with Lightning checkpoint["hyper_parameters"] = { 'hparams': checkpoint["hyper_parameters"] } diff --git a/prototorch/y/callbacks.py b/prototorch/y/callbacks.py index 99c7a3e..84b5dfb 100644 --- a/prototorch/y/callbacks.py +++ b/prototorch/y/callbacks.py @@ -1,3 +1,4 @@ +import logging import warnings from typing import Optional, Type @@ -8,7 +9,7 @@ import torchmetrics from matplotlib import pyplot as plt from prototorch.models.vis import Vis2DAbstract from prototorch.utils.utils import mesh2d -from prototorch.y.architectures.base import BaseYArchitecture +from prototorch.y.architectures.base import BaseYArchitecture, Steps from prototorch.y.library.gmlvq import GMLVQ from pytorch_lightning.loggers import TensorBoardLogger @@ -34,13 +35,13 @@ class LogTorchmetricCallback(pl.Callback): self, name, metric: Type[torchmetrics.Metric], - on="prediction", + step: str = Steps.TRAINING, **metric_kwargs, ) -> None: self.name = name self.metric = metric self.metric_kwargs = metric_kwargs - self.on = on + self.step = step def setup( self, @@ -48,14 +49,12 @@ class LogTorchmetricCallback(pl.Callback): pl_module: BaseYArchitecture, stage: Optional[str] = None, ) -> None: - if self.on == "prediction": - pl_module.register_torchmetric( - self, - self.metric, - **self.metric_kwargs, - ) - else: - raise ValueError(f"{self.on} is no valid metric hook") + pl_module.register_torchmetric( + self, + self.metric, + step=self.step, + **self.metric_kwargs, + ) def __call__(self, value, pl_module: BaseYArchitecture): pl_module.log(self.name, value)