Add LVQ1 and LVQ2.1 Models.
This commit is contained in:
parent
30ee287ecc
commit
3fa6378c4d
42
examples/lvq_iris.py
Normal file
42
examples/lvq_iris.py
Normal file
@ -0,0 +1,42 @@
|
||||
"""Classical LVQ using GLVQ example on the Iris dataset."""
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Dataset
|
||||
from sklearn.datasets import load_iris
|
||||
x_train, y_train = load_iris(return_X_y=True)
|
||||
x_train = x_train[:, [0, 2]]
|
||||
train_ds = pt.datasets.NumpyDataset(x_train, y_train)
|
||||
|
||||
# Dataloaders
|
||||
train_loader = torch.utils.data.DataLoader(train_ds,
|
||||
num_workers=0,
|
||||
batch_size=150)
|
||||
|
||||
# Hyperparameters
|
||||
hparams = dict(
|
||||
nclasses=3,
|
||||
prototypes_per_class=2,
|
||||
prototype_initializer=pt.components.SMI(train_ds),
|
||||
#prototype_initializer=pt.components.Random(2),
|
||||
lr=0.005,
|
||||
)
|
||||
|
||||
# Initialize the model
|
||||
model = pt.models.LVQ1(hparams)
|
||||
#model = pt.models.LVQ21(hparams)
|
||||
|
||||
# Callbacks
|
||||
vis = pt.models.VisGLVQ2D(data=(x_train, y_train))
|
||||
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer(
|
||||
max_epochs=200,
|
||||
callbacks=[vis],
|
||||
)
|
||||
|
||||
# Training loop
|
||||
trainer.fit(model, train_loader)
|
@ -1,7 +1,7 @@
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
|
||||
from .cbc import CBC
|
||||
from .glvq import GLVQ, GMLVQ, GRLVQ, LVQMLN, ImageGLVQ, SiameseGLVQ
|
||||
from .glvq import GLVQ, GMLVQ, GRLVQ, LVQMLN, ImageGLVQ, SiameseGLVQ, LVQ1, LVQ21
|
||||
from .neural_gas import NeuralGas
|
||||
from .vis import *
|
||||
|
||||
|
@ -5,10 +5,12 @@ from prototorch.functions.activations import get_activation
|
||||
from prototorch.functions.competitions import wtac
|
||||
from prototorch.functions.distances import (euclidean_distance, omega_distance,
|
||||
squared_euclidean_distance)
|
||||
from prototorch.functions.losses import glvq_loss
|
||||
from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
|
||||
|
||||
from .abstract import AbstractPrototypeModel
|
||||
|
||||
from torch.optim.lr_scheduler import ExponentialLR
|
||||
|
||||
|
||||
class GLVQ(AbstractPrototypeModel):
|
||||
"""Generalized Learning Vector Quantization."""
|
||||
@ -30,6 +32,8 @@ class GLVQ(AbstractPrototypeModel):
|
||||
self.transfer_function = get_activation(self.hparams.transfer_function)
|
||||
self.train_acc = torchmetrics.Accuracy()
|
||||
|
||||
self.loss = glvq_loss
|
||||
|
||||
@property
|
||||
def prototype_labels(self):
|
||||
return self.proto_layer.component_labels.detach().cpu()
|
||||
@ -44,7 +48,7 @@ class GLVQ(AbstractPrototypeModel):
|
||||
x = x.view(x.size(0), -1) # flatten
|
||||
dis = self(x)
|
||||
plabels = self.proto_layer.component_labels
|
||||
mu = glvq_loss(dis, y, prototype_labels=plabels)
|
||||
mu = self.loss(dis, y, prototype_labels=plabels)
|
||||
batch_loss = self.transfer_function(mu,
|
||||
beta=self.hparams.transfer_beta)
|
||||
loss = batch_loss.sum(dim=0)
|
||||
@ -76,6 +80,42 @@ class GLVQ(AbstractPrototypeModel):
|
||||
return y_pred.numpy()
|
||||
|
||||
|
||||
class LVQ1(GLVQ):
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
self.loss = lvq1_loss
|
||||
|
||||
def configure_optimizers(self):
|
||||
optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr)
|
||||
scheduler = ExponentialLR(optimizer,
|
||||
gamma=0.99,
|
||||
last_epoch=-1,
|
||||
verbose=False)
|
||||
sch = {
|
||||
"scheduler": scheduler,
|
||||
"interval": "step",
|
||||
} # called after each training step
|
||||
return [optimizer], [sch]
|
||||
|
||||
|
||||
class LVQ21(GLVQ):
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
self.loss = lvq21_loss
|
||||
|
||||
def configure_optimizers(self):
|
||||
optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr)
|
||||
scheduler = ExponentialLR(optimizer,
|
||||
gamma=0.99,
|
||||
last_epoch=-1,
|
||||
verbose=False)
|
||||
sch = {
|
||||
"scheduler": scheduler,
|
||||
"interval": "step",
|
||||
} # called after each training step
|
||||
return [optimizer], [sch]
|
||||
|
||||
|
||||
class ImageGLVQ(GLVQ):
|
||||
"""GLVQ for training on image data.
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user