From 3e50d0d817f330fb5b4043ebe3e66123ab8d8557 Mon Sep 17 00:00:00 2001 From: Alexander Engelsberger Date: Wed, 18 May 2022 15:43:09 +0200 Subject: [PATCH] chore(protoy): mixin restructuring --- .../models/proto_y_architecture/base.py | 21 +-- .../models/proto_y_architecture/glvq.py | 132 ++++++++++++------ .../y_architecture_example.py | 5 +- 3 files changed, 101 insertions(+), 57 deletions(-) diff --git a/prototorch/models/proto_y_architecture/base.py b/prototorch/models/proto_y_architecture/base.py index 41591d8..b551d66 100644 --- a/prototorch/models/proto_y_architecture/base.py +++ b/prototorch/models/proto_y_architecture/base.py @@ -17,7 +17,8 @@ from typing import ( import pytorch_lightning as pl import torch -from torchmetrics import Accuracy, Metric +from torchmetrics import Metric +from torchmetrics.classification.accuracy import Accuracy class BaseYArchitecture(pl.LightningModule): @@ -29,7 +30,7 @@ class BaseYArchitecture(pl.LightningModule): registered_metrics: Dict[Type[Metric], Metric] = {} registered_metric_names: Dict[Type[Metric], Set[str]] = {} - components_layer: pl.LightningModule + components_layer: torch.nn.Module def __init__(self, hparams) -> None: super().__init__() @@ -63,7 +64,7 @@ class BaseYArchitecture(pl.LightningModule): self.registered_metric_names[metric].add(name) # external API - def get_competion(self, batch, components): + def get_competition(self, batch, components): latent_batch, latent_components = self.latent(batch, components) # TODO: => Latent Hook comparison_tensor = self.comparison(latent_batch, latent_components) @@ -76,7 +77,7 @@ class BaseYArchitecture(pl.LightningModule): # TODO: manage different datatypes? components = self.components_layer() # TODO: => Component Hook - comparison_tensor = self.get_competion(batch, components) + comparison_tensor = self.get_competition(batch, components) # TODO: => Competition Hook return self.inference(comparison_tensor, components) @@ -92,13 +93,13 @@ class BaseYArchitecture(pl.LightningModule): # TODO: manage different datatypes? components = self.components_layer() # TODO: => Component Hook - return self.get_competion(batch, components) + return self.get_competition(batch, components) def loss_forward(self, batch): # TODO: manage different datatypes? components = self.components_layer() # TODO: => Component Hook - comparison_tensor = self.get_competion(batch, components) + comparison_tensor = self.get_competition(batch, components) # TODO: => Competition Hook return self.loss(comparison_tensor, batch, components) @@ -148,14 +149,14 @@ class BaseYArchitecture(pl.LightningModule): def comparison(self, batch, components): """ - Takes a batch of size N and the componentsset of size M. + Takes a batch of size N and the component set of size M. It returns an NxMxD tensor containing D (usually 1) pairwise comparison measures. """ raise NotImplementedError( "The comparison step has no reasonable default.") - def competition(self, comparisonmeasures, components): + def competition(self, comparison_measures, components): """ Takes the tensor of comparison measures. @@ -164,7 +165,7 @@ class BaseYArchitecture(pl.LightningModule): raise NotImplementedError( "The competition step has no reasonable default.") - def loss(self, comparisonmeasures, batch, components): + def loss(self, comparison_measures, batch, components): """ Takes the tensor of competition measures. @@ -172,7 +173,7 @@ class BaseYArchitecture(pl.LightningModule): """ raise NotImplementedError("The loss step has no reasonable default.") - def inference(self, comparisonmeasures, components): + def inference(self, comparison_measures, components): """ Takes the tensor of competition measures. diff --git a/prototorch/models/proto_y_architecture/glvq.py b/prototorch/models/proto_y_architecture/glvq.py index 9d0c644..67d8883 100644 --- a/prototorch/models/proto_y_architecture/glvq.py +++ b/prototorch/models/proto_y_architecture/glvq.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Callable, Type import torch @@ -14,7 +14,8 @@ from prototorch.models.proto_y_architecture.base import BaseYArchitecture from prototorch.nn.wrappers import LambdaLayer -class SupervisedScheme(BaseYArchitecture): +class SupervisedArchitecture(BaseYArchitecture): + components_layer: LabeledComponents @dataclass class HyperParameters: @@ -28,23 +29,59 @@ class SupervisedScheme(BaseYArchitecture): labels_initializer=LabelsInitializer(), ) + @property + def prototypes(self): + return self.components_layer.components.detach().cpu() -# ############################################################################## -# GLVQ -# ############################################################################## -class GLVQ( - SupervisedScheme, ): - """GLVQ using the new Scheme - """ + @property + def prototype_labels(self): + return self.components_layer.labels.detach().cpu() + + +class WTACompetitionMixin(BaseYArchitecture): @dataclass - class HyperParameters(SupervisedScheme.HyperParameters): - distance_fn: Callable = euclidean_distance - lr: float = 0.01 + class HyperParameters(BaseYArchitecture.HyperParameters): + pass + + def init_inference(self, hparams: HyperParameters): + self.competition_layer = WTAC() + + def inference(self, comparison_measures, components): + comp_labels = components[1] + return self.competition_layer(comparison_measures, comp_labels) + + +class GLVQLossMixin(BaseYArchitecture): + + @dataclass + class HyperParameters(BaseYArchitecture.HyperParameters): margin: float = 0.0 - # TODO: make nicer - transfer_fn: str = "identity" - transfer_beta: float = 10.0 + + transfer_fn: str = "sigmoid_beta" + transfer_args: dict = field(default_factory=lambda: dict(beta=10.0)) + + def init_loss(self, hparams: HyperParameters): + self.loss_layer = GLVQLoss( + margin=hparams.margin, + transfer_fn=hparams.transfer_fn, + **hparams.transfer_args, + ) + + def loss(self, comparison_measures, batch, components): + target = batch[1] + comp_labels = components[1] + loss = self.loss_layer(comparison_measures, target, comp_labels) + self.log('loss', loss) + return loss + + +class SingleLearningRateMixin(BaseYArchitecture): + + @dataclass + class HyperParameters(BaseYArchitecture.HyperParameters): + # Training Hyperparameters + lr: float = 0.01 optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam def __init__(self, hparams: HyperParameters) -> None: @@ -52,20 +89,22 @@ class GLVQ( self.lr = hparams.lr self.optimizer = hparams.optimizer + def configure_optimizers(self): + return self.optimizer(self.parameters(), lr=self.lr) # type: ignore + + +class SimpleComparisonMixin(BaseYArchitecture): + + @dataclass + class HyperParameters(BaseYArchitecture.HyperParameters): + # Training Hyperparameters + comparison_fn: Callable = euclidean_distance + comparison_args: dict = field(default_factory=lambda: dict()) + def init_comparison(self, hparams: HyperParameters): - self.comparison_layer = LambdaLayer(hparams.distance_fn) + self.comparison_layer = LambdaLayer(fn=hparams.comparison_fn, + **hparams.comparison_args) - def init_inference(self, hparams: HyperParameters): - self.competition_layer = WTAC() - - def init_loss(self, hparams): - self.loss_layer = GLVQLoss( - margin=hparams.margin, - transfer_fn=hparams.transfer_fn, - beta=hparams.transfer_beta, - ) - - # Steps def comparison(self, batch, components): comp_tensor, _ = components batch_tensor, _ = batch @@ -76,23 +115,26 @@ class GLVQ( return distances - def inference(self, comparisonmeasures, components): - comp_labels = components[1] - return self.competition_layer(comparisonmeasures, comp_labels) - def loss(self, comparisonmeasures, batch, components): - target = batch[1] - comp_labels = components[1] - return self.loss_layer(comparisonmeasures, target, comp_labels) +# ############################################################################## +# GLVQ +# ############################################################################## +class GLVQ( + SupervisedArchitecture, + SimpleComparisonMixin, + GLVQLossMixin, + WTACompetitionMixin, + SingleLearningRateMixin, +): + """GLVQ using the new Scheme + """ - def configure_optimizers(self): - return self.optimizer(self.parameters(), lr=self.lr) # type: ignore - - # Properties - @property - def prototypes(self): - return self.components_layer.components.detach().cpu() - - @property - def prototype_labels(self): - return self.components_layer.labels.detach().cpu() + @dataclass + class HyperParameters( + SimpleComparisonMixin.HyperParameters, + SingleLearningRateMixin.HyperParameters, + GLVQLossMixin.HyperParameters, + WTACompetitionMixin.HyperParameters, + SupervisedArchitecture.HyperParameters, + ): + pass diff --git a/prototorch/models/proto_y_architecture/y_architecture_example.py b/prototorch/models/proto_y_architecture/y_architecture_example.py index a58afeb..6155f10 100644 --- a/prototorch/models/proto_y_architecture/y_architecture_example.py +++ b/prototorch/models/proto_y_architecture/y_architecture_example.py @@ -25,7 +25,7 @@ if __name__ == "__main__": # Dataloader train_loader = DataLoader( train_ds, - batch_size=64, + batch_size=32, num_workers=0, shuffle=True, ) @@ -39,7 +39,7 @@ if __name__ == "__main__": # Define Hyperparameters hyperparameters = GLVQ.HyperParameters( - lr=0.5, + lr=0.1, distribution=dict( num_classes=2, per_class=1, @@ -49,6 +49,7 @@ if __name__ == "__main__": # Create Model model = GLVQ(hyperparameters) + print(model) # ------------------------------------------------------------