chore: rename clc-lc to proto-Y-architecture

This commit is contained in:
Alexander Engelsberger
2022-05-18 14:11:46 +02:00
parent 02954044d7
commit dc4f31d700
5 changed files with 94 additions and 75 deletions

View File

@@ -0,0 +1,216 @@
"""
CLCC Scheme
CLCC is a LVQ scheme containing 4 steps
- Components
- Latent Space
- Comparison
- Competition
"""
from dataclasses import dataclass
from typing import (
Dict,
Set,
Type,
)
import pytorch_lightning as pl
import torch
from torchmetrics import Accuracy, Metric
class BaseYArchitecture(pl.LightningModule):
@dataclass
class HyperParameters:
...
registered_metrics: Dict[Type[Metric], Metric] = {}
registered_metric_names: Dict[Type[Metric], Set[str]] = {}
components_layer: pl.LightningModule
def __init__(self, hparams) -> None:
super().__init__()
# Common Steps
self.init_components(hparams)
self.init_latent(hparams)
self.init_comparison(hparams)
self.init_competition(hparams)
# Train Steps
self.init_loss(hparams)
# Inference Steps
self.init_inference(hparams)
# Initialize Model Metrics
self.init_model_metrics()
# internal API, called by models and callbacks
def register_torchmetric(
self,
name: str,
metric: Type[Metric],
**metric_kwargs,
):
if metric not in self.registered_metrics:
self.registered_metrics[metric] = metric(**metric_kwargs)
self.registered_metric_names[metric] = {name}
else:
self.registered_metric_names[metric].add(name)
# external API
def get_competion(self, batch, components):
latent_batch, latent_components = self.latent(batch, components)
# TODO: => Latent Hook
comparison_tensor = self.comparison(latent_batch, latent_components)
# TODO: => Comparison Hook
return comparison_tensor
def forward(self, batch):
if isinstance(batch, torch.Tensor):
batch = (batch, None)
# TODO: manage different datatypes?
components = self.components_layer()
# TODO: => Component Hook
comparison_tensor = self.get_competion(batch, components)
# TODO: => Competition Hook
return self.inference(comparison_tensor, components)
def predict(self, batch):
"""
Alias for forward
"""
return self.forward(batch)
def forward_comparison(self, batch):
if isinstance(batch, torch.Tensor):
batch = (batch, None)
# TODO: manage different datatypes?
components = self.components_layer()
# TODO: => Component Hook
return self.get_competion(batch, components)
def loss_forward(self, batch):
# TODO: manage different datatypes?
components = self.components_layer()
# TODO: => Component Hook
comparison_tensor = self.get_competion(batch, components)
# TODO: => Competition Hook
return self.loss(comparison_tensor, batch, components)
# Empty Initialization
# TODO: Type hints
# TODO: Docs
def init_components(self, hparams: HyperParameters) -> None:
...
def init_latent(self, hparams: HyperParameters) -> None:
...
def init_comparison(self, hparams: HyperParameters) -> None:
...
def init_competition(self, hparams: HyperParameters) -> None:
...
def init_loss(self, hparams: HyperParameters) -> None:
...
def init_inference(self, hparams: HyperParameters) -> None:
...
def init_model_metrics(self) -> None:
self.register_torchmetric('accuracy', Accuracy)
# Empty Steps
# TODO: Type hints
def components(self):
"""
This step has no input.
It returns the components.
"""
raise NotImplementedError(
"The components step has no reasonable default.")
def latent(self, batch, components):
"""
The latent step receives the data batch and the components.
It can transform both by an arbitrary function.
It returns the transformed batch and components, each of the same length as the original input.
"""
return batch, components
def comparison(self, batch, components):
"""
Takes a batch of size N and the componentsset of size M.
It returns an NxMxD tensor containing D (usually 1) pairwise comparison measures.
"""
raise NotImplementedError(
"The comparison step has no reasonable default.")
def competition(self, comparisonmeasures, components):
"""
Takes the tensor of comparison measures.
Assigns a competition vector to each class.
"""
raise NotImplementedError(
"The competition step has no reasonable default.")
def loss(self, comparisonmeasures, batch, components):
"""
Takes the tensor of competition measures.
Calculates a single loss value
"""
raise NotImplementedError("The loss step has no reasonable default.")
def inference(self, comparisonmeasures, components):
"""
Takes the tensor of competition measures.
Returns the inferred vector.
"""
raise NotImplementedError(
"The inference step has no reasonable default.")
def update_metrics_step(self, batch):
x, y = batch
# Prediction Metrics
preds = self(x)
for metric in self.registered_metrics:
instance = self.registered_metrics[metric].to(self.device)
instance(y, preds)
def update_metrics_epoch(self):
for metric in self.registered_metrics:
instance = self.registered_metrics[metric].to(self.device)
value = instance.compute()
for name in self.registered_metric_names[metric]:
self.log(name, value)
instance.reset()
# Lightning Hooks
def training_step(self, batch, batch_idx, optimizer_idx=None):
self.update_metrics_step(batch)
return self.loss_forward(batch)
def training_epoch_end(self, outs) -> None:
self.update_metrics_epoch()
def validation_step(self, batch, batch_idx):
return self.loss_forward(batch)
def test_step(self, batch, batch_idx):
return self.loss_forward(batch)

View File

@@ -0,0 +1,63 @@
from typing import Optional, Type
import numpy as np
import pytorch_lightning as pl
import torch
import torchmetrics
from prototorch.models.proto_y_architecture.base import BaseYArchitecture
from prototorch.models.vis import Vis2DAbstract
from prototorch.utils.utils import mesh2d
class LogTorchmetricCallback(pl.Callback):
def __init__(
self,
name,
metric: Type[torchmetrics.Metric],
on="prediction",
**metric_kwargs,
) -> None:
self.name = name
self.metric = metric
self.metric_kwargs = metric_kwargs
self.on = on
def setup(
self,
trainer: pl.Trainer,
pl_module: BaseYArchitecture,
stage: Optional[str] = None,
) -> None:
if self.on == "prediction":
pl_module.register_torchmetric(
self.name,
self.metric,
**self.metric_kwargs,
)
else:
raise ValueError(f"{self.on} is no valid metric hook")
class VisGLVQ2D(Vis2DAbstract):
def visualize(self, pl_module):
protos = pl_module.prototypes
plabels = pl_module.prototype_labels
x_train, y_train = self.x_train, self.y_train
ax = self.setup_ax()
self.plot_protos(ax, protos, plabels)
if x_train is not None:
self.plot_data(ax, x_train, y_train)
mesh_input, xx, yy = mesh2d(
np.vstack([x_train, protos]),
self.border,
self.resolution,
)
else:
mesh_input, xx, yy = mesh2d(protos, self.border, self.resolution)
_components = pl_module.components_layer.components
mesh_input = torch.from_numpy(mesh_input).type_as(_components)
y_pred = pl_module.predict(mesh_input)
y_pred = y_pred.cpu().reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)

View File

@@ -0,0 +1,98 @@
from dataclasses import dataclass
from typing import Callable, Type
import torch
from prototorch.core.competitions import WTAC
from prototorch.core.components import LabeledComponents
from prototorch.core.distances import euclidean_distance
from prototorch.core.initializers import (
AbstractComponentsInitializer,
LabelsInitializer,
)
from prototorch.core.losses import GLVQLoss
from prototorch.models.proto_y_architecture.base import BaseYArchitecture
from prototorch.nn.wrappers import LambdaLayer
class SupervisedScheme(BaseYArchitecture):
@dataclass
class HyperParameters:
distribution: dict[str, int]
component_initializer: AbstractComponentsInitializer
def init_components(self, hparams: HyperParameters):
self.components_layer = LabeledComponents(
distribution=hparams.distribution,
components_initializer=hparams.component_initializer,
labels_initializer=LabelsInitializer(),
)
# ##############################################################################
# GLVQ
# ##############################################################################
class GLVQ(
SupervisedScheme, ):
"""GLVQ using the new Scheme
"""
@dataclass
class HyperParameters(SupervisedScheme.HyperParameters):
distance_fn: Callable = euclidean_distance
lr: float = 0.01
margin: float = 0.0
# TODO: make nicer
transfer_fn: str = "identity"
transfer_beta: float = 10.0
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
def __init__(self, hparams: HyperParameters) -> None:
super().__init__(hparams)
self.lr = hparams.lr
self.optimizer = hparams.optimizer
def init_comparison(self, hparams: HyperParameters):
self.comparison_layer = LambdaLayer(hparams.distance_fn)
def init_inference(self, hparams: HyperParameters):
self.competition_layer = WTAC()
def init_loss(self, hparams):
self.loss_layer = GLVQLoss(
margin=hparams.margin,
transfer_fn=hparams.transfer_fn,
beta=hparams.transfer_beta,
)
# Steps
def comparison(self, batch, components):
comp_tensor, _ = components
batch_tensor, _ = batch
comp_tensor = comp_tensor.unsqueeze(1)
distances = self.comparison_layer(batch_tensor, comp_tensor)
return distances
def inference(self, comparisonmeasures, components):
comp_labels = components[1]
return self.competition_layer(comparisonmeasures, comp_labels)
def loss(self, comparisonmeasures, batch, components):
target = batch[1]
comp_labels = components[1]
return self.loss_layer(comparisonmeasures, target, comp_labels)
def configure_optimizers(self):
return self.optimizer(self.parameters(), lr=self.lr) # type: ignore
# Properties
@property
def prototypes(self):
return self.components_layer.components.detach().cpu()
@property
def prototype_labels(self):
return self.components_layer.labels.detach().cpu()

View File

@@ -0,0 +1,89 @@
import prototorch as pt
import pytorch_lightning as pl
import torchmetrics
from prototorch.core import SMCI
from prototorch.models.proto_y_architecture.callbacks import (
LogTorchmetricCallback,
VisGLVQ2D,
)
from prototorch.models.proto_y_architecture.glvq import GLVQ
from pytorch_lightning.callbacks import EarlyStopping
from torch.utils.data import DataLoader
# ##############################################################################
if __name__ == "__main__":
# ------------------------------------------------------------
# DATA
# ------------------------------------------------------------
# Dataset
train_ds = pt.datasets.Iris(dims=[0, 2])
train_ds.targets[train_ds.targets == 2.0] = 1.0
# Dataloader
train_loader = DataLoader(
train_ds,
batch_size=64,
num_workers=0,
shuffle=True,
)
# ------------------------------------------------------------
# HYPERPARAMETERS
# ------------------------------------------------------------
# Select Initializer
components_initializer = SMCI(train_ds)
# Define Hyperparameters
hyperparameters = GLVQ.HyperParameters(
lr=0.5,
distribution=dict(
num_classes=2,
per_class=1,
),
component_initializer=components_initializer,
)
# Create Model
model = GLVQ(hyperparameters)
print(model)
# ------------------------------------------------------------
# TRAINING
# ------------------------------------------------------------
# Controlling Callbacks
stopping_criterion = LogTorchmetricCallback(
'recall',
torchmetrics.Recall,
num_classes=2,
)
es = EarlyStopping(
monitor=stopping_criterion.name,
min_delta=0.001,
patience=15,
mode="max",
check_on_train_epoch_end=True,
)
# Visualization Callback
vis = VisGLVQ2D(data=train_ds)
# Define trainer
trainer = pl.Trainer(
callbacks=[
vis,
stopping_criterion,
es,
],
gpus=0,
max_epochs=200,
log_every_n_steps=1,
)
# Train
trainer.fit(model, train_loader)