3 Commits

Author SHA1 Message Date
Alexander Engelsberger
e87563e10d Bump version: 0.1.5 → 0.1.6 2021-05-11 13:41:26 +02:00
Alexander Engelsberger
767206f905 Define minimum prototorch version in setup 2021-05-11 13:41:09 +02:00
Alexander Engelsberger
3fa6378c4d Add LVQ1 and LVQ2.1 Models. 2021-05-11 13:26:13 +02:00
5 changed files with 89 additions and 7 deletions

View File

@@ -1,5 +1,5 @@
[bumpversion] [bumpversion]
current_version = 0.1.5 current_version = 0.1.6
commit = True commit = True
tag = True tag = True
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+) parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)

42
examples/lvq_iris.py Normal file
View File

@@ -0,0 +1,42 @@
"""Classical LVQ using GLVQ example on the Iris dataset."""
import prototorch as pt
import pytorch_lightning as pl
import torch
if __name__ == "__main__":
# Dataset
from sklearn.datasets import load_iris
x_train, y_train = load_iris(return_X_y=True)
x_train = x_train[:, [0, 2]]
train_ds = pt.datasets.NumpyDataset(x_train, y_train)
# Dataloaders
train_loader = torch.utils.data.DataLoader(train_ds,
num_workers=0,
batch_size=150)
# Hyperparameters
hparams = dict(
nclasses=3,
prototypes_per_class=2,
prototype_initializer=pt.components.SMI(train_ds),
#prototype_initializer=pt.components.Random(2),
lr=0.005,
)
# Initialize the model
model = pt.models.LVQ1(hparams)
#model = pt.models.LVQ21(hparams)
# Callbacks
vis = pt.models.VisGLVQ2D(data=(x_train, y_train))
# Setup trainer
trainer = pl.Trainer(
max_epochs=200,
callbacks=[vis],
)
# Training loop
trainer.fit(model, train_loader)

View File

@@ -1,8 +1,8 @@
from importlib.metadata import PackageNotFoundError, version from importlib.metadata import PackageNotFoundError, version
from .cbc import CBC from .cbc import CBC
from .glvq import GLVQ, GMLVQ, GRLVQ, LVQMLN, ImageGLVQ, SiameseGLVQ from .glvq import GLVQ, GMLVQ, GRLVQ, LVQMLN, ImageGLVQ, SiameseGLVQ, LVQ1, LVQ21
from .neural_gas import NeuralGas from .neural_gas import NeuralGas
from .vis import * from .vis import *
__version__ = "0.1.5" __version__ = "0.1.6"

View File

@@ -5,10 +5,12 @@ from prototorch.functions.activations import get_activation
from prototorch.functions.competitions import wtac from prototorch.functions.competitions import wtac
from prototorch.functions.distances import (euclidean_distance, omega_distance, from prototorch.functions.distances import (euclidean_distance, omega_distance,
squared_euclidean_distance) squared_euclidean_distance)
from prototorch.functions.losses import glvq_loss from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
from .abstract import AbstractPrototypeModel from .abstract import AbstractPrototypeModel
from torch.optim.lr_scheduler import ExponentialLR
class GLVQ(AbstractPrototypeModel): class GLVQ(AbstractPrototypeModel):
"""Generalized Learning Vector Quantization.""" """Generalized Learning Vector Quantization."""
@@ -30,6 +32,8 @@ class GLVQ(AbstractPrototypeModel):
self.transfer_function = get_activation(self.hparams.transfer_function) self.transfer_function = get_activation(self.hparams.transfer_function)
self.train_acc = torchmetrics.Accuracy() self.train_acc = torchmetrics.Accuracy()
self.loss = glvq_loss
@property @property
def prototype_labels(self): def prototype_labels(self):
return self.proto_layer.component_labels.detach().cpu() return self.proto_layer.component_labels.detach().cpu()
@@ -44,7 +48,7 @@ class GLVQ(AbstractPrototypeModel):
x = x.view(x.size(0), -1) # flatten x = x.view(x.size(0), -1) # flatten
dis = self(x) dis = self(x)
plabels = self.proto_layer.component_labels plabels = self.proto_layer.component_labels
mu = glvq_loss(dis, y, prototype_labels=plabels) mu = self.loss(dis, y, prototype_labels=plabels)
batch_loss = self.transfer_function(mu, batch_loss = self.transfer_function(mu,
beta=self.hparams.transfer_beta) beta=self.hparams.transfer_beta)
loss = batch_loss.sum(dim=0) loss = batch_loss.sum(dim=0)
@@ -76,6 +80,42 @@ class GLVQ(AbstractPrototypeModel):
return y_pred.numpy() return y_pred.numpy()
class LVQ1(GLVQ):
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
self.loss = lvq1_loss
def configure_optimizers(self):
optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr)
scheduler = ExponentialLR(optimizer,
gamma=0.99,
last_epoch=-1,
verbose=False)
sch = {
"scheduler": scheduler,
"interval": "step",
} # called after each training step
return [optimizer], [sch]
class LVQ21(GLVQ):
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
self.loss = lvq21_loss
def configure_optimizers(self):
optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr)
scheduler = ExponentialLR(optimizer,
gamma=0.99,
last_epoch=-1,
verbose=False)
sch = {
"scheduler": scheduler,
"interval": "step",
} # called after each training step
return [optimizer], [sch]
class ImageGLVQ(GLVQ): class ImageGLVQ(GLVQ):
"""GLVQ for training on image data. """GLVQ for training on image data.

View File

@@ -19,7 +19,7 @@ DOWNLOAD_URL = "https://github.com/si-cim/prototorch_models.git"
with open("README.md", "r") as fh: with open("README.md", "r") as fh:
long_description = fh.read() long_description = fh.read()
INSTALL_REQUIRES = ["prototorch", "pytorch_lightning", "torchmetrics"] INSTALL_REQUIRES = ["prototorch>=0.4.1", "pytorch_lightning", "torchmetrics"]
DEV = ["bumpversion"] DEV = ["bumpversion"]
EXAMPLES = ["matplotlib", "scikit-learn"] EXAMPLES = ["matplotlib", "scikit-learn"]
TESTS = ["codecov", "pytest"] TESTS = ["codecov", "pytest"]
@@ -27,7 +27,7 @@ ALL = DEV + EXAMPLES + TESTS
setup( setup(
name=safe_name("prototorch_" + PLUGIN_NAME), name=safe_name("prototorch_" + PLUGIN_NAME),
version="0.1.5", version="0.1.6",
description="Pre-packaged prototype-based " description="Pre-packaged prototype-based "
"machine learning models using ProtoTorch and PyTorch-Lightning.", "machine learning models using ProtoTorch and PyTorch-Lightning.",
long_description=long_description, long_description=long_description,