feat: distribute GMLVQ into mixins

Alexander Engelsberger 2022-05-31 17:56:03 +02:00
parent e922aae432
commit 23d1a71b31
14 changed files with 211 additions and 152 deletions

View File

@@ -2,12 +2,12 @@ import prototorch as pt
import pytorch_lightning as pl
import torchmetrics
from prototorch.core import SMCI
from prototorch.models.y_arch.callbacks import (
from prototorch.y_arch.callbacks import (
LogTorchmetricCallback,
PlotLambdaMatrixToTensorboard,
VisGMLVQ2D,
)
from prototorch.models.y_arch.library.gmlvq import GMLVQ
from prototorch.y_arch.library.gmlvq import GMLVQ
from pytorch_lightning.callbacks import EarlyStopping
from torch.utils.data import DataLoader
@@ -39,8 +39,7 @@ if __name__ == "__main__":
# Define Hyperparameters
hyperparameters = GMLVQ.HyperParameters(
lr=0.1,
backbone_lr=5,
lr=dict(components_layer=0.1, _omega=0),
input_dim=4,
distribution=dict(
num_classes=3,
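
The hunk above replaces the separate lr / backbone_lr pair with a single lr dict, keyed by the name of the model attribute each optimizer should update (see MultipleLearningRateMixin below). A minimal sketch of the new setting; a rate of 0 simply freezes the omega matrix:

    # old: one global rate plus a special-cased backbone rate
    #   GMLVQ.HyperParameters(lr=0.1, backbone_lr=5, ...)
    # new: one dict, keyed by the attribute to optimize
    lr = dict(
        components_layer=0.1,  # prototype positions
        _omega=0,              # omega matrix; rate 0 keeps it fixed
    )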

View File

@@ -1,41 +0,0 @@
from dataclasses import dataclass, field
from typing import Callable
from prototorch.core.distances import euclidean_distance
from prototorch.models.y_arch.architectures.base import BaseYArchitecture
from prototorch.nn.wrappers import LambdaLayer
class SimpleComparisonMixin(BaseYArchitecture):
"""
Simple Comparison
A comparison layer that uses only the component positions and the batch for the dissimilarity computation.
"""
# HyperParameters
# ----------------------------------------------------------------------------------------------------
@dataclass
class HyperParameters(BaseYArchitecture.HyperParameters):
"""
comparison_fn: The comparison / dissimilarity function to use. Default: euclidean_distance.
comparison_args: Keyword arguments for the comparison function. Default: {}.
"""
comparison_fn: Callable = euclidean_distance
comparison_args: dict = field(default_factory=lambda: dict())
# Steps
# ----------------------------------------------------------------------------------------------------
def init_comparison(self, hparams: HyperParameters):
self.comparison_layer = LambdaLayer(fn=hparams.comparison_fn,
**hparams.comparison_args)
def comparison(self, batch, components):
comp_tensor, _ = components
batch_tensor, _ = batch
comp_tensor = comp_tensor.unsqueeze(1)
distances = self.comparison_layer(batch_tensor, comp_tensor)
return distances

View File

@@ -1,36 +0,0 @@
from dataclasses import dataclass
from typing import Type
import torch
from prototorch.models.y_arch import BaseYArchitecture
class SingleLearningRateMixin(BaseYArchitecture):
"""
Single Learning Rate
All parameters are updated with a single learning rate.
"""
# HyperParameters
# ----------------------------------------------------------------------------------------------------
@dataclass
class HyperParameters(BaseYArchitecture.HyperParameters):
"""
lr: The learning rate. Default: 0.1.
optimizer: The optimizer to use. Default: torch.optim.Adam.
"""
lr: float = 0.1
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
# Steps
# ----------------------------------------------------------------------------------------------------
def __init__(self, hparams: HyperParameters) -> None:
super().__init__(hparams)
self.lr = hparams.lr
self.optimizer = hparams.optimizer
# Hooks
# ----------------------------------------------------------------------------------------------------
def configure_optimizers(self):
return self.optimizer(self.parameters(), lr=self.lr) # type: ignore

View File

@@ -1,14 +1,22 @@
from .architectures.base import BaseYArchitecture
from .architectures.comparison import SimpleComparisonMixin
from .architectures.comparison import (
OmegaComparisonMixin,
SimpleComparisonMixin,
)
from .architectures.competition import WTACompetitionMixin
from .architectures.components import SupervisedArchitecture
from .architectures.loss import GLVQLossMixin
from .architectures.optimization import SingleLearningRateMixin
from .architectures.optimization import (
MultipleLearningRateMixin,
SingleLearningRateMixin,
)
__all__ = [
'BaseYArchitecture',
"OmegaComparisonMixin",
"SimpleComparisonMixin",
"SingleLearningRateMixin",
"MultipleLearningRateMixin",
"SupervisedArchitecture",
"WTACompetitionMixin",
"GLVQLossMixin",

View File

@@ -1,86 +1,50 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Callable
from typing import Callable, Dict
import torch
from prototorch.core.distances import omega_distance
from prototorch.core.distances import euclidean_distance
from prototorch.core.initializers import (
AbstractLinearTransformInitializer,
EyeLinearTransformInitializer,
)
from prototorch.models.y_arch import (
GLVQLossMixin,
SimpleComparisonMixin,
SupervisedArchitecture,
WTACompetitionMixin,
)
from prototorch.nn.wrappers import LambdaLayer
from prototorch.y_arch.architectures.base import BaseYArchitecture
from torch import Tensor
from torch.nn.parameter import Parameter
class GMLVQ(
SupervisedArchitecture,
SimpleComparisonMixin,
GLVQLossMixin,
WTACompetitionMixin,
):
class SimpleComparisonMixin(BaseYArchitecture):
"""
Generalized Matrix Learning Vector Quantization (GMLVQ)
Simple Comparison
A GMLVQ architecture that uses the winner-take-all strategy and the GLVQ loss.
A comparison layer that uses only the component positions and the batch for the dissimilarity computation.
"""
_omega: torch.Tensor
# HyperParameters
# ----------------------------------------------------------------------------------------------------
@dataclass
class HyperParameters(
SimpleComparisonMixin.HyperParameters,
GLVQLossMixin.HyperParameters,
WTACompetitionMixin.HyperParameters,
SupervisedArchitecture.HyperParameters,
):
class HyperParameters(BaseYArchitecture.HyperParameters):
"""
comparison_fn: The comparison / dissimilarity function to use. Override default: omega_distance.
comparison_args: Keyword arguments for the comparison function. Override default: {}.
input_dim: Required field. The dimensionality of the input.
latent_dim: The dimensionality of the latent space. Default: 2.
omega_initializer: The initializer to use for the omega matrix. Default: EyeLinearTransformInitializer.
comparison_fn: The comparison / dissimilarity function to use. Default: euclidean_distance.
comparison_args: Keyword arguments for the comparison function. Default: {}.
"""
backbone_lr: float = 0.1
lr: float = 0.1
comparison_fn: Callable = omega_distance
comparison_fn: Callable = euclidean_distance
comparison_args: dict = field(default_factory=lambda: dict())
input_dim: int | None = None
latent_dim: int = 2
omega_initializer: type[
AbstractLinearTransformInitializer] = EyeLinearTransformInitializer
optimizer: type[torch.optim.Optimizer] = torch.optim.Adam
comparison_parameters: dict = field(default_factory=lambda: dict())
# Steps
# ----------------------------------------------------------------------------------------------------
def __init__(self, hparams) -> None:
super().__init__(hparams)
self.lr = hparams.lr
self.backbone_lr = hparams.backbone_lr
self.optimizer = hparams.optimizer
def init_comparison(self, hparams: HyperParameters) -> None:
if hparams.input_dim is None:
raise ValueError("input_dim must be specified.")
omega = hparams.omega_initializer().generate(
hparams.input_dim,
hparams.latent_dim,
)
self.register_parameter("_omega", Parameter(omega))
def init_comparison(self, hparams: HyperParameters):
self.comparison_layer = LambdaLayer(
fn=hparams.comparison_fn,
**hparams.comparison_args,
)
self.comparison_kwargs: dict[str, Tensor] = dict()
def comparison(self, batch, components):
comp_tensor, _ = components
batch_tensor, _ = batch
@@ -90,21 +54,50 @@ class GMLVQ(
distances = self.comparison_layer(
batch_tensor,
comp_tensor,
self._omega,
**self.comparison_kwargs,
)
return distances
def configure_optimizers(self):
proto_opt = self.optimizer(
self.components_layer.parameters(),
lr=self.lr,
class OmegaComparisonMixin(SimpleComparisonMixin):
"""
Omega Comparison
A comparison layer that applies a learned linear transformation (omega) to the batch and the components before the dissimilarity computation.
"""
_omega: torch.Tensor
# HyperParameters
# ----------------------------------------------------------------------------------------------------
@dataclass
class HyperParameters(SimpleComparisonMixin.HyperParameters):
"""
input_dim: Required field. The dimensionality of the input.
latent_dim: The dimensionality of the latent space. Default: 2.
omega_initializer: The initializer to use for the omega matrix. Default: EyeLinearTransformInitializer.
"""
input_dim: int | None = None
latent_dim: int = 2
omega_initializer: type[
AbstractLinearTransformInitializer] = EyeLinearTransformInitializer
# Steps
# ----------------------------------------------------------------------------------------------------
def init_comparison(self, hparams: HyperParameters) -> None:
super().init_comparison(hparams)
# Initialize the omega matrix
if hparams.input_dim is None:
raise ValueError("input_dim must be specified.")
else:
omega = hparams.omega_initializer().generate(
hparams.input_dim,
hparams.latent_dim,
)
omega_opt = self.optimizer(
[self._omega],
lr=self.backbone_lr,
)
return [proto_opt, omega_opt]
self.register_parameter("_omega", Parameter(omega))
self.comparison_kwargs = dict(omega=self._omega)
# Properties
# ----------------------------------------------------------------------------------------------------
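
Net effect of the split: SimpleComparisonMixin wraps an arbitrary dissimilarity in a LambdaLayer and calls it with just the batch and the components, while OmegaComparisonMixin injects the learned matrix through comparison_kwargs. A minimal sketch of that call pattern outside the architecture, assuming prototorch's pairwise distance signatures and illustrative shapes:

    import torch
    from prototorch.core.distances import euclidean_distance, omega_distance

    batch = torch.randn(8, 4)   # 8 samples, input_dim = 4
    protos = torch.randn(3, 4)  # 3 prototype positions
    omega = torch.eye(4, 2)     # input_dim x latent_dim, roughly what EyeLinearTransformInitializer generates

    # SimpleComparisonMixin: comparison_kwargs stays empty
    d_simple = euclidean_distance(batch, protos)

    # OmegaComparisonMixin: comparison_kwargs = dict(omega=self._omega)
    d_omega = omega_distance(batch, protos, omega=omega)
    # both yield one dissimilarity per (sample, prototype) pair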

View File

@@ -1,7 +1,7 @@
from dataclasses import dataclass
from prototorch.core.competitions import WTAC
from prototorch.models.y_arch.architectures.base import BaseYArchitecture
from prototorch.y_arch.architectures.base import BaseYArchitecture
class WTACompetitionMixin(BaseYArchitecture):

View File

@@ -5,7 +5,7 @@ from prototorch.core.initializers import (
AbstractComponentsInitializer,
LabelsInitializer,
)
from prototorch.models.y_arch import BaseYArchitecture
from prototorch.y_arch import BaseYArchitecture
class SupervisedArchitecture(BaseYArchitecture):

View File

@@ -1,7 +1,7 @@
from dataclasses import dataclass, field
from prototorch.core.losses import GLVQLoss
from prototorch.models.y_arch.architectures.base import BaseYArchitecture
from prototorch.y_arch.architectures.base import BaseYArchitecture
class GLVQLossMixin(BaseYArchitecture):

View File

@@ -0,0 +1,86 @@
from dataclasses import dataclass, field
from typing import Type
import torch
from prototorch.y_arch import BaseYArchitecture
from torch.nn.parameter import Parameter
class SingleLearningRateMixin(BaseYArchitecture):
"""
Single Learning Rate
All parameters are updated with a single learning rate.
"""
# HyperParameters
# ----------------------------------------------------------------------------------------------------
@dataclass
class HyperParameters(BaseYArchitecture.HyperParameters):
"""
lr: The learning rate. Default: 0.1.
optimizer: The optimizer to use. Default: torch.optim.Adam.
"""
lr: float = 0.1
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
# Steps
# ----------------------------------------------------------------------------------------------------
def __init__(self, hparams: HyperParameters) -> None:
super().__init__(hparams)
self.lr = hparams.lr
self.optimizer = hparams.optimizer
# Hooks
# ----------------------------------------------------------------------------------------------------
def configure_optimizers(self):
return self.optimizer(self.parameters(), lr=self.lr) # type: ignore
class MultipleLearningRateMixin(BaseYArchitecture):
"""
Multiple Learning Rates
Different parameters are updated with different learning rates.
"""
# HyperParameters
# ----------------------------------------------------------------------------------------------------
@dataclass
class HyperParameters(BaseYArchitecture.HyperParameters):
"""
lr: A dict mapping model attribute names to learning rates. Default: {}.
optimizer: The optimizer to use. Default: torch.optim.Adam.
"""
lr: dict = field(default_factory=lambda: dict())
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
# Steps
# ----------------------------------------------------------------------------------------------------
def __init__(self, hparams: HyperParameters) -> None:
super().__init__(hparams)
self.lr = hparams.lr
self.optimizer = hparams.optimizer
# Hooks
# ----------------------------------------------------------------------------------------------------
def configure_optimizers(self):
optimizers = []
for name, lr in self.lr.items():
if not hasattr(self, name):
raise ValueError(f"{name} is not a parameter of {self}")
else:
model_part = getattr(self, name)
if isinstance(model_part, Parameter):
optimizers.append(
self.optimizer(
[model_part],
lr=lr, # type: ignore
))
elif hasattr(model_part, "parameters"):
optimizers.append(
self.optimizer(
model_part.parameters(),
lr=lr, # type: ignore
))
return optimizers
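
For reference, the dispatch above treats a bare torch Parameter and a submodule differently: the former is wrapped in a list, the latter contributes its .parameters(). A standalone sketch with a hypothetical Toy module (the attribute names mirror GMLVQ; everything else is made up):

    import torch
    from torch.nn.parameter import Parameter

    class Toy(torch.nn.Module):
        # hypothetical stand-in for an architecture with two parameter groups
        def __init__(self):
            super().__init__()
            self.components_layer = torch.nn.Linear(4, 3)  # stands in for the prototype layer
            self.register_parameter("_omega", Parameter(torch.eye(4, 2)))

    model = Toy()
    optimizers = []
    for name, rate in dict(components_layer=0.1, _omega=0.5).items():
        part = getattr(model, name)  # the mixin raises ValueError for unknown names
        if isinstance(part, Parameter):
            optimizers.append(torch.optim.Adam([part], lr=rate))  # bare parameter
        else:
            optimizers.append(torch.optim.Adam(part.parameters(), lr=rate))  # module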

View File

@@ -7,9 +7,9 @@ import torch
import torchmetrics
from matplotlib import pyplot as plt
from prototorch.models.vis import Vis2DAbstract
from prototorch.models.y_arch.architectures.base import BaseYArchitecture
from prototorch.models.y_arch.library.gmlvq import GMLVQ
from prototorch.utils.utils import mesh2d
from prototorch.y_arch.architectures.base import BaseYArchitecture
from prototorch.y_arch.library.gmlvq import GMLVQ
from pytorch_lightning.loggers import TensorBoardLogger
DIVERGING_COLOR_MAPS = [

View File

@@ -1,12 +1,12 @@
from dataclasses import dataclass
from prototorch.models.y_arch import (
from prototorch.y_arch import (
SimpleComparisonMixin,
SingleLearningRateMixin,
SupervisedArchitecture,
WTACompetitionMixin,
)
from prototorch.models.y_arch.architectures.loss import GLVQLossMixin
from prototorch.y_arch.architectures.loss import GLVQLossMixin
class GLVQ(

View File

@@ -0,0 +1,50 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Callable
import torch
from prototorch.core.distances import omega_distance
from prototorch.y_arch import (
GLVQLossMixin,
MultipleLearningRateMixin,
OmegaComparisonMixin,
SupervisedArchitecture,
WTACompetitionMixin,
)
class GMLVQ(
SupervisedArchitecture,
OmegaComparisonMixin,
GLVQLossMixin,
WTACompetitionMixin,
MultipleLearningRateMixin,
):
"""
Generalized Matrix Learning Vector Quantization (GMLVQ)
A GMLVQ architecture that uses the winner-take-all strategy and the GLVQ loss.
"""
# HyperParameters
# ----------------------------------------------------------------------------------------------------
@dataclass
class HyperParameters(
MultipleLearningRateMixin.HyperParameters,
OmegaComparisonMixin.HyperParameters,
GLVQLossMixin.HyperParameters,
WTACompetitionMixin.HyperParameters,
SupervisedArchitecture.HyperParameters,
):
"""
comparison_fn: The comparison / dissimilarity function to use. Override default: omega_distance.
comparison_args: Keyword arguments for the comparison function. Override default: {}.
"""
comparison_fn: Callable = omega_distance
comparison_args: dict = field(default_factory=lambda: dict())
optimizer: type[torch.optim.Optimizer] = torch.optim.Adam
lr: dict = field(default_factory=lambda: dict(
components_layer=0.1,
_omega=0.5,
))
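
Putting the mixins together, configuration mirrors the example script at the top of this commit. A minimal sketch; the per_class key is an assumption (the example's distribution dict is truncated above), and SupervisedArchitecture.HyperParameters may require further fields, e.g. a component initializer, that this diff does not show:

    from prototorch.y_arch.library.gmlvq import GMLVQ

    hparams = GMLVQ.HyperParameters(
        lr=dict(components_layer=0.1, _omega=0.5),      # MultipleLearningRateMixin
        input_dim=4,                                    # required by OmegaComparisonMixin
        latent_dim=2,
        distribution=dict(num_classes=3, per_class=1),  # per_class is an assumption
    )
    model = GMLVQ(hparams)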