feat: Add basic GLVQ with new architecture
This commit is contained in:
parent
d4448f2bc9
commit
967953442b
@ -3,6 +3,7 @@
|
|||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
|
import prototorch.models.expanded
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
from torch.optim.lr_scheduler import ExponentialLR
|
from torch.optim.lr_scheduler import ExponentialLR
|
||||||
@ -29,7 +30,7 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Initialize the model
|
# Initialize the model
|
||||||
model = pt.models.GLVQ(
|
model = prototorch.models.expanded.GLVQ(
|
||||||
hparams,
|
hparams,
|
||||||
optimizer=torch.optim.Adam,
|
optimizer=torch.optim.Adam,
|
||||||
prototypes_initializer=pt.initializers.SMCI(train_ds),
|
prototypes_initializer=pt.initializers.SMCI(train_ds),
|
||||||
|
1
prototorch/models/expanded/__init__.py
Normal file
1
prototorch/models/expanded/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
from .glvq import GLVQ
|
86
prototorch/models/expanded/clcc_glvq.py
Normal file
86
prototorch/models/expanded/clcc_glvq.py
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from prototorch.core.competitions import WTAC
|
||||||
|
from prototorch.core.components import LabeledComponents
|
||||||
|
from prototorch.core.distances import euclidean_distance
|
||||||
|
from prototorch.core.initializers import AbstractComponentsInitializer, LabelsInitializer
|
||||||
|
from prototorch.core.losses import GLVQLoss
|
||||||
|
from prototorch.models.expanded.clcc_scheme import CLCCScheme
|
||||||
|
from prototorch.nn.wrappers import LambdaLayer
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GLVQhparams:
|
||||||
|
distribution: dict
|
||||||
|
component_initializer: AbstractComponentsInitializer
|
||||||
|
distance_fn: Callable = euclidean_distance
|
||||||
|
lr: float = 0.01
|
||||||
|
margin: float = 0.0
|
||||||
|
# TODO: make nicer
|
||||||
|
transfer_fn: str = "identity"
|
||||||
|
transfer_beta: float = 10.0
|
||||||
|
optimizer: torch.optim.Optimizer = torch.optim.Adam
|
||||||
|
|
||||||
|
|
||||||
|
class GLVQ(CLCCScheme):
|
||||||
|
def __init__(self, hparams: GLVQhparams) -> None:
|
||||||
|
super().__init__(hparams)
|
||||||
|
self.lr = hparams.lr
|
||||||
|
self.optimizer = hparams.optimizer
|
||||||
|
|
||||||
|
# Initializers
|
||||||
|
def init_components(self, hparams):
|
||||||
|
# initialize Component Layer
|
||||||
|
self.components_layer = LabeledComponents(
|
||||||
|
distribution=hparams.distribution,
|
||||||
|
components_initializer=hparams.component_initializer,
|
||||||
|
labels_initializer=LabelsInitializer(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def init_comparison(self, hparams):
|
||||||
|
# initialize Distance Layer
|
||||||
|
self.comparison_layer = LambdaLayer(hparams.distance_fn)
|
||||||
|
|
||||||
|
def init_inference(self, hparams):
|
||||||
|
self.competition_layer = WTAC()
|
||||||
|
|
||||||
|
def init_loss(self, hparams):
|
||||||
|
self.loss_layer = GLVQLoss(
|
||||||
|
margin=hparams.margin,
|
||||||
|
transfer_fn=hparams.transfer_fn,
|
||||||
|
beta=hparams.transfer_beta,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Steps
|
||||||
|
def comparison(self, batch, components):
|
||||||
|
comp_tensor, _ = components
|
||||||
|
batch_tensor, _ = batch
|
||||||
|
|
||||||
|
comp_tensor = comp_tensor.unsqueeze(1)
|
||||||
|
|
||||||
|
distances = self.comparison_layer(batch_tensor, comp_tensor)
|
||||||
|
|
||||||
|
return distances
|
||||||
|
|
||||||
|
def inference(self, comparisonmeasures, components):
|
||||||
|
comp_labels = components[1]
|
||||||
|
return self.competition_layer(comparisonmeasures, comp_labels)
|
||||||
|
|
||||||
|
def loss(self, comparisonmeasures, batch, components):
|
||||||
|
target = batch[1]
|
||||||
|
comp_labels = components[1]
|
||||||
|
return self.loss_layer(comparisonmeasures, target, comp_labels)
|
||||||
|
|
||||||
|
def configure_optimizers(self):
|
||||||
|
return self.optimizer(self.parameters(), lr=self.lr)
|
||||||
|
|
||||||
|
# Properties
|
||||||
|
@property
|
||||||
|
def prototypes(self):
|
||||||
|
return self.components_layer.components.detach().cpu()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def prototype_labels(self):
|
||||||
|
return self.components_layer.labels.detach().cpu()
|
138
prototorch/models/expanded/clcc_scheme.py
Normal file
138
prototorch/models/expanded/clcc_scheme.py
Normal file
@ -0,0 +1,138 @@
|
|||||||
|
"""
|
||||||
|
CLCC Scheme
|
||||||
|
|
||||||
|
CLCC is a LVQ scheme containing 4 steps
|
||||||
|
- Components
|
||||||
|
- Latent Space
|
||||||
|
- Comparison
|
||||||
|
- Competition
|
||||||
|
|
||||||
|
"""
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
|
||||||
|
|
||||||
|
class CLCCScheme(pl.LightningModule):
|
||||||
|
def __init__(self, hparams) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# Common Steps
|
||||||
|
self.init_components(hparams)
|
||||||
|
self.init_latent(hparams)
|
||||||
|
self.init_comparison(hparams)
|
||||||
|
self.init_competition(hparams)
|
||||||
|
|
||||||
|
# Train Steps
|
||||||
|
self.init_loss(hparams)
|
||||||
|
|
||||||
|
# Inference Steps
|
||||||
|
self.init_inference(hparams)
|
||||||
|
|
||||||
|
# API
|
||||||
|
def get_competion(self, batch, components):
|
||||||
|
latent_batch, latent_components = self.latent(batch, components)
|
||||||
|
# TODO: => Latent Hook
|
||||||
|
comparison_tensor = self.comparison(latent_batch, latent_components)
|
||||||
|
# TODO: => Comparison Hook
|
||||||
|
return comparison_tensor
|
||||||
|
|
||||||
|
def forward(self, batch):
|
||||||
|
# TODO: manage different datatypes?
|
||||||
|
components = self.components_layer()
|
||||||
|
# TODO: => Component Hook
|
||||||
|
comparison_tensor = self.get_competion(batch, components)
|
||||||
|
# TODO: => Competition Hook
|
||||||
|
return self.inference(comparison_tensor, components)
|
||||||
|
|
||||||
|
def loss_forward(self, batch):
|
||||||
|
# TODO: manage different datatypes?
|
||||||
|
components = self.components_layer()
|
||||||
|
# TODO: => Component Hook
|
||||||
|
comparison_tensor = self.get_competion(batch, components)
|
||||||
|
# TODO: => Competition Hook
|
||||||
|
return self.loss(comparison_tensor, batch, components)
|
||||||
|
|
||||||
|
# Empty Initialization
|
||||||
|
# TODO: Type hints
|
||||||
|
# TODO: Docs
|
||||||
|
def init_components(self, hparams):
|
||||||
|
...
|
||||||
|
|
||||||
|
def init_latent(self, hparams):
|
||||||
|
...
|
||||||
|
|
||||||
|
def init_comparison(self, hparams):
|
||||||
|
...
|
||||||
|
|
||||||
|
def init_competition(self, hparams):
|
||||||
|
...
|
||||||
|
|
||||||
|
def init_loss(self, hparams):
|
||||||
|
...
|
||||||
|
|
||||||
|
def init_inference(self, hparams):
|
||||||
|
...
|
||||||
|
|
||||||
|
# Empty Steps
|
||||||
|
# TODO: Type hints
|
||||||
|
def components(self):
|
||||||
|
"""
|
||||||
|
This step has no input.
|
||||||
|
|
||||||
|
It returns the components.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(
|
||||||
|
"The components step has no reasonable default.")
|
||||||
|
|
||||||
|
def latent(self, batch, components):
|
||||||
|
"""
|
||||||
|
The latent step receives the data batch and the components.
|
||||||
|
It can transform both by an arbitrary function.
|
||||||
|
|
||||||
|
It returns the transformed batch and components, each of the same length as the original input.
|
||||||
|
"""
|
||||||
|
return batch, components
|
||||||
|
|
||||||
|
def comparison(self, batch, components):
|
||||||
|
"""
|
||||||
|
Takes a batch of size N and the componentsset of size M.
|
||||||
|
|
||||||
|
It returns an NxMxD tensor containing D (usually 1) pairwise comparison measures.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(
|
||||||
|
"The comparison step has no reasonable default.")
|
||||||
|
|
||||||
|
def competition(self, comparisonmeasures, components):
|
||||||
|
"""
|
||||||
|
Takes the tensor of comparison measures.
|
||||||
|
|
||||||
|
Assigns a competition vector to each class.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(
|
||||||
|
"The competition step has no reasonable default.")
|
||||||
|
|
||||||
|
def loss(self, comparisonmeasures, batch, components):
|
||||||
|
"""
|
||||||
|
Takes the tensor of competition measures.
|
||||||
|
|
||||||
|
Calculates a single loss value
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("The loss step has no reasonable default.")
|
||||||
|
|
||||||
|
def inference(self, comparisonmeasures, components):
|
||||||
|
"""
|
||||||
|
Takes the tensor of competition measures.
|
||||||
|
|
||||||
|
Returns the inferred vector.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(
|
||||||
|
"The inference step has no reasonable default.")
|
||||||
|
|
||||||
|
# Lightning Hooks
|
||||||
|
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
||||||
|
return self.loss_forward(batch)
|
||||||
|
|
||||||
|
def validation_step(self, batch, batch_idx):
|
||||||
|
return self.loss_forward(batch)
|
||||||
|
|
||||||
|
def test_step(self, batch, batch_idx):
|
||||||
|
return self.loss_forward(batch)
|
164
prototorch/models/expanded/glvq.py
Normal file
164
prototorch/models/expanded/glvq.py
Normal file
@ -0,0 +1,164 @@
|
|||||||
|
import pytorch_lightning as pl
|
||||||
|
import torch
|
||||||
|
import torchmetrics
|
||||||
|
from prototorch.core.competitions import WTAC, wtac
|
||||||
|
from prototorch.core.components import Components, LabeledComponents
|
||||||
|
from prototorch.core.distances import (
|
||||||
|
euclidean_distance,
|
||||||
|
lomega_distance,
|
||||||
|
omega_distance,
|
||||||
|
squared_euclidean_distance,
|
||||||
|
)
|
||||||
|
from prototorch.core.initializers import EyeTransformInitializer, LabelsInitializer
|
||||||
|
from prototorch.core.losses import GLVQLoss, lvq1_loss, lvq21_loss
|
||||||
|
from prototorch.core.pooling import stratified_min_pooling
|
||||||
|
from prototorch.core.transforms import LinearTransform
|
||||||
|
from prototorch.nn.wrappers import LambdaLayer, LossLayer
|
||||||
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
|
|
||||||
|
class GLVQ(pl.LightningModule):
|
||||||
|
def __init__(self, hparams, **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# Hyperparameters
|
||||||
|
self.save_hyperparameters(hparams)
|
||||||
|
|
||||||
|
# Default hparams
|
||||||
|
# TODO: Manage by an HPARAMS Object
|
||||||
|
self.hparams.setdefault("lr", 0.01)
|
||||||
|
self.hparams.setdefault("margin", 0.0)
|
||||||
|
self.hparams.setdefault("transfer_fn", "identity")
|
||||||
|
self.hparams.setdefault("transfer_beta", 10.0)
|
||||||
|
|
||||||
|
# Default config
|
||||||
|
self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
|
||||||
|
self.lr_scheduler = kwargs.get("lr_scheduler", None)
|
||||||
|
self.lr_scheduler_kwargs = kwargs.get("lr_scheduler_kwargs", dict())
|
||||||
|
distance_fn = kwargs.get("distance_fn", euclidean_distance)
|
||||||
|
prototypes_initializer = kwargs.get("prototypes_initializer", None)
|
||||||
|
labels_initializer = kwargs.get("labels_initializer",
|
||||||
|
LabelsInitializer())
|
||||||
|
|
||||||
|
if prototypes_initializer is not None:
|
||||||
|
self.proto_layer = LabeledComponents(
|
||||||
|
distribution=self.hparams.distribution,
|
||||||
|
components_initializer=prototypes_initializer,
|
||||||
|
labels_initializer=labels_initializer,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.distance_layer = LambdaLayer(distance_fn)
|
||||||
|
self.competition_layer = WTAC()
|
||||||
|
|
||||||
|
self.loss = GLVQLoss(
|
||||||
|
margin=self.hparams.margin,
|
||||||
|
transfer_fn=self.hparams.transfer_fn,
|
||||||
|
beta=self.hparams.transfer_beta,
|
||||||
|
)
|
||||||
|
|
||||||
|
def log_acc(self, distances, targets, tag):
|
||||||
|
preds = self.predict_from_distances(distances)
|
||||||
|
accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int())
|
||||||
|
self.log(tag,
|
||||||
|
accuracy,
|
||||||
|
on_step=False,
|
||||||
|
on_epoch=True,
|
||||||
|
prog_bar=True,
|
||||||
|
logger=True)
|
||||||
|
|
||||||
|
def configure_optimizers(self):
|
||||||
|
optimizer = self.optimizer(self.parameters(), lr=self.hparams.lr)
|
||||||
|
if self.lr_scheduler is not None:
|
||||||
|
scheduler = self.lr_scheduler(optimizer,
|
||||||
|
**self.lr_scheduler_kwargs)
|
||||||
|
sch = {
|
||||||
|
"scheduler": scheduler,
|
||||||
|
"interval": "step",
|
||||||
|
} # called after each training step
|
||||||
|
return [optimizer], [sch]
|
||||||
|
else:
|
||||||
|
return optimizer
|
||||||
|
|
||||||
|
def shared_step(self, batch, batch_idx, optimizer_idx=None):
|
||||||
|
x, y = batch
|
||||||
|
out = self.compute_distances(x)
|
||||||
|
_, plabels = self.proto_layer()
|
||||||
|
loss = self.loss(out, y, plabels)
|
||||||
|
return out, loss
|
||||||
|
|
||||||
|
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
||||||
|
out, train_loss = self.shared_step(batch, batch_idx, optimizer_idx)
|
||||||
|
self.log_prototype_win_ratios(out)
|
||||||
|
self.log("train_loss", train_loss)
|
||||||
|
self.log_acc(out, batch[-1], tag="train_acc")
|
||||||
|
return train_loss
|
||||||
|
|
||||||
|
def validation_step(self, batch, batch_idx):
|
||||||
|
out, val_loss = self.shared_step(batch, batch_idx)
|
||||||
|
self.log("val_loss", val_loss)
|
||||||
|
self.log_acc(out, batch[-1], tag="val_acc")
|
||||||
|
return val_loss
|
||||||
|
|
||||||
|
def test_step(self, batch, batch_idx):
|
||||||
|
out, test_loss = self.shared_step(batch, batch_idx)
|
||||||
|
self.log_acc(out, batch[-1], tag="test_acc")
|
||||||
|
return test_loss
|
||||||
|
|
||||||
|
def test_epoch_end(self, outputs):
|
||||||
|
test_loss = 0.0
|
||||||
|
for batch_loss in outputs:
|
||||||
|
test_loss += batch_loss.item()
|
||||||
|
self.log("test_loss", test_loss)
|
||||||
|
|
||||||
|
# API
|
||||||
|
def compute_distances(self, x):
|
||||||
|
protos, _ = self.proto_layer()
|
||||||
|
distances = self.distance_layer(x, protos)
|
||||||
|
return distances
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
distances = self.compute_distances(x)
|
||||||
|
_, plabels = self.proto_layer()
|
||||||
|
winning = stratified_min_pooling(distances, plabels)
|
||||||
|
y_pred = torch.nn.functional.softmin(winning)
|
||||||
|
return y_pred
|
||||||
|
|
||||||
|
def predict_from_distances(self, distances):
|
||||||
|
with torch.no_grad():
|
||||||
|
_, plabels = self.proto_layer()
|
||||||
|
y_pred = self.competition_layer(distances, plabels)
|
||||||
|
return y_pred
|
||||||
|
|
||||||
|
def predict(self, x):
|
||||||
|
with torch.no_grad():
|
||||||
|
distances = self.compute_distances(x)
|
||||||
|
y_pred = self.predict_from_distances(distances)
|
||||||
|
return y_pred
|
||||||
|
|
||||||
|
@property
|
||||||
|
def prototype_labels(self):
|
||||||
|
return self.proto_layer.labels.detach().cpu()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_classes(self):
|
||||||
|
return self.proto_layer.num_classes
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_prototypes(self):
|
||||||
|
return len(self.proto_layer.components)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def prototypes(self):
|
||||||
|
return self.proto_layer.components.detach().cpu()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def components(self):
|
||||||
|
"""Only an alias for the prototypes."""
|
||||||
|
return self.prototypes
|
||||||
|
|
||||||
|
# Python overwrites
|
||||||
|
def __repr__(self):
|
||||||
|
surep = super().__repr__()
|
||||||
|
indented = "".join([f"\t{line}\n" for line in surep.splitlines()])
|
||||||
|
wrapped = f"ProtoTorch Bolt(\n{indented})"
|
||||||
|
return wrapped
|
35
prototorch/models/expanded/test_clcc.py
Normal file
35
prototorch/models/expanded/test_clcc.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import prototorch as pt
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
import torch
|
||||||
|
from prototorch.core.initializers import SMCI, RandomNormalCompInitializer
|
||||||
|
from prototorch.models.expanded.clcc_glvq import GLVQ, GLVQhparams
|
||||||
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
from torchvision import datasets
|
||||||
|
from torchvision.transforms import Compose, Lambda, ToTensor
|
||||||
|
|
||||||
|
plt.gray()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Dataset
|
||||||
|
train_ds = pt.datasets.Iris(dims=[0, 2])
|
||||||
|
# Dataloaders
|
||||||
|
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)
|
||||||
|
|
||||||
|
components_initializer = SMCI(train_ds)
|
||||||
|
|
||||||
|
hparams = GLVQhparams(
|
||||||
|
distribution=dict(
|
||||||
|
num_classes=3,
|
||||||
|
per_class=2,
|
||||||
|
),
|
||||||
|
component_initializer=components_initializer,
|
||||||
|
)
|
||||||
|
model = GLVQ(hparams)
|
||||||
|
|
||||||
|
print(model)
|
||||||
|
# Callbacks
|
||||||
|
vis = pt.models.VisGLVQ2D(data=train_ds)
|
||||||
|
# Train
|
||||||
|
trainer = pl.Trainer(callbacks=[vis], gpus=1)
|
||||||
|
trainer.fit(model, train_loader)
|
@ -129,12 +129,14 @@ class VisGLVQ2D(Vis2DAbstract):
|
|||||||
self.plot_data(ax, x_train, y_train)
|
self.plot_data(ax, x_train, y_train)
|
||||||
self.plot_protos(ax, protos, plabels)
|
self.plot_protos(ax, protos, plabels)
|
||||||
x = np.vstack((x_train, protos))
|
x = np.vstack((x_train, protos))
|
||||||
mesh_input, xx, yy = mesh2d(x, self.border, self.resolution)
|
mesh_input, xx, yy = mesh2d(x,
|
||||||
_components = pl_module.proto_layer._components
|
self.border,
|
||||||
mesh_input = torch.from_numpy(mesh_input).type_as(_components)
|
self.resolution,
|
||||||
y_pred = pl_module.predict(mesh_input)
|
device=pl_module.device)
|
||||||
|
mesh_input = (mesh_input, None)
|
||||||
|
y_pred = pl_module(mesh_input)
|
||||||
y_pred = y_pred.cpu().reshape(xx.shape)
|
y_pred = y_pred.cpu().reshape(xx.shape)
|
||||||
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
|
ax.contourf(xx.cpu(), yy.cpu(), y_pred, cmap=self.cmap, alpha=0.35)
|
||||||
|
|
||||||
self.log_and_display(trainer, pl_module)
|
self.log_and_display(trainer, pl_module)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user