7 Commits

Author SHA1 Message Date
Alexander Engelsberger
1498c4bde5 Bump version: 0.1.6 → 0.1.7 2021-05-11 17:18:29 +02:00
Jensun Ravichandran
59b8ab6643 Add knn 2021-05-11 17:22:02 +02:00
Jensun Ravichandran
2a4f184163 Update example scripts 2021-05-11 16:15:08 +02:00
Jensun Ravichandran
265e74dd31 Require prototorch>=0.4.2 2021-05-11 16:14:47 +02:00
Jensun Ravichandran
daad018a78 Update readme 2021-05-11 16:14:23 +02:00
Jensun Ravichandran
eab1ec72c2 Change optimizer using kwargs 2021-05-11 16:13:00 +02:00
Jensun Ravichandran
b38acd58a8 [BUGFIX] Fix visualization callbacks bug 2021-05-11 16:09:27 +02:00
15 changed files with 139 additions and 99 deletions

View File

@@ -1,5 +1,5 @@
[bumpversion] [bumpversion]
current_version = 0.1.6 current_version = 0.1.7
commit = True commit = True
tag = True tag = True
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+) parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)

View File

@@ -8,24 +8,18 @@ PyTorch-Lightning.
## Installation ## Installation
To install this plugin, first install To install this plugin, simply run the following command:
[ProtoTorch](https://github.com/si-cim/prototorch) with:
```sh ```sh
git clone https://github.com/si-cim/prototorch.git && cd prototorch pip install prototorch_models
pip install -e .
```
and then install the plugin itself with:
```sh
git clone https://github.com/si-cim/prototorch_models.git && cd prototorch_models
pip install -e .
``` ```
The plugin should then be available for use in your Python environment as The plugin should then be available for use in your Python environment as
`prototorch.models`. `prototorch.models`.
*Note: Installing the models plugin should automatically install a suitable
version of* [ProtoTorch](https://github.com/si-cim/prototorch).
## Development setup ## Development setup
It is recommended that you use a virtual environment for development. If you do It is recommended that you use a virtual environment for development. If you do
@@ -57,17 +51,20 @@ To assist in the development process, you may also find it useful to install
## Available models ## Available models
- K-Nearest Neighbors (KNN)
- Learning Vector Quantization 1 (LVQ1)
- Generalized Learning Vector Quantization (GLVQ) - Generalized Learning Vector Quantization (GLVQ)
- Generalized Relevance Learning Vector Quantization (GRLVQ) - Generalized Relevance Learning Vector Quantization (GRLVQ)
- Generalized Matrix Learning Vector Quantization (GMLVQ) - Generalized Matrix Learning Vector Quantization (GMLVQ)
- Limited-Rank Matrix Learning Vector Quantization (LiRaMLVQ) - Limited-Rank Matrix Learning Vector Quantization (LiRaMLVQ)
- Learning Vector Quantization Multi-Layer Network (LVQMLN)
- Siamese GLVQ - Siamese GLVQ
- Neural Gas (NG) - Neural Gas (NG)
## Work in Progress ## Work in Progress
- Classification-By-Components Network (CBC) - Classification-By-Components Network (CBC)
- Learning Vector Quantization Multi-Layer Network (LVQMLN) - Learning Vector Quantization 2.1 (LVQ2.1)
## Planned models ## Planned models
@@ -76,8 +73,6 @@ To assist in the development process, you may also find it useful to install
- Robust Soft Learning Vector Quantization (RSLVQ) - Robust Soft Learning Vector Quantization (RSLVQ)
- Probabilistic Learning Vector Quantization (PLVQ) - Probabilistic Learning Vector Quantization (PLVQ)
- Self-Incremental Learning Vector Quantization (SILVQ) - Self-Incremental Learning Vector Quantization (SILVQ)
- K-Nearest Neighbors (KNN)
- Learning Vector Quantization 1 (LVQ1)
## FAQ ## FAQ

View File

@@ -17,15 +17,16 @@ if __name__ == "__main__":
batch_size=150) batch_size=150)
# Hyperparameters # Hyperparameters
nclasses = 3
prototypes_per_class = 2
hparams = dict( hparams = dict(
nclasses=3, distribution=(nclasses, prototypes_per_class),
prototypes_per_class=2,
prototype_initializer=pt.components.SMI(train_ds), prototype_initializer=pt.components.SMI(train_ds),
lr=0.01, lr=0.01,
) )
# Initialize the model # Initialize the model
model = pt.models.GLVQ(hparams) model = pt.models.GLVQ(hparams, optimizer=torch.optim.Adam)
# Callbacks # Callbacks
vis = pt.models.VisGLVQ2D(data=(x_train, y_train)) vis = pt.models.VisGLVQ2D(data=(x_train, y_train))

View File

@@ -25,10 +25,11 @@ if __name__ == "__main__":
batch_size=256) batch_size=256)
# Hyperparameters # Hyperparameters
nclasses = 2
prototypes_per_class = 20
hparams = dict( hparams = dict(
nclasses=2, distribution=(nclasses, prototypes_per_class),
prototypes_per_class=20, prototype_initializer=pt.components.SSI(train_ds, noise=1e-1),
prototype_initializer=pt.components.SSI(train_ds, noise=1e-7),
transfer_function="sigmoid_beta", transfer_function="sigmoid_beta",
transfer_beta=10.0, transfer_beta=10.0,
lr=0.01, lr=0.01,

View File

@@ -15,9 +15,10 @@ if __name__ == "__main__":
num_workers=0, num_workers=0,
batch_size=150) batch_size=150)
# Hyperparameters # Hyperparameters
nclasses = 3
prototypes_per_class = 1
hparams = dict( hparams = dict(
nclasses=3, distribution=(nclasses, prototypes_per_class),
prototypes_per_class=1,
input_dim=x_train.shape[1], input_dim=x_train.shape[1],
latent_dim=x_train.shape[1], latent_dim=x_train.shape[1],
prototype_initializer=pt.components.SMI(train_ds), prototype_initializer=pt.components.SMI(train_ds),

View File

@@ -1,4 +1,4 @@
"""Classical LVQ using GLVQ example on the Iris dataset.""" """k-NN example using the Iris dataset."""
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
@@ -17,26 +17,21 @@ if __name__ == "__main__":
batch_size=150) batch_size=150)
# Hyperparameters # Hyperparameters
hparams = dict( hparams = dict(k=20)
nclasses=3,
prototypes_per_class=2,
prototype_initializer=pt.components.SMI(train_ds),
#prototype_initializer=pt.components.Random(2),
lr=0.005,
)
# Initialize the model # Initialize the model
model = pt.models.LVQ1(hparams) model = pt.models.KNN(hparams, data=train_ds)
#model = pt.models.LVQ21(hparams)
# Callbacks # Callbacks
vis = pt.models.VisGLVQ2D(data=(x_train, y_train)) vis = pt.models.VisGLVQ2D(data=(x_train, y_train))
# Setup trainer # Setup trainer
trainer = pl.Trainer( trainer = pl.Trainer(max_epochs=1, callbacks=[vis])
max_epochs=200,
callbacks=[vis],
)
# Training loop # Training loop
# This is only for visualization. k-NN has no training phase.
trainer.fit(model, train_loader) trainer.fit(model, train_loader)
# Recall
y_pred = model.predict(torch.tensor(x_train))
print(y_pred)

View File

@@ -17,9 +17,10 @@ if __name__ == "__main__":
batch_size=32) batch_size=32)
# Hyperparameters # Hyperparameters
nclasses = 2
prototypes_per_class = 2
hparams = dict( hparams = dict(
nclasses=2, distribution=(nclasses, prototypes_per_class),
prototypes_per_class=2,
input_dim=100, input_dim=100,
latent_dim=2, latent_dim=2,
prototype_initializer=pt.components.SMI(train_ds), prototype_initializer=pt.components.SMI(train_ds),

View File

@@ -24,9 +24,7 @@ class Backbone(torch.nn.Module):
if __name__ == "__main__": if __name__ == "__main__":
# Dataset # Dataset
from sklearn.datasets import load_iris train_ds = pt.datasets.Iris()
x_train, y_train = load_iris(return_X_y=True)
train_ds = pt.datasets.NumpyDataset(x_train, y_train)
# Reproducibility # Reproducibility
pl.utilities.seed.seed_everything(seed=2) pl.utilities.seed.seed_everything(seed=2)
@@ -38,11 +36,10 @@ if __name__ == "__main__":
# Hyperparameters # Hyperparameters
hparams = dict( hparams = dict(
nclasses=3, distribution=[1, 2, 3],
prototypes_per_class=2, prototype_initializer=pt.components.SMI(train_ds),
prototype_initializer=pt.components.SMI((x_train, y_train)), proto_lr=0.01,
proto_lr=0.001, bb_lr=0.01,
bb_lr=0.001,
) )
# Initialize the model # Initialize the model
@@ -55,7 +52,7 @@ if __name__ == "__main__":
print(model) print(model)
# Callbacks # Callbacks
vis = pt.models.VisSiameseGLVQ2D(data=(x_train, y_train), border=0.1) vis = pt.models.VisSiameseGLVQ2D(data=train_ds, border=0.1)
# Setup trainer # Setup trainer
trainer = pl.Trainer(max_epochs=100, callbacks=[vis]) trainer = pl.Trainer(max_epochs=100, callbacks=[vis])

View File

@@ -1,8 +1,10 @@
from importlib.metadata import PackageNotFoundError, version from importlib.metadata import PackageNotFoundError, version
from .cbc import CBC from .cbc import CBC
from .glvq import GLVQ, GMLVQ, GRLVQ, LVQMLN, ImageGLVQ, SiameseGLVQ, LVQ1, LVQ21 from .glvq import (GLVQ, GMLVQ, GRLVQ, LVQ1, LVQ21, LVQMLN, ImageGLVQ,
SiameseGLVQ)
from .knn import KNN
from .neural_gas import NeuralGas from .neural_gas import NeuralGas
from .vis import * from .vis import *
__version__ = "0.1.6" __version__ = "0.1.7"

View File

@@ -3,9 +3,13 @@ import torch
from torch.optim.lr_scheduler import ExponentialLR from torch.optim.lr_scheduler import ExponentialLR
class AbstractLightningModel(pl.LightningModule): class AbstractPrototypeModel(pl.LightningModule):
@property
def prototypes(self):
return self.proto_layer.components.detach().cpu()
def configure_optimizers(self): def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr) optimizer = self.optimizer(self.parameters(), lr=self.hparams.lr)
scheduler = ExponentialLR(optimizer, scheduler = ExponentialLR(optimizer,
gamma=0.99, gamma=0.99,
last_epoch=-1, last_epoch=-1,
@@ -15,9 +19,3 @@ class AbstractLightningModel(pl.LightningModule):
"interval": "step", "interval": "step",
} # called after each training step } # called after each training step
return [optimizer], [sch] return [optimizer], [sch]
class AbstractPrototypeModel(AbstractLightningModel):
@property
def prototypes(self):
return self.proto_layer.components.detach().cpu()

View File

@@ -9,8 +9,6 @@ from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
from .abstract import AbstractPrototypeModel from .abstract import AbstractPrototypeModel
from torch.optim.lr_scheduler import ExponentialLR
class GLVQ(AbstractPrototypeModel): class GLVQ(AbstractPrototypeModel):
"""Generalized Learning Vector Quantization.""" """Generalized Learning Vector Quantization."""
@@ -19,14 +17,15 @@ class GLVQ(AbstractPrototypeModel):
self.save_hyperparameters(hparams) self.save_hyperparameters(hparams)
self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
# Default Values # Default Values
self.hparams.setdefault("distance", euclidean_distance) self.hparams.setdefault("distance", euclidean_distance)
self.hparams.setdefault("optimizer", torch.optim.Adam)
self.hparams.setdefault("transfer_function", "identity") self.hparams.setdefault("transfer_function", "identity")
self.hparams.setdefault("transfer_beta", 10.0) self.hparams.setdefault("transfer_beta", 10.0)
self.proto_layer = LabeledComponents( self.proto_layer = LabeledComponents(
labels=(self.hparams.nclasses, self.hparams.prototypes_per_class), distribution=self.hparams.distribution,
initializer=self.hparams.prototype_initializer) initializer=self.hparams.prototype_initializer)
self.transfer_function = get_activation(self.hparams.transfer_function) self.transfer_function = get_activation(self.hparams.transfer_function)
@@ -81,39 +80,19 @@ class GLVQ(AbstractPrototypeModel):
class LVQ1(GLVQ): class LVQ1(GLVQ):
"""Learning Vector Quantization 1."""
def __init__(self, hparams, **kwargs): def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs) super().__init__(hparams, **kwargs)
self.loss = lvq1_loss self.loss = lvq1_loss
self.optimizer = torch.optim.SGD
def configure_optimizers(self):
optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr)
scheduler = ExponentialLR(optimizer,
gamma=0.99,
last_epoch=-1,
verbose=False)
sch = {
"scheduler": scheduler,
"interval": "step",
} # called after each training step
return [optimizer], [sch]
class LVQ21(GLVQ): class LVQ21(GLVQ):
"""Learning Vector Quantization 2.1."""
def __init__(self, hparams, **kwargs): def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs) super().__init__(hparams, **kwargs)
self.loss = lvq21_loss self.loss = lvq21_loss
self.optimizer = torch.optim.SGD
def configure_optimizers(self):
optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr)
scheduler = ExponentialLR(optimizer,
gamma=0.99,
last_epoch=-1,
verbose=False)
sch = {
"scheduler": scheduler,
"interval": "step",
} # called after each training step
return [optimizer], [sch]
class ImageGLVQ(GLVQ): class ImageGLVQ(GLVQ):
@@ -152,13 +131,13 @@ class SiameseGLVQ(GLVQ):
self.backbone_dependent.load_state_dict(master_state, strict=True) self.backbone_dependent.load_state_dict(master_state, strict=True)
def configure_optimizers(self): def configure_optimizers(self):
optim = self.hparams.optimizer proto_opt = self.optimizer(self.proto_layer.parameters(),
proto_opt = optim(self.proto_layer.parameters(), lr=self.hparams.proto_lr)
lr=self.hparams.proto_lr)
if list(self.backbone.parameters()): if list(self.backbone.parameters()):
# only add an optimizer if the backbone has trainable parameters # only add an optimizer if the backbone has trainable parameters
# otherwise, the next line fails # otherwise, the next line fails
bb_opt = optim(self.backbone.parameters(), lr=self.hparams.bb_lr) bb_opt = self.optimizer(self.backbone.parameters(),
lr=self.hparams.bb_lr)
return proto_opt, bb_opt return proto_opt, bb_opt
else: else:
return proto_opt return proto_opt

62
prototorch/models/knn.py Normal file
View File

@@ -0,0 +1,62 @@
"""The popular K-Nearest-Neighbors classification algorithm."""
import warnings
import torch
import torchmetrics
from prototorch.components import LabeledComponents
from prototorch.components.initializers import parse_init_arg
from prototorch.functions.competitions import knnc
from prototorch.functions.distances import euclidean_distance
from .abstract import AbstractPrototypeModel
class KNN(AbstractPrototypeModel):
    """K-Nearest-Neighbors classification algorithm.

    The data passed through the ``data`` keyword argument is stored as
    labeled components. Prediction computes the distance from each query to
    every stored component and lets ``knnc`` pick the label.
    """

    def __init__(self, hparams, **kwargs):
        """Set hyperparameter defaults and store the training data.

        Args:
            hparams: Hyperparameters handed to ``save_hyperparameters``.
                ``k`` defaults to 1 and ``distance`` defaults to
                ``euclidean_distance`` when they are not supplied.
            **kwargs: Only ``data`` is read here; it is forwarded to
                ``parse_init_arg`` to obtain the inputs and their labels.
        """
        super().__init__()

        self.save_hyperparameters(hparams)

        # Default Values
        self.hparams.setdefault("k", 1)
        self.hparams.setdefault("distance", euclidean_distance)

        # NOTE(review): ``data`` is None when the keyword is omitted; how
        # ``parse_init_arg`` reacts to None is not visible here - confirm.
        data = kwargs.get("data")
        x_train, y_train = parse_init_arg(data)

        self.proto_layer = LabeledComponents(initialized_components=(x_train,
                                                                     y_train))

        # Created here but not updated anywhere in this class.
        self.train_acc = torchmetrics.Accuracy()

    @property
    def prototype_labels(self):
        """Return the stored component labels, detached and moved to CPU."""
        return self.proto_layer.component_labels.detach().cpu()

    def forward(self, x):
        """Return the distances between ``x`` and the stored components."""
        protos, _ = self.proto_layer()
        dis = self.hparams.distance(x, protos)
        return dis

    def predict(self, x):
        """Predict labels for ``x`` and return them as a NumPy array.

        Args:
            x: Query inputs, passed unchanged to ``forward``.

        Returns:
            The result of ``knnc`` on the distances, the stored component
            labels and ``k``, converted to a NumPy array.
        """
        # NOTE(review): the original author left an open question on whether
        # ``model.eval()`` should be called here; behavior is unchanged.
        with torch.no_grad():
            d = self(x)
            plabels = self.proto_layer.component_labels
            y_pred = knnc(d, plabels, k=self.hparams.k)
        # ``Tensor.numpy()`` only works on CPU tensors, so move the result
        # first; ``.cpu()`` is a no-op when the tensor already lives there.
        return y_pred.cpu().numpy()

    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
        """Return a constant placeholder; nothing is computed from the batch."""
        return 1

    def on_train_batch_start(self,
                             train_batch,
                             batch_idx,
                             dataloader_idx=None):
        """Warn that k-NN has no training phase and return ``-1``.

        NOTE(review): PyTorch-Lightning documents a ``-1`` return from this
        hook as "skip training for the rest of the epoch" - confirm for the
        Lightning version in use.
        """
        warnings.warn("k-NN has no training, skipping!")
        return -1

    def configure_optimizers(self):
        """Return ``None``: this model configures no optimizer."""
        return None

View File

@@ -1,6 +1,7 @@
import torch import torch
from prototorch.components import Components from prototorch.components import Components
from prototorch.components import initializers as cinit from prototorch.components import initializers as cinit
from prototorch.components.initializers import ZerosInitializer
from prototorch.functions.distances import euclidean_distance from prototorch.functions.distances import euclidean_distance
from prototorch.modules.losses import NeuralGasEnergy from prototorch.modules.losses import NeuralGasEnergy
@@ -41,12 +42,14 @@ class NeuralGas(AbstractPrototypeModel):
self.save_hyperparameters(hparams) self.save_hyperparameters(hparams)
self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
# Default Values # Default Values
self.hparams.setdefault("input_dim", 2) self.hparams.setdefault("input_dim", 2)
self.hparams.setdefault("agelimit", 10) self.hparams.setdefault("agelimit", 10)
self.hparams.setdefault("lm", 1) self.hparams.setdefault("lm", 1)
self.hparams.setdefault("prototype_initializer", self.hparams.setdefault("prototype_initializer",
cinit.ZerosInitializer(self.hparams.input_dim)) ZerosInitializer(self.hparams.input_dim))
self.proto_layer = Components( self.proto_layer = Components(
self.hparams.num_prototypes, self.hparams.num_prototypes,

View File

@@ -298,7 +298,8 @@ class Vis2DAbstract(pl.Callback):
def precheck(self, trainer): def precheck(self, trainer):
if self.show_last_only: if self.show_last_only:
if trainer.current_epoch != trainer.max_epochs - 1: if trainer.current_epoch != trainer.max_epochs - 1:
return return False
return True
def setup_ax(self, xlabel=None, ylabel=None): def setup_ax(self, xlabel=None, ylabel=None):
ax = self.fig.gca() ax = self.fig.gca()
@@ -362,7 +363,8 @@ class Vis2DAbstract(pl.Callback):
class VisGLVQ2D(Vis2DAbstract): class VisGLVQ2D(Vis2DAbstract):
def on_epoch_end(self, trainer, pl_module): def on_epoch_end(self, trainer, pl_module):
self.precheck(trainer) if not self.precheck(trainer):
return True
protos = pl_module.prototypes protos = pl_module.prototypes
plabels = pl_module.prototype_labels plabels = pl_module.prototype_labels
@@ -386,7 +388,8 @@ class VisSiameseGLVQ2D(Vis2DAbstract):
self.map_protos = map_protos self.map_protos = map_protos
def on_epoch_end(self, trainer, pl_module): def on_epoch_end(self, trainer, pl_module):
self.precheck(trainer) if not self.precheck(trainer):
return True
protos = pl_module.prototypes protos = pl_module.prototypes
plabels = pl_module.prototype_labels plabels = pl_module.prototype_labels
@@ -411,14 +414,15 @@ class VisSiameseGLVQ2D(Vis2DAbstract):
class VisCBC2D(Vis2DAbstract): class VisCBC2D(Vis2DAbstract):
def on_epoch_end(self, trainer, pl_module): def on_epoch_end(self, trainer, pl_module):
self.precheck(trainer) if not self.precheck(trainer):
return True
x_train, y_train = self.x_train, self.y_train x_train, y_train = self.x_train, self.y_train
protos = pl_module.components protos = pl_module.components
ax = self.setup_ax(xlabel="Data dimension 1", ax = self.setup_ax(xlabel="Data dimension 1",
ylabel="Data dimension 2") ylabel="Data dimension 2")
self.plot_data(ax, x_train, y_train) self.plot_data(ax, x_train, y_train)
self.plot_protos(ax, protos, plabels) self.plot_protos(ax, protos, "w")
x = np.vstack((x_train, protos)) x = np.vstack((x_train, protos))
mesh_input, xx, yy = self.get_mesh_input(x) mesh_input, xx, yy = self.get_mesh_input(x)
y_pred = pl_module.predict(torch.Tensor(mesh_input)) y_pred = pl_module.predict(torch.Tensor(mesh_input))
@@ -431,7 +435,8 @@ class VisCBC2D(Vis2DAbstract):
class VisNG2D(Vis2DAbstract): class VisNG2D(Vis2DAbstract):
def on_epoch_end(self, trainer, pl_module): def on_epoch_end(self, trainer, pl_module):
self.precheck(trainer) if not self.precheck(trainer):
return True
x_train, y_train = self.x_train, self.y_train x_train, y_train = self.x_train, self.y_train
protos = pl_module.prototypes protos = pl_module.prototypes

View File

@@ -19,7 +19,7 @@ DOWNLOAD_URL = "https://github.com/si-cim/prototorch_models.git"
with open("README.md", "r") as fh: with open("README.md", "r") as fh:
long_description = fh.read() long_description = fh.read()
INSTALL_REQUIRES = ["prototorch>=0.4.1", "pytorch_lightning", "torchmetrics"] INSTALL_REQUIRES = ["prototorch>=0.4.4", "pytorch_lightning", "torchmetrics"]
DEV = ["bumpversion"] DEV = ["bumpversion"]
EXAMPLES = ["matplotlib", "scikit-learn"] EXAMPLES = ["matplotlib", "scikit-learn"]
TESTS = ["codecov", "pytest"] TESTS = ["codecov", "pytest"]
@@ -27,7 +27,7 @@ ALL = DEV + EXAMPLES + TESTS
setup( setup(
name=safe_name("prototorch_" + PLUGIN_NAME), name=safe_name("prototorch_" + PLUGIN_NAME),
version="0.1.6", version="0.1.7",
description="Pre-packaged prototype-based " description="Pre-packaged prototype-based "
"machine learning models using ProtoTorch and PyTorch-Lightning.", "machine learning models using ProtoTorch and PyTorch-Lightning.",
long_description=long_description, long_description=long_description,