Compare commits
17 Commits
Author | SHA1 | Date | |
---|---|---|---|
|
e87563e10d | ||
|
767206f905 | ||
|
3fa6378c4d | ||
|
30ee287ecc | ||
|
e323f9d4ca | ||
|
f49db0bf2c | ||
|
db38667306 | ||
|
54a8494d86 | ||
|
bf310be97c | ||
|
32ae1b7862 | ||
|
dfddb92aba | ||
|
4a38bb2bfe | ||
|
6680d4b9df | ||
|
1ae2b41edd | ||
|
9300a6d14d | ||
|
3d42876df1 | ||
|
fbadacdbca |
@@ -1,5 +1,5 @@
|
|||||||
[bumpversion]
|
[bumpversion]
|
||||||
current_version = 0.1.0
|
current_version = 0.1.6
|
||||||
commit = True
|
commit = True
|
||||||
tag = True
|
tag = True
|
||||||
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
|
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
|
||||||
|
@@ -1,5 +1,8 @@
|
|||||||
# ProtoTorch Models
|
# ProtoTorch Models
|
||||||
|
|
||||||
|
[](https://travis-ci.org/si-cim/prototorch_models)
|
||||||
|
[](https://pypi.org/project/prototorch_models/)
|
||||||
|
|
||||||
Pre-packaged prototype-based machine learning models using ProtoTorch and
|
Pre-packaged prototype-based machine learning models using ProtoTorch and
|
||||||
PyTorch-Lightning.
|
PyTorch-Lightning.
|
||||||
|
|
||||||
|
42
examples/lvq_iris.py
Normal file
42
examples/lvq_iris.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
"""Classical LVQ using GLVQ example on the Iris dataset."""
|
||||||
|
|
||||||
|
import prototorch as pt
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
import torch
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Dataset
|
||||||
|
from sklearn.datasets import load_iris
|
||||||
|
x_train, y_train = load_iris(return_X_y=True)
|
||||||
|
x_train = x_train[:, [0, 2]]
|
||||||
|
train_ds = pt.datasets.NumpyDataset(x_train, y_train)
|
||||||
|
|
||||||
|
# Dataloaders
|
||||||
|
train_loader = torch.utils.data.DataLoader(train_ds,
|
||||||
|
num_workers=0,
|
||||||
|
batch_size=150)
|
||||||
|
|
||||||
|
# Hyperparameters
|
||||||
|
hparams = dict(
|
||||||
|
nclasses=3,
|
||||||
|
prototypes_per_class=2,
|
||||||
|
prototype_initializer=pt.components.SMI(train_ds),
|
||||||
|
#prototype_initializer=pt.components.Random(2),
|
||||||
|
lr=0.005,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize the model
|
||||||
|
model = pt.models.LVQ1(hparams)
|
||||||
|
#model = pt.models.LVQ21(hparams)
|
||||||
|
|
||||||
|
# Callbacks
|
||||||
|
vis = pt.models.VisGLVQ2D(data=(x_train, y_train))
|
||||||
|
|
||||||
|
# Setup trainer
|
||||||
|
trainer = pl.Trainer(
|
||||||
|
max_epochs=200,
|
||||||
|
callbacks=[vis],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Training loop
|
||||||
|
trainer.fit(model, train_loader)
|
@@ -1,8 +1,8 @@
|
|||||||
from importlib.metadata import PackageNotFoundError, version
|
from importlib.metadata import PackageNotFoundError, version
|
||||||
|
|
||||||
from .cbc import CBC
|
from .cbc import CBC
|
||||||
from .glvq import GLVQ, GMLVQ, GRLVQ, LVQMLN, ImageGLVQ, SiameseGLVQ
|
from .glvq import GLVQ, GMLVQ, GRLVQ, LVQMLN, ImageGLVQ, SiameseGLVQ, LVQ1, LVQ21
|
||||||
from .neural_gas import NeuralGas
|
from .neural_gas import NeuralGas
|
||||||
from .vis import *
|
from .vis import *
|
||||||
|
|
||||||
__version__ = "0.1.0"
|
__version__ = "0.1.6"
|
@@ -5,10 +5,12 @@ from prototorch.functions.activations import get_activation
|
|||||||
from prototorch.functions.competitions import wtac
|
from prototorch.functions.competitions import wtac
|
||||||
from prototorch.functions.distances import (euclidean_distance, omega_distance,
|
from prototorch.functions.distances import (euclidean_distance, omega_distance,
|
||||||
squared_euclidean_distance)
|
squared_euclidean_distance)
|
||||||
from prototorch.functions.losses import glvq_loss
|
from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
|
||||||
|
|
||||||
from .abstract import AbstractPrototypeModel
|
from .abstract import AbstractPrototypeModel
|
||||||
|
|
||||||
|
from torch.optim.lr_scheduler import ExponentialLR
|
||||||
|
|
||||||
|
|
||||||
class GLVQ(AbstractPrototypeModel):
|
class GLVQ(AbstractPrototypeModel):
|
||||||
"""Generalized Learning Vector Quantization."""
|
"""Generalized Learning Vector Quantization."""
|
||||||
@@ -30,6 +32,8 @@ class GLVQ(AbstractPrototypeModel):
|
|||||||
self.transfer_function = get_activation(self.hparams.transfer_function)
|
self.transfer_function = get_activation(self.hparams.transfer_function)
|
||||||
self.train_acc = torchmetrics.Accuracy()
|
self.train_acc = torchmetrics.Accuracy()
|
||||||
|
|
||||||
|
self.loss = glvq_loss
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def prototype_labels(self):
|
def prototype_labels(self):
|
||||||
return self.proto_layer.component_labels.detach().cpu()
|
return self.proto_layer.component_labels.detach().cpu()
|
||||||
@@ -44,7 +48,7 @@ class GLVQ(AbstractPrototypeModel):
|
|||||||
x = x.view(x.size(0), -1) # flatten
|
x = x.view(x.size(0), -1) # flatten
|
||||||
dis = self(x)
|
dis = self(x)
|
||||||
plabels = self.proto_layer.component_labels
|
plabels = self.proto_layer.component_labels
|
||||||
mu = glvq_loss(dis, y, prototype_labels=plabels)
|
mu = self.loss(dis, y, prototype_labels=plabels)
|
||||||
batch_loss = self.transfer_function(mu,
|
batch_loss = self.transfer_function(mu,
|
||||||
beta=self.hparams.transfer_beta)
|
beta=self.hparams.transfer_beta)
|
||||||
loss = batch_loss.sum(dim=0)
|
loss = batch_loss.sum(dim=0)
|
||||||
@@ -76,6 +80,42 @@ class GLVQ(AbstractPrototypeModel):
|
|||||||
return y_pred.numpy()
|
return y_pred.numpy()
|
||||||
|
|
||||||
|
|
||||||
|
class LVQ1(GLVQ):
|
||||||
|
def __init__(self, hparams, **kwargs):
|
||||||
|
super().__init__(hparams, **kwargs)
|
||||||
|
self.loss = lvq1_loss
|
||||||
|
|
||||||
|
def configure_optimizers(self):
|
||||||
|
optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr)
|
||||||
|
scheduler = ExponentialLR(optimizer,
|
||||||
|
gamma=0.99,
|
||||||
|
last_epoch=-1,
|
||||||
|
verbose=False)
|
||||||
|
sch = {
|
||||||
|
"scheduler": scheduler,
|
||||||
|
"interval": "step",
|
||||||
|
} # called after each training step
|
||||||
|
return [optimizer], [sch]
|
||||||
|
|
||||||
|
|
||||||
|
class LVQ21(GLVQ):
|
||||||
|
def __init__(self, hparams, **kwargs):
|
||||||
|
super().__init__(hparams, **kwargs)
|
||||||
|
self.loss = lvq21_loss
|
||||||
|
|
||||||
|
def configure_optimizers(self):
|
||||||
|
optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr)
|
||||||
|
scheduler = ExponentialLR(optimizer,
|
||||||
|
gamma=0.99,
|
||||||
|
last_epoch=-1,
|
||||||
|
verbose=False)
|
||||||
|
sch = {
|
||||||
|
"scheduler": scheduler,
|
||||||
|
"interval": "step",
|
||||||
|
} # called after each training step
|
||||||
|
return [optimizer], [sch]
|
||||||
|
|
||||||
|
|
||||||
class ImageGLVQ(GLVQ):
|
class ImageGLVQ(GLVQ):
|
||||||
"""GLVQ for training on image data.
|
"""GLVQ for training on image data.
|
||||||
|
|
||||||
|
5
setup.py
5
setup.py
@@ -19,7 +19,7 @@ DOWNLOAD_URL = "https://github.com/si-cim/prototorch_models.git"
|
|||||||
with open("README.md", "r") as fh:
|
with open("README.md", "r") as fh:
|
||||||
long_description = fh.read()
|
long_description = fh.read()
|
||||||
|
|
||||||
INSTALL_REQUIRES = ["prototorch", "pytorch_lightning", "torchmetrics"]
|
INSTALL_REQUIRES = ["prototorch>=0.4.1", "pytorch_lightning", "torchmetrics"]
|
||||||
DEV = ["bumpversion"]
|
DEV = ["bumpversion"]
|
||||||
EXAMPLES = ["matplotlib", "scikit-learn"]
|
EXAMPLES = ["matplotlib", "scikit-learn"]
|
||||||
TESTS = ["codecov", "pytest"]
|
TESTS = ["codecov", "pytest"]
|
||||||
@@ -27,10 +27,11 @@ ALL = DEV + EXAMPLES + TESTS
|
|||||||
|
|
||||||
setup(
|
setup(
|
||||||
name=safe_name("prototorch_" + PLUGIN_NAME),
|
name=safe_name("prototorch_" + PLUGIN_NAME),
|
||||||
version="0.1.0",
|
version="0.1.6",
|
||||||
description="Pre-packaged prototype-based "
|
description="Pre-packaged prototype-based "
|
||||||
"machine learning models using ProtoTorch and PyTorch-Lightning.",
|
"machine learning models using ProtoTorch and PyTorch-Lightning.",
|
||||||
long_description=long_description,
|
long_description=long_description,
|
||||||
|
long_description_content_type="text/markdown",
|
||||||
author="Alexander Engelsberger",
|
author="Alexander Engelsberger",
|
||||||
author_email="engelsbe@hs-mittweida.de",
|
author_email="engelsbe@hs-mittweida.de",
|
||||||
url=PROJECT_URL,
|
url=PROJECT_URL,
|
||||||
|
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
6
tests/test_dummy.py
Normal file
6
tests/test_dummy.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
import unittest
|
||||||
|
|
||||||
|
|
||||||
|
class TestDummy(unittest.TestCase):
|
||||||
|
def test_one(self):
|
||||||
|
self.assertEqual(True, True)
|
Reference in New Issue
Block a user