feat: add GMLVQ with new architecture

Alexander Engelsberger
2022-05-19 16:13:08 +02:00
parent 3e50d0d817
commit e922aae432
15 changed files with 541 additions and 226 deletions

View File

@@ -0,0 +1,15 @@
from .architectures.base import BaseYArchitecture
from .architectures.comparison import SimpleComparisonMixin
from .architectures.competition import WTACompetitionMixin
from .architectures.components import SupervisedArchitecture
from .architectures.loss import GLVQLossMixin
from .architectures.optimization import SingleLearningRateMixin
__all__ = [
"BaseYArchitecture",
"SimpleComparisonMixin",
"SingleLearningRateMixin",
"SupervisedArchitecture",
"WTACompetitionMixin",
"GLVQLossMixin",
]

View File

@@ -0,0 +1,212 @@
"""
Proto Y Architecture
Network architecture for component-based learning.
"""
from dataclasses import dataclass
from typing import (
Dict,
Set,
Type,
)
import pytorch_lightning as pl
import torch
from torchmetrics import Metric
from torchmetrics.classification.accuracy import Accuracy
class BaseYArchitecture(pl.LightningModule):
@dataclass
class HyperParameters:
...
registered_metrics: Dict[Type[Metric], Metric]
registered_metric_names: Dict[Type[Metric], Set[str]]
components_layer: torch.nn.Module
def __init__(self, hparams) -> None:
super().__init__()
# Initialize the metric registries per instance; a class-level dict
# would be shared state across all model instances.
self.registered_metrics = {}
self.registered_metric_names = {}
# Common Steps
self.init_components(hparams)
self.init_latent(hparams)
self.init_comparison(hparams)
self.init_competition(hparams)
# Train Steps
self.init_loss(hparams)
# Inference Steps
self.init_inference(hparams)
# Initialize Model Metrics
self.init_model_metrics()
# internal API, called by models and callbacks
def register_torchmetric(
self,
name: str,
metric: Type[Metric],
**metric_kwargs,
):
if metric not in self.registered_metrics:
self.registered_metrics[metric] = metric(**metric_kwargs)
self.registered_metric_names[metric] = {name}
else:
self.registered_metric_names[metric].add(name)
# external API
def get_competition(self, batch, components):
latent_batch, latent_components = self.latent(batch, components)
# TODO: => Latent Hook
comparison_tensor = self.comparison(latent_batch, latent_components)
# TODO: => Comparison Hook
return comparison_tensor
def forward(self, batch):
if isinstance(batch, torch.Tensor):
batch = (batch, None)
# TODO: manage different datatypes?
components = self.components_layer()
# TODO: => Component Hook
comparison_tensor = self.get_competition(batch, components)
# TODO: => Competition Hook
return self.inference(comparison_tensor, components)
def predict(self, batch):
"""
Alias for forward
"""
return self.forward(batch)
def forward_comparison(self, batch):
if isinstance(batch, torch.Tensor):
batch = (batch, None)
# TODO: manage different datatypes?
components = self.components_layer()
# TODO: => Component Hook
return self.get_competition(batch, components)
def loss_forward(self, batch):
# TODO: manage different datatypes?
components = self.components_layer()
# TODO: => Component Hook
comparison_tensor = self.get_competition(batch, components)
# TODO: => Competition Hook
return self.loss(comparison_tensor, batch, components)
# Empty Initialization
# TODO: Type hints
# TODO: Docs
def init_components(self, hparams: HyperParameters) -> None:
...
def init_latent(self, hparams: HyperParameters) -> None:
...
def init_comparison(self, hparams: HyperParameters) -> None:
...
def init_competition(self, hparams: HyperParameters) -> None:
...
def init_loss(self, hparams: HyperParameters) -> None:
...
def init_inference(self, hparams: HyperParameters) -> None:
...
def init_model_metrics(self) -> None:
self.register_torchmetric('accuracy', Accuracy)
# Empty Steps
# TODO: Type hints
def components(self):
"""
This step has no input.
It returns the components.
"""
raise NotImplementedError(
"The components step has no reasonable default.")
def latent(self, batch, components):
"""
The latent step receives the data batch and the components.
It can transform both by an arbitrary function.
It returns the transformed batch and components, each of the same length as the original input.
"""
return batch, components
def comparison(self, batch, components):
"""
Takes a batch of size N and the component set of size M.
It returns an NxMxD tensor containing D (usually 1) pairwise comparison measures.
"""
raise NotImplementedError(
"The comparison step has no reasonable default.")
def competition(self, comparison_measures, components):
"""
Takes the tensor of comparison measures.
Assigns a competition vector to each class.
"""
raise NotImplementedError(
"The competition step has no reasonable default.")
def loss(self, comparison_measures, batch, components):
"""
Takes the tensor of comparison measures.
Calculates a single loss value.
"""
raise NotImplementedError("The loss step has no reasonable default.")
def inference(self, comparison_measures, components):
"""
Takes the tensor of comparison measures.
Returns the inferred vector.
"""
raise NotImplementedError(
"The inference step has no reasonable default.")
def update_metrics_step(self, batch):
x, y = batch
# Prediction Metrics
preds = self(x)
for metric in self.registered_metrics:
instance = self.registered_metrics[metric].to(self.device)
# torchmetrics expects (preds, target) in this order
instance(preds, y)
def update_metrics_epoch(self):
for metric in self.registered_metrics:
instance = self.registered_metrics[metric].to(self.device)
value = instance.compute()
for name in self.registered_metric_names[metric]:
self.log(name, value)
instance.reset()
# Lightning Hooks
def training_step(self, batch, batch_idx, optimizer_idx=None):
self.update_metrics_step(batch)
return self.loss_forward(batch)
def training_epoch_end(self, outs) -> None:
self.update_metrics_epoch()
def validation_step(self, batch, batch_idx):
return self.loss_forward(batch)
def test_step(self, batch, batch_idx):
return self.loss_forward(batch)
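
To make the step API concrete, here is a self-contained toy sketch (not part of the commit): a subclass that fills in the steps that have no reasonable default. The FixedComponents module and the (positions, labels) tuple convention are assumptions that mirror what the mixins below unpack, and a 2022-era torchmetrics (where Accuracy() takes no required arguments) is assumed.

import torch

from prototorch.models.y_arch import BaseYArchitecture


class FixedComponents(torch.nn.Module):
    """Toy component layer returning a (positions, labels) tuple."""

    def __init__(self, positions, labels):
        super().__init__()
        self.register_buffer("positions", positions)
        self.register_buffer("labels", labels)

    def forward(self):
        return self.positions, self.labels


class NearestComponent(BaseYArchitecture):
    """Toy architecture: Euclidean comparison, arg-min inference."""

    def init_components(self, hparams):
        self.components_layer = FixedComponents(
            torch.tensor([[0.0, 0.0], [1.0, 1.0]]),
            torch.tensor([0, 1]),
        )

    def comparison(self, batch, components):
        batch_tensor, _ = batch
        comp_tensor, _ = components
        return torch.cdist(batch_tensor, comp_tensor)

    def inference(self, comparison_measures, components):
        _, comp_labels = components
        return comp_labels[comparison_measures.argmin(dim=1)]


model = NearestComponent(NearestComponent.HyperParameters())
print(model(torch.tensor([[0.9, 0.8]])))  # tensor([1])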

View File

@@ -0,0 +1,41 @@
from dataclasses import dataclass, field
from typing import Callable
from prototorch.core.distances import euclidean_distance
from prototorch.models.y_arch.architectures.base import BaseYArchitecture
from prototorch.nn.wrappers import LambdaLayer
class SimpleComparisonMixin(BaseYArchitecture):
"""
Simple Comparison
A comparison layer that only uses the positions of the components and the batch for dissimilarity computation.
"""
# HyperParameters
# ----------------------------------------------------------------------------------------------------
@dataclass
class HyperParameters(BaseYArchitecture.HyperParameters):
"""
comparison_fn: The comparison / dissimilarity function to use. Default: euclidean_distance.
comparison_args: Keyword arguments for the comparison function. Default: {}.
"""
comparison_fn: Callable = euclidean_distance
comparison_args: dict = field(default_factory=lambda: dict())
# Steps
# ----------------------------------------------------------------------------------------------------
def init_comparison(self, hparams: HyperParameters):
self.comparison_layer = LambdaLayer(fn=hparams.comparison_fn,
**hparams.comparison_args)
def comparison(self, batch, components):
comp_tensor, _ = components
batch_tensor, _ = batch
comp_tensor = comp_tensor.unsqueeze(1)
distances = self.comparison_layer(batch_tensor, comp_tensor)
return distances
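
A hedged usage sketch: because the dissimilarity is a hyperparameter, swapping it requires no subclassing. squared_euclidean_distance is assumed to exist in prototorch.core.distances, SMCI and train_ds are placeholder assumptions (see the components example further down), and GLVQ is the class defined later in this commit.

from prototorch.core.distances import squared_euclidean_distance
from prototorch.core.initializers import SMCI
from prototorch.models.y_arch.library.glvq import GLVQ

hparams = GLVQ.HyperParameters(
    distribution={"num_classes": 3, "per_class": 2},
    component_initializer=SMCI(train_ds),  # train_ds: any labeled Dataset
    comparison_fn=squared_euclidean_distance,
)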

View File

@@ -0,0 +1,29 @@
from dataclasses import dataclass
from prototorch.core.competitions import WTAC
from prototorch.models.y_arch.architectures.base import BaseYArchitecture
class WTACompetitionMixin(BaseYArchitecture):
"""
Winner Take All Competition
A competition layer that uses the winner-take-all strategy.
"""
# HyperParameters
# ----------------------------------------------------------------------------------------------------
@dataclass
class HyperParameters(BaseYArchitecture.HyperParameters):
"""
No hyperparameters.
"""
# Steps
# ----------------------------------------------------------------------------------------------------
def init_inference(self, hparams: HyperParameters):
self.competition_layer = WTAC()
def inference(self, comparison_measures, components):
comp_labels = components[1]
return self.competition_layer(comparison_measures, comp_labels)
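
The winner-take-all rule itself lives in prototorch.core.competitions; the plain-torch snippet below is a sketch of the assumed semantics, not the WTAC internals: each sample receives the label of its best-matching (smallest-dissimilarity) component.

import torch

comparison_measures = torch.tensor([[0.3, 0.1, 0.9],
                                    [0.7, 0.8, 0.2]])  # N=2 samples, M=3 components
comp_labels = torch.tensor([0, 1, 1])

winners = comparison_measures.argmin(dim=1)  # best-matching component per sample
print(comp_labels[winners])  # tensor([1, 1])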

View File

@@ -0,0 +1,53 @@
from dataclasses import dataclass
from prototorch.core.components import LabeledComponents
from prototorch.core.initializers import (
AbstractComponentsInitializer,
LabelsInitializer,
)
from prototorch.models.y_arch import BaseYArchitecture
class SupervisedArchitecture(BaseYArchitecture):
"""
Supervised Architecture
An architecture that uses LabeledComponents as its component layer.
"""
components_layer: LabeledComponents
# HyperParameters
# ----------------------------------------------------------------------------------------------------
@dataclass
class HyperParameters:
"""
distribution: A valid prototype distribution. No default possible.
component_initializer: An implementation of AbstractComponentsInitializer. No default possible.
"""
distribution: "dict[str, int]"
component_initializer: AbstractComponentsInitializer
# Steps
# ----------------------------------------------------------------------------------------------------
def init_components(self, hparams: HyperParameters):
self.components_layer = LabeledComponents(
distribution=hparams.distribution,
components_initializer=hparams.component_initializer,
labels_initializer=LabelsInitializer(),
)
# Properties
# ----------------------------------------------------------------------------------------------------
@property
def prototypes(self):
"""
Returns the position of the prototypes.
"""
return self.components_layer.components.detach().cpu()
@property
def prototype_labels(self):
"""
Returns the labels of the prototypes.
"""
return self.components_layer.labels.detach().cpu()
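
A hedged sketch of filling the two required fields. The distribution dict format and the SMCI (stratified-mean) initializer are assumptions based on the prototorch core API; train_ds stands for any labeled dataset.

from prototorch.core.initializers import SMCI
from prototorch.models.y_arch import SupervisedArchitecture

hparams = SupervisedArchitecture.HyperParameters(
    distribution={"num_classes": 3, "per_class": 4},  # 12 prototypes in total
    component_initializer=SMCI(train_ds),  # prototypes start at stratified class means
)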

View File

@@ -0,0 +1,42 @@
from dataclasses import dataclass, field
from prototorch.core.losses import GLVQLoss
from prototorch.models.y_arch.architectures.base import BaseYArchitecture
class GLVQLossMixin(BaseYArchitecture):
"""
GLVQ Loss
A loss layer that uses the Generalized Learning Vector Quantization (GLVQ) loss.
"""
# HyperParameters
# ----------------------------------------------------------------------------------------------------
@dataclass
class HyperParameters(BaseYArchitecture.HyperParameters):
"""
margin: The margin of the GLVQ loss. Default: 0.0.
transfer_fn: Transfer function to use. Default: sigmoid_beta.
transfer_args: Keyword arguments for the transfer function. Default: {beta: 10.0}.
"""
margin: float = 0.0
transfer_fn: str = "sigmoid_beta"
transfer_args: dict = field(default_factory=lambda: dict(beta=10.0))
# Steps
# ----------------------------------------------------------------------------------------------------
def init_loss(self, hparams: HyperParameters):
self.loss_layer = GLVQLoss(
margin=hparams.margin,
transfer_fn=hparams.transfer_fn,
**hparams.transfer_args,
)
def loss(self, comparison_measures, batch, components):
target = batch[1]
comp_labels = components[1]
loss = self.loss_layer(comparison_measures, target, comp_labels)
self.log('loss', loss)
return loss
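
For background (standard GLVQ, not taken from this diff): with d1 the dissimilarity to the closest prototype of the correct class and d2 the closest one of any wrong class, the loss pushes the classifier function mu = (d1 - d2) / (d1 + d2) below zero. A minimal numeric sketch, assuming sigmoid_beta is the beta-scaled sigmoid:

import torch


def glvq_mu(d1, d2):
    """Classifier function: negative iff the correct prototype is closer."""
    return (d1 - d2) / (d1 + d2)


def sigmoid_beta(x, beta=10.0):
    return torch.sigmoid(beta * x)


d1 = torch.tensor([0.2, 0.9])  # first sample classified correctly, second not
d2 = torch.tensor([0.8, 0.3])
loss = sigmoid_beta(glvq_mu(d1, d2) + 0.0).sum()  # margin = 0.0
print(loss)  # ~= 0.0025 + 0.9933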

View File

@@ -0,0 +1,36 @@
from dataclasses import dataclass
from typing import Type
import torch
from prototorch.models.y_arch import BaseYArchitecture
class SingleLearningRateMixin(BaseYArchitecture):
"""
Single Learning Rate
All parameters are updated with a single learning rate.
"""
# HyperParameters
# ----------------------------------------------------------------------------------------------------
@dataclass
class HyperParameters(BaseYArchitecture.HyperParameters):
"""
lr: The learning rate. Default: 0.1.
optimizer: The optimizer to use. Default: torch.optim.Adam.
"""
lr: float = 0.1
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
# Steps
# ----------------------------------------------------------------------------------------------------
def __init__(self, hparams: HyperParameters) -> None:
super().__init__(hparams)
self.lr = hparams.lr
self.optimizer = hparams.optimizer
# Hooks
# ----------------------------------------------------------------------------------------------------
def configure_optimizers(self):
return self.optimizer(self.parameters(), lr=self.lr) # type: ignore
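
Because the optimizer class is itself a hyperparameter, switching from Adam to SGD is a one-field change for any architecture using this mixin. A hedged sketch; GLVQ is defined later in this commit, and SMCI / train_ds are the same assumptions as above.

import torch

from prototorch.core.initializers import SMCI
from prototorch.models.y_arch.library.glvq import GLVQ

hparams = GLVQ.HyperParameters(
    distribution={"num_classes": 2, "per_class": 1},
    component_initializer=SMCI(train_ds),
    lr=0.05,
    optimizer=torch.optim.SGD,  # configure_optimizers builds from these fields
)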

View File

@@ -0,0 +1,149 @@
import warnings
from typing import Optional, Type
import numpy as np
import pytorch_lightning as pl
import torch
import torchmetrics
from matplotlib import pyplot as plt
from prototorch.models.vis import Vis2DAbstract
from prototorch.models.y_arch.architectures.base import BaseYArchitecture
from prototorch.models.y_arch.library.gmlvq import GMLVQ
from prototorch.utils.utils import mesh2d
from pytorch_lightning.loggers import TensorBoardLogger
DIVERGING_COLOR_MAPS = [
'PiYG', 'PRGn', 'BrBG', 'PuOr', 'RdGy', 'RdBu', 'RdYlBu', 'RdYlGn',
'Spectral', 'coolwarm', 'bwr', 'seismic'
]
class LogTorchmetricCallback(pl.Callback):
def __init__(
self,
name,
metric: Type[torchmetrics.Metric],
on="prediction",
**metric_kwargs,
) -> None:
self.name = name
self.metric = metric
self.metric_kwargs = metric_kwargs
self.on = on
def setup(
self,
trainer: pl.Trainer,
pl_module: BaseYArchitecture,
stage: Optional[str] = None,
) -> None:
if self.on == "prediction":
pl_module.register_torchmetric(
self.name,
self.metric,
**self.metric_kwargs,
)
else:
raise ValueError(f"{self.on} is not a valid metric hook")
class VisGLVQ2D(Vis2DAbstract):
def visualize(self, pl_module):
protos = pl_module.prototypes
plabels = pl_module.prototype_labels
x_train, y_train = self.x_train, self.y_train
ax = self.setup_ax()
self.plot_protos(ax, protos, plabels)
if x_train is not None:
self.plot_data(ax, x_train, y_train)
mesh_input, xx, yy = mesh2d(
np.vstack([x_train, protos]),
self.border,
self.resolution,
)
else:
mesh_input, xx, yy = mesh2d(protos, self.border, self.resolution)
_components = pl_module.components_layer.components
mesh_input = torch.from_numpy(mesh_input).type_as(_components)
y_pred = pl_module.predict(mesh_input)
y_pred = y_pred.cpu().reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
class VisGMLVQ2D(Vis2DAbstract):
def __init__(self, *args, ev_proj=True, **kwargs):
super().__init__(*args, **kwargs)
self.ev_proj = ev_proj
def visualize(self, pl_module):
protos = pl_module.prototypes
plabels = pl_module.prototype_labels
x_train, y_train = self.x_train, self.y_train
device = pl_module.device
omega = pl_module._omega.detach()
lam = omega @ omega.T
u, _, _ = torch.pca_lowrank(lam, q=2)
with torch.no_grad():
x_train = torch.Tensor(x_train).to(device)
x_train = x_train @ u
x_train = x_train.cpu().detach()
if self.show_protos:
with torch.no_grad():
protos = torch.Tensor(protos).to(device)
protos = protos @ u
protos = protos.cpu().detach()
ax = self.setup_ax()
self.plot_data(ax, x_train, y_train)
if self.show_protos:
self.plot_protos(ax, protos, plabels)
class PlotLambdaMatrixToTensorboard(pl.Callback):
def __init__(self, cmap='seismic') -> None:
super().__init__()
self.cmap = cmap
if isinstance(self.cmap, str) and self.cmap not in DIVERGING_COLOR_MAPS:
warnings.warn(
f"{self.cmap} is not a diverging color map. We recommend using one of the following: {DIVERGING_COLOR_MAPS}"
)
def on_train_start(self, trainer, pl_module: GMLVQ):
self.plot_lambda(trainer, pl_module)
def on_train_epoch_end(self, trainer, pl_module: GMLVQ):
self.plot_lambda(trainer, pl_module)
def plot_lambda(self, trainer, pl_module: GMLVQ):
self.fig, self.ax = plt.subplots(1, 1)
# plot lambda matrix
l_matrix = pl_module.lambda_matrix
# normalize lambda matrix
l_matrix = l_matrix / torch.max(torch.abs(l_matrix))
# plot lambda matrix
self.ax.imshow(l_matrix.detach().numpy(), self.cmap, vmin=-1, vmax=1)
self.fig.colorbar(self.ax.images[-1])
# add title
self.ax.set_title('Lambda Matrix')
# add to tensorboard
if isinstance(trainer.logger, TensorBoardLogger):
trainer.logger.experiment.add_figure(
f"lambda_matrix",
self.fig,
trainer.global_step,
)
else:
warnings.warn(
f"{self.__class__.__name__} is not compatible with {trainer.logger.__class__.__name__} as logger. Use TensorBoardLogger instead."
)
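
A hedged wiring sketch for the three callbacks in this file. The module path, the Vis2DAbstract constructor signature (dataset as first argument), and torchmetrics.Recall as the extra metric are assumptions for illustration only.

import pytorch_lightning as pl
import torchmetrics

from prototorch.models.y_arch.callbacks import (  # module path assumed
    LogTorchmetricCallback,
    PlotLambdaMatrixToTensorboard,
    VisGMLVQ2D,
)

trainer = pl.Trainer(
    callbacks=[
        LogTorchmetricCallback("recall", torchmetrics.Recall),
        VisGMLVQ2D(train_ds),  # train_ds: the dataset to visualize
        PlotLambdaMatrixToTensorboard(cmap="RdBu"),
    ],
    max_epochs=100,
)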

View File

@@ -0,0 +1,5 @@
from .glvq import GLVQ
from .gmlvq import GMLVQ
__all__ = [
"GLVQ",
"GMLVQ",
]

View File

@@ -0,0 +1,35 @@
from dataclasses import dataclass
from prototorch.models.y_arch import (
SimpleComparisonMixin,
SingleLearningRateMixin,
SupervisedArchitecture,
WTACompetitionMixin,
)
from prototorch.models.y_arch.architectures.loss import GLVQLossMixin
class GLVQ(
SupervisedArchitecture,
SimpleComparisonMixin,
GLVQLossMixin,
WTACompetitionMixin,
SingleLearningRateMixin,
):
"""
Generalized Learning Vector Quantization (GLVQ)
A GLVQ architecture that uses the winner-take-all strategy and the GLVQ loss.
"""
@dataclass
class HyperParameters(
SimpleComparisonMixin.HyperParameters,
SingleLearningRateMixin.HyperParameters,
GLVQLossMixin.HyperParameters,
WTACompetitionMixin.HyperParameters,
SupervisedArchitecture.HyperParameters,
):
"""
No hyperparameters.
"""

View File

@@ -0,0 +1,119 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Callable
import torch
from prototorch.core.distances import omega_distance
from prototorch.core.initializers import (
AbstractLinearTransformInitializer,
EyeLinearTransformInitializer,
)
from prototorch.models.y_arch import (
GLVQLossMixin,
SimpleComparisonMixin,
SupervisedArchitecture,
WTACompetitionMixin,
)
from prototorch.nn.wrappers import LambdaLayer
from torch.nn.parameter import Parameter
class GMLVQ(
SupervisedArchitecture,
SimpleComparisonMixin,
GLVQLossMixin,
WTACompetitionMixin,
):
"""
Generalized Matrix Learning Vector Quantization (GMLVQ)
A GMLVQ architecture that uses the winner-take-all strategy and the GLVQ loss.
"""
_omega: torch.Tensor
# HyperParameters
# ----------------------------------------------------------------------------------------------------
@dataclass
class HyperParameters(
SimpleComparisonMixin.HyperParameters,
GLVQLossMixin.HyperParameters,
WTACompetitionMixin.HyperParameters,
SupervisedArchitecture.HyperParameters,
):
"""
comparison_fn: The comparison / dissimilarity function to use. Override Default: omega_distance.
comparison_args: Keyword arguments for the comparison function. Override Default: {}.
input_dim: Necessary Field: The dimensionality of the input.
latent_dim: The dimensionality of the latent space. Default: 2.
omega_initializer: The initializer to use for the omega matrix. Default: EyeLinearTransformInitializer.
"""
backbone_lr: float = 0.1
lr: float = 0.1
comparison_fn: Callable = omega_distance
comparison_args: dict = field(default_factory=lambda: dict())
input_dim: int | None = None
latent_dim: int = 2
omega_initializer: type[
AbstractLinearTransformInitializer] = EyeLinearTransformInitializer
optimizer: type[torch.optim.Optimizer] = torch.optim.Adam
# Steps
# ----------------------------------------------------------------------------------------------------
def __init__(self, hparams) -> None:
super().__init__(hparams)
self.lr = hparams.lr
self.backbone_lr = hparams.backbone_lr
self.optimizer = hparams.optimizer
def init_comparison(self, hparams: HyperParameters) -> None:
if hparams.input_dim is None:
raise ValueError("input_dim must be specified.")
omega = hparams.omega_initializer().generate(
hparams.input_dim,
hparams.latent_dim,
)
self.register_parameter("_omega", Parameter(omega))
self.comparison_layer = LambdaLayer(
fn=hparams.comparison_fn,
**hparams.comparison_args,
)
def comparison(self, batch, components):
comp_tensor, _ = components
batch_tensor, _ = batch
comp_tensor = comp_tensor.unsqueeze(1)
distances = self.comparison_layer(
batch_tensor,
comp_tensor,
self._omega,
)
return distances
def configure_optimizers(self):
proto_opt = self.optimizer(
self.components_layer.parameters(),
lr=self.lr,
)
omega_opt = self.optimizer(
[self._omega],
lr=self.backbone_lr,
)
return [proto_opt, omega_opt]
# Properties
# ----------------------------------------------------------------------------------------------------
@property
def omega_matrix(self):
return self._omega.detach().cpu()
@property
def lambda_matrix(self):
omega = self._omega.detach()
lam = omega @ omega.T
return lam.detach().cpu()
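
Finally, a hedged sketch of instantiating GMLVQ and reading back the learned relevance matrices. Dataset and initializer names are the same assumptions as in the GLVQ example above.

from prototorch.core.initializers import SMCI
from prototorch.datasets import Iris
from prototorch.models.y_arch.library.gmlvq import GMLVQ

train_ds = Iris()  # all four input dimensions this time

hparams = GMLVQ.HyperParameters(
    distribution={"num_classes": 3, "per_class": 1},
    component_initializer=SMCI(train_ds),
    input_dim=4,       # required; init_comparison raises otherwise
    latent_dim=2,      # omega maps inputs into a 2D latent space
    backbone_lr=0.05,  # separate learning rate for omega
)
model = GMLVQ(hparams)

print(model.omega_matrix.shape)   # (input_dim, latent_dim) -> (4, 2)
print(model.lambda_matrix.shape)  # (input_dim, input_dim) -> (4, 4)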