Add LVQ1 and LVQ2.1 Models.

This commit is contained in:
Alexander Engelsberger 2021-05-11 13:26:13 +02:00
parent 30ee287ecc
commit 3fa6378c4d
3 changed files with 85 additions and 3 deletions

42
examples/lvq_iris.py Normal file
View File

@ -0,0 +1,42 @@
"""Classical LVQ using GLVQ example on the Iris dataset."""
import prototorch as pt
import pytorch_lightning as pl
import torch
if __name__ == "__main__":
# Dataset
from sklearn.datasets import load_iris
x_train, y_train = load_iris(return_X_y=True)
x_train = x_train[:, [0, 2]]
train_ds = pt.datasets.NumpyDataset(x_train, y_train)
# Dataloaders
train_loader = torch.utils.data.DataLoader(train_ds,
num_workers=0,
batch_size=150)
# Hyperparameters
hparams = dict(
nclasses=3,
prototypes_per_class=2,
prototype_initializer=pt.components.SMI(train_ds),
#prototype_initializer=pt.components.Random(2),
lr=0.005,
)
# Initialize the model
model = pt.models.LVQ1(hparams)
#model = pt.models.LVQ21(hparams)
# Callbacks
vis = pt.models.VisGLVQ2D(data=(x_train, y_train))
# Setup trainer
trainer = pl.Trainer(
max_epochs=200,
callbacks=[vis],
)
# Training loop
trainer.fit(model, train_loader)

View File

@ -1,7 +1,7 @@
from importlib.metadata import PackageNotFoundError, version
from .cbc import CBC
from .glvq import GLVQ, GMLVQ, GRLVQ, LVQMLN, ImageGLVQ, SiameseGLVQ
from .glvq import GLVQ, GMLVQ, GRLVQ, LVQMLN, ImageGLVQ, SiameseGLVQ, LVQ1, LVQ21
from .neural_gas import NeuralGas
from .vis import *

View File

@ -5,10 +5,12 @@ from prototorch.functions.activations import get_activation
from prototorch.functions.competitions import wtac
from prototorch.functions.distances import (euclidean_distance, omega_distance,
squared_euclidean_distance)
from prototorch.functions.losses import glvq_loss
from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
from .abstract import AbstractPrototypeModel
from torch.optim.lr_scheduler import ExponentialLR
class GLVQ(AbstractPrototypeModel):
"""Generalized Learning Vector Quantization."""
@ -30,6 +32,8 @@ class GLVQ(AbstractPrototypeModel):
self.transfer_function = get_activation(self.hparams.transfer_function)
self.train_acc = torchmetrics.Accuracy()
self.loss = glvq_loss
@property
def prototype_labels(self):
return self.proto_layer.component_labels.detach().cpu()
@ -44,7 +48,7 @@ class GLVQ(AbstractPrototypeModel):
x = x.view(x.size(0), -1) # flatten
dis = self(x)
plabels = self.proto_layer.component_labels
mu = glvq_loss(dis, y, prototype_labels=plabels)
mu = self.loss(dis, y, prototype_labels=plabels)
batch_loss = self.transfer_function(mu,
beta=self.hparams.transfer_beta)
loss = batch_loss.sum(dim=0)
@ -76,6 +80,42 @@ class GLVQ(AbstractPrototypeModel):
return y_pred.numpy()
class LVQ1(GLVQ):
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
self.loss = lvq1_loss
def configure_optimizers(self):
optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr)
scheduler = ExponentialLR(optimizer,
gamma=0.99,
last_epoch=-1,
verbose=False)
sch = {
"scheduler": scheduler,
"interval": "step",
} # called after each training step
return [optimizer], [sch]
class LVQ21(GLVQ):
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
self.loss = lvq21_loss
def configure_optimizers(self):
optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr)
scheduler = ExponentialLR(optimizer,
gamma=0.99,
last_epoch=-1,
verbose=False)
sch = {
"scheduler": scheduler,
"interval": "step",
} # called after each training step
return [optimizer], [sch]
class ImageGLVQ(GLVQ):
"""GLVQ for training on image data.