feat: Add basic GLVQ with new architecture

This commit is contained in:
Alexander Engelsberger 2021-10-14 15:49:12 +02:00
parent d4448f2bc9
commit 967953442b
No known key found for this signature in database
GPG Key ID: BE3F5909FF0D83E3
7 changed files with 433 additions and 6 deletions

View File

@ -3,6 +3,7 @@
import argparse
import prototorch as pt
import prototorch.models.expanded
import pytorch_lightning as pl
import torch
from torch.optim.lr_scheduler import ExponentialLR
@ -29,7 +30,7 @@ if __name__ == "__main__":
)
# Initialize the model
model = pt.models.GLVQ(
model = prototorch.models.expanded.GLVQ(
hparams,
optimizer=torch.optim.Adam,
prototypes_initializer=pt.initializers.SMCI(train_ds),

View File

@ -0,0 +1 @@
from .glvq import GLVQ

View File

@ -0,0 +1,86 @@
from dataclasses import dataclass
from typing import Callable
import torch
from prototorch.core.competitions import WTAC
from prototorch.core.components import LabeledComponents
from prototorch.core.distances import euclidean_distance
from prototorch.core.initializers import AbstractComponentsInitializer, LabelsInitializer
from prototorch.core.losses import GLVQLoss
from prototorch.models.expanded.clcc_scheme import CLCCScheme
from prototorch.nn.wrappers import LambdaLayer
@dataclass
class GLVQhparams:
    """Hyperparameters consumed by the CLCC-based ``GLVQ`` model.

    The fields are read by the model's constructor and its ``init_*`` hooks;
    the comments below name the consumer of each field.
    """

    # Passed as ``distribution`` to ``LabeledComponents``.
    distribution: dict
    # Passed as ``components_initializer`` to ``LabeledComponents``.
    component_initializer: AbstractComponentsInitializer
    # Wrapped in a ``LambdaLayer`` and used as the comparison layer.
    distance_fn: Callable = euclidean_distance
    # Learning rate handed to the optimizer factory as ``lr``.
    lr: float = 0.01
    # Forwarded to ``GLVQLoss`` as ``margin``.
    margin: float = 0.0
    # TODO: make nicer
    # Forwarded to ``GLVQLoss`` as ``transfer_fn`` and ``beta`` respectively.
    transfer_fn: str = "identity"
    transfer_beta: float = 10.0
    # Optimizer *factory* (typically an optimizer class), not an instance:
    # the model calls it as ``optimizer(parameters, lr=lr)``.
    optimizer: Callable[..., torch.optim.Optimizer] = torch.optim.Adam
class GLVQ(CLCCScheme):
    """GLVQ model assembled from the steps of the CLCC scheme.

    The ``init_*`` hooks build the layers from the hyperparameter object and
    the step methods wire those layers together.
    """

    def __init__(self, hparams: GLVQhparams) -> None:
        # The base constructor runs all ``init_*`` hooks defined below.
        super().__init__(hparams)

        # Kept for ``configure_optimizers``.
        self.lr, self.optimizer = hparams.lr, hparams.optimizer

    # Initializers
    def init_components(self, hparams):
        """Create the labeled component layer."""
        layer_kwargs = dict(
            distribution=hparams.distribution,
            components_initializer=hparams.component_initializer,
            labels_initializer=LabelsInitializer(),
        )
        self.components_layer = LabeledComponents(**layer_kwargs)

    def init_comparison(self, hparams):
        """Wrap the configured distance function in a layer."""
        distance_fn = hparams.distance_fn
        self.comparison_layer = LambdaLayer(distance_fn)

    def init_inference(self, hparams):
        """Create the winner-takes-all competition layer."""
        self.competition_layer = WTAC()

    def init_loss(self, hparams):
        """Create the GLVQ loss layer from the loss hyperparameters."""
        loss_kwargs = dict(
            margin=hparams.margin,
            transfer_fn=hparams.transfer_fn,
            beta=hparams.transfer_beta,
        )
        self.loss_layer = GLVQLoss(**loss_kwargs)

    # Steps
    def comparison(self, batch, components):
        """Apply the comparison layer to the batch inputs and the components.

        Both ``components`` and ``batch`` are unpacked as 2-tuples whose
        second entries are ignored here. The component tensor gets an extra
        axis inserted at position 1 before the comparison layer is called.
        """
        protos, _ = components
        inputs, _ = batch
        return self.comparison_layer(inputs, protos.unsqueeze(1))

    def inference(self, comparisonmeasures, components):
        """Run the competition layer on the measures and component labels."""
        return self.competition_layer(comparisonmeasures, components[1])

    def loss(self, comparisonmeasures, batch, components):
        """Evaluate the loss layer for the batch targets."""
        targets, labels = batch[1], components[1]
        return self.loss_layer(comparisonmeasures, targets, labels)

    def configure_optimizers(self):
        """Instantiate the configured optimizer over all model parameters."""
        make_optimizer = self.optimizer
        return make_optimizer(self.parameters(), lr=self.lr)

    # Properties
    @property
    def prototypes(self):
        """Detached CPU copy of the component tensor."""
        protos = self.components_layer.components
        return protos.detach().cpu()

    @property
    def prototype_labels(self):
        """Detached CPU copy of the component labels."""
        labels = self.components_layer.labels
        return labels.detach().cpu()

View File

@ -0,0 +1,138 @@
"""
CLCC Scheme
CLCC is an LVQ scheme consisting of 4 steps:
- Components
- Latent Space
- Comparison
- Competition
"""
import pytorch_lightning as pl
class CLCCScheme(pl.LightningModule):
    """Skeleton of a model built from Components, Latent, Comparison and
    Competition steps.

    The constructor calls every ``init_*`` hook with the hyperparameter
    object. Subclasses override the hooks to create their layers and override
    the step methods to use them. ``forward`` and ``loss_forward`` expect the
    subclass to provide a callable ``self.components_layer``.
    """

    def __init__(self, hparams) -> None:
        super().__init__()

        # Common Steps
        self.init_components(hparams)
        self.init_latent(hparams)
        self.init_comparison(hparams)
        self.init_competition(hparams)

        # Train Steps
        self.init_loss(hparams)

        # Inference Steps
        self.init_inference(hparams)

    # API
    def get_competition(self, batch, components):
        """Run the latent step followed by the comparison step.

        Returns whatever the ``comparison`` step returns.
        """
        latent_batch, latent_components = self.latent(batch, components)
        # TODO: => Latent Hook
        comparison_tensor = self.comparison(latent_batch, latent_components)
        # TODO: => Comparison Hook
        return comparison_tensor

    def get_competion(self, batch, components):
        """Misspelled former name of :meth:`get_competition`.

        Kept so existing callers keep working; it only delegates.
        """
        return self.get_competition(batch, components)

    def forward(self, batch):
        """Compute the comparison measures for ``batch`` and run inference."""
        # TODO: manage different datatypes?
        components = self.components_layer()
        # TODO: => Component Hook
        comparison_tensor = self.get_competition(batch, components)
        # TODO: => Competition Hook
        return self.inference(comparison_tensor, components)

    def loss_forward(self, batch):
        """Compute the comparison measures for ``batch`` and the loss."""
        # TODO: manage different datatypes?
        components = self.components_layer()
        # TODO: => Component Hook
        comparison_tensor = self.get_competition(batch, components)
        # TODO: => Competition Hook
        return self.loss(comparison_tensor, batch, components)

    # Empty Initialization
    # TODO: Type hints
    # TODO: Docs
    def init_components(self, hparams):
        ...

    def init_latent(self, hparams):
        ...

    def init_comparison(self, hparams):
        ...

    def init_competition(self, hparams):
        ...

    def init_loss(self, hparams):
        ...

    def init_inference(self, hparams):
        ...

    # Empty Steps
    # TODO: Type hints
    def components(self):
        """
        This step has no input.
        It returns the components.
        """
        raise NotImplementedError(
            "The components step has no reasonable default.")

    def latent(self, batch, components):
        """
        The latent step receives the data batch and the components.
        It can transform both by an arbitrary function.
        It returns the transformed batch and components, each of the same
        length as the original input.

        The default implementation returns both inputs unchanged.
        """
        return batch, components

    def comparison(self, batch, components):
        """
        Takes a batch of size N and the component set of size M.
        It returns an NxMxD tensor containing D (usually 1) pairwise
        comparison measures.
        """
        raise NotImplementedError(
            "The comparison step has no reasonable default.")

    def competition(self, comparisonmeasures, components):
        """
        Takes the tensor of comparison measures.
        Assigns a competition vector to each class.
        """
        raise NotImplementedError(
            "The competition step has no reasonable default.")

    def loss(self, comparisonmeasures, batch, components):
        """
        Takes the tensor of comparison measures.
        Calculates a single loss value.
        """
        raise NotImplementedError("The loss step has no reasonable default.")

    def inference(self, comparisonmeasures, components):
        """
        Takes the tensor of comparison measures.
        Returns the inferred vector.
        """
        raise NotImplementedError(
            "The inference step has no reasonable default.")

    # Lightning Hooks
    def training_step(self, batch, batch_idx, optimizer_idx=None):
        """Return the loss of ``batch``; the index arguments are unused."""
        return self.loss_forward(batch)

    def validation_step(self, batch, batch_idx):
        """Return the loss of ``batch``; ``batch_idx`` is unused."""
        return self.loss_forward(batch)

    def test_step(self, batch, batch_idx):
        """Return the loss of ``batch``; ``batch_idx`` is unused."""
        return self.loss_forward(batch)

View File

@ -0,0 +1,164 @@
import pytorch_lightning as pl
import torch
import torchmetrics
from prototorch.core.competitions import WTAC, wtac
from prototorch.core.components import Components, LabeledComponents
from prototorch.core.distances import (
euclidean_distance,
lomega_distance,
omega_distance,
squared_euclidean_distance,
)
from prototorch.core.initializers import EyeTransformInitializer, LabelsInitializer
from prototorch.core.losses import GLVQLoss, lvq1_loss, lvq21_loss
from prototorch.core.pooling import stratified_min_pooling
from prototorch.core.transforms import LinearTransform
from prototorch.nn.wrappers import LambdaLayer, LossLayer
from torch.nn.parameter import Parameter
class GLVQ(pl.LightningModule):
    """Self-contained GLVQ Lightning module.

    Hyperparameters come from ``hparams`` (with defaults filled in below);
    non-serialisable configuration such as the optimizer class, the distance
    function and the initializers is passed through ``**kwargs``.
    """

    def __init__(self, hparams, **kwargs):
        super().__init__()

        # Hyperparameters
        self.save_hyperparameters(hparams)

        # Default hparams
        # TODO: Manage by an HPARAMS Object
        self.hparams.setdefault("lr", 0.01)
        self.hparams.setdefault("margin", 0.0)
        self.hparams.setdefault("transfer_fn", "identity")
        self.hparams.setdefault("transfer_beta", 10.0)

        # Default config
        self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
        self.lr_scheduler = kwargs.get("lr_scheduler", None)
        self.lr_scheduler_kwargs = kwargs.get("lr_scheduler_kwargs", dict())

        distance_fn = kwargs.get("distance_fn", euclidean_distance)
        prototypes_initializer = kwargs.get("prototypes_initializer", None)
        labels_initializer = kwargs.get("labels_initializer",
                                        LabelsInitializer())

        # NOTE: without a prototypes initializer no ``proto_layer`` is
        # created, and every method that uses it will fail.
        if prototypes_initializer is not None:
            self.proto_layer = LabeledComponents(
                distribution=self.hparams.distribution,
                components_initializer=prototypes_initializer,
                labels_initializer=labels_initializer,
            )

        self.distance_layer = LambdaLayer(distance_fn)
        self.competition_layer = WTAC()

        self.loss = GLVQLoss(
            margin=self.hparams.margin,
            transfer_fn=self.hparams.transfer_fn,
            beta=self.hparams.transfer_beta,
        )

    def log_acc(self, distances, targets, tag):
        """Log the accuracy of the predictions derived from ``distances``.

        The value is logged under ``tag`` once per epoch and shown in the
        progress bar.
        """
        preds = self.predict_from_distances(distances)
        accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int())

        self.log(tag,
                 accuracy,
                 on_step=False,
                 on_epoch=True,
                 prog_bar=True,
                 logger=True)

    def configure_optimizers(self):
        """Build the optimizer and, if configured, a per-step LR scheduler."""
        optimizer = self.optimizer(self.parameters(), lr=self.hparams.lr)
        if self.lr_scheduler is not None:
            scheduler = self.lr_scheduler(optimizer,
                                          **self.lr_scheduler_kwargs)
            sch = {
                "scheduler": scheduler,
                "interval": "step",
            }  # called after each training step
            return [optimizer], [sch]
        else:
            return optimizer

    def shared_step(self, batch, batch_idx, optimizer_idx=None):
        """Return ``(distances, loss)`` for an ``(x, y)`` batch."""
        x, y = batch
        out = self.compute_distances(x)
        _, plabels = self.proto_layer()
        loss = self.loss(out, y, plabels)
        return out, loss

    def training_step(self, batch, batch_idx, optimizer_idx=None):
        """Compute, log and return the training loss of ``batch``."""
        out, train_loss = self.shared_step(batch, batch_idx, optimizer_idx)
        # The previous call to ``self.log_prototype_win_ratios(out)`` was
        # removed: this class defines no such method, so the call raised
        # AttributeError on the first training batch.
        self.log("train_loss", train_loss)
        self.log_acc(out, batch[-1], tag="train_acc")
        return train_loss

    def validation_step(self, batch, batch_idx):
        """Compute, log and return the validation loss of ``batch``."""
        out, val_loss = self.shared_step(batch, batch_idx)
        self.log("val_loss", val_loss)
        self.log_acc(out, batch[-1], tag="val_acc")
        return val_loss

    def test_step(self, batch, batch_idx):
        """Log the test accuracy and return the test loss of ``batch``."""
        out, test_loss = self.shared_step(batch, batch_idx)
        self.log_acc(out, batch[-1], tag="test_acc")
        return test_loss

    def test_epoch_end(self, outputs):
        """Log the sum of all per-batch test losses as ``test_loss``."""
        test_loss = 0.0
        for batch_loss in outputs:
            test_loss += batch_loss.item()
        self.log("test_loss", test_loss)

    # API
    def compute_distances(self, x):
        """Apply the distance layer to ``x`` and the current prototypes."""
        protos, _ = self.proto_layer()
        distances = self.distance_layer(x, protos)
        return distances

    def forward(self, x):
        """Return a softmin over the label-wise pooled distances."""
        distances = self.compute_distances(x)
        _, plabels = self.proto_layer()
        winning = stratified_min_pooling(distances, plabels)
        # ``dim`` is explicit because the implicit choice is deprecated.
        # Assumes ``winning`` is 2-D (batch, classes), for which the implicit
        # choice was also 1 -- TODO confirm against stratified_min_pooling.
        y_pred = torch.nn.functional.softmin(winning, dim=1)
        return y_pred

    def predict_from_distances(self, distances):
        """Run the competition layer on ``distances`` without gradients."""
        with torch.no_grad():
            _, plabels = self.proto_layer()
            y_pred = self.competition_layer(distances, plabels)
        return y_pred

    def predict(self, x):
        """Predict labels for ``x`` without tracking gradients."""
        with torch.no_grad():
            distances = self.compute_distances(x)
            y_pred = self.predict_from_distances(distances)
        return y_pred

    @property
    def prototype_labels(self):
        """Detached CPU copy of the prototype labels."""
        return self.proto_layer.labels.detach().cpu()

    @property
    def num_classes(self):
        """Number of classes as reported by the prototype layer."""
        return self.proto_layer.num_classes

    @property
    def num_prototypes(self):
        """Number of prototypes held by the prototype layer."""
        return len(self.proto_layer.components)

    @property
    def prototypes(self):
        """Detached CPU copy of the prototype tensor."""
        return self.proto_layer.components.detach().cpu()

    @property
    def components(self):
        """Only an alias for the prototypes."""
        return self.prototypes

    # Python overwrites
    def __repr__(self):
        """Wrap the default module representation in a ProtoTorch banner."""
        surep = super().__repr__()
        indented = "".join([f"\t{line}\n" for line in surep.splitlines()])
        wrapped = f"ProtoTorch Bolt(\n{indented})"
        return wrapped

View File

@ -0,0 +1,35 @@
import matplotlib.pyplot as plt
import prototorch as pt
import pytorch_lightning as pl
import torch
from prototorch.core.initializers import SMCI, RandomNormalCompInitializer
from prototorch.models.expanded.clcc_glvq import GLVQ, GLVQhparams
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets
from torchvision.transforms import Compose, Lambda, ToTensor
plt.gray()

if __name__ == "__main__":
    # Dataset
    train_ds = pt.datasets.Iris(dims=[0, 2])

    # Dataloaders
    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)

    # Hyperparameters
    components_initializer = SMCI(train_ds)
    hparams = GLVQhparams(
        distribution=dict(
            num_classes=3,
            per_class=2,
        ),
        component_initializer=components_initializer,
    )

    # Model
    model = GLVQ(hparams)
    print(model)

    # Callbacks
    vis = pt.models.VisGLVQ2D(data=train_ds)

    # Train
    # Use one GPU when CUDA is available and fall back to the CPU otherwise,
    # so the example also runs on machines without a GPU.
    trainer = pl.Trainer(
        callbacks=[vis],
        gpus=1 if torch.cuda.is_available() else 0,
    )
    trainer.fit(model, train_loader)

View File

@ -129,12 +129,14 @@ class VisGLVQ2D(Vis2DAbstract):
self.plot_data(ax, x_train, y_train)
self.plot_protos(ax, protos, plabels)
x = np.vstack((x_train, protos))
mesh_input, xx, yy = mesh2d(x, self.border, self.resolution)
_components = pl_module.proto_layer._components
mesh_input = torch.from_numpy(mesh_input).type_as(_components)
y_pred = pl_module.predict(mesh_input)
mesh_input, xx, yy = mesh2d(x,
self.border,
self.resolution,
device=pl_module.device)
mesh_input = (mesh_input, None)
y_pred = pl_module(mesh_input)
y_pred = y_pred.cpu().reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
ax.contourf(xx.cpu(), yy.cpu(), y_pred, cmap=self.cmap, alpha=0.35)
self.log_and_display(trainer, pl_module)