diff --git a/examples/lvq_iris.py b/examples/lvq_iris.py new file mode 100644 index 0000000..af2beb7 --- /dev/null +++ b/examples/lvq_iris.py @@ -0,0 +1,42 @@ +"""Classical LVQ using GLVQ example on the Iris dataset.""" + +import prototorch as pt +import pytorch_lightning as pl +import torch + +if __name__ == "__main__": + # Dataset + from sklearn.datasets import load_iris + x_train, y_train = load_iris(return_X_y=True) + x_train = x_train[:, [0, 2]] + train_ds = pt.datasets.NumpyDataset(x_train, y_train) + + # Dataloaders + train_loader = torch.utils.data.DataLoader(train_ds, + num_workers=0, + batch_size=150) + + # Hyperparameters + hparams = dict( + nclasses=3, + prototypes_per_class=2, + prototype_initializer=pt.components.SMI(train_ds), + #prototype_initializer=pt.components.Random(2), + lr=0.005, + ) + + # Initialize the model + model = pt.models.LVQ1(hparams) + #model = pt.models.LVQ21(hparams) + + # Callbacks + vis = pt.models.VisGLVQ2D(data=(x_train, y_train)) + + # Setup trainer + trainer = pl.Trainer( + max_epochs=200, + callbacks=[vis], + ) + + # Training loop + trainer.fit(model, train_loader) diff --git a/prototorch/models/__init__.py b/prototorch/models/__init__.py index 9fc99ee..d806f07 100644 --- a/prototorch/models/__init__.py +++ b/prototorch/models/__init__.py @@ -1,7 +1,7 @@ from importlib.metadata import PackageNotFoundError, version from .cbc import CBC -from .glvq import GLVQ, GMLVQ, GRLVQ, LVQMLN, ImageGLVQ, SiameseGLVQ +from .glvq import GLVQ, GMLVQ, GRLVQ, LVQMLN, ImageGLVQ, SiameseGLVQ, LVQ1, LVQ21 from .neural_gas import NeuralGas from .vis import * diff --git a/prototorch/models/glvq.py b/prototorch/models/glvq.py index bd8c3be..b88ea23 100644 --- a/prototorch/models/glvq.py +++ b/prototorch/models/glvq.py @@ -5,10 +5,12 @@ from prototorch.functions.activations import get_activation from prototorch.functions.competitions import wtac from prototorch.functions.distances import (euclidean_distance, omega_distance, squared_euclidean_distance) -from prototorch.functions.losses import glvq_loss +from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss from .abstract import AbstractPrototypeModel +from torch.optim.lr_scheduler import ExponentialLR + class GLVQ(AbstractPrototypeModel): """Generalized Learning Vector Quantization.""" @@ -30,6 +32,8 @@ class GLVQ(AbstractPrototypeModel): self.transfer_function = get_activation(self.hparams.transfer_function) self.train_acc = torchmetrics.Accuracy() + self.loss = glvq_loss + @property def prototype_labels(self): return self.proto_layer.component_labels.detach().cpu() @@ -44,7 +48,7 @@ class GLVQ(AbstractPrototypeModel): x = x.view(x.size(0), -1) # flatten dis = self(x) plabels = self.proto_layer.component_labels - mu = glvq_loss(dis, y, prototype_labels=plabels) + mu = self.loss(dis, y, prototype_labels=plabels) batch_loss = self.transfer_function(mu, beta=self.hparams.transfer_beta) loss = batch_loss.sum(dim=0) @@ -76,6 +80,42 @@ class GLVQ(AbstractPrototypeModel): return y_pred.numpy() +class LVQ1(GLVQ): + def __init__(self, hparams, **kwargs): + super().__init__(hparams, **kwargs) + self.loss = lvq1_loss + + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr) + scheduler = ExponentialLR(optimizer, + gamma=0.99, + last_epoch=-1, + verbose=False) + sch = { + "scheduler": scheduler, + "interval": "step", + } # called after each training step + return [optimizer], [sch] + + +class LVQ21(GLVQ): + def __init__(self, hparams, **kwargs): + super().__init__(hparams, **kwargs) + self.loss = lvq21_loss + + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr) + scheduler = ExponentialLR(optimizer, + gamma=0.99, + last_epoch=-1, + verbose=False) + sch = { + "scheduler": scheduler, + "interval": "step", + } # called after each training step + return [optimizer], [sch] + + class ImageGLVQ(GLVQ): """GLVQ for training on image data.