feat: copy old clc-lc implementation
This commit is contained in:
parent
e0b92e9ac2
commit
8f08ba66ea
0
prototorch/models/clcc/__init__.py
Normal file
0
prototorch/models/clcc/__init__.py
Normal file
90
prototorch/models/clcc/clcc_glvq.py
Normal file
90
prototorch/models/clcc/clcc_glvq.py
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from prototorch.core.competitions import WTAC
|
||||||
|
from prototorch.core.components import LabeledComponents
|
||||||
|
from prototorch.core.distances import euclidean_distance
|
||||||
|
from prototorch.core.initializers import (
|
||||||
|
AbstractComponentsInitializer,
|
||||||
|
LabelsInitializer,
|
||||||
|
)
|
||||||
|
from prototorch.core.losses import GLVQLoss
|
||||||
|
from prototorch.models.clcc.clcc_scheme import CLCCScheme
|
||||||
|
from prototorch.nn.wrappers import LambdaLayer
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GLVQhparams:
|
||||||
|
distribution: dict
|
||||||
|
component_initializer: AbstractComponentsInitializer
|
||||||
|
distance_fn: Callable = euclidean_distance
|
||||||
|
lr: float = 0.01
|
||||||
|
margin: float = 0.0
|
||||||
|
# TODO: make nicer
|
||||||
|
transfer_fn: str = "identity"
|
||||||
|
transfer_beta: float = 10.0
|
||||||
|
optimizer: torch.optim.Optimizer = torch.optim.Adam
|
||||||
|
|
||||||
|
|
||||||
|
class GLVQ(CLCCScheme):
|
||||||
|
|
||||||
|
def __init__(self, hparams: GLVQhparams) -> None:
|
||||||
|
super().__init__(hparams)
|
||||||
|
self.lr = hparams.lr
|
||||||
|
self.optimizer = hparams.optimizer
|
||||||
|
|
||||||
|
# Initializers
|
||||||
|
def init_components(self, hparams):
|
||||||
|
# initialize Component Layer
|
||||||
|
self.components_layer = LabeledComponents(
|
||||||
|
distribution=hparams.distribution,
|
||||||
|
components_initializer=hparams.component_initializer,
|
||||||
|
labels_initializer=LabelsInitializer(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def init_comparison(self, hparams):
|
||||||
|
# initialize Distance Layer
|
||||||
|
self.comparison_layer = LambdaLayer(hparams.distance_fn)
|
||||||
|
|
||||||
|
def init_inference(self, hparams):
|
||||||
|
self.competition_layer = WTAC()
|
||||||
|
|
||||||
|
def init_loss(self, hparams):
|
||||||
|
self.loss_layer = GLVQLoss(
|
||||||
|
margin=hparams.margin,
|
||||||
|
transfer_fn=hparams.transfer_fn,
|
||||||
|
beta=hparams.transfer_beta,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Steps
|
||||||
|
def comparison(self, batch, components):
|
||||||
|
comp_tensor, _ = components
|
||||||
|
batch_tensor, _ = batch
|
||||||
|
|
||||||
|
comp_tensor = comp_tensor.unsqueeze(1)
|
||||||
|
|
||||||
|
distances = self.comparison_layer(batch_tensor, comp_tensor)
|
||||||
|
|
||||||
|
return distances
|
||||||
|
|
||||||
|
def inference(self, comparisonmeasures, components):
|
||||||
|
comp_labels = components[1]
|
||||||
|
return self.competition_layer(comparisonmeasures, comp_labels)
|
||||||
|
|
||||||
|
def loss(self, comparisonmeasures, batch, components):
|
||||||
|
target = batch[1]
|
||||||
|
comp_labels = components[1]
|
||||||
|
return self.loss_layer(comparisonmeasures, target, comp_labels)
|
||||||
|
|
||||||
|
def configure_optimizers(self):
|
||||||
|
return self.optimizer(self.parameters(), lr=self.lr)
|
||||||
|
|
||||||
|
# Properties
|
||||||
|
@property
|
||||||
|
def prototypes(self):
|
||||||
|
return self.components_layer.components.detach().cpu()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def prototype_labels(self):
|
||||||
|
return self.components_layer.labels.detach().cpu()
|
203
prototorch/models/clcc/clcc_scheme.py
Normal file
203
prototorch/models/clcc/clcc_scheme.py
Normal file
@ -0,0 +1,203 @@
|
|||||||
|
"""
|
||||||
|
CLCC Scheme
|
||||||
|
|
||||||
|
CLCC is a LVQ scheme containing 4 steps
|
||||||
|
- Components
|
||||||
|
- Latent Space
|
||||||
|
- Comparison
|
||||||
|
- Competition
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import (
|
||||||
|
Dict,
|
||||||
|
Set,
|
||||||
|
Type,
|
||||||
|
)
|
||||||
|
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
import torch
|
||||||
|
from torchmetrics import Accuracy, Metric
|
||||||
|
|
||||||
|
|
||||||
|
class CLCCScheme(pl.LightningModule):
|
||||||
|
registered_metrics: Dict[Type[Metric], Metric] = {}
|
||||||
|
registered_metric_names: Dict[Type[Metric], Set[str]] = {}
|
||||||
|
|
||||||
|
def __init__(self, hparams) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# Common Steps
|
||||||
|
self.init_components(hparams)
|
||||||
|
self.init_latent(hparams)
|
||||||
|
self.init_comparison(hparams)
|
||||||
|
self.init_competition(hparams)
|
||||||
|
|
||||||
|
# Train Steps
|
||||||
|
self.init_loss(hparams)
|
||||||
|
|
||||||
|
# Inference Steps
|
||||||
|
self.init_inference(hparams)
|
||||||
|
|
||||||
|
# Initialize Model Metrics
|
||||||
|
self.init_model_metrics()
|
||||||
|
|
||||||
|
# internal API, called by models and callbacks
|
||||||
|
def register_torchmetric(self, name: str, metric: Metric, **metricargs):
|
||||||
|
if metric not in self.registered_metrics:
|
||||||
|
self.registered_metrics[metric] = metric(**metricargs)
|
||||||
|
self.registered_metric_names[metric] = {name}
|
||||||
|
else:
|
||||||
|
self.registered_metric_names[metric].add(name)
|
||||||
|
|
||||||
|
# external API
|
||||||
|
def get_competion(self, batch, components):
|
||||||
|
latent_batch, latent_components = self.latent(batch, components)
|
||||||
|
# TODO: => Latent Hook
|
||||||
|
comparison_tensor = self.comparison(latent_batch, latent_components)
|
||||||
|
# TODO: => Comparison Hook
|
||||||
|
return comparison_tensor
|
||||||
|
|
||||||
|
def forward(self, batch):
|
||||||
|
if isinstance(batch, torch.Tensor):
|
||||||
|
batch = (batch, None)
|
||||||
|
# TODO: manage different datatypes?
|
||||||
|
components = self.components_layer()
|
||||||
|
# TODO: => Component Hook
|
||||||
|
comparison_tensor = self.get_competion(batch, components)
|
||||||
|
# TODO: => Competition Hook
|
||||||
|
return self.inference(comparison_tensor, components)
|
||||||
|
|
||||||
|
def predict(self, batch):
|
||||||
|
"""
|
||||||
|
Alias for forward
|
||||||
|
"""
|
||||||
|
return self.forward(batch)
|
||||||
|
|
||||||
|
def forward_comparison(self, batch):
|
||||||
|
if isinstance(batch, torch.Tensor):
|
||||||
|
batch = (batch, None)
|
||||||
|
# TODO: manage different datatypes?
|
||||||
|
components = self.components_layer()
|
||||||
|
# TODO: => Component Hook
|
||||||
|
return self.get_competion(batch, components)
|
||||||
|
|
||||||
|
def loss_forward(self, batch):
|
||||||
|
# TODO: manage different datatypes?
|
||||||
|
components = self.components_layer()
|
||||||
|
# TODO: => Component Hook
|
||||||
|
comparison_tensor = self.get_competion(batch, components)
|
||||||
|
# TODO: => Competition Hook
|
||||||
|
return self.loss(comparison_tensor, batch, components)
|
||||||
|
|
||||||
|
# Empty Initialization
|
||||||
|
# TODO: Type hints
|
||||||
|
# TODO: Docs
|
||||||
|
def init_components(self, hparams):
|
||||||
|
...
|
||||||
|
|
||||||
|
def init_latent(self, hparams):
|
||||||
|
...
|
||||||
|
|
||||||
|
def init_comparison(self, hparams):
|
||||||
|
...
|
||||||
|
|
||||||
|
def init_competition(self, hparams):
|
||||||
|
...
|
||||||
|
|
||||||
|
def init_loss(self, hparams):
|
||||||
|
...
|
||||||
|
|
||||||
|
def init_inference(self, hparams):
|
||||||
|
...
|
||||||
|
|
||||||
|
def init_model_metrics(self):
|
||||||
|
self.register_torchmetric('accuracy', Accuracy)
|
||||||
|
|
||||||
|
# Empty Steps
|
||||||
|
# TODO: Type hints
|
||||||
|
def components(self):
|
||||||
|
"""
|
||||||
|
This step has no input.
|
||||||
|
|
||||||
|
It returns the components.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(
|
||||||
|
"The components step has no reasonable default.")
|
||||||
|
|
||||||
|
def latent(self, batch, components):
|
||||||
|
"""
|
||||||
|
The latent step receives the data batch and the components.
|
||||||
|
It can transform both by an arbitrary function.
|
||||||
|
|
||||||
|
It returns the transformed batch and components, each of the same length as the original input.
|
||||||
|
"""
|
||||||
|
return batch, components
|
||||||
|
|
||||||
|
def comparison(self, batch, components):
|
||||||
|
"""
|
||||||
|
Takes a batch of size N and the componentsset of size M.
|
||||||
|
|
||||||
|
It returns an NxMxD tensor containing D (usually 1) pairwise comparison measures.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(
|
||||||
|
"The comparison step has no reasonable default.")
|
||||||
|
|
||||||
|
def competition(self, comparisonmeasures, components):
|
||||||
|
"""
|
||||||
|
Takes the tensor of comparison measures.
|
||||||
|
|
||||||
|
Assigns a competition vector to each class.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(
|
||||||
|
"The competition step has no reasonable default.")
|
||||||
|
|
||||||
|
def loss(self, comparisonmeasures, batch, components):
|
||||||
|
"""
|
||||||
|
Takes the tensor of competition measures.
|
||||||
|
|
||||||
|
Calculates a single loss value
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("The loss step has no reasonable default.")
|
||||||
|
|
||||||
|
def inference(self, comparisonmeasures, components):
|
||||||
|
"""
|
||||||
|
Takes the tensor of competition measures.
|
||||||
|
|
||||||
|
Returns the inferred vector.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(
|
||||||
|
"The inference step has no reasonable default.")
|
||||||
|
|
||||||
|
def update_metrics_step(self, batch):
|
||||||
|
x, y = batch
|
||||||
|
|
||||||
|
# Prediction Metrics
|
||||||
|
preds = self(x)
|
||||||
|
for metric in self.registered_metrics:
|
||||||
|
instance = self.registered_metrics[metric].to(self.device)
|
||||||
|
instance(y, preds)
|
||||||
|
|
||||||
|
def update_metrics_epoch(self):
|
||||||
|
for metric in self.registered_metrics:
|
||||||
|
instance = self.registered_metrics[metric].to(self.device)
|
||||||
|
value = instance.compute()
|
||||||
|
|
||||||
|
for name in self.registered_metric_names[metric]:
|
||||||
|
self.log(name, value)
|
||||||
|
|
||||||
|
instance.reset()
|
||||||
|
|
||||||
|
# Lightning Hooks
|
||||||
|
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
||||||
|
self.update_metrics_step(batch)
|
||||||
|
|
||||||
|
return self.loss_forward(batch)
|
||||||
|
|
||||||
|
def training_epoch_end(self, outs) -> None:
|
||||||
|
self.update_metrics_epoch()
|
||||||
|
|
||||||
|
def validation_step(self, batch, batch_idx):
|
||||||
|
return self.loss_forward(batch)
|
||||||
|
|
||||||
|
def test_step(self, batch, batch_idx):
|
||||||
|
return self.loss_forward(batch)
|
103
prototorch/models/clcc/test_clcc.py
Normal file
103
prototorch/models/clcc/test_clcc.py
Normal file
@ -0,0 +1,103 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import prototorch as pt
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
import torch
|
||||||
|
import torchmetrics
|
||||||
|
from prototorch.core.initializers import SMCI, RandomNormalCompInitializer
|
||||||
|
from prototorch.models.clcc.clcc_glvq import GLVQ, GLVQhparams
|
||||||
|
from prototorch.models.clcc.clcc_scheme import CLCCScheme
|
||||||
|
from prototorch.models.vis import Visualize2DVoronoiCallback
|
||||||
|
|
||||||
|
# NEW STUFF
|
||||||
|
# ##############################################################################
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Metrics
|
||||||
|
class MetricsTestCallback(pl.Callback):
|
||||||
|
metric_name = "test_cb_acc"
|
||||||
|
|
||||||
|
def setup(self,
|
||||||
|
trainer: pl.Trainer,
|
||||||
|
pl_module: CLCCScheme,
|
||||||
|
stage: Optional[str] = None) -> None:
|
||||||
|
pl_module.register_torchmetric(self.metric_name, torchmetrics.Accuracy)
|
||||||
|
|
||||||
|
def on_epoch_end(self, trainer: pl.Trainer,
|
||||||
|
pl_module: pl.LightningModule) -> None:
|
||||||
|
metric = trainer.logged_metrics[self.metric_name]
|
||||||
|
if metric > 0.95:
|
||||||
|
trainer.should_stop = True
|
||||||
|
|
||||||
|
|
||||||
|
class LogTorchmetricCallback(pl.Callback):
|
||||||
|
|
||||||
|
def __init__(self, name, metric, on="prediction", **metric_args) -> None:
|
||||||
|
self.name = name
|
||||||
|
self.metric = metric
|
||||||
|
self.metric_args = metric_args
|
||||||
|
self.on = on
|
||||||
|
|
||||||
|
def setup(self,
|
||||||
|
trainer: pl.Trainer,
|
||||||
|
pl_module: CLCCScheme,
|
||||||
|
stage: Optional[str] = None) -> None:
|
||||||
|
if self.on == "prediction":
|
||||||
|
pl_module.register_torchmetric(self.name, self.metric,
|
||||||
|
**self.metric_args)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"{self.on} is no valid metric hook")
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Pruning
|
||||||
|
|
||||||
|
# ##############################################################################
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Dataset
|
||||||
|
train_ds = pt.datasets.Iris(dims=[0, 2])
|
||||||
|
train_ds.targets[train_ds.targets == 2.0] = 1.0
|
||||||
|
# Dataloaders
|
||||||
|
train_loader = torch.utils.data.DataLoader(train_ds,
|
||||||
|
batch_size=64,
|
||||||
|
num_workers=0,
|
||||||
|
shuffle=True)
|
||||||
|
|
||||||
|
#components_initializer = SMCI(train_ds)
|
||||||
|
components_initializer = RandomNormalCompInitializer(2)
|
||||||
|
|
||||||
|
hparams = GLVQhparams(
|
||||||
|
lr=0.5,
|
||||||
|
distribution=dict(
|
||||||
|
num_classes=2,
|
||||||
|
per_class=1,
|
||||||
|
),
|
||||||
|
component_initializer=components_initializer,
|
||||||
|
)
|
||||||
|
model = GLVQ(hparams)
|
||||||
|
|
||||||
|
print(model)
|
||||||
|
# Callbacks
|
||||||
|
vis = Visualize2DVoronoiCallback(
|
||||||
|
data=train_ds,
|
||||||
|
resolution=500,
|
||||||
|
)
|
||||||
|
metrics = MetricsTestCallback()
|
||||||
|
recall = LogTorchmetricCallback('recall',
|
||||||
|
torchmetrics.Recall,
|
||||||
|
num_classes=2)
|
||||||
|
|
||||||
|
# Train
|
||||||
|
trainer = pl.Trainer(
|
||||||
|
callbacks=[
|
||||||
|
vis,
|
||||||
|
#metrics,
|
||||||
|
recall,
|
||||||
|
],
|
||||||
|
gpus=0,
|
||||||
|
max_epochs=200,
|
||||||
|
weights_summary=None,
|
||||||
|
log_every_n_steps=1,
|
||||||
|
)
|
||||||
|
trainer.fit(model, train_loader)
|
Loading…
Reference in New Issue
Block a user