feat: copy old clc-lc implementation

This commit is contained in:
Alexander Engelsberger 2022-05-17 16:25:43 +02:00
parent e0b92e9ac2
commit 8f08ba66ea
No known key found for this signature in database
GPG Key ID: 72E54A9DAE51EB96
4 changed files with 396 additions and 0 deletions

View File

View File

@ -0,0 +1,90 @@
from dataclasses import dataclass
from typing import Callable
import torch
from prototorch.core.competitions import WTAC
from prototorch.core.components import LabeledComponents
from prototorch.core.distances import euclidean_distance
from prototorch.core.initializers import (
AbstractComponentsInitializer,
LabelsInitializer,
)
from prototorch.core.losses import GLVQLoss
from prototorch.models.clcc.clcc_scheme import CLCCScheme
from prototorch.nn.wrappers import LambdaLayer
@dataclass
class GLVQhparams:
distribution: dict
component_initializer: AbstractComponentsInitializer
distance_fn: Callable = euclidean_distance
lr: float = 0.01
margin: float = 0.0
# TODO: make nicer
transfer_fn: str = "identity"
transfer_beta: float = 10.0
optimizer: torch.optim.Optimizer = torch.optim.Adam
class GLVQ(CLCCScheme):
def __init__(self, hparams: GLVQhparams) -> None:
super().__init__(hparams)
self.lr = hparams.lr
self.optimizer = hparams.optimizer
# Initializers
def init_components(self, hparams):
# initialize Component Layer
self.components_layer = LabeledComponents(
distribution=hparams.distribution,
components_initializer=hparams.component_initializer,
labels_initializer=LabelsInitializer(),
)
def init_comparison(self, hparams):
# initialize Distance Layer
self.comparison_layer = LambdaLayer(hparams.distance_fn)
def init_inference(self, hparams):
self.competition_layer = WTAC()
def init_loss(self, hparams):
self.loss_layer = GLVQLoss(
margin=hparams.margin,
transfer_fn=hparams.transfer_fn,
beta=hparams.transfer_beta,
)
# Steps
def comparison(self, batch, components):
comp_tensor, _ = components
batch_tensor, _ = batch
comp_tensor = comp_tensor.unsqueeze(1)
distances = self.comparison_layer(batch_tensor, comp_tensor)
return distances
def inference(self, comparisonmeasures, components):
comp_labels = components[1]
return self.competition_layer(comparisonmeasures, comp_labels)
def loss(self, comparisonmeasures, batch, components):
target = batch[1]
comp_labels = components[1]
return self.loss_layer(comparisonmeasures, target, comp_labels)
def configure_optimizers(self):
return self.optimizer(self.parameters(), lr=self.lr)
# Properties
@property
def prototypes(self):
return self.components_layer.components.detach().cpu()
@property
def prototype_labels(self):
return self.components_layer.labels.detach().cpu()

View File

@ -0,0 +1,203 @@
"""
CLCC Scheme
CLCC is a LVQ scheme containing 4 steps
- Components
- Latent Space
- Comparison
- Competition
"""
from typing import (
Dict,
Set,
Type,
)
import pytorch_lightning as pl
import torch
from torchmetrics import Accuracy, Metric
class CLCCScheme(pl.LightningModule):
registered_metrics: Dict[Type[Metric], Metric] = {}
registered_metric_names: Dict[Type[Metric], Set[str]] = {}
def __init__(self, hparams) -> None:
super().__init__()
# Common Steps
self.init_components(hparams)
self.init_latent(hparams)
self.init_comparison(hparams)
self.init_competition(hparams)
# Train Steps
self.init_loss(hparams)
# Inference Steps
self.init_inference(hparams)
# Initialize Model Metrics
self.init_model_metrics()
# internal API, called by models and callbacks
def register_torchmetric(self, name: str, metric: Metric, **metricargs):
if metric not in self.registered_metrics:
self.registered_metrics[metric] = metric(**metricargs)
self.registered_metric_names[metric] = {name}
else:
self.registered_metric_names[metric].add(name)
# external API
def get_competion(self, batch, components):
latent_batch, latent_components = self.latent(batch, components)
# TODO: => Latent Hook
comparison_tensor = self.comparison(latent_batch, latent_components)
# TODO: => Comparison Hook
return comparison_tensor
def forward(self, batch):
if isinstance(batch, torch.Tensor):
batch = (batch, None)
# TODO: manage different datatypes?
components = self.components_layer()
# TODO: => Component Hook
comparison_tensor = self.get_competion(batch, components)
# TODO: => Competition Hook
return self.inference(comparison_tensor, components)
def predict(self, batch):
"""
Alias for forward
"""
return self.forward(batch)
def forward_comparison(self, batch):
if isinstance(batch, torch.Tensor):
batch = (batch, None)
# TODO: manage different datatypes?
components = self.components_layer()
# TODO: => Component Hook
return self.get_competion(batch, components)
def loss_forward(self, batch):
# TODO: manage different datatypes?
components = self.components_layer()
# TODO: => Component Hook
comparison_tensor = self.get_competion(batch, components)
# TODO: => Competition Hook
return self.loss(comparison_tensor, batch, components)
# Empty Initialization
# TODO: Type hints
# TODO: Docs
def init_components(self, hparams):
...
def init_latent(self, hparams):
...
def init_comparison(self, hparams):
...
def init_competition(self, hparams):
...
def init_loss(self, hparams):
...
def init_inference(self, hparams):
...
def init_model_metrics(self):
self.register_torchmetric('accuracy', Accuracy)
# Empty Steps
# TODO: Type hints
def components(self):
"""
This step has no input.
It returns the components.
"""
raise NotImplementedError(
"The components step has no reasonable default.")
def latent(self, batch, components):
"""
The latent step receives the data batch and the components.
It can transform both by an arbitrary function.
It returns the transformed batch and components, each of the same length as the original input.
"""
return batch, components
def comparison(self, batch, components):
"""
Takes a batch of size N and the componentsset of size M.
It returns an NxMxD tensor containing D (usually 1) pairwise comparison measures.
"""
raise NotImplementedError(
"The comparison step has no reasonable default.")
def competition(self, comparisonmeasures, components):
"""
Takes the tensor of comparison measures.
Assigns a competition vector to each class.
"""
raise NotImplementedError(
"The competition step has no reasonable default.")
def loss(self, comparisonmeasures, batch, components):
"""
Takes the tensor of competition measures.
Calculates a single loss value
"""
raise NotImplementedError("The loss step has no reasonable default.")
def inference(self, comparisonmeasures, components):
"""
Takes the tensor of competition measures.
Returns the inferred vector.
"""
raise NotImplementedError(
"The inference step has no reasonable default.")
def update_metrics_step(self, batch):
x, y = batch
# Prediction Metrics
preds = self(x)
for metric in self.registered_metrics:
instance = self.registered_metrics[metric].to(self.device)
instance(y, preds)
def update_metrics_epoch(self):
for metric in self.registered_metrics:
instance = self.registered_metrics[metric].to(self.device)
value = instance.compute()
for name in self.registered_metric_names[metric]:
self.log(name, value)
instance.reset()
# Lightning Hooks
def training_step(self, batch, batch_idx, optimizer_idx=None):
self.update_metrics_step(batch)
return self.loss_forward(batch)
def training_epoch_end(self, outs) -> None:
self.update_metrics_epoch()
def validation_step(self, batch, batch_idx):
return self.loss_forward(batch)
def test_step(self, batch, batch_idx):
return self.loss_forward(batch)

View File

@ -0,0 +1,103 @@
from typing import Optional
import matplotlib.pyplot as plt
import prototorch as pt
import pytorch_lightning as pl
import torch
import torchmetrics
from prototorch.core.initializers import SMCI, RandomNormalCompInitializer
from prototorch.models.clcc.clcc_glvq import GLVQ, GLVQhparams
from prototorch.models.clcc.clcc_scheme import CLCCScheme
from prototorch.models.vis import Visualize2DVoronoiCallback
# NEW STUFF
# ##############################################################################
# TODO: Metrics
class MetricsTestCallback(pl.Callback):
metric_name = "test_cb_acc"
def setup(self,
trainer: pl.Trainer,
pl_module: CLCCScheme,
stage: Optional[str] = None) -> None:
pl_module.register_torchmetric(self.metric_name, torchmetrics.Accuracy)
def on_epoch_end(self, trainer: pl.Trainer,
pl_module: pl.LightningModule) -> None:
metric = trainer.logged_metrics[self.metric_name]
if metric > 0.95:
trainer.should_stop = True
class LogTorchmetricCallback(pl.Callback):
def __init__(self, name, metric, on="prediction", **metric_args) -> None:
self.name = name
self.metric = metric
self.metric_args = metric_args
self.on = on
def setup(self,
trainer: pl.Trainer,
pl_module: CLCCScheme,
stage: Optional[str] = None) -> None:
if self.on == "prediction":
pl_module.register_torchmetric(self.name, self.metric,
**self.metric_args)
else:
raise ValueError(f"{self.on} is no valid metric hook")
# TODO: Pruning
# ##############################################################################
if __name__ == "__main__":
# Dataset
train_ds = pt.datasets.Iris(dims=[0, 2])
train_ds.targets[train_ds.targets == 2.0] = 1.0
# Dataloaders
train_loader = torch.utils.data.DataLoader(train_ds,
batch_size=64,
num_workers=0,
shuffle=True)
#components_initializer = SMCI(train_ds)
components_initializer = RandomNormalCompInitializer(2)
hparams = GLVQhparams(
lr=0.5,
distribution=dict(
num_classes=2,
per_class=1,
),
component_initializer=components_initializer,
)
model = GLVQ(hparams)
print(model)
# Callbacks
vis = Visualize2DVoronoiCallback(
data=train_ds,
resolution=500,
)
metrics = MetricsTestCallback()
recall = LogTorchmetricCallback('recall',
torchmetrics.Recall,
num_classes=2)
# Train
trainer = pl.Trainer(
callbacks=[
vis,
#metrics,
recall,
],
gpus=0,
max_epochs=200,
weights_summary=None,
log_every_n_steps=1,
)
trainer.fit(model, train_loader)