feat: distribute GMLVQ into mixins
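Rebuild the GMLVQ library model from reusable mixins: the comparison step is split into SimpleComparisonMixin and a new OmegaComparisonMixin that owns the omega matrix and passes it to the distance function via comparison_kwargs, and the optimization step gains a MultipleLearningRateMixin that builds one optimizer per named parameter or sub-module. GMLVQ itself becomes a plain composition of these mixins; its lr/backbone_lr pair is replaced by a per-parameter lr dict, and the package moves from prototorch.models.y_arch to prototorch.y_arch.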

Alexander Engelsberger 2022-05-31 17:56:03 +02:00
parent e922aae432
commit 23d1a71b31
14 changed files with 211 additions and 152 deletions

View File

@@ -2,12 +2,12 @@ import prototorch as pt
 import pytorch_lightning as pl
 import torchmetrics
 from prototorch.core import SMCI
-from prototorch.models.y_arch.callbacks import (
+from prototorch.y_arch.callbacks import (
     LogTorchmetricCallback,
     PlotLambdaMatrixToTensorboard,
     VisGMLVQ2D,
 )
-from prototorch.models.y_arch.library.gmlvq import GMLVQ
+from prototorch.y_arch.library.gmlvq import GMLVQ
 from pytorch_lightning.callbacks import EarlyStopping
 from torch.utils.data import DataLoader
@@ -39,8 +39,7 @@ if __name__ == "__main__":
     # Define Hyperparameters
     hyperparameters = GMLVQ.HyperParameters(
-        lr=0.1,
-        backbone_lr=5,
+        lr=dict(components_layer=0.1, _omega=0),
         input_dim=4,
         distribution=dict(
             num_classes=3,

View File

@@ -1,41 +0,0 @@
-from dataclasses import dataclass, field
-from typing import Callable
-
-from prototorch.core.distances import euclidean_distance
-from prototorch.models.y_arch.architectures.base import BaseYArchitecture
-from prototorch.nn.wrappers import LambdaLayer
-
-
-class SimpleComparisonMixin(BaseYArchitecture):
-    """
-    Simple Comparison
-
-    A comparison layer that only uses the positions of the components and the batch for dissimilarity computation.
-    """
-
-    # HyperParameters
-    # ----------------------------------------------------------------------------------------------------
-    @dataclass
-    class HyperParameters(BaseYArchitecture.HyperParameters):
-        """
-        comparison_fn: The comparison / dissimilarity function to use. Default: euclidean_distance.
-        comparison_args: Keyword arguments for the comparison function. Default: {}.
-        """
-        comparison_fn: Callable = euclidean_distance
-        comparison_args: dict = field(default_factory=lambda: dict())
-
-    # Steps
-    # ----------------------------------------------------------------------------------------------------
-    def init_comparison(self, hparams: HyperParameters):
-        self.comparison_layer = LambdaLayer(fn=hparams.comparison_fn,
-                                            **hparams.comparison_args)
-
-    def comparison(self, batch, components):
-        comp_tensor, _ = components
-        batch_tensor, _ = batch
-
-        comp_tensor = comp_tensor.unsqueeze(1)
-
-        distances = self.comparison_layer(batch_tensor, comp_tensor)
-
-        return distances

View File

@@ -1,36 +0,0 @@
-from dataclasses import dataclass
-from typing import Type
-
-import torch
-from prototorch.models.y_arch import BaseYArchitecture
-
-
-class SingleLearningRateMixin(BaseYArchitecture):
-    """
-    Single Learning Rate
-
-    All parameters are updated with a single learning rate.
-    """
-
-    # HyperParameters
-    # ----------------------------------------------------------------------------------------------------
-    @dataclass
-    class HyperParameters(BaseYArchitecture.HyperParameters):
-        """
-        lr: The learning rate. Default: 0.1.
-        optimizer: The optimizer to use. Default: torch.optim.Adam.
-        """
-        lr: float = 0.1
-        optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
-
-    # Steps
-    # ----------------------------------------------------------------------------------------------------
-    def __init__(self, hparams: HyperParameters) -> None:
-        super().__init__(hparams)
-        self.lr = hparams.lr
-        self.optimizer = hparams.optimizer
-
-    # Hooks
-    # ----------------------------------------------------------------------------------------------------
-    def configure_optimizers(self):
-        return self.optimizer(self.parameters(), lr=self.lr)  # type: ignore

View File

@@ -1,14 +1,22 @@
 from .architectures.base import BaseYArchitecture
-from .architectures.comparison import SimpleComparisonMixin
+from .architectures.comparison import (
+    OmegaComparisonMixin,
+    SimpleComparisonMixin,
+)
 from .architectures.competition import WTACompetitionMixin
 from .architectures.components import SupervisedArchitecture
 from .architectures.loss import GLVQLossMixin
-from .architectures.optimization import SingleLearningRateMixin
+from .architectures.optimization import (
+    MultipleLearningRateMixin,
+    SingleLearningRateMixin,
+)
 
 __all__ = [
     'BaseYArchitecture',
+    "OmegaComparisonMixin",
     "SimpleComparisonMixin",
     "SingleLearningRateMixin",
+    "MultipleLearningRateMixin",
     "SupervisedArchitecture",
     "WTACompetitionMixin",
     "GLVQLossMixin",

View File

@@ -1,86 +1,50 @@
 from __future__ import annotations
 
 from dataclasses import dataclass, field
-from typing import Callable
+from typing import Callable, Dict
 
 import torch
-from prototorch.core.distances import omega_distance
+from prototorch.core.distances import euclidean_distance
 from prototorch.core.initializers import (
     AbstractLinearTransformInitializer,
     EyeLinearTransformInitializer,
 )
-from prototorch.models.y_arch import (
-    GLVQLossMixin,
-    SimpleComparisonMixin,
-    SupervisedArchitecture,
-    WTACompetitionMixin,
-)
 from prototorch.nn.wrappers import LambdaLayer
+from prototorch.y_arch.architectures.base import BaseYArchitecture
+from torch import Tensor
 from torch.nn.parameter import Parameter
 
 
-class GMLVQ(
-    SupervisedArchitecture,
-    SimpleComparisonMixin,
-    GLVQLossMixin,
-    WTACompetitionMixin,
-):
+class SimpleComparisonMixin(BaseYArchitecture):
     """
-    Generalized Matrix Learning Vector Quantization (GMLVQ)
+    Simple Comparison
 
-    A GMLVQ architecture that uses the winner-take-all strategy and the GLVQ loss.
+    A comparison layer that only uses the positions of the components and the batch for dissimilarity computation.
     """
-    _omega: torch.Tensor
 
     # HyperParameters
     # ----------------------------------------------------------------------------------------------------
     @dataclass
-    class HyperParameters(
-        SimpleComparisonMixin.HyperParameters,
-        GLVQLossMixin.HyperParameters,
-        WTACompetitionMixin.HyperParameters,
-        SupervisedArchitecture.HyperParameters,
-    ):
+    class HyperParameters(BaseYArchitecture.HyperParameters):
         """
-        comparison_fn: The comparison / dissimilarity function to use. Override Default: omega_distance.
-        comparison_args: Keyword arguments for the comparison function. Override Default: {}.
-        input_dim: Necessary Field: The dimensionality of the input.
-        latent_dim: The dimensionality of the latent space. Default: 2.
-        omega_initializer: The initializer to use for the omega matrix. Default: EyeLinearTransformInitializer.
+        comparison_fn: The comparison / dissimilarity function to use. Default: euclidean_distance.
+        comparison_args: Keyword arguments for the comparison function. Default: {}.
         """
-        backbone_lr: float = 0.1
-        lr: float = 0.1
-        comparison_fn: Callable = omega_distance
+        comparison_fn: Callable = euclidean_distance
         comparison_args: dict = field(default_factory=lambda: dict())
-        input_dim: int | None = None
-        latent_dim: int = 2
-        omega_initializer: type[
-            AbstractLinearTransformInitializer] = EyeLinearTransformInitializer
-        optimizer: type[torch.optim.Optimizer] = torch.optim.Adam
+        comparison_parameters: dict = field(default_factory=lambda: dict())
 
     # Steps
     # ----------------------------------------------------------------------------------------------------
-    def __init__(self, hparams) -> None:
-        super().__init__(hparams)
-        self.lr = hparams.lr
-        self.backbone_lr = hparams.backbone_lr
-        self.optimizer = hparams.optimizer
-
-    def init_comparison(self, hparams: HyperParameters) -> None:
-        if hparams.input_dim is None:
-            raise ValueError("input_dim must be specified.")
-
-        omega = hparams.omega_initializer().generate(
-            hparams.input_dim,
-            hparams.latent_dim,
-        )
-        self.register_parameter("_omega", Parameter(omega))
-
+    def init_comparison(self, hparams: HyperParameters):
         self.comparison_layer = LambdaLayer(
             fn=hparams.comparison_fn,
             **hparams.comparison_args,
         )
+        self.comparison_kwargs: dict[str, Tensor] = dict()
 
     def comparison(self, batch, components):
         comp_tensor, _ = components
         batch_tensor, _ = batch
@@ -90,21 +54,50 @@ class GMLVQ(
         distances = self.comparison_layer(
             batch_tensor,
             comp_tensor,
-            self._omega,
+            **self.comparison_kwargs,
         )
 
         return distances
 
-    def configure_optimizers(self):
-        proto_opt = self.optimizer(
-            self.components_layer.parameters(),
-            lr=self.lr,
-        )
-        omega_opt = self.optimizer(
-            [self._omega],
-            lr=self.backbone_lr,
-        )
-        return [proto_opt, omega_opt]
+
+class OmegaComparisonMixin(SimpleComparisonMixin):
+    """
+    Omega Comparison
+
+    A comparison layer that applies a learnable linear transformation (the omega matrix) to the components and the batch before the dissimilarity computation.
+    """
+
+    _omega: torch.Tensor
+
+    # HyperParameters
+    # ----------------------------------------------------------------------------------------------------
+    @dataclass
+    class HyperParameters(SimpleComparisonMixin.HyperParameters):
+        """
+        input_dim: Necessary Field: The dimensionality of the input.
+        latent_dim: The dimensionality of the latent space. Default: 2.
+        omega_initializer: The initializer to use for the omega matrix. Default: EyeLinearTransformInitializer.
+        """
+        input_dim: int | None = None
+        latent_dim: int = 2
+        omega_initializer: type[
+            AbstractLinearTransformInitializer] = EyeLinearTransformInitializer
+
+    # Steps
+    # ----------------------------------------------------------------------------------------------------
+    def init_comparison(self, hparams: HyperParameters) -> None:
+        super().init_comparison(hparams)
+
+        # Initialize the omega matrix
+        if hparams.input_dim is None:
+            raise ValueError("input_dim must be specified.")
+        else:
+            omega = hparams.omega_initializer().generate(
+                hparams.input_dim,
+                hparams.latent_dim,
+            )
+            self.register_parameter("_omega", Parameter(omega))
+            self.comparison_kwargs = dict(omega=self._omega)
 
     # Properties
     # ----------------------------------------------------------------------------------------------------
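For orientation, a minimal standalone sketch (not part of the diff) of the matrix dissimilarity that OmegaComparisonMixin delegates to: comparison_fn defaults to prototorch's omega_distance, and init_comparison stores the learnable matrix in comparison_kwargs so that comparison() can forward it as a keyword argument. The helper below reimplements the idea with plain torch; its name and the shapes are illustrative, not taken from the library.

import torch

def omega_squared_distance(batch: torch.Tensor, protos: torch.Tensor,
                           omega: torch.Tensor) -> torch.Tensor:
    # Project both sides with omega, then compute pairwise squared distances.
    batch_lat = batch @ omega    # (n_batch, latent_dim)
    protos_lat = protos @ omega  # (n_protos, latent_dim)
    return torch.cdist(batch_lat, protos_lat)**2  # (n_batch, n_protos)

batch = torch.randn(8, 4)   # input_dim = 4, as in the example script above
protos = torch.randn(3, 4)  # one prototype per class
omega = torch.eye(4, 2)     # identity-like omega, in the spirit of EyeLinearTransformInitializer
print(omega_squared_distance(batch, protos, omega).shape)  # torch.Size([8, 3])

Because _omega is registered as a Parameter, the projection is trained jointly with the prototypes.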

View File

@@ -1,7 +1,7 @@
 from dataclasses import dataclass
 
 from prototorch.core.competitions import WTAC
-from prototorch.models.y_arch.architectures.base import BaseYArchitecture
+from prototorch.y_arch.architectures.base import BaseYArchitecture
 
 
 class WTACompetitionMixin(BaseYArchitecture):

View File

@@ -5,7 +5,7 @@ from prototorch.core.initializers import (
     AbstractComponentsInitializer,
     LabelsInitializer,
 )
-from prototorch.models.y_arch import BaseYArchitecture
+from prototorch.y_arch import BaseYArchitecture
 
 
 class SupervisedArchitecture(BaseYArchitecture):

View File

@@ -1,7 +1,7 @@
 from dataclasses import dataclass, field
 
 from prototorch.core.losses import GLVQLoss
-from prototorch.models.y_arch.architectures.base import BaseYArchitecture
+from prototorch.y_arch.architectures.base import BaseYArchitecture
 
 
 class GLVQLossMixin(BaseYArchitecture):

View File

@@ -0,0 +1,86 @@
+from dataclasses import dataclass, field
+from typing import Type
+
+import torch
+from prototorch.y_arch import BaseYArchitecture
+from torch.nn.parameter import Parameter
+
+
+class SingleLearningRateMixin(BaseYArchitecture):
+    """
+    Single Learning Rate
+
+    All parameters are updated with a single learning rate.
+    """
+
+    # HyperParameters
+    # ----------------------------------------------------------------------------------------------------
+    @dataclass
+    class HyperParameters(BaseYArchitecture.HyperParameters):
+        """
+        lr: The learning rate. Default: 0.1.
+        optimizer: The optimizer to use. Default: torch.optim.Adam.
+        """
+        lr: float = 0.1
+        optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
+
+    # Steps
+    # ----------------------------------------------------------------------------------------------------
+    def __init__(self, hparams: HyperParameters) -> None:
+        super().__init__(hparams)
+        self.lr = hparams.lr
+        self.optimizer = hparams.optimizer
+
+    # Hooks
+    # ----------------------------------------------------------------------------------------------------
+    def configure_optimizers(self):
+        return self.optimizer(self.parameters(), lr=self.lr)  # type: ignore
+
+
+class MultipleLearningRateMixin(BaseYArchitecture):
+    """
+    Multiple Learning Rates
+
+    Defines different learning rates for different parameters or sub-modules.
+    """
+
+    # HyperParameters
+    # ----------------------------------------------------------------------------------------------------
+    @dataclass
+    class HyperParameters(BaseYArchitecture.HyperParameters):
+        """
+        lr: Dictionary of learning rates, keyed by the names of the parameters or modules they apply to. Default: {}.
+        optimizer: The optimizer to use. Default: torch.optim.Adam.
+        """
+        lr: dict = field(default_factory=lambda: dict())
+        optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
+
+    # Steps
+    # ----------------------------------------------------------------------------------------------------
+    def __init__(self, hparams: HyperParameters) -> None:
+        super().__init__(hparams)
+        self.lr = hparams.lr
+        self.optimizer = hparams.optimizer
+
+    # Hooks
+    # ----------------------------------------------------------------------------------------------------
+    def configure_optimizers(self):
+        optimizers = []
+        for name, lr in self.lr.items():
+            if not hasattr(self, name):
+                raise ValueError(f"{name} is not a parameter of {self}")
+            else:
+                model_part = getattr(self, name)
+                if isinstance(model_part, Parameter):
+                    optimizers.append(
+                        self.optimizer(
+                            [model_part],
+                            lr=lr,  # type: ignore
+                        ))
+                elif hasattr(model_part, "parameters"):
+                    optimizers.append(
+                        self.optimizer(
+                            model_part.parameters(),
+                            lr=lr,  # type: ignore
+                        ))
+        return optimizers
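For orientation, a standalone sketch (not part of the diff) of the per-name optimizer construction that MultipleLearningRateMixin.configure_optimizers performs: each key of the lr dict names an attribute of the model, bare Parameters are wrapped in a list, and sub-modules contribute their .parameters(). TinyModel and its attribute names are illustrative only.

import torch
from torch.nn.parameter import Parameter

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.prototypes = torch.nn.Linear(4, 3)   # a sub-module: handled via .parameters()
        self._omega = Parameter(torch.eye(4, 2))  # a bare Parameter: wrapped in a list

model = TinyModel()
lrs = dict(prototypes=0.1, _omega=0.5)

optimizers = []
for name, lr in lrs.items():
    part = getattr(model, name)  # the mixin raises a ValueError if this attribute is missing
    params = [part] if isinstance(part, Parameter) else part.parameters()
    optimizers.append(torch.optim.Adam(params, lr=lr))

print([opt.param_groups[0]["lr"] for opt in optimizers])  # [0.1, 0.5]

In the mixin this list is returned from configure_optimizers, which PyTorch Lightning accepts as multiple optimizers.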

View File

@@ -7,9 +7,9 @@ import torch
 import torchmetrics
 from matplotlib import pyplot as plt
 from prototorch.models.vis import Vis2DAbstract
-from prototorch.models.y_arch.architectures.base import BaseYArchitecture
-from prototorch.models.y_arch.library.gmlvq import GMLVQ
 from prototorch.utils.utils import mesh2d
+from prototorch.y_arch.architectures.base import BaseYArchitecture
+from prototorch.y_arch.library.gmlvq import GMLVQ
 from pytorch_lightning.loggers import TensorBoardLogger
 
 DIVERGING_COLOR_MAPS = [

View File

@@ -1,12 +1,12 @@
 from dataclasses import dataclass
 
-from prototorch.models.y_arch import (
+from prototorch.y_arch import (
     SimpleComparisonMixin,
     SingleLearningRateMixin,
     SupervisedArchitecture,
     WTACompetitionMixin,
 )
-from prototorch.models.y_arch.architectures.loss import GLVQLossMixin
+from prototorch.y_arch.architectures.loss import GLVQLossMixin
 
 
 class GLVQ(

View File

@@ -0,0 +1,50 @@
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from typing import Callable
+
+import torch
+from prototorch.core.distances import omega_distance
+from prototorch.y_arch import (
+    GLVQLossMixin,
+    MultipleLearningRateMixin,
+    OmegaComparisonMixin,
+    SupervisedArchitecture,
+    WTACompetitionMixin,
+)
+
+
+class GMLVQ(
+    SupervisedArchitecture,
+    OmegaComparisonMixin,
+    GLVQLossMixin,
+    WTACompetitionMixin,
+    MultipleLearningRateMixin,
+):
+    """
+    Generalized Matrix Learning Vector Quantization (GMLVQ)
+
+    A GMLVQ architecture that uses the winner-take-all strategy and the GLVQ loss.
+    """
+
+    # HyperParameters
+    # ----------------------------------------------------------------------------------------------------
+    @dataclass
+    class HyperParameters(
+        MultipleLearningRateMixin.HyperParameters,
+        OmegaComparisonMixin.HyperParameters,
+        GLVQLossMixin.HyperParameters,
+        WTACompetitionMixin.HyperParameters,
+        SupervisedArchitecture.HyperParameters,
+    ):
+        """
+        comparison_fn: The comparison / dissimilarity function to use. Override Default: omega_distance.
+        comparison_args: Keyword arguments for the comparison function. Override Default: {}.
+        """
+        comparison_fn: Callable = omega_distance
+        comparison_args: dict = field(default_factory=lambda: dict())
+        optimizer: type[torch.optim.Optimizer] = torch.optim.Adam
+        lr: dict = field(default_factory=lambda: dict(
+            components_layer=0.1,
+            _omega=0.5,
+        ))
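For orientation, a hedged usage sketch (not part of the commit) showing how the recomposed GMLVQ is configured; it mirrors the updated example script at the top of this diff. The Iris dataset, the SMCI component initializer, the per_class key, and the component_initializer field name are assumptions based on the imports shown in that script, since its hunk is truncated.

import prototorch as pt
from prototorch.core import SMCI
from prototorch.y_arch.library.gmlvq import GMLVQ

train_ds = pt.datasets.Iris()  # 4 features, 3 classes -> input_dim=4, num_classes=3

hyperparameters = GMLVQ.HyperParameters(
    lr=dict(components_layer=0.1, _omega=0.5),  # per-parameter rates via MultipleLearningRateMixin
    input_dim=4,                                # required by OmegaComparisonMixin
    latent_dim=2,
    distribution=dict(
        num_classes=3,
        per_class=1,  # assumed key
    ),
    component_initializer=SMCI(train_ds),  # assumed field name
)

model = GMLVQ(hyperparameters)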