feat: add GMLVQ with new architecture

parent 3e50d0d817
commit e922aae432
@@ -2,11 +2,12 @@ import prototorch as pt
 import pytorch_lightning as pl
 import torchmetrics
 from prototorch.core import SMCI
-from prototorch.models.proto_y_architecture.callbacks import (
+from prototorch.models.y_arch.callbacks import (
     LogTorchmetricCallback,
-    VisGLVQ2D,
+    PlotLambdaMatrixToTensorboard,
+    VisGMLVQ2D,
 )
-from prototorch.models.proto_y_architecture.glvq import GLVQ
+from prototorch.models.y_arch.library.gmlvq import GMLVQ
 from pytorch_lightning.callbacks import EarlyStopping
 from torch.utils.data import DataLoader
 
@@ -19,8 +20,7 @@ if __name__ == "__main__":
     # ------------------------------------------------------------
 
     # Dataset
-    train_ds = pt.datasets.Iris(dims=[0, 2])
-    train_ds.targets[train_ds.targets == 2.0] = 1.0
+    train_ds = pt.datasets.Iris()
 
     # Dataloader
     train_loader = DataLoader(
@@ -38,17 +38,19 @@ if __name__ == "__main__":
     components_initializer = SMCI(train_ds)
 
     # Define Hyperparameters
-    hyperparameters = GLVQ.HyperParameters(
+    hyperparameters = GMLVQ.HyperParameters(
         lr=0.1,
+        backbone_lr=5,
+        input_dim=4,
         distribution=dict(
-            num_classes=2,
+            num_classes=3,
             per_class=1,
         ),
         component_initializer=components_initializer,
     )
 
     # Create Model
-    model = GLVQ(hyperparameters)
+    model = GMLVQ(hyperparameters)
 
     print(model)
 
@@ -60,19 +62,17 @@ if __name__ == "__main__":
     stopping_criterion = LogTorchmetricCallback(
         'recall',
         torchmetrics.Recall,
-        num_classes=2,
+        num_classes=3,
     )
 
     es = EarlyStopping(
         monitor=stopping_criterion.name,
-        min_delta=0.001,
-        patience=15,
         mode="max",
-        check_on_train_epoch_end=True,
+        patience=10,
     )
 
     # Visualization Callback
-    vis = VisGLVQ2D(data=train_ds)
+    vis = VisGMLVQ2D(data=train_ds)
 
     # Define trainer
     trainer = pl.Trainer(
@@ -80,10 +80,9 @@ if __name__ == "__main__":
             vis,
             stopping_criterion,
             es,
+            PlotLambdaMatrixToTensorboard(),
         ],
-        gpus=0,
-        max_epochs=200,
-        log_every_n_steps=1,
+        max_epochs=1000,
     )
 
     # Train
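Taken together, these hunks convert the example script from GLVQ on a 2-d Iris slice to GMLVQ on the full 4-d Iris data. A consolidated sketch of the updated script, assembled from the + lines above; the DataLoader arguments and the final trainer.fit(...) call are not visible in the diff and are filled in as assumptions:

import prototorch as pt
import pytorch_lightning as pl
import torchmetrics
from prototorch.core import SMCI
from prototorch.models.y_arch.callbacks import (
    LogTorchmetricCallback,
    PlotLambdaMatrixToTensorboard,
    VisGMLVQ2D,
)
from prototorch.models.y_arch.library.gmlvq import GMLVQ
from pytorch_lightning.callbacks import EarlyStopping
from torch.utils.data import DataLoader

if __name__ == "__main__":
    # Full 4-dimensional Iris; the old dims=[0, 2] restriction is gone.
    train_ds = pt.datasets.Iris()
    train_loader = DataLoader(train_ds, batch_size=32)  # batch size assumed

    # input_dim is mandatory for GMLVQ; backbone_lr drives the omega matrix.
    hyperparameters = GMLVQ.HyperParameters(
        lr=0.1,
        backbone_lr=5,
        input_dim=4,
        distribution=dict(num_classes=3, per_class=1),
        component_initializer=SMCI(train_ds),
    )
    model = GMLVQ(hyperparameters)

    stopping_criterion = LogTorchmetricCallback(
        'recall',
        torchmetrics.Recall,
        num_classes=3,
    )
    es = EarlyStopping(
        monitor=stopping_criterion.name,
        mode="max",
        patience=10,
    )
    trainer = pl.Trainer(
        callbacks=[
            VisGMLVQ2D(data=train_ds),
            stopping_criterion,
            es,
            PlotLambdaMatrixToTensorboard(),
        ],
        max_epochs=1000,
    )
    trainer.fit(model, train_loader)  # assumed; the "# Train" section is cut off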
@@ -1,63 +0,0 @@
-from typing import Optional, Type
-
-import numpy as np
-import pytorch_lightning as pl
-import torch
-import torchmetrics
-from prototorch.models.proto_y_architecture.base import BaseYArchitecture
-from prototorch.models.vis import Vis2DAbstract
-from prototorch.utils.utils import mesh2d
-
-
-class LogTorchmetricCallback(pl.Callback):
-
-    def __init__(
-        self,
-        name,
-        metric: Type[torchmetrics.Metric],
-        on="prediction",
-        **metric_kwargs,
-    ) -> None:
-        self.name = name
-        self.metric = metric
-        self.metric_kwargs = metric_kwargs
-        self.on = on
-
-    def setup(
-        self,
-        trainer: pl.Trainer,
-        pl_module: BaseYArchitecture,
-        stage: Optional[str] = None,
-    ) -> None:
-        if self.on == "prediction":
-            pl_module.register_torchmetric(
-                self.name,
-                self.metric,
-                **self.metric_kwargs,
-            )
-        else:
-            raise ValueError(f"{self.on} is no valid metric hook")
-
-
-class VisGLVQ2D(Vis2DAbstract):
-
-    def visualize(self, pl_module):
-        protos = pl_module.prototypes
-        plabels = pl_module.prototype_labels
-        x_train, y_train = self.x_train, self.y_train
-        ax = self.setup_ax()
-        self.plot_protos(ax, protos, plabels)
-        if x_train is not None:
-            self.plot_data(ax, x_train, y_train)
-            mesh_input, xx, yy = mesh2d(
-                np.vstack([x_train, protos]),
-                self.border,
-                self.resolution,
-            )
-        else:
-            mesh_input, xx, yy = mesh2d(protos, self.border, self.resolution)
-        _components = pl_module.components_layer.components
-        mesh_input = torch.from_numpy(mesh_input).type_as(_components)
-        y_pred = pl_module.predict(mesh_input)
-        y_pred = y_pred.cpu().reshape(xx.shape)
-        ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
@@ -1,140 +0,0 @@
-from dataclasses import dataclass, field
-from typing import Callable, Type
-
-import torch
-from prototorch.core.competitions import WTAC
-from prototorch.core.components import LabeledComponents
-from prototorch.core.distances import euclidean_distance
-from prototorch.core.initializers import (
-    AbstractComponentsInitializer,
-    LabelsInitializer,
-)
-from prototorch.core.losses import GLVQLoss
-from prototorch.models.proto_y_architecture.base import BaseYArchitecture
-from prototorch.nn.wrappers import LambdaLayer
-
-
-class SupervisedArchitecture(BaseYArchitecture):
-    components_layer: LabeledComponents
-
-    @dataclass
-    class HyperParameters:
-        distribution: dict[str, int]
-        component_initializer: AbstractComponentsInitializer
-
-    def init_components(self, hparams: HyperParameters):
-        self.components_layer = LabeledComponents(
-            distribution=hparams.distribution,
-            components_initializer=hparams.component_initializer,
-            labels_initializer=LabelsInitializer(),
-        )
-
-    @property
-    def prototypes(self):
-        return self.components_layer.components.detach().cpu()
-
-    @property
-    def prototype_labels(self):
-        return self.components_layer.labels.detach().cpu()
-
-
-class WTACompetitionMixin(BaseYArchitecture):
-
-    @dataclass
-    class HyperParameters(BaseYArchitecture.HyperParameters):
-        pass
-
-    def init_inference(self, hparams: HyperParameters):
-        self.competition_layer = WTAC()
-
-    def inference(self, comparison_measures, components):
-        comp_labels = components[1]
-        return self.competition_layer(comparison_measures, comp_labels)
-
-
-class GLVQLossMixin(BaseYArchitecture):
-
-    @dataclass
-    class HyperParameters(BaseYArchitecture.HyperParameters):
-        margin: float = 0.0
-
-        transfer_fn: str = "sigmoid_beta"
-        transfer_args: dict = field(default_factory=lambda: dict(beta=10.0))
-
-    def init_loss(self, hparams: HyperParameters):
-        self.loss_layer = GLVQLoss(
-            margin=hparams.margin,
-            transfer_fn=hparams.transfer_fn,
-            **hparams.transfer_args,
-        )
-
-    def loss(self, comparison_measures, batch, components):
-        target = batch[1]
-        comp_labels = components[1]
-        loss = self.loss_layer(comparison_measures, target, comp_labels)
-        self.log('loss', loss)
-        return loss
-
-
-class SingleLearningRateMixin(BaseYArchitecture):
-
-    @dataclass
-    class HyperParameters(BaseYArchitecture.HyperParameters):
-        # Training Hyperparameters
-        lr: float = 0.01
-        optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
-
-    def __init__(self, hparams: HyperParameters) -> None:
-        super().__init__(hparams)
-        self.lr = hparams.lr
-        self.optimizer = hparams.optimizer
-
-    def configure_optimizers(self):
-        return self.optimizer(self.parameters(), lr=self.lr)  # type: ignore
-
-
-class SimpleComparisonMixin(BaseYArchitecture):
-
-    @dataclass
-    class HyperParameters(BaseYArchitecture.HyperParameters):
-        # Training Hyperparameters
-        comparison_fn: Callable = euclidean_distance
-        comparison_args: dict = field(default_factory=lambda: dict())
-
-    def init_comparison(self, hparams: HyperParameters):
-        self.comparison_layer = LambdaLayer(fn=hparams.comparison_fn,
-                                            **hparams.comparison_args)
-
-    def comparison(self, batch, components):
-        comp_tensor, _ = components
-        batch_tensor, _ = batch
-
-        comp_tensor = comp_tensor.unsqueeze(1)
-
-        distances = self.comparison_layer(batch_tensor, comp_tensor)
-
-        return distances
-
-
-# ##############################################################################
-# GLVQ
-# ##############################################################################
-class GLVQ(
-    SupervisedArchitecture,
-    SimpleComparisonMixin,
-    GLVQLossMixin,
-    WTACompetitionMixin,
-    SingleLearningRateMixin,
-):
-    """GLVQ using the new Scheme
-    """
-
-    @dataclass
-    class HyperParameters(
-        SimpleComparisonMixin.HyperParameters,
-        SingleLearningRateMixin.HyperParameters,
-        GLVQLossMixin.HyperParameters,
-        WTACompetitionMixin.HyperParameters,
-        SupervisedArchitecture.HyperParameters,
-    ):
-        pass
prototorch/models/y_arch/__init__.py (new file, 15 lines)
@@ -0,0 +1,15 @@
+from .architectures.base import BaseYArchitecture
+from .architectures.comparison import SimpleComparisonMixin
+from .architectures.competition import WTACompetitionMixin
+from .architectures.components import SupervisedArchitecture
+from .architectures.loss import GLVQLossMixin
+from .architectures.optimization import SingleLearningRateMixin
+
+__all__ = [
+    'BaseYArchitecture',
+    "SimpleComparisonMixin",
+    "SingleLearningRateMixin",
+    "SupervisedArchitecture",
+    "WTACompetitionMixin",
+    "GLVQLossMixin",
+]
@@ -1,12 +1,7 @@
 """
-CLCC Scheme
+Proto Y Architecture
 
-CLCC is a LVQ scheme containing 4 steps
-- Components
-- Latent Space
-- Comparison
-- Competition
-
+Network architecture for Component based Learning.
 """
 from dataclasses import dataclass
 from typing import (
prototorch/models/y_arch/architectures/comparison.py (new file, 41 lines)
@@ -0,0 +1,41 @@
+from dataclasses import dataclass, field
+from typing import Callable
+
+from prototorch.core.distances import euclidean_distance
+from prototorch.models.y_arch.architectures.base import BaseYArchitecture
+from prototorch.nn.wrappers import LambdaLayer
+
+
+class SimpleComparisonMixin(BaseYArchitecture):
+    """
+    Simple Comparison
+
+    A comparison layer that only uses the positions of the components and the batch for dissimilarity computation.
+    """
+
+    # HyperParameters
+    # ----------------------------------------------------------------------------------------------------
+    @dataclass
+    class HyperParameters(BaseYArchitecture.HyperParameters):
+        """
+        comparison_fn: The comparison / dissimilarity function to use. Default: euclidean_distance.
+        comparison_args: Keyword arguments for the comparison function. Default: {}.
+        """
+        comparison_fn: Callable = euclidean_distance
+        comparison_args: dict = field(default_factory=lambda: dict())
+
+    # Steps
+    # ----------------------------------------------------------------------------------------------------
+    def init_comparison(self, hparams: HyperParameters):
+        self.comparison_layer = LambdaLayer(fn=hparams.comparison_fn,
+                                            **hparams.comparison_args)
+
+    def comparison(self, batch, components):
+        comp_tensor, _ = components
+        batch_tensor, _ = batch
+
+        comp_tensor = comp_tensor.unsqueeze(1)
+
+        distances = self.comparison_layer(batch_tensor, comp_tensor)
+
+        return distances
prototorch/models/y_arch/architectures/competition.py (new file, 29 lines)
@@ -0,0 +1,29 @@
+from dataclasses import dataclass
+
+from prototorch.core.competitions import WTAC
+from prototorch.models.y_arch.architectures.base import BaseYArchitecture
+
+
+class WTACompetitionMixin(BaseYArchitecture):
+    """
+    Winner Take All Competition
+
+    A competition layer that uses the winner-take-all strategy.
+    """
+
+    # HyperParameters
+    # ----------------------------------------------------------------------------------------------------
+    @dataclass
+    class HyperParameters(BaseYArchitecture.HyperParameters):
+        """
+        No hyperparameters.
+        """
+
+    # Steps
+    # ----------------------------------------------------------------------------------------------------
+    def init_inference(self, hparams: HyperParameters):
+        self.competition_layer = WTAC()
+
+    def inference(self, comparison_measures, components):
+        comp_labels = components[1]
+        return self.competition_layer(comparison_measures, comp_labels)
prototorch/models/y_arch/architectures/components.py (new file, 53 lines)
@@ -0,0 +1,53 @@
+from dataclasses import dataclass
+
+from prototorch.core.components import LabeledComponents
+from prototorch.core.initializers import (
+    AbstractComponentsInitializer,
+    LabelsInitializer,
+)
+from prototorch.models.y_arch import BaseYArchitecture
+
+
+class SupervisedArchitecture(BaseYArchitecture):
+    """
+    Supervised Architecture
+
+    An architecture that uses labeled Components as component Layer.
+    """
+    components_layer: LabeledComponents
+
+    # HyperParameters
+    # ----------------------------------------------------------------------------------------------------
+    @dataclass
+    class HyperParameters:
+        """
+        distribution: A valid prototype distribution. No default possible.
+        components_initializer: An implementation of AbstractComponentsInitializer. No default possible.
+        """
+        distribution: "dict[str, int]"
+        component_initializer: AbstractComponentsInitializer
+
+    # Steps
+    # ----------------------------------------------------------------------------------------------------
+    def init_components(self, hparams: HyperParameters):
+        self.components_layer = LabeledComponents(
+            distribution=hparams.distribution,
+            components_initializer=hparams.component_initializer,
+            labels_initializer=LabelsInitializer(),
+        )
+
+    # Properties
+    # ----------------------------------------------------------------------------------------------------
+    @property
+    def prototypes(self):
+        """
+        Returns the position of the prototypes.
+        """
+        return self.components_layer.components.detach().cpu()
+
+    @property
+    def prototype_labels(self):
+        """
+        Returns the labels of the prototypes.
+        """
+        return self.components_layer.labels.detach().cpu()
prototorch/models/y_arch/architectures/loss.py (new file, 42 lines)
@@ -0,0 +1,42 @@
+from dataclasses import dataclass, field
+
+from prototorch.core.losses import GLVQLoss
+from prototorch.models.y_arch.architectures.base import BaseYArchitecture
+
+
+class GLVQLossMixin(BaseYArchitecture):
+    """
+    GLVQ Loss
+
+    A loss layer that uses the Generalized Learning Vector Quantization (GLVQ) loss.
+    """
+
+    # HyperParameters
+    # ----------------------------------------------------------------------------------------------------
+    @dataclass
+    class HyperParameters(BaseYArchitecture.HyperParameters):
+        """
+        margin: The margin of the GLVQ loss. Default: 0.0.
+        transfer_fn: Transfer function to use. Default: sigmoid_beta.
+        transfer_args: Keyword arguments for the transfer function. Default: {beta: 10.0}.
+        """
+        margin: float = 0.0
+
+        transfer_fn: str = "sigmoid_beta"
+        transfer_args: dict = field(default_factory=lambda: dict(beta=10.0))
+
+    # Steps
+    # ----------------------------------------------------------------------------------------------------
+    def init_loss(self, hparams: HyperParameters):
+        self.loss_layer = GLVQLoss(
+            margin=hparams.margin,
+            transfer_fn=hparams.transfer_fn,
+            **hparams.transfer_args,
+        )
+
+    def loss(self, comparison_measures, batch, components):
+        target = batch[1]
+        comp_labels = components[1]
+        loss = self.loss_layer(comparison_measures, target, comp_labels)
+        self.log('loss', loss)
+        return loss
prototorch/models/y_arch/architectures/optimization.py (new file, 36 lines)
@@ -0,0 +1,36 @@
+from dataclasses import dataclass
+from typing import Type
+
+import torch
+from prototorch.models.y_arch import BaseYArchitecture
+
+
+class SingleLearningRateMixin(BaseYArchitecture):
+    """
+    Single Learning Rate
+
+    All parameters are updated with a single learning rate.
+    """
+
+    # HyperParameters
+    # ----------------------------------------------------------------------------------------------------
+    @dataclass
+    class HyperParameters(BaseYArchitecture.HyperParameters):
+        """
+        lr: The learning rate. Default: 0.1.
+        optimizer: The optimizer to use. Default: torch.optim.Adam.
+        """
+        lr: float = 0.1
+        optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
+
+    # Steps
+    # ----------------------------------------------------------------------------------------------------
+    def __init__(self, hparams: HyperParameters) -> None:
+        super().__init__(hparams)
+        self.lr = hparams.lr
+        self.optimizer = hparams.optimizer
+
+    # Hooks
+    # ----------------------------------------------------------------------------------------------------
+    def configure_optimizers(self):
+        return self.optimizer(self.parameters(), lr=self.lr)  # type: ignore
prototorch/models/y_arch/callbacks.py (new file, 149 lines)
@@ -0,0 +1,149 @@
+import warnings
+from typing import Optional, Type
+
+import numpy as np
+import pytorch_lightning as pl
+import torch
+import torchmetrics
+from matplotlib import pyplot as plt
+from prototorch.models.vis import Vis2DAbstract
+from prototorch.models.y_arch.architectures.base import BaseYArchitecture
+from prototorch.models.y_arch.library.gmlvq import GMLVQ
+from prototorch.utils.utils import mesh2d
+from pytorch_lightning.loggers import TensorBoardLogger
+
+DIVERGING_COLOR_MAPS = [
+    'PiYG', 'PRGn', 'BrBG', 'PuOr', 'RdGy', 'RdBu', 'RdYlBu', 'RdYlGn',
+    'Spectral', 'coolwarm', 'bwr', 'seismic'
+]
+
+
+class LogTorchmetricCallback(pl.Callback):
+
+    def __init__(
+        self,
+        name,
+        metric: Type[torchmetrics.Metric],
+        on="prediction",
+        **metric_kwargs,
+    ) -> None:
+        self.name = name
+        self.metric = metric
+        self.metric_kwargs = metric_kwargs
+        self.on = on
+
+    def setup(
+        self,
+        trainer: pl.Trainer,
+        pl_module: BaseYArchitecture,
+        stage: Optional[str] = None,
+    ) -> None:
+        if self.on == "prediction":
+            pl_module.register_torchmetric(
+                self.name,
+                self.metric,
+                **self.metric_kwargs,
+            )
+        else:
+            raise ValueError(f"{self.on} is no valid metric hook")
+
+
+class VisGLVQ2D(Vis2DAbstract):
+
+    def visualize(self, pl_module):
+        protos = pl_module.prototypes
+        plabels = pl_module.prototype_labels
+        x_train, y_train = self.x_train, self.y_train
+        ax = self.setup_ax()
+        self.plot_protos(ax, protos, plabels)
+        if x_train is not None:
+            self.plot_data(ax, x_train, y_train)
+            mesh_input, xx, yy = mesh2d(
+                np.vstack([x_train, protos]),
+                self.border,
+                self.resolution,
+            )
+        else:
+            mesh_input, xx, yy = mesh2d(protos, self.border, self.resolution)
+        _components = pl_module.components_layer.components
+        mesh_input = torch.from_numpy(mesh_input).type_as(_components)
+        y_pred = pl_module.predict(mesh_input)
+        y_pred = y_pred.cpu().reshape(xx.shape)
+        ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
+
+
+class VisGMLVQ2D(Vis2DAbstract):
+
+    def __init__(self, *args, ev_proj=True, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.ev_proj = ev_proj
+
+    def visualize(self, pl_module):
+        protos = pl_module.prototypes
+        plabels = pl_module.prototype_labels
+        x_train, y_train = self.x_train, self.y_train
+        device = pl_module.device
+        omega = pl_module._omega.detach()
+        lam = omega @ omega.T
+        u, _, _ = torch.pca_lowrank(lam, q=2)
+        with torch.no_grad():
+            x_train = torch.Tensor(x_train).to(device)
+            x_train = x_train @ u
+            x_train = x_train.cpu().detach()
+        if self.show_protos:
+            with torch.no_grad():
+                protos = torch.Tensor(protos).to(device)
+                protos = protos @ u
+                protos = protos.cpu().detach()
+        ax = self.setup_ax()
+        self.plot_data(ax, x_train, y_train)
+        if self.show_protos:
+            self.plot_protos(ax, protos, plabels)
+
+
+class PlotLambdaMatrixToTensorboard(pl.Callback):
+
+    def __init__(self, cmap='seismic') -> None:
+        super().__init__()
+        self.cmap = cmap
+
+        if self.cmap not in DIVERGING_COLOR_MAPS and type(self.cmap) is str:
+            warnings.warn(
+                f"{self.cmap} is not a diverging color map. We recommend to use one of the following: {DIVERGING_COLOR_MAPS}"
+            )
+
+    def on_train_start(self, trainer, pl_module: GMLVQ):
+        self.plot_lambda(trainer, pl_module)
+
+    def on_train_epoch_end(self, trainer, pl_module: GMLVQ):
+        self.plot_lambda(trainer, pl_module)
+
+    def plot_lambda(self, trainer, pl_module: GMLVQ):
+
+        self.fig, self.ax = plt.subplots(1, 1)
+
+        # plot lambda matrix
+        l_matrix = pl_module.lambda_matrix
+
+        # normalize lambda matrix
+        l_matrix = l_matrix / torch.max(torch.abs(l_matrix))
+
+        # plot lambda matrix
+        self.ax.imshow(l_matrix.detach().numpy(), self.cmap, vmin=-1, vmax=1)
+
+        self.fig.colorbar(self.ax.images[-1])
+
+        # add title
+        self.ax.set_title('Lambda Matrix')
+
+        # add to tensorboard
+        if isinstance(trainer.logger, TensorBoardLogger):
+            trainer.logger.experiment.add_figure(
+                f"lambda_matrix",
+                self.fig,
+                trainer.global_step,
+            )
+        else:
+            warnings.warn(
+                f"{self.__class__.__name__} is not compatible with {trainer.logger.__class__.__name__} as logger. Use TensorBoardLogger instead."
+            )
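For context on what PlotLambdaMatrixToTensorboard draws: in GMLVQ the learned linear map \(\Omega\) induces the relevance matrix

\[
\Lambda = \Omega \Omega^{\top},
\qquad
d_{\Lambda}(x, w) = (x - w)^{\top} \Lambda \, (x - w),
\]

which is exactly what the lambda_matrix property in library/gmlvq.py below returns. \(\Lambda\) is symmetric positive semi-definite and the callback normalizes it to \([-1, 1]\) before plotting, hence the recommendation of a diverging color map.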
prototorch/models/y_arch/library/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
+from .glvq import GLVQ
+
+__all__ = [
+    "GLVQ",
+]
prototorch/models/y_arch/library/glvq.py (new file, 35 lines)
@@ -0,0 +1,35 @@
+from dataclasses import dataclass
+
+from prototorch.models.y_arch import (
+    SimpleComparisonMixin,
+    SingleLearningRateMixin,
+    SupervisedArchitecture,
+    WTACompetitionMixin,
+)
+from prototorch.models.y_arch.architectures.loss import GLVQLossMixin
+
+
+class GLVQ(
+    SupervisedArchitecture,
+    SimpleComparisonMixin,
+    GLVQLossMixin,
+    WTACompetitionMixin,
+    SingleLearningRateMixin,
+):
+    """
+    Generalized Learning Vector Quantization (GLVQ)
+
+    A GLVQ architecture that uses the winner-take-all strategy and the GLVQ loss.
+    """
+
+    @dataclass
+    class HyperParameters(
+        SimpleComparisonMixin.HyperParameters,
+        SingleLearningRateMixin.HyperParameters,
+        GLVQLossMixin.HyperParameters,
+        WTACompetitionMixin.HyperParameters,
+        SupervisedArchitecture.HyperParameters,
+    ):
+        """
+        No hyperparameters.
+        """
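Because GLVQ is assembled purely from mixins, a variant only needs to override the hyperparameters of the step it changes. A hypothetical extension sketch (not part of this commit; it assumes squared_euclidean_distance is available from prototorch.core.distances):

from dataclasses import dataclass
from typing import Callable

from prototorch.core.distances import squared_euclidean_distance
from prototorch.models.y_arch.library.glvq import GLVQ


class SquaredGLVQ(GLVQ):
    """GLVQ that compares with the squared Euclidean distance."""

    @dataclass
    class HyperParameters(GLVQ.HyperParameters):
        # Overrides SimpleComparisonMixin's default of euclidean_distance;
        # everything else (loss, competition, optimizer) is inherited.
        comparison_fn: Callable = squared_euclidean_distance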
prototorch/models/y_arch/library/gmlvq.py (new file, 119 lines)
@@ -0,0 +1,119 @@
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from typing import Callable
+
+import torch
+from prototorch.core.distances import omega_distance
+from prototorch.core.initializers import (
+    AbstractLinearTransformInitializer,
+    EyeLinearTransformInitializer,
+)
+from prototorch.models.y_arch import (
+    GLVQLossMixin,
+    SimpleComparisonMixin,
+    SupervisedArchitecture,
+    WTACompetitionMixin,
+)
+from prototorch.nn.wrappers import LambdaLayer
+from torch.nn.parameter import Parameter
+
+
+class GMLVQ(
+    SupervisedArchitecture,
+    SimpleComparisonMixin,
+    GLVQLossMixin,
+    WTACompetitionMixin,
+):
+    """
+    Generalized Matrix Learning Vector Quantization (GMLVQ)
+
+    A GMLVQ architecture that uses the winner-take-all strategy and the GLVQ loss.
+    """
+
+    _omega: torch.Tensor
+
+    # HyperParameters
+    # ----------------------------------------------------------------------------------------------------
+    @dataclass
+    class HyperParameters(
+        SimpleComparisonMixin.HyperParameters,
+        GLVQLossMixin.HyperParameters,
+        WTACompetitionMixin.HyperParameters,
+        SupervisedArchitecture.HyperParameters,
+    ):
+        """
+        comparison_fn: The comparison / dissimilarity function to use. Override Default: omega_distance.
+        comparison_args: Keyword arguments for the comparison function. Override Default: {}.
+        input_dim: Necessary Field: The dimensionality of the input.
+        latent_dim: The dimensionality of the latent space. Default: 2.
+        omega_initializer: The initializer to use for the omega matrix. Default: EyeLinearTransformInitializer.
+        """
+        backbone_lr: float = 0.1
+        lr: float = 0.1
+        comparison_fn: Callable = omega_distance
+        comparison_args: dict = field(default_factory=lambda: dict())
+        input_dim: int | None = None
+        latent_dim: int = 2
+        omega_initializer: type[
+            AbstractLinearTransformInitializer] = EyeLinearTransformInitializer
+
+        optimizer: type[torch.optim.Optimizer] = torch.optim.Adam
+
+    # Steps
+    # ----------------------------------------------------------------------------------------------------
+    def __init__(self, hparams) -> None:
+        super().__init__(hparams)
+        self.lr = hparams.lr
+        self.backbone_lr = hparams.backbone_lr
+        self.optimizer = hparams.optimizer
+
+    def init_comparison(self, hparams: HyperParameters) -> None:
+        if hparams.input_dim is None:
+            raise ValueError("input_dim must be specified.")
+        omega = hparams.omega_initializer().generate(
+            hparams.input_dim,
+            hparams.latent_dim,
+        )
+        self.register_parameter("_omega", Parameter(omega))
+        self.comparison_layer = LambdaLayer(
+            fn=hparams.comparison_fn,
+            **hparams.comparison_args,
+        )
+
+    def comparison(self, batch, components):
+        comp_tensor, _ = components
+        batch_tensor, _ = batch
+
+        comp_tensor = comp_tensor.unsqueeze(1)
+
+        distances = self.comparison_layer(
+            batch_tensor,
+            comp_tensor,
+            self._omega,
+        )
+
+        return distances
+
+    def configure_optimizers(self):
+        proto_opt = self.optimizer(
+            self.components_layer.parameters(),
+            lr=self.lr,
+        )
+        omega_opt = self.optimizer(
+            [self._omega],
+            lr=self.backbone_lr,
+        )
+        return [proto_opt, omega_opt]
+
+    # Properties
+    # ----------------------------------------------------------------------------------------------------
+    @property
+    def omega_matrix(self):
+        return self._omega.detach().cpu()
+
+    @property
+    def lambda_matrix(self):
+        omega = self._omega.detach()
+        lam = omega @ omega.T
+        return lam.detach().cpu()