feat: distribute GMLVQ into mixins
This commit is contained in:
parent
e922aae432
commit
23d1a71b31
@ -2,12 +2,12 @@ import prototorch as pt
|
|||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torchmetrics
|
import torchmetrics
|
||||||
from prototorch.core import SMCI
|
from prototorch.core import SMCI
|
||||||
from prototorch.models.y_arch.callbacks import (
|
from prototorch.y_arch.callbacks import (
|
||||||
LogTorchmetricCallback,
|
LogTorchmetricCallback,
|
||||||
PlotLambdaMatrixToTensorboard,
|
PlotLambdaMatrixToTensorboard,
|
||||||
VisGMLVQ2D,
|
VisGMLVQ2D,
|
||||||
)
|
)
|
||||||
from prototorch.models.y_arch.library.gmlvq import GMLVQ
|
from prototorch.y_arch.library.gmlvq import GMLVQ
|
||||||
from pytorch_lightning.callbacks import EarlyStopping
|
from pytorch_lightning.callbacks import EarlyStopping
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
@ -39,8 +39,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# Define Hyperparameters
|
# Define Hyperparameters
|
||||||
hyperparameters = GMLVQ.HyperParameters(
|
hyperparameters = GMLVQ.HyperParameters(
|
||||||
lr=0.1,
|
lr=dict(components_layer=0.1, _omega=0),
|
||||||
backbone_lr=5,
|
|
||||||
input_dim=4,
|
input_dim=4,
|
||||||
distribution=dict(
|
distribution=dict(
|
||||||
num_classes=3,
|
num_classes=3,
|
||||||
|
@ -1,41 +0,0 @@
|
|||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Callable
|
|
||||||
|
|
||||||
from prototorch.core.distances import euclidean_distance
|
|
||||||
from prototorch.models.y_arch.architectures.base import BaseYArchitecture
|
|
||||||
from prototorch.nn.wrappers import LambdaLayer
|
|
||||||
|
|
||||||
|
|
||||||
class SimpleComparisonMixin(BaseYArchitecture):
|
|
||||||
"""
|
|
||||||
Simple Comparison
|
|
||||||
|
|
||||||
A comparison layer that only uses the positions of the components and the batch for dissimilarity computation.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# HyperParameters
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
|
||||||
@dataclass
|
|
||||||
class HyperParameters(BaseYArchitecture.HyperParameters):
|
|
||||||
"""
|
|
||||||
comparison_fn: The comparison / dissimilarity function to use. Default: euclidean_distance.
|
|
||||||
comparison_args: Keyword arguments for the comparison function. Default: {}.
|
|
||||||
"""
|
|
||||||
comparison_fn: Callable = euclidean_distance
|
|
||||||
comparison_args: dict = field(default_factory=lambda: dict())
|
|
||||||
|
|
||||||
# Steps
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
|
||||||
def init_comparison(self, hparams: HyperParameters):
|
|
||||||
self.comparison_layer = LambdaLayer(fn=hparams.comparison_fn,
|
|
||||||
**hparams.comparison_args)
|
|
||||||
|
|
||||||
def comparison(self, batch, components):
|
|
||||||
comp_tensor, _ = components
|
|
||||||
batch_tensor, _ = batch
|
|
||||||
|
|
||||||
comp_tensor = comp_tensor.unsqueeze(1)
|
|
||||||
|
|
||||||
distances = self.comparison_layer(batch_tensor, comp_tensor)
|
|
||||||
|
|
||||||
return distances
|
|
@ -1,36 +0,0 @@
|
|||||||
from dataclasses import dataclass
|
|
||||||
from typing import Type
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from prototorch.models.y_arch import BaseYArchitecture
|
|
||||||
|
|
||||||
|
|
||||||
class SingleLearningRateMixin(BaseYArchitecture):
|
|
||||||
"""
|
|
||||||
Single Learning Rate
|
|
||||||
|
|
||||||
All parameters are updated with a single learning rate.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# HyperParameters
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
|
||||||
@dataclass
|
|
||||||
class HyperParameters(BaseYArchitecture.HyperParameters):
|
|
||||||
"""
|
|
||||||
lr: The learning rate. Default: 0.1.
|
|
||||||
optimizer: The optimizer to use. Default: torch.optim.Adam.
|
|
||||||
"""
|
|
||||||
lr: float = 0.1
|
|
||||||
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
|
|
||||||
|
|
||||||
# Steps
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
|
||||||
def __init__(self, hparams: HyperParameters) -> None:
|
|
||||||
super().__init__(hparams)
|
|
||||||
self.lr = hparams.lr
|
|
||||||
self.optimizer = hparams.optimizer
|
|
||||||
|
|
||||||
# Hooks
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
|
||||||
def configure_optimizers(self):
|
|
||||||
return self.optimizer(self.parameters(), lr=self.lr) # type: ignore
|
|
@ -1,14 +1,22 @@
|
|||||||
from .architectures.base import BaseYArchitecture
|
from .architectures.base import BaseYArchitecture
|
||||||
from .architectures.comparison import SimpleComparisonMixin
|
from .architectures.comparison import (
|
||||||
|
OmegaComparisonMixin,
|
||||||
|
SimpleComparisonMixin,
|
||||||
|
)
|
||||||
from .architectures.competition import WTACompetitionMixin
|
from .architectures.competition import WTACompetitionMixin
|
||||||
from .architectures.components import SupervisedArchitecture
|
from .architectures.components import SupervisedArchitecture
|
||||||
from .architectures.loss import GLVQLossMixin
|
from .architectures.loss import GLVQLossMixin
|
||||||
from .architectures.optimization import SingleLearningRateMixin
|
from .architectures.optimization import (
|
||||||
|
MultipleLearningRateMixin,
|
||||||
|
SingleLearningRateMixin,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'BaseYArchitecture',
|
'BaseYArchitecture',
|
||||||
|
"OmegaComparisonMixin",
|
||||||
"SimpleComparisonMixin",
|
"SimpleComparisonMixin",
|
||||||
"SingleLearningRateMixin",
|
"SingleLearningRateMixin",
|
||||||
|
"MultipleLearningRateMixin",
|
||||||
"SupervisedArchitecture",
|
"SupervisedArchitecture",
|
||||||
"WTACompetitionMixin",
|
"WTACompetitionMixin",
|
||||||
"GLVQLossMixin",
|
"GLVQLossMixin",
|
@ -1,86 +1,50 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Callable
|
from typing import Callable, Dict
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from prototorch.core.distances import omega_distance
|
from prototorch.core.distances import euclidean_distance
|
||||||
from prototorch.core.initializers import (
|
from prototorch.core.initializers import (
|
||||||
AbstractLinearTransformInitializer,
|
AbstractLinearTransformInitializer,
|
||||||
EyeLinearTransformInitializer,
|
EyeLinearTransformInitializer,
|
||||||
)
|
)
|
||||||
from prototorch.models.y_arch import (
|
|
||||||
GLVQLossMixin,
|
|
||||||
SimpleComparisonMixin,
|
|
||||||
SupervisedArchitecture,
|
|
||||||
WTACompetitionMixin,
|
|
||||||
)
|
|
||||||
from prototorch.nn.wrappers import LambdaLayer
|
from prototorch.nn.wrappers import LambdaLayer
|
||||||
|
from prototorch.y_arch.architectures.base import BaseYArchitecture
|
||||||
|
from torch import Tensor
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
|
|
||||||
class GMLVQ(
|
class SimpleComparisonMixin(BaseYArchitecture):
|
||||||
SupervisedArchitecture,
|
|
||||||
SimpleComparisonMixin,
|
|
||||||
GLVQLossMixin,
|
|
||||||
WTACompetitionMixin,
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Generalized Matrix Learning Vector Quantization (GMLVQ)
|
Simple Comparison
|
||||||
|
|
||||||
A GMLVQ architecture that uses the winner-take-all strategy and the GLVQ loss.
|
A comparison layer that only uses the positions of the components and the batch for dissimilarity computation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_omega: torch.Tensor
|
|
||||||
|
|
||||||
# HyperParameters
|
# HyperParameters
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
@dataclass
|
@dataclass
|
||||||
class HyperParameters(
|
class HyperParameters(BaseYArchitecture.HyperParameters):
|
||||||
SimpleComparisonMixin.HyperParameters,
|
|
||||||
GLVQLossMixin.HyperParameters,
|
|
||||||
WTACompetitionMixin.HyperParameters,
|
|
||||||
SupervisedArchitecture.HyperParameters,
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
comparison_fn: The comparison / dissimilarity function to use. Override Default: omega_distance.
|
comparison_fn: The comparison / dissimilarity function to use. Default: euclidean_distance.
|
||||||
comparison_args: Keyword arguments for the comparison function. Override Default: {}.
|
comparison_args: Keyword arguments for the comparison function. Default: {}.
|
||||||
input_dim: Necessary Field: The dimensionality of the input.
|
|
||||||
latent_dim: The dimensionality of the latent space. Default: 2.
|
|
||||||
omega_initializer: The initializer to use for the omega matrix. Default: EyeLinearTransformInitializer.
|
|
||||||
"""
|
"""
|
||||||
backbone_lr: float = 0.1
|
comparison_fn: Callable = euclidean_distance
|
||||||
lr: float = 0.1
|
|
||||||
comparison_fn: Callable = omega_distance
|
|
||||||
comparison_args: dict = field(default_factory=lambda: dict())
|
comparison_args: dict = field(default_factory=lambda: dict())
|
||||||
input_dim: int | None = None
|
|
||||||
latent_dim: int = 2
|
|
||||||
omega_initializer: type[
|
|
||||||
AbstractLinearTransformInitializer] = EyeLinearTransformInitializer
|
|
||||||
|
|
||||||
optimizer: type[torch.optim.Optimizer] = torch.optim.Adam
|
comparison_parameters: dict = field(default_factory=lambda: dict())
|
||||||
|
|
||||||
# Steps
|
# Steps
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
def __init__(self, hparams) -> None:
|
def init_comparison(self, hparams: HyperParameters):
|
||||||
super().__init__(hparams)
|
|
||||||
self.lr = hparams.lr
|
|
||||||
self.backbone_lr = hparams.backbone_lr
|
|
||||||
self.optimizer = hparams.optimizer
|
|
||||||
|
|
||||||
def init_comparison(self, hparams: HyperParameters) -> None:
|
|
||||||
if hparams.input_dim is None:
|
|
||||||
raise ValueError("input_dim must be specified.")
|
|
||||||
omega = hparams.omega_initializer().generate(
|
|
||||||
hparams.input_dim,
|
|
||||||
hparams.latent_dim,
|
|
||||||
)
|
|
||||||
self.register_parameter("_omega", Parameter(omega))
|
|
||||||
self.comparison_layer = LambdaLayer(
|
self.comparison_layer = LambdaLayer(
|
||||||
fn=hparams.comparison_fn,
|
fn=hparams.comparison_fn,
|
||||||
**hparams.comparison_args,
|
**hparams.comparison_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.comparison_kwargs: dict[str, Tensor] = dict()
|
||||||
|
|
||||||
def comparison(self, batch, components):
|
def comparison(self, batch, components):
|
||||||
comp_tensor, _ = components
|
comp_tensor, _ = components
|
||||||
batch_tensor, _ = batch
|
batch_tensor, _ = batch
|
||||||
@ -90,21 +54,50 @@ class GMLVQ(
|
|||||||
distances = self.comparison_layer(
|
distances = self.comparison_layer(
|
||||||
batch_tensor,
|
batch_tensor,
|
||||||
comp_tensor,
|
comp_tensor,
|
||||||
self._omega,
|
**self.comparison_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
return distances
|
return distances
|
||||||
|
|
||||||
def configure_optimizers(self):
|
|
||||||
proto_opt = self.optimizer(
|
class OmegaComparisonMixin(SimpleComparisonMixin):
|
||||||
self.components_layer.parameters(),
|
"""
|
||||||
lr=self.lr,
|
Omega Comparison
|
||||||
)
|
|
||||||
omega_opt = self.optimizer(
|
A comparison layer that uses the positions of the components and the batch for dissimilarity computation.
|
||||||
[self._omega],
|
"""
|
||||||
lr=self.backbone_lr,
|
|
||||||
)
|
_omega: torch.Tensor
|
||||||
return [proto_opt, omega_opt]
|
|
||||||
|
# HyperParameters
|
||||||
|
# ----------------------------------------------------------------------------------------------------
|
||||||
|
@dataclass
|
||||||
|
class HyperParameters(SimpleComparisonMixin.HyperParameters):
|
||||||
|
"""
|
||||||
|
input_dim: Necessary Field: The dimensionality of the input.
|
||||||
|
latent_dim: The dimensionality of the latent space. Default: 2.
|
||||||
|
omega_initializer: The initializer to use for the omega matrix. Default: EyeLinearTransformInitializer.
|
||||||
|
"""
|
||||||
|
input_dim: int | None = None
|
||||||
|
latent_dim: int = 2
|
||||||
|
omega_initializer: type[
|
||||||
|
AbstractLinearTransformInitializer] = EyeLinearTransformInitializer
|
||||||
|
|
||||||
|
# Steps
|
||||||
|
# ----------------------------------------------------------------------------------------------------
|
||||||
|
def init_comparison(self, hparams: HyperParameters) -> None:
|
||||||
|
super().init_comparison(hparams)
|
||||||
|
|
||||||
|
# Initialize the omega matrix
|
||||||
|
if hparams.input_dim is None:
|
||||||
|
raise ValueError("input_dim must be specified.")
|
||||||
|
else:
|
||||||
|
omega = hparams.omega_initializer().generate(
|
||||||
|
hparams.input_dim,
|
||||||
|
hparams.latent_dim,
|
||||||
|
)
|
||||||
|
self.register_parameter("_omega", Parameter(omega))
|
||||||
|
self.comparison_kwargs = dict(omega=self._omega)
|
||||||
|
|
||||||
# Properties
|
# Properties
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
@ -1,7 +1,7 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from prototorch.core.competitions import WTAC
|
from prototorch.core.competitions import WTAC
|
||||||
from prototorch.models.y_arch.architectures.base import BaseYArchitecture
|
from prototorch.y_arch.architectures.base import BaseYArchitecture
|
||||||
|
|
||||||
|
|
||||||
class WTACompetitionMixin(BaseYArchitecture):
|
class WTACompetitionMixin(BaseYArchitecture):
|
@ -5,7 +5,7 @@ from prototorch.core.initializers import (
|
|||||||
AbstractComponentsInitializer,
|
AbstractComponentsInitializer,
|
||||||
LabelsInitializer,
|
LabelsInitializer,
|
||||||
)
|
)
|
||||||
from prototorch.models.y_arch import BaseYArchitecture
|
from prototorch.y_arch import BaseYArchitecture
|
||||||
|
|
||||||
|
|
||||||
class SupervisedArchitecture(BaseYArchitecture):
|
class SupervisedArchitecture(BaseYArchitecture):
|
@ -1,7 +1,7 @@
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
from prototorch.core.losses import GLVQLoss
|
from prototorch.core.losses import GLVQLoss
|
||||||
from prototorch.models.y_arch.architectures.base import BaseYArchitecture
|
from prototorch.y_arch.architectures.base import BaseYArchitecture
|
||||||
|
|
||||||
|
|
||||||
class GLVQLossMixin(BaseYArchitecture):
|
class GLVQLossMixin(BaseYArchitecture):
|
86
prototorch/y_arch/architectures/optimization.py
Normal file
86
prototorch/y_arch/architectures/optimization.py
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Type
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from prototorch.y_arch import BaseYArchitecture
|
||||||
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
|
|
||||||
|
class SingleLearningRateMixin(BaseYArchitecture):
|
||||||
|
"""
|
||||||
|
Single Learning Rate
|
||||||
|
|
||||||
|
All parameters are updated with a single learning rate.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# HyperParameters
|
||||||
|
# ----------------------------------------------------------------------------------------------------
|
||||||
|
@dataclass
|
||||||
|
class HyperParameters(BaseYArchitecture.HyperParameters):
|
||||||
|
"""
|
||||||
|
lr: The learning rate. Default: 0.1.
|
||||||
|
optimizer: The optimizer to use. Default: torch.optim.Adam.
|
||||||
|
"""
|
||||||
|
lr: float = 0.1
|
||||||
|
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
|
||||||
|
|
||||||
|
# Steps
|
||||||
|
# ----------------------------------------------------------------------------------------------------
|
||||||
|
def __init__(self, hparams: HyperParameters) -> None:
|
||||||
|
super().__init__(hparams)
|
||||||
|
self.lr = hparams.lr
|
||||||
|
self.optimizer = hparams.optimizer
|
||||||
|
|
||||||
|
# Hooks
|
||||||
|
# ----------------------------------------------------------------------------------------------------
|
||||||
|
def configure_optimizers(self):
|
||||||
|
return self.optimizer(self.parameters(), lr=self.lr) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
class MultipleLearningRateMixin(BaseYArchitecture):
|
||||||
|
"""
|
||||||
|
Multiple Learning Rates
|
||||||
|
|
||||||
|
Define Different Learning Rates for different parameters.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# HyperParameters
|
||||||
|
# ----------------------------------------------------------------------------------------------------
|
||||||
|
@dataclass
|
||||||
|
class HyperParameters(BaseYArchitecture.HyperParameters):
|
||||||
|
"""
|
||||||
|
lr: The learning rate. Default: 0.1.
|
||||||
|
optimizer: The optimizer to use. Default: torch.optim.Adam.
|
||||||
|
"""
|
||||||
|
lr: dict = field(default_factory=lambda: dict())
|
||||||
|
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
|
||||||
|
|
||||||
|
# Steps
|
||||||
|
# ----------------------------------------------------------------------------------------------------
|
||||||
|
def __init__(self, hparams: HyperParameters) -> None:
|
||||||
|
super().__init__(hparams)
|
||||||
|
self.lr = hparams.lr
|
||||||
|
self.optimizer = hparams.optimizer
|
||||||
|
|
||||||
|
# Hooks
|
||||||
|
# ----------------------------------------------------------------------------------------------------
|
||||||
|
def configure_optimizers(self):
|
||||||
|
optimizers = []
|
||||||
|
for name, lr in self.lr.items():
|
||||||
|
if not hasattr(self, name):
|
||||||
|
raise ValueError(f"{name} is not a parameter of {self}")
|
||||||
|
else:
|
||||||
|
model_part = getattr(self, name)
|
||||||
|
if isinstance(model_part, Parameter):
|
||||||
|
optimizers.append(
|
||||||
|
self.optimizer(
|
||||||
|
[model_part],
|
||||||
|
lr=lr, # type: ignore
|
||||||
|
))
|
||||||
|
elif hasattr(model_part, "parameters"):
|
||||||
|
optimizers.append(
|
||||||
|
self.optimizer(
|
||||||
|
model_part.parameters(),
|
||||||
|
lr=lr, # type: ignore
|
||||||
|
))
|
||||||
|
return optimizers
|
@ -7,9 +7,9 @@ import torch
|
|||||||
import torchmetrics
|
import torchmetrics
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
from prototorch.models.vis import Vis2DAbstract
|
from prototorch.models.vis import Vis2DAbstract
|
||||||
from prototorch.models.y_arch.architectures.base import BaseYArchitecture
|
|
||||||
from prototorch.models.y_arch.library.gmlvq import GMLVQ
|
|
||||||
from prototorch.utils.utils import mesh2d
|
from prototorch.utils.utils import mesh2d
|
||||||
|
from prototorch.y_arch.architectures.base import BaseYArchitecture
|
||||||
|
from prototorch.y_arch.library.gmlvq import GMLVQ
|
||||||
from pytorch_lightning.loggers import TensorBoardLogger
|
from pytorch_lightning.loggers import TensorBoardLogger
|
||||||
|
|
||||||
DIVERGING_COLOR_MAPS = [
|
DIVERGING_COLOR_MAPS = [
|
@ -1,12 +1,12 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from prototorch.models.y_arch import (
|
from prototorch.y_arch import (
|
||||||
SimpleComparisonMixin,
|
SimpleComparisonMixin,
|
||||||
SingleLearningRateMixin,
|
SingleLearningRateMixin,
|
||||||
SupervisedArchitecture,
|
SupervisedArchitecture,
|
||||||
WTACompetitionMixin,
|
WTACompetitionMixin,
|
||||||
)
|
)
|
||||||
from prototorch.models.y_arch.architectures.loss import GLVQLossMixin
|
from prototorch.y_arch.architectures.loss import GLVQLossMixin
|
||||||
|
|
||||||
|
|
||||||
class GLVQ(
|
class GLVQ(
|
50
prototorch/y_arch/library/gmlvq.py
Normal file
50
prototorch/y_arch/library/gmlvq.py
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from prototorch.core.distances import omega_distance
|
||||||
|
from prototorch.y_arch import (
|
||||||
|
GLVQLossMixin,
|
||||||
|
MultipleLearningRateMixin,
|
||||||
|
OmegaComparisonMixin,
|
||||||
|
SupervisedArchitecture,
|
||||||
|
WTACompetitionMixin,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class GMLVQ(
|
||||||
|
SupervisedArchitecture,
|
||||||
|
OmegaComparisonMixin,
|
||||||
|
GLVQLossMixin,
|
||||||
|
WTACompetitionMixin,
|
||||||
|
MultipleLearningRateMixin,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Generalized Matrix Learning Vector Quantization (GMLVQ)
|
||||||
|
|
||||||
|
A GMLVQ architecture that uses the winner-take-all strategy and the GLVQ loss.
|
||||||
|
"""
|
||||||
|
# HyperParameters
|
||||||
|
# ----------------------------------------------------------------------------------------------------
|
||||||
|
@dataclass
|
||||||
|
class HyperParameters(
|
||||||
|
MultipleLearningRateMixin.HyperParameters,
|
||||||
|
OmegaComparisonMixin.HyperParameters,
|
||||||
|
GLVQLossMixin.HyperParameters,
|
||||||
|
WTACompetitionMixin.HyperParameters,
|
||||||
|
SupervisedArchitecture.HyperParameters,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
comparison_fn: The comparison / dissimilarity function to use. Override Default: omega_distance.
|
||||||
|
comparison_args: Keyword arguments for the comparison function. Override Default: {}.
|
||||||
|
"""
|
||||||
|
comparison_fn: Callable = omega_distance
|
||||||
|
comparison_args: dict = field(default_factory=lambda: dict())
|
||||||
|
optimizer: type[torch.optim.Optimizer] = torch.optim.Adam
|
||||||
|
|
||||||
|
lr: dict = field(default_factory=lambda: dict(
|
||||||
|
components_layer=0.1,
|
||||||
|
_omega=0.5,
|
||||||
|
))
|
Loading…
Reference in New Issue
Block a user