feat: distribute GMLVQ into mixins
This commit is contained in:
parent
e922aae432
commit
23d1a71b31
@ -2,12 +2,12 @@ import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torchmetrics
|
||||
from prototorch.core import SMCI
|
||||
from prototorch.models.y_arch.callbacks import (
|
||||
from prototorch.y_arch.callbacks import (
|
||||
LogTorchmetricCallback,
|
||||
PlotLambdaMatrixToTensorboard,
|
||||
VisGMLVQ2D,
|
||||
)
|
||||
from prototorch.models.y_arch.library.gmlvq import GMLVQ
|
||||
from prototorch.y_arch.library.gmlvq import GMLVQ
|
||||
from pytorch_lightning.callbacks import EarlyStopping
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
@ -39,8 +39,7 @@ if __name__ == "__main__":
|
||||
|
||||
# Define Hyperparameters
|
||||
hyperparameters = GMLVQ.HyperParameters(
|
||||
lr=0.1,
|
||||
backbone_lr=5,
|
||||
lr=dict(components_layer=0.1, _omega=0),
|
||||
input_dim=4,
|
||||
distribution=dict(
|
||||
num_classes=3,
|
||||
|
@ -1,41 +0,0 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable
|
||||
|
||||
from prototorch.core.distances import euclidean_distance
|
||||
from prototorch.models.y_arch.architectures.base import BaseYArchitecture
|
||||
from prototorch.nn.wrappers import LambdaLayer
|
||||
|
||||
|
||||
class SimpleComparisonMixin(BaseYArchitecture):
|
||||
"""
|
||||
Simple Comparison
|
||||
|
||||
A comparison layer that only uses the positions of the components and the batch for dissimilarity computation.
|
||||
"""
|
||||
|
||||
# HyperParameters
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@dataclass
|
||||
class HyperParameters(BaseYArchitecture.HyperParameters):
|
||||
"""
|
||||
comparison_fn: The comparison / dissimilarity function to use. Default: euclidean_distance.
|
||||
comparison_args: Keyword arguments for the comparison function. Default: {}.
|
||||
"""
|
||||
comparison_fn: Callable = euclidean_distance
|
||||
comparison_args: dict = field(default_factory=lambda: dict())
|
||||
|
||||
# Steps
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def init_comparison(self, hparams: HyperParameters):
|
||||
self.comparison_layer = LambdaLayer(fn=hparams.comparison_fn,
|
||||
**hparams.comparison_args)
|
||||
|
||||
def comparison(self, batch, components):
|
||||
comp_tensor, _ = components
|
||||
batch_tensor, _ = batch
|
||||
|
||||
comp_tensor = comp_tensor.unsqueeze(1)
|
||||
|
||||
distances = self.comparison_layer(batch_tensor, comp_tensor)
|
||||
|
||||
return distances
|
@ -1,36 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Type
|
||||
|
||||
import torch
|
||||
from prototorch.models.y_arch import BaseYArchitecture
|
||||
|
||||
|
||||
class SingleLearningRateMixin(BaseYArchitecture):
|
||||
"""
|
||||
Single Learning Rate
|
||||
|
||||
All parameters are updated with a single learning rate.
|
||||
"""
|
||||
|
||||
# HyperParameters
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@dataclass
|
||||
class HyperParameters(BaseYArchitecture.HyperParameters):
|
||||
"""
|
||||
lr: The learning rate. Default: 0.1.
|
||||
optimizer: The optimizer to use. Default: torch.optim.Adam.
|
||||
"""
|
||||
lr: float = 0.1
|
||||
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
|
||||
|
||||
# Steps
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def __init__(self, hparams: HyperParameters) -> None:
|
||||
super().__init__(hparams)
|
||||
self.lr = hparams.lr
|
||||
self.optimizer = hparams.optimizer
|
||||
|
||||
# Hooks
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def configure_optimizers(self):
|
||||
return self.optimizer(self.parameters(), lr=self.lr) # type: ignore
|
@ -1,14 +1,22 @@
|
||||
from .architectures.base import BaseYArchitecture
|
||||
from .architectures.comparison import SimpleComparisonMixin
|
||||
from .architectures.comparison import (
|
||||
OmegaComparisonMixin,
|
||||
SimpleComparisonMixin,
|
||||
)
|
||||
from .architectures.competition import WTACompetitionMixin
|
||||
from .architectures.components import SupervisedArchitecture
|
||||
from .architectures.loss import GLVQLossMixin
|
||||
from .architectures.optimization import SingleLearningRateMixin
|
||||
from .architectures.optimization import (
|
||||
MultipleLearningRateMixin,
|
||||
SingleLearningRateMixin,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'BaseYArchitecture',
|
||||
"OmegaComparisonMixin",
|
||||
"SimpleComparisonMixin",
|
||||
"SingleLearningRateMixin",
|
||||
"MultipleLearningRateMixin",
|
||||
"SupervisedArchitecture",
|
||||
"WTACompetitionMixin",
|
||||
"GLVQLossMixin",
|
@ -1,86 +1,50 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable
|
||||
from typing import Callable, Dict
|
||||
|
||||
import torch
|
||||
from prototorch.core.distances import omega_distance
|
||||
from prototorch.core.distances import euclidean_distance
|
||||
from prototorch.core.initializers import (
|
||||
AbstractLinearTransformInitializer,
|
||||
EyeLinearTransformInitializer,
|
||||
)
|
||||
from prototorch.models.y_arch import (
|
||||
GLVQLossMixin,
|
||||
SimpleComparisonMixin,
|
||||
SupervisedArchitecture,
|
||||
WTACompetitionMixin,
|
||||
)
|
||||
from prototorch.nn.wrappers import LambdaLayer
|
||||
from prototorch.y_arch.architectures.base import BaseYArchitecture
|
||||
from torch import Tensor
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
|
||||
class GMLVQ(
|
||||
SupervisedArchitecture,
|
||||
SimpleComparisonMixin,
|
||||
GLVQLossMixin,
|
||||
WTACompetitionMixin,
|
||||
):
|
||||
class SimpleComparisonMixin(BaseYArchitecture):
|
||||
"""
|
||||
Generalized Matrix Learning Vector Quantization (GMLVQ)
|
||||
Simple Comparison
|
||||
|
||||
A GMLVQ architecture that uses the winner-take-all strategy and the GLVQ loss.
|
||||
A comparison layer that only uses the positions of the components and the batch for dissimilarity computation.
|
||||
"""
|
||||
|
||||
_omega: torch.Tensor
|
||||
|
||||
# HyperParameters
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@dataclass
|
||||
class HyperParameters(
|
||||
SimpleComparisonMixin.HyperParameters,
|
||||
GLVQLossMixin.HyperParameters,
|
||||
WTACompetitionMixin.HyperParameters,
|
||||
SupervisedArchitecture.HyperParameters,
|
||||
):
|
||||
class HyperParameters(BaseYArchitecture.HyperParameters):
|
||||
"""
|
||||
comparison_fn: The comparison / dissimilarity function to use. Override Default: omega_distance.
|
||||
comparison_args: Keyword arguments for the comparison function. Override Default: {}.
|
||||
input_dim: Necessary Field: The dimensionality of the input.
|
||||
latent_dim: The dimensionality of the latent space. Default: 2.
|
||||
omega_initializer: The initializer to use for the omega matrix. Default: EyeLinearTransformInitializer.
|
||||
comparison_fn: The comparison / dissimilarity function to use. Default: euclidean_distance.
|
||||
comparison_args: Keyword arguments for the comparison function. Default: {}.
|
||||
"""
|
||||
backbone_lr: float = 0.1
|
||||
lr: float = 0.1
|
||||
comparison_fn: Callable = omega_distance
|
||||
comparison_fn: Callable = euclidean_distance
|
||||
comparison_args: dict = field(default_factory=lambda: dict())
|
||||
input_dim: int | None = None
|
||||
latent_dim: int = 2
|
||||
omega_initializer: type[
|
||||
AbstractLinearTransformInitializer] = EyeLinearTransformInitializer
|
||||
|
||||
optimizer: type[torch.optim.Optimizer] = torch.optim.Adam
|
||||
comparison_parameters: dict = field(default_factory=lambda: dict())
|
||||
|
||||
# Steps
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def __init__(self, hparams) -> None:
|
||||
super().__init__(hparams)
|
||||
self.lr = hparams.lr
|
||||
self.backbone_lr = hparams.backbone_lr
|
||||
self.optimizer = hparams.optimizer
|
||||
|
||||
def init_comparison(self, hparams: HyperParameters) -> None:
|
||||
if hparams.input_dim is None:
|
||||
raise ValueError("input_dim must be specified.")
|
||||
omega = hparams.omega_initializer().generate(
|
||||
hparams.input_dim,
|
||||
hparams.latent_dim,
|
||||
)
|
||||
self.register_parameter("_omega", Parameter(omega))
|
||||
def init_comparison(self, hparams: HyperParameters):
|
||||
self.comparison_layer = LambdaLayer(
|
||||
fn=hparams.comparison_fn,
|
||||
**hparams.comparison_args,
|
||||
)
|
||||
|
||||
self.comparison_kwargs: dict[str, Tensor] = dict()
|
||||
|
||||
def comparison(self, batch, components):
|
||||
comp_tensor, _ = components
|
||||
batch_tensor, _ = batch
|
||||
@ -90,21 +54,50 @@ class GMLVQ(
|
||||
distances = self.comparison_layer(
|
||||
batch_tensor,
|
||||
comp_tensor,
|
||||
self._omega,
|
||||
**self.comparison_kwargs,
|
||||
)
|
||||
|
||||
return distances
|
||||
|
||||
def configure_optimizers(self):
|
||||
proto_opt = self.optimizer(
|
||||
self.components_layer.parameters(),
|
||||
lr=self.lr,
|
||||
|
||||
class OmegaComparisonMixin(SimpleComparisonMixin):
|
||||
"""
|
||||
Omega Comparison
|
||||
|
||||
A comparison layer that uses the positions of the components and the batch for dissimilarity computation.
|
||||
"""
|
||||
|
||||
_omega: torch.Tensor
|
||||
|
||||
# HyperParameters
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@dataclass
|
||||
class HyperParameters(SimpleComparisonMixin.HyperParameters):
|
||||
"""
|
||||
input_dim: Necessary Field: The dimensionality of the input.
|
||||
latent_dim: The dimensionality of the latent space. Default: 2.
|
||||
omega_initializer: The initializer to use for the omega matrix. Default: EyeLinearTransformInitializer.
|
||||
"""
|
||||
input_dim: int | None = None
|
||||
latent_dim: int = 2
|
||||
omega_initializer: type[
|
||||
AbstractLinearTransformInitializer] = EyeLinearTransformInitializer
|
||||
|
||||
# Steps
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def init_comparison(self, hparams: HyperParameters) -> None:
|
||||
super().init_comparison(hparams)
|
||||
|
||||
# Initialize the omega matrix
|
||||
if hparams.input_dim is None:
|
||||
raise ValueError("input_dim must be specified.")
|
||||
else:
|
||||
omega = hparams.omega_initializer().generate(
|
||||
hparams.input_dim,
|
||||
hparams.latent_dim,
|
||||
)
|
||||
omega_opt = self.optimizer(
|
||||
[self._omega],
|
||||
lr=self.backbone_lr,
|
||||
)
|
||||
return [proto_opt, omega_opt]
|
||||
self.register_parameter("_omega", Parameter(omega))
|
||||
self.comparison_kwargs = dict(omega=self._omega)
|
||||
|
||||
# Properties
|
||||
# ----------------------------------------------------------------------------------------------------
|
@ -1,7 +1,7 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from prototorch.core.competitions import WTAC
|
||||
from prototorch.models.y_arch.architectures.base import BaseYArchitecture
|
||||
from prototorch.y_arch.architectures.base import BaseYArchitecture
|
||||
|
||||
|
||||
class WTACompetitionMixin(BaseYArchitecture):
|
@ -5,7 +5,7 @@ from prototorch.core.initializers import (
|
||||
AbstractComponentsInitializer,
|
||||
LabelsInitializer,
|
||||
)
|
||||
from prototorch.models.y_arch import BaseYArchitecture
|
||||
from prototorch.y_arch import BaseYArchitecture
|
||||
|
||||
|
||||
class SupervisedArchitecture(BaseYArchitecture):
|
@ -1,7 +1,7 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from prototorch.core.losses import GLVQLoss
|
||||
from prototorch.models.y_arch.architectures.base import BaseYArchitecture
|
||||
from prototorch.y_arch.architectures.base import BaseYArchitecture
|
||||
|
||||
|
||||
class GLVQLossMixin(BaseYArchitecture):
|
86
prototorch/y_arch/architectures/optimization.py
Normal file
86
prototorch/y_arch/architectures/optimization.py
Normal file
@ -0,0 +1,86 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Type
|
||||
|
||||
import torch
|
||||
from prototorch.y_arch import BaseYArchitecture
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
|
||||
class SingleLearningRateMixin(BaseYArchitecture):
|
||||
"""
|
||||
Single Learning Rate
|
||||
|
||||
All parameters are updated with a single learning rate.
|
||||
"""
|
||||
|
||||
# HyperParameters
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@dataclass
|
||||
class HyperParameters(BaseYArchitecture.HyperParameters):
|
||||
"""
|
||||
lr: The learning rate. Default: 0.1.
|
||||
optimizer: The optimizer to use. Default: torch.optim.Adam.
|
||||
"""
|
||||
lr: float = 0.1
|
||||
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
|
||||
|
||||
# Steps
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def __init__(self, hparams: HyperParameters) -> None:
|
||||
super().__init__(hparams)
|
||||
self.lr = hparams.lr
|
||||
self.optimizer = hparams.optimizer
|
||||
|
||||
# Hooks
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def configure_optimizers(self):
|
||||
return self.optimizer(self.parameters(), lr=self.lr) # type: ignore
|
||||
|
||||
|
||||
class MultipleLearningRateMixin(BaseYArchitecture):
|
||||
"""
|
||||
Multiple Learning Rates
|
||||
|
||||
Define Different Learning Rates for different parameters.
|
||||
"""
|
||||
|
||||
# HyperParameters
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@dataclass
|
||||
class HyperParameters(BaseYArchitecture.HyperParameters):
|
||||
"""
|
||||
lr: The learning rate. Default: 0.1.
|
||||
optimizer: The optimizer to use. Default: torch.optim.Adam.
|
||||
"""
|
||||
lr: dict = field(default_factory=lambda: dict())
|
||||
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
|
||||
|
||||
# Steps
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def __init__(self, hparams: HyperParameters) -> None:
|
||||
super().__init__(hparams)
|
||||
self.lr = hparams.lr
|
||||
self.optimizer = hparams.optimizer
|
||||
|
||||
# Hooks
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def configure_optimizers(self):
|
||||
optimizers = []
|
||||
for name, lr in self.lr.items():
|
||||
if not hasattr(self, name):
|
||||
raise ValueError(f"{name} is not a parameter of {self}")
|
||||
else:
|
||||
model_part = getattr(self, name)
|
||||
if isinstance(model_part, Parameter):
|
||||
optimizers.append(
|
||||
self.optimizer(
|
||||
[model_part],
|
||||
lr=lr, # type: ignore
|
||||
))
|
||||
elif hasattr(model_part, "parameters"):
|
||||
optimizers.append(
|
||||
self.optimizer(
|
||||
model_part.parameters(),
|
||||
lr=lr, # type: ignore
|
||||
))
|
||||
return optimizers
|
@ -7,9 +7,9 @@ import torch
|
||||
import torchmetrics
|
||||
from matplotlib import pyplot as plt
|
||||
from prototorch.models.vis import Vis2DAbstract
|
||||
from prototorch.models.y_arch.architectures.base import BaseYArchitecture
|
||||
from prototorch.models.y_arch.library.gmlvq import GMLVQ
|
||||
from prototorch.utils.utils import mesh2d
|
||||
from prototorch.y_arch.architectures.base import BaseYArchitecture
|
||||
from prototorch.y_arch.library.gmlvq import GMLVQ
|
||||
from pytorch_lightning.loggers import TensorBoardLogger
|
||||
|
||||
DIVERGING_COLOR_MAPS = [
|
@ -1,12 +1,12 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from prototorch.models.y_arch import (
|
||||
from prototorch.y_arch import (
|
||||
SimpleComparisonMixin,
|
||||
SingleLearningRateMixin,
|
||||
SupervisedArchitecture,
|
||||
WTACompetitionMixin,
|
||||
)
|
||||
from prototorch.models.y_arch.architectures.loss import GLVQLossMixin
|
||||
from prototorch.y_arch.architectures.loss import GLVQLossMixin
|
||||
|
||||
|
||||
class GLVQ(
|
50
prototorch/y_arch/library/gmlvq.py
Normal file
50
prototorch/y_arch/library/gmlvq.py
Normal file
@ -0,0 +1,50 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
from prototorch.core.distances import omega_distance
|
||||
from prototorch.y_arch import (
|
||||
GLVQLossMixin,
|
||||
MultipleLearningRateMixin,
|
||||
OmegaComparisonMixin,
|
||||
SupervisedArchitecture,
|
||||
WTACompetitionMixin,
|
||||
)
|
||||
|
||||
|
||||
class GMLVQ(
|
||||
SupervisedArchitecture,
|
||||
OmegaComparisonMixin,
|
||||
GLVQLossMixin,
|
||||
WTACompetitionMixin,
|
||||
MultipleLearningRateMixin,
|
||||
):
|
||||
"""
|
||||
Generalized Matrix Learning Vector Quantization (GMLVQ)
|
||||
|
||||
A GMLVQ architecture that uses the winner-take-all strategy and the GLVQ loss.
|
||||
"""
|
||||
# HyperParameters
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@dataclass
|
||||
class HyperParameters(
|
||||
MultipleLearningRateMixin.HyperParameters,
|
||||
OmegaComparisonMixin.HyperParameters,
|
||||
GLVQLossMixin.HyperParameters,
|
||||
WTACompetitionMixin.HyperParameters,
|
||||
SupervisedArchitecture.HyperParameters,
|
||||
):
|
||||
"""
|
||||
comparison_fn: The comparison / dissimilarity function to use. Override Default: omega_distance.
|
||||
comparison_args: Keyword arguments for the comparison function. Override Default: {}.
|
||||
"""
|
||||
comparison_fn: Callable = omega_distance
|
||||
comparison_args: dict = field(default_factory=lambda: dict())
|
||||
optimizer: type[torch.optim.Optimizer] = torch.optim.Adam
|
||||
|
||||
lr: dict = field(default_factory=lambda: dict(
|
||||
components_layer=0.1,
|
||||
_omega=0.5,
|
||||
))
|
Loading…
Reference in New Issue
Block a user