feat: add GMLVQ with new architecture

This commit is contained in:
Alexander Engelsberger 2022-05-19 16:13:08 +02:00
parent 3e50d0d817
commit e922aae432
No known key found for this signature in database
GPG Key ID: 72E54A9DAE51EB96
15 changed files with 541 additions and 226 deletions

View File

@ -2,11 +2,12 @@ import prototorch as pt
import pytorch_lightning as pl
import torchmetrics
from prototorch.core import SMCI
from prototorch.models.proto_y_architecture.callbacks import (
from prototorch.models.y_arch.callbacks import (
LogTorchmetricCallback,
VisGLVQ2D,
PlotLambdaMatrixToTensorboard,
VisGMLVQ2D,
)
from prototorch.models.proto_y_architecture.glvq import GLVQ
from prototorch.models.y_arch.library.gmlvq import GMLVQ
from pytorch_lightning.callbacks import EarlyStopping
from torch.utils.data import DataLoader
@ -19,8 +20,7 @@ if __name__ == "__main__":
# ------------------------------------------------------------
# Dataset
train_ds = pt.datasets.Iris(dims=[0, 2])
train_ds.targets[train_ds.targets == 2.0] = 1.0
train_ds = pt.datasets.Iris()
# Dataloader
train_loader = DataLoader(
@ -38,17 +38,19 @@ if __name__ == "__main__":
components_initializer = SMCI(train_ds)
# Define Hyperparameters
hyperparameters = GLVQ.HyperParameters(
hyperparameters = GMLVQ.HyperParameters(
lr=0.1,
backbone_lr=5,
input_dim=4,
distribution=dict(
num_classes=2,
num_classes=3,
per_class=1,
),
component_initializer=components_initializer,
)
# Create Model
model = GLVQ(hyperparameters)
model = GMLVQ(hyperparameters)
print(model)
@ -60,19 +62,17 @@ if __name__ == "__main__":
stopping_criterion = LogTorchmetricCallback(
'recall',
torchmetrics.Recall,
num_classes=2,
num_classes=3,
)
es = EarlyStopping(
monitor=stopping_criterion.name,
min_delta=0.001,
patience=15,
mode="max",
check_on_train_epoch_end=True,
patience=10,
)
# Visualization Callback
vis = VisGLVQ2D(data=train_ds)
vis = VisGMLVQ2D(data=train_ds)
# Define trainer
trainer = pl.Trainer(
@ -80,10 +80,9 @@ if __name__ == "__main__":
vis,
stopping_criterion,
es,
PlotLambdaMatrixToTensorboard(),
],
gpus=0,
max_epochs=200,
log_every_n_steps=1,
max_epochs=1000,
)
# Train

View File

@ -1,63 +0,0 @@
from typing import Optional, Type
import numpy as np
import pytorch_lightning as pl
import torch
import torchmetrics
from prototorch.models.proto_y_architecture.base import BaseYArchitecture
from prototorch.models.vis import Vis2DAbstract
from prototorch.utils.utils import mesh2d
class LogTorchmetricCallback(pl.Callback):
    """
    Callback that asks the module to track a torchmetric.

    The metric is not instantiated here. Its class and keyword arguments are
    stored and handed to ``pl_module.register_torchmetric`` during ``setup``.
    """

    def __init__(
        self,
        name,
        metric: Type[torchmetrics.Metric],
        on="prediction",
        **metric_kwargs,
    ) -> None:
        """
        name: Identifier under which the metric is registered.
        metric: The torchmetrics metric class to register.
        on: The hook the metric is attached to. Only "prediction" is accepted by ``setup``.
        metric_kwargs: Keyword arguments forwarded together with the metric class.
        """
        self.on = on
        self.name = name
        self.metric = metric
        self.metric_kwargs = metric_kwargs

    def setup(
        self,
        trainer: pl.Trainer,
        pl_module: BaseYArchitecture,
        stage: Optional[str] = None,
    ) -> None:
        """
        Register the stored metric on ``pl_module``.

        Raises ValueError when ``on`` is anything other than "prediction".
        """
        # Guard clause: reject unsupported hooks before touching the module.
        if self.on != "prediction":
            raise ValueError(f"{self.on} is no valid metric hook")

        pl_module.register_torchmetric(
            self.name,
            self.metric,
            **self.metric_kwargs,
        )
class VisGLVQ2D(Vis2DAbstract):
    """
    2D visualization for GLVQ-style models.

    Draws the prototypes, the training data (if any was attached) and a
    filled contour of the model's predictions over a mesh.
    """

    def visualize(self, pl_module):
        """
        Render one frame for ``pl_module``.

        pl_module: Module exposing ``prototypes``, ``prototype_labels``,
            ``components_layer`` and ``predict``.
        """
        prototypes = pl_module.prototypes
        proto_labels = pl_module.prototype_labels
        data, targets = self.x_train, self.y_train

        axes = self.setup_ax()
        self.plot_protos(axes, prototypes, proto_labels)

        # The mesh is built from data and prototypes together when data is
        # available, otherwise from the prototypes alone.
        if data is None:
            mesh_points = prototypes
        else:
            self.plot_data(axes, data, targets)
            mesh_points = np.vstack([data, prototypes])
        grid, xx, yy = mesh2d(mesh_points, self.border, self.resolution)

        # Match the mesh to the dtype/device of the model's components
        # before asking the model for predictions.
        reference = pl_module.components_layer.components
        grid = torch.from_numpy(grid).type_as(reference)
        predicted = pl_module.predict(grid).cpu().reshape(xx.shape)

        axes.contourf(xx, yy, predicted, cmap=self.cmap, alpha=0.35)

View File

@ -1,140 +0,0 @@
from dataclasses import dataclass, field
from typing import Callable, Type
import torch
from prototorch.core.competitions import WTAC
from prototorch.core.components import LabeledComponents
from prototorch.core.distances import euclidean_distance
from prototorch.core.initializers import (
AbstractComponentsInitializer,
LabelsInitializer,
)
from prototorch.core.losses import GLVQLoss
from prototorch.models.proto_y_architecture.base import BaseYArchitecture
from prototorch.nn.wrappers import LambdaLayer
class SupervisedArchitecture(BaseYArchitecture):
    """
    Supervised Architecture

    An architecture whose component layer is a ``LabeledComponents`` instance.
    """
    components_layer: LabeledComponents

    @dataclass
    class HyperParameters:
        """
        distribution: The prototype distribution passed to ``LabeledComponents``. No default.
        component_initializer: The initializer passed to ``LabeledComponents``. No default.
        """
        # Quoted so the annotation is never evaluated: a bare ``dict[str, int]``
        # raises TypeError at class creation on Python < 3.9.
        distribution: "dict[str, int]"
        component_initializer: AbstractComponentsInitializer

    def init_components(self, hparams: HyperParameters):
        """Create the labeled component layer from the hyperparameters."""
        self.components_layer = LabeledComponents(
            distribution=hparams.distribution,
            components_initializer=hparams.component_initializer,
            labels_initializer=LabelsInitializer(),
        )

    @property
    def prototypes(self):
        """Detached CPU copy of the component layer's ``components``."""
        return self.components_layer.components.detach().cpu()

    @property
    def prototype_labels(self):
        """Detached CPU copy of the component layer's ``labels``."""
        return self.components_layer.labels.detach().cpu()
class WTACompetitionMixin(BaseYArchitecture):
    """
    Winner Take All Competition

    Uses a ``WTAC`` layer to turn comparison measures into predictions.
    """

    @dataclass
    class HyperParameters(BaseYArchitecture.HyperParameters):
        """This mixin adds no fields of its own."""

    def init_inference(self, hparams: HyperParameters):
        """Create the competition layer; ``hparams`` is not consulted."""
        self.competition_layer = WTAC()

    def inference(self, comparison_measures, components):
        """
        Run the competition layer.

        components: Indexable whose element 1 holds the component labels.
        """
        return self.competition_layer(comparison_measures, components[1])
class GLVQLossMixin(BaseYArchitecture):
    """
    GLVQ Loss

    Uses ``GLVQLoss`` as loss layer and logs its value under the key 'loss'.
    """

    @dataclass
    class HyperParameters(BaseYArchitecture.HyperParameters):
        """
        margin: Passed to ``GLVQLoss`` as ``margin``. Default: 0.0.
        transfer_fn: Passed to ``GLVQLoss`` as ``transfer_fn``. Default: "sigmoid_beta".
        transfer_args: Extra keyword arguments for ``GLVQLoss``. Default: {beta: 10.0}.
        """
        margin: float = 0.0
        transfer_fn: str = "sigmoid_beta"
        transfer_args: dict = field(default_factory=lambda: {"beta": 10.0})

    def init_loss(self, hparams: HyperParameters):
        """Create the loss layer from the hyperparameters."""
        self.loss_layer = GLVQLoss(
            margin=hparams.margin,
            transfer_fn=hparams.transfer_fn,
            **hparams.transfer_args,
        )

    def loss(self, comparison_measures, batch, components):
        """
        Compute, log and return the loss for one batch.

        batch: Indexable whose element 1 holds the targets.
        components: Indexable whose element 1 holds the component labels.
        """
        targets = batch[1]
        component_labels = components[1]
        loss_value = self.loss_layer(
            comparison_measures,
            targets,
            component_labels,
        )
        self.log('loss', loss_value)
        return loss_value
class SingleLearningRateMixin(BaseYArchitecture):
    """
    Single Learning Rate

    One optimizer with one learning rate is built over all module parameters.
    """

    @dataclass
    class HyperParameters(BaseYArchitecture.HyperParameters):
        """
        lr: The learning rate. Default: 0.01.
        optimizer: The optimizer class to instantiate. Default: torch.optim.Adam.
        """
        lr: float = 0.01
        optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam

    def __init__(self, hparams: HyperParameters) -> None:
        """Initialize the base class, then remember the optimizer settings."""
        super().__init__(hparams)
        self.optimizer = hparams.optimizer
        self.lr = hparams.lr

    def configure_optimizers(self):
        """Instantiate the configured optimizer over ``self.parameters()``."""
        optimizer_cls = self.optimizer
        return optimizer_cls(self.parameters(), lr=self.lr)  # type: ignore
class SimpleComparisonMixin(BaseYArchitecture):
    """
    Simple Comparison

    Compares the batch inputs with the component positions through a
    configurable function wrapped in a ``LambdaLayer``.
    """

    @dataclass
    class HyperParameters(BaseYArchitecture.HyperParameters):
        """
        comparison_fn: The function wrapped by the comparison layer. Default: euclidean_distance.
        comparison_args: Keyword arguments for the comparison layer. Default: {}.
        """
        comparison_fn: Callable = euclidean_distance
        comparison_args: dict = field(default_factory=dict)

    def init_comparison(self, hparams: HyperParameters):
        """Wrap the configured comparison function in a ``LambdaLayer``."""
        self.comparison_layer = LambdaLayer(
            fn=hparams.comparison_fn,
            **hparams.comparison_args,
        )

    def comparison(self, batch, components):
        """
        Return the comparison layer's output for a batch.

        batch: Pair whose first element is the input tensor.
        components: Pair whose first element is the component tensor.
        """
        positions, _ = components
        inputs, _ = batch
        # A singleton axis is inserted at position 1 of the component tensor
        # before it is handed to the comparison layer.
        return self.comparison_layer(inputs, positions.unsqueeze(1))
# ##############################################################################
# GLVQ
# ##############################################################################
class GLVQ(
        SupervisedArchitecture,
        SimpleComparisonMixin,
        GLVQLossMixin,
        WTACompetitionMixin,
        SingleLearningRateMixin,
):
    """
    GLVQ assembled from the mixins of the new scheme.

    The class adds no behavior of its own; everything comes from its bases.
    """

    @dataclass
    class HyperParameters(
            SimpleComparisonMixin.HyperParameters,
            SingleLearningRateMixin.HyperParameters,
            GLVQLossMixin.HyperParameters,
            WTACompetitionMixin.HyperParameters,
            SupervisedArchitecture.HyperParameters,
    ):
        """Union of the hyperparameters of all bases; no additional fields."""

View File

@ -0,0 +1,15 @@
from .architectures.base import BaseYArchitecture
from .architectures.comparison import SimpleComparisonMixin
from .architectures.competition import WTACompetitionMixin
from .architectures.components import SupervisedArchitecture
from .architectures.loss import GLVQLossMixin
from .architectures.optimization import SingleLearningRateMixin
# Public API of this package: exactly the names imported above from the
# ``architectures`` submodules.
__all__ = [
'BaseYArchitecture',
"SimpleComparisonMixin",
"SingleLearningRateMixin",
"SupervisedArchitecture",
"WTACompetitionMixin",
"GLVQLossMixin",
]

View File

@ -1,12 +1,7 @@
"""
CLCC Scheme
CLCC is a LVQ scheme containing 4 steps
- Components
- Latent Space
- Comparison
- Competition
Proto Y Architecture
Network architecture for Component based Learning.
"""
from dataclasses import dataclass
from typing import (

View File

@ -0,0 +1,41 @@
from dataclasses import dataclass, field
from typing import Callable
from prototorch.core.distances import euclidean_distance
from prototorch.models.y_arch.architectures.base import BaseYArchitecture
from prototorch.nn.wrappers import LambdaLayer
class SimpleComparisonMixin(BaseYArchitecture):
    """
    Simple Comparison

    The comparison step only sees the batch inputs and the component
    positions; labels on either side are ignored.
    """

    # HyperParameters
    # ----------------------------------------------------------------------------------------------------
    @dataclass
    class HyperParameters(BaseYArchitecture.HyperParameters):
        """
        comparison_fn: Function wrapped by the comparison layer. Default: euclidean_distance.
        comparison_args: Keyword arguments passed to the comparison layer. Default: {}.
        """
        comparison_fn: Callable = euclidean_distance
        comparison_args: dict = field(default_factory=dict)

    # Steps
    # ----------------------------------------------------------------------------------------------------
    def init_comparison(self, hparams: HyperParameters):
        """Build the comparison layer from the hyperparameters."""
        self.comparison_layer = LambdaLayer(
            fn=hparams.comparison_fn,
            **hparams.comparison_args,
        )

    def comparison(self, batch, components):
        """
        Compare a batch with the components.

        batch: Pair (inputs, second element unused here).
        components: Pair (component tensor, second element unused here).
        """
        component_positions, _ = components
        inputs, _ = batch
        # Singleton axis at position 1 is added before the comparison call.
        expanded = component_positions.unsqueeze(1)
        return self.comparison_layer(inputs, expanded)

View File

@ -0,0 +1,29 @@
from dataclasses import dataclass
from prototorch.core.competitions import WTAC
from prototorch.models.y_arch.architectures.base import BaseYArchitecture
class WTACompetitionMixin(BaseYArchitecture):
    """
    Winner Take All Competition

    Inference is delegated to a ``WTAC`` competition layer.
    """

    # HyperParameters
    # ----------------------------------------------------------------------------------------------------
    @dataclass
    class HyperParameters(BaseYArchitecture.HyperParameters):
        """
        This mixin contributes no fields.
        """

    # Steps
    # ----------------------------------------------------------------------------------------------------
    def init_inference(self, hparams: HyperParameters):
        """Create the competition layer. ``hparams`` is accepted but not read."""
        self.competition_layer = WTAC()

    def inference(self, comparison_measures, components):
        """
        Feed the comparison measures and the component labels
        (``components[1]``) to the competition layer and return its output.
        """
        labels = components[1]
        winner = self.competition_layer(comparison_measures, labels)
        return winner

View File

@ -0,0 +1,53 @@
from dataclasses import dataclass
from prototorch.core.components import LabeledComponents
from prototorch.core.initializers import (
AbstractComponentsInitializer,
LabelsInitializer,
)
from prototorch.models.y_arch import BaseYArchitecture
class SupervisedArchitecture(BaseYArchitecture):
    """
    Supervised Architecture

    The component layer of this architecture is a ``LabeledComponents`` layer.
    """
    components_layer: LabeledComponents

    # HyperParameters
    # ----------------------------------------------------------------------------------------------------
    @dataclass
    class HyperParameters:
        """
        distribution: Prototype distribution handed to ``LabeledComponents``. No default possible.
        component_initializer: An ``AbstractComponentsInitializer`` handed to ``LabeledComponents``. No default possible.
        """
        distribution: "dict[str, int]"
        component_initializer: AbstractComponentsInitializer

    # Steps
    # ----------------------------------------------------------------------------------------------------
    def init_components(self, hparams: HyperParameters):
        """Create the labeled component layer described by ``hparams``."""
        layer_options = dict(
            distribution=hparams.distribution,
            components_initializer=hparams.component_initializer,
            labels_initializer=LabelsInitializer(),
        )
        self.components_layer = LabeledComponents(**layer_options)

    # Properties
    # ----------------------------------------------------------------------------------------------------
    @property
    def prototypes(self):
        """
        The component layer's ``components``, detached and moved to the CPU.
        """
        positions = self.components_layer.components
        return positions.detach().cpu()

    @property
    def prototype_labels(self):
        """
        The component layer's ``labels``, detached and moved to the CPU.
        """
        labels = self.components_layer.labels
        return labels.detach().cpu()

View File

@ -0,0 +1,42 @@
from dataclasses import dataclass, field
from prototorch.core.losses import GLVQLoss
from prototorch.models.y_arch.architectures.base import BaseYArchitecture
class GLVQLossMixin(BaseYArchitecture):
    """
    GLVQ Loss

    The loss step is computed by a ``GLVQLoss`` layer and logged as 'loss'.
    """

    # HyperParameters
    # ----------------------------------------------------------------------------------------------------
    @dataclass
    class HyperParameters(BaseYArchitecture.HyperParameters):
        """
        margin: Margin handed to the GLVQ loss. Default: 0.0.
        transfer_fn: Name of the transfer function. Default: sigmoid_beta.
        transfer_args: Keyword arguments forwarded to the loss layer. Default: {beta: 10.0}.
        """
        margin: float = 0.0
        transfer_fn: str = "sigmoid_beta"
        transfer_args: dict = field(default_factory=lambda: {"beta": 10.0})

    # Steps
    # ----------------------------------------------------------------------------------------------------
    def init_loss(self, hparams: HyperParameters):
        """Instantiate the loss layer from the hyperparameters."""
        self.loss_layer = GLVQLoss(
            margin=hparams.margin,
            transfer_fn=hparams.transfer_fn,
            **hparams.transfer_args,
        )

    def loss(self, comparison_measures, batch, components):
        """
        Evaluate the loss layer, log the result and return it.

        batch: Indexable; element 1 is used as the target.
        components: Indexable; element 1 is used as the component labels.
        """
        expected = batch[1]
        component_labels = components[1]
        value = self.loss_layer(comparison_measures, expected, component_labels)
        self.log('loss', value)
        return value

View File

@ -0,0 +1,36 @@
from dataclasses import dataclass
from typing import Type
import torch
from prototorch.models.y_arch import BaseYArchitecture
class SingleLearningRateMixin(BaseYArchitecture):
    """
    Single Learning Rate

    A single optimizer instance updates every parameter of the module with
    the same learning rate.
    """

    # HyperParameters
    # ----------------------------------------------------------------------------------------------------
    @dataclass
    class HyperParameters(BaseYArchitecture.HyperParameters):
        """
        lr: Learning rate given to the optimizer. Default: 0.1.
        optimizer: Optimizer class to instantiate. Default: torch.optim.Adam.
        """
        lr: float = 0.1
        optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam

    # Steps
    # ----------------------------------------------------------------------------------------------------
    def __init__(self, hparams: HyperParameters) -> None:
        """Run the base initializer, then store optimizer class and rate."""
        super().__init__(hparams)
        self.optimizer = hparams.optimizer
        self.lr = hparams.lr

    # Hooks
    # ----------------------------------------------------------------------------------------------------
    def configure_optimizers(self):
        """Build the stored optimizer class over ``self.parameters()``."""
        build_optimizer = self.optimizer
        return build_optimizer(self.parameters(), lr=self.lr)  # type: ignore

View File

@ -0,0 +1,149 @@
import warnings
from typing import Optional, Type
import numpy as np
import pytorch_lightning as pl
import torch
import torchmetrics
from matplotlib import pyplot as plt
from prototorch.models.vis import Vis2DAbstract
from prototorch.models.y_arch.architectures.base import BaseYArchitecture
from prototorch.models.y_arch.library.gmlvq import GMLVQ
from prototorch.utils.utils import mesh2d
from pytorch_lightning.loggers import TensorBoardLogger
# Color map names regarded as diverging. A string ``cmap`` outside this list
# triggers a warning in PlotLambdaMatrixToTensorboard.
DIVERGING_COLOR_MAPS = [
    'PiYG',
    'PRGn',
    'BrBG',
    'PuOr',
    'RdGy',
    'RdBu',
    'RdYlBu',
    'RdYlGn',
    'Spectral',
    'coolwarm',
    'bwr',
    'seismic',
]
class LogTorchmetricCallback(pl.Callback):
    """
    Registers a torchmetric on the trained module.

    The callback only stores the metric class and its keyword arguments;
    the registration itself happens in ``setup`` through
    ``pl_module.register_torchmetric``.
    """

    def __init__(
        self,
        name,
        metric: Type[torchmetrics.Metric],
        on="prediction",
        **metric_kwargs,
    ) -> None:
        """
        name: Name under which the metric is registered.
        metric: Metric class (not an instance).
        on: Hook to attach to; ``setup`` accepts only "prediction".
        metric_kwargs: Extra keyword arguments forwarded at registration.
        """
        self.metric_kwargs = metric_kwargs
        self.metric = metric
        self.name = name
        self.on = on

    def setup(
        self,
        trainer: pl.Trainer,
        pl_module: BaseYArchitecture,
        stage: Optional[str] = None,
    ) -> None:
        """
        Register the metric on ``pl_module``.

        Raises ValueError if ``on`` is not "prediction".
        """
        supported = self.on == "prediction"
        if not supported:
            raise ValueError(f"{self.on} is no valid metric hook")

        pl_module.register_torchmetric(
            self.name,
            self.metric,
            **self.metric_kwargs,
        )
class VisGLVQ2D(Vis2DAbstract):
    """
    Visualization callback for two-dimensional GLVQ models.

    Plots prototypes, optional training data and the predicted label over a
    mesh as a filled contour.
    """

    def visualize(self, pl_module):
        """
        Draw the current state of ``pl_module``.

        pl_module: Module providing ``prototypes``, ``prototype_labels``,
            ``components_layer`` and ``predict``.
        """
        proto_positions = pl_module.prototypes
        proto_labels = pl_module.prototype_labels
        samples, sample_labels = self.x_train, self.y_train

        canvas = self.setup_ax()
        self.plot_protos(canvas, proto_positions, proto_labels)

        # With data attached, the mesh is built from data and prototypes
        # stacked together; without data, from the prototypes only.
        has_data = samples is not None
        if has_data:
            self.plot_data(canvas, samples, sample_labels)
        extent_source = (np.vstack([samples, proto_positions])
                         if has_data else proto_positions)
        mesh, xx, yy = mesh2d(extent_source, self.border, self.resolution)

        # Convert the mesh to the same tensor type as the components so it
        # can be fed to the model.
        like = pl_module.components_layer.components
        mesh = torch.from_numpy(mesh).type_as(like)

        prediction = pl_module.predict(mesh)
        prediction = prediction.cpu().reshape(xx.shape)
        canvas.contourf(xx, yy, prediction, cmap=self.cmap, alpha=0.35)
class VisGMLVQ2D(Vis2DAbstract):
    """
    2D visualization for GMLVQ models.

    Data and prototypes are multiplied with the first factor returned by
    ``torch.pca_lowrank(omega @ omega.T, q=2)`` and plotted in the resulting
    two-dimensional space.
    """

    def __init__(self, *args, ev_proj=True, **kwargs):
        """
        ev_proj: Stored on the instance. NOTE(review): ``visualize`` does not
            read this flag at the moment -- confirm the intended behavior.
        All other arguments are forwarded to ``Vis2DAbstract``.
        """
        super().__init__(*args, **kwargs)
        self.ev_proj = ev_proj

    @staticmethod
    def _project(points, basis):
        """Return ``points @ basis`` as a detached CPU tensor."""
        with torch.no_grad():
            # as_tensor accepts arrays and tensors alike and puts the points
            # on the dtype/device of the basis so the matmul cannot mismatch.
            points = torch.as_tensor(
                points,
                dtype=basis.dtype,
                device=basis.device,
            )
            return (points @ basis).cpu()

    def visualize(self, pl_module):
        """
        Draw the projected data and prototypes of ``pl_module``.

        pl_module: Module providing ``prototypes``, ``prototype_labels`` and
            the ``_omega`` parameter.
        """
        protos = pl_module.prototypes
        plabels = pl_module.prototype_labels
        x_train, y_train = self.x_train, self.y_train

        omega = pl_module._omega.detach()
        lam = omega @ omega.T
        u, _, _ = torch.pca_lowrank(lam, q=2)

        ax = self.setup_ax()

        # Training data is optional (VisGLVQ2D guards against ``None`` the
        # same way); without the guard the conversion below would raise.
        if x_train is not None:
            self.plot_data(ax, self._project(x_train, u), y_train)

        if self.show_protos:
            self.plot_protos(ax, self._project(protos, u), plabels)
class PlotLambdaMatrixToTensorboard(pl.Callback):
    """
    Logs an image of the module's ``lambda_matrix`` to TensorBoard.

    The matrix is plotted when training starts and after every training
    epoch. With any logger other than ``TensorBoardLogger`` a warning is
    issued instead of logging.
    """

    def __init__(self, cmap='seismic') -> None:
        """
        cmap: Color map passed to ``imshow``. A string that is not listed in
            ``DIVERGING_COLOR_MAPS`` triggers a warning; non-string values
            are accepted without a check.
        """
        super().__init__()
        self.cmap = cmap

        if isinstance(self.cmap, str) and self.cmap not in DIVERGING_COLOR_MAPS:
            warnings.warn(
                f"{self.cmap} is not a diverging color map. We recommend to use one of the following: {DIVERGING_COLOR_MAPS}"
            )

    def on_train_start(self, trainer, pl_module: GMLVQ):
        """Plot the initial lambda matrix."""
        self.plot_lambda(trainer, pl_module)

    def on_train_epoch_end(self, trainer, pl_module: GMLVQ):
        """Plot the lambda matrix after each training epoch."""
        self.plot_lambda(trainer, pl_module)

    def plot_lambda(self, trainer, pl_module: GMLVQ):
        """
        Render ``pl_module.lambda_matrix`` and hand the figure to the logger.

        The matrix is scaled by its largest absolute entry so that the values
        fit the fixed color range [-1, 1].
        """
        self.fig, self.ax = plt.subplots(1, 1)
        try:
            l_matrix = pl_module.lambda_matrix

            # Skip the scaling for an all-zero matrix; dividing by zero
            # would fill the image with NaN.
            peak = torch.max(torch.abs(l_matrix))
            if peak > 0:
                l_matrix = l_matrix / peak

            self.ax.imshow(
                l_matrix.detach().numpy(),
                self.cmap,
                vmin=-1,
                vmax=1,
            )
            self.fig.colorbar(self.ax.images[-1])
            self.ax.set_title('Lambda Matrix')

            if isinstance(trainer.logger, TensorBoardLogger):
                trainer.logger.experiment.add_figure(
                    "lambda_matrix",
                    self.fig,
                    trainer.global_step,
                )
            else:
                warnings.warn(
                    f"{self.__class__.__name__} is not compatible with {trainer.logger.__class__.__name__} as logger. Use TensorBoardLogger instead."
                )
        finally:
            # One figure is created per call. Release it on every path so
            # pyplot does not accumulate open figures over the epochs.
            plt.close(self.fig)

View File

@ -0,0 +1,5 @@
# Models shipped with the library package. GMLVQ lives in the sibling module
# ``gmlvq`` and is re-exported here next to GLVQ so both can be imported
# from the package itself.
from .glvq import GLVQ
from .gmlvq import GMLVQ

__all__ = [
    "GLVQ",
    "GMLVQ",
]

View File

@ -0,0 +1,35 @@
from dataclasses import dataclass
from prototorch.models.y_arch import (
SimpleComparisonMixin,
SingleLearningRateMixin,
SupervisedArchitecture,
WTACompetitionMixin,
)
from prototorch.models.y_arch.architectures.loss import GLVQLossMixin
class GLVQ(
        SupervisedArchitecture,
        SimpleComparisonMixin,
        GLVQLossMixin,
        WTACompetitionMixin,
        SingleLearningRateMixin,
):
    """
    Generalized Learning Vector Quantization (GLVQ)

    Composition of labeled components, a simple comparison, the GLVQ loss,
    winner-take-all competition and a single learning rate. The class body
    itself adds no behavior.
    """

    @dataclass
    class HyperParameters(
            SimpleComparisonMixin.HyperParameters,
            SingleLearningRateMixin.HyperParameters,
            GLVQLossMixin.HyperParameters,
            WTACompetitionMixin.HyperParameters,
            SupervisedArchitecture.HyperParameters,
    ):
        """
        Combines the hyperparameters of all bases; no additional fields are declared here.
        """

View File

@ -0,0 +1,119 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Callable
import torch
from prototorch.core.distances import omega_distance
from prototorch.core.initializers import (
AbstractLinearTransformInitializer,
EyeLinearTransformInitializer,
)
from prototorch.models.y_arch import (
GLVQLossMixin,
SimpleComparisonMixin,
SupervisedArchitecture,
WTACompetitionMixin,
)
from prototorch.nn.wrappers import LambdaLayer
from torch.nn.parameter import Parameter
class GMLVQ(
        SupervisedArchitecture,
        SimpleComparisonMixin,
        GLVQLossMixin,
        WTACompetitionMixin,
):
    """
    Generalized Matrix Learning Vector Quantization (GMLVQ)

    Combines labeled components, the GLVQ loss and winner-take-all
    competition with a comparison step that additionally receives a
    learnable matrix ``_omega``.
    """
    _omega: torch.Tensor

    # HyperParameters
    # ----------------------------------------------------------------------------------------------------
    @dataclass
    class HyperParameters(
            SimpleComparisonMixin.HyperParameters,
            GLVQLossMixin.HyperParameters,
            WTACompetitionMixin.HyperParameters,
            SupervisedArchitecture.HyperParameters,
    ):
        """
        backbone_lr: Learning rate of the optimizer that updates omega. Default: 0.1.
        lr: Learning rate of the optimizer that updates the components. Default: 0.1.
        comparison_fn: The comparison / dissimilarity function to use. Override Default: omega_distance.
        comparison_args: Keyword arguments for the comparison function. Override Default: {}.
        input_dim: Necessary Field: first argument of the omega initializer's ``generate``.
        latent_dim: Second argument of the omega initializer's ``generate``. Default: 2.
        omega_initializer: Initializer class used to create omega. Default: EyeLinearTransformInitializer.
        optimizer: Optimizer class used for both optimizers. Default: torch.optim.Adam.
        """
        backbone_lr: float = 0.1
        lr: float = 0.1
        comparison_fn: Callable = omega_distance
        comparison_args: dict = field(default_factory=dict)
        input_dim: int | None = None
        latent_dim: int = 2
        omega_initializer: type[
            AbstractLinearTransformInitializer] = EyeLinearTransformInitializer
        optimizer: type[torch.optim.Optimizer] = torch.optim.Adam

    # Steps
    # ----------------------------------------------------------------------------------------------------
    def __init__(self, hparams) -> None:
        """Run the base initializer, then store the optimizer settings."""
        super().__init__(hparams)
        self.optimizer = hparams.optimizer
        self.lr = hparams.lr
        self.backbone_lr = hparams.backbone_lr

    def init_comparison(self, hparams: HyperParameters) -> None:
        """
        Create the omega parameter and the comparison layer.

        Raises ValueError when ``hparams.input_dim`` is None.
        """
        input_dim = hparams.input_dim
        if input_dim is None:
            raise ValueError("input_dim must be specified.")

        initializer = hparams.omega_initializer()
        initial_omega = initializer.generate(input_dim, hparams.latent_dim)
        self.register_parameter("_omega", Parameter(initial_omega))

        self.comparison_layer = LambdaLayer(
            fn=hparams.comparison_fn,
            **hparams.comparison_args,
        )

    def comparison(self, batch, components):
        """
        Compare a batch with the components, passing omega as third argument.

        batch: Pair whose first element is the input tensor.
        components: Pair whose first element is the component tensor.
        """
        positions, _ = components
        inputs, _ = batch
        # Singleton axis at position 1 is added before the comparison call.
        expanded = positions.unsqueeze(1)
        return self.comparison_layer(inputs, expanded, self._omega)

    def configure_optimizers(self):
        """
        Return two optimizers of the configured class: the first over the
        component layer's parameters with ``lr``, the second over omega with
        ``backbone_lr``.
        """
        make_optimizer = self.optimizer
        return [
            make_optimizer(self.components_layer.parameters(), lr=self.lr),
            make_optimizer([self._omega], lr=self.backbone_lr),
        ]

    # Properties
    # ----------------------------------------------------------------------------------------------------
    @property
    def omega_matrix(self):
        """Omega, detached and moved to the CPU."""
        return self._omega.detach().cpu()

    @property
    def lambda_matrix(self):
        """The product ``omega @ omega.T``, detached and moved to the CPU."""
        detached_omega = self._omega.detach()
        product = detached_omega @ detached_omega.T
        return product.detach().cpu()