3 Commits

Author SHA1 Message Date
Alexander Engelsberger
e87563e10d Bump version: 0.1.5 → 0.1.6 2021-05-11 13:41:26 +02:00
Alexander Engelsberger
767206f905 Define minimum prototorch version in setup 2021-05-11 13:41:09 +02:00
Alexander Engelsberger
3fa6378c4d Add LVQ1 and LVQ2.1 Models. 2021-05-11 13:26:13 +02:00
5 changed files with 89 additions and 7 deletions

View File

@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.1.5
current_version = 0.1.6
commit = True
tag = True
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)

42
examples/lvq_iris.py Normal file
View File

@@ -0,0 +1,42 @@
"""Classical LVQ using GLVQ example on the Iris dataset."""
import prototorch as pt
import pytorch_lightning as pl
import torch
if __name__ == "__main__":
# Dataset
from sklearn.datasets import load_iris
x_train, y_train = load_iris(return_X_y=True)
x_train = x_train[:, [0, 2]]
train_ds = pt.datasets.NumpyDataset(x_train, y_train)
# Dataloaders
train_loader = torch.utils.data.DataLoader(train_ds,
num_workers=0,
batch_size=150)
# Hyperparameters
hparams = dict(
nclasses=3,
prototypes_per_class=2,
prototype_initializer=pt.components.SMI(train_ds),
#prototype_initializer=pt.components.Random(2),
lr=0.005,
)
# Initialize the model
model = pt.models.LVQ1(hparams)
#model = pt.models.LVQ21(hparams)
# Callbacks
vis = pt.models.VisGLVQ2D(data=(x_train, y_train))
# Setup trainer
trainer = pl.Trainer(
max_epochs=200,
callbacks=[vis],
)
# Training loop
trainer.fit(model, train_loader)

View File

@@ -1,8 +1,8 @@
from importlib.metadata import PackageNotFoundError, version
from .cbc import CBC
from .glvq import GLVQ, GMLVQ, GRLVQ, LVQMLN, ImageGLVQ, SiameseGLVQ
from .glvq import GLVQ, GMLVQ, GRLVQ, LVQMLN, ImageGLVQ, SiameseGLVQ, LVQ1, LVQ21
from .neural_gas import NeuralGas
from .vis import *
__version__ = "0.1.5"
__version__ = "0.1.6"

View File

@@ -5,10 +5,12 @@ from prototorch.functions.activations import get_activation
from prototorch.functions.competitions import wtac
from prototorch.functions.distances import (euclidean_distance, omega_distance,
squared_euclidean_distance)
from prototorch.functions.losses import glvq_loss
from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
from .abstract import AbstractPrototypeModel
from torch.optim.lr_scheduler import ExponentialLR
class GLVQ(AbstractPrototypeModel):
"""Generalized Learning Vector Quantization."""
@@ -30,6 +32,8 @@ class GLVQ(AbstractPrototypeModel):
self.transfer_function = get_activation(self.hparams.transfer_function)
self.train_acc = torchmetrics.Accuracy()
self.loss = glvq_loss
@property
def prototype_labels(self):
return self.proto_layer.component_labels.detach().cpu()
@@ -44,7 +48,7 @@ class GLVQ(AbstractPrototypeModel):
x = x.view(x.size(0), -1) # flatten
dis = self(x)
plabels = self.proto_layer.component_labels
mu = glvq_loss(dis, y, prototype_labels=plabels)
mu = self.loss(dis, y, prototype_labels=plabels)
batch_loss = self.transfer_function(mu,
beta=self.hparams.transfer_beta)
loss = batch_loss.sum(dim=0)
@@ -76,6 +80,42 @@ class GLVQ(AbstractPrototypeModel):
return y_pred.numpy()
class LVQ1(GLVQ):
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
self.loss = lvq1_loss
def configure_optimizers(self):
optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr)
scheduler = ExponentialLR(optimizer,
gamma=0.99,
last_epoch=-1,
verbose=False)
sch = {
"scheduler": scheduler,
"interval": "step",
} # called after each training step
return [optimizer], [sch]
class LVQ21(GLVQ):
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
self.loss = lvq21_loss
def configure_optimizers(self):
optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr)
scheduler = ExponentialLR(optimizer,
gamma=0.99,
last_epoch=-1,
verbose=False)
sch = {
"scheduler": scheduler,
"interval": "step",
} # called after each training step
return [optimizer], [sch]
class ImageGLVQ(GLVQ):
"""GLVQ for training on image data.

View File

@@ -19,7 +19,7 @@ DOWNLOAD_URL = "https://github.com/si-cim/prototorch_models.git"
with open("README.md", "r") as fh:
long_description = fh.read()
INSTALL_REQUIRES = ["prototorch", "pytorch_lightning", "torchmetrics"]
INSTALL_REQUIRES = ["prototorch>=0.4.1", "pytorch_lightning", "torchmetrics"]
DEV = ["bumpversion"]
EXAMPLES = ["matplotlib", "scikit-learn"]
TESTS = ["codecov", "pytest"]
@@ -27,7 +27,7 @@ ALL = DEV + EXAMPLES + TESTS
setup(
name=safe_name("prototorch_" + PLUGIN_NAME),
version="0.1.5",
version="0.1.6",
description="Pre-packaged prototype-based "
"machine learning models using ProtoTorch and PyTorch-Lightning.",
long_description=long_description,