This commit is contained in:
Jensun Ravichandran 2021-05-11 17:22:02 +02:00
parent 2a4f184163
commit 59b8ab6643
6 changed files with 108 additions and 9 deletions

View File

@ -51,6 +51,7 @@ To assist in the development process, you may also find it useful to install
## Available models ## Available models
- K-Nearest Neighbors (KNN)
- Learning Vector Quantization 1 (LVQ1) - Learning Vector Quantization 1 (LVQ1)
- Generalized Learning Vector Quantization (GLVQ) - Generalized Learning Vector Quantization (GLVQ)
- Generalized Relevance Learning Vector Quantization (GRLVQ) - Generalized Relevance Learning Vector Quantization (GRLVQ)
@ -72,7 +73,6 @@ To assist in the development process, you may also find it useful to install
- Robust Soft Learning Vector Quantization (RSLVQ) - Robust Soft Learning Vector Quantization (RSLVQ)
- Probabilistic Learning Vector Quantization (PLVQ) - Probabilistic Learning Vector Quantization (PLVQ)
- Self-Incremental Learning Vector Quantization (SILVQ) - Self-Incremental Learning Vector Quantization (SILVQ)
- K-Nearest Neighbors (KNN)
## FAQ ## FAQ

37
examples/knn_iris.py Normal file
View File

@ -0,0 +1,37 @@
"""k-NN example using the Iris dataset."""
import prototorch as pt
import pytorch_lightning as pl
import torch
if __name__ == "__main__":
# Dataset
from sklearn.datasets import load_iris
x_train, y_train = load_iris(return_X_y=True)
x_train = x_train[:, [0, 2]]
train_ds = pt.datasets.NumpyDataset(x_train, y_train)
# Dataloaders
train_loader = torch.utils.data.DataLoader(train_ds,
num_workers=0,
batch_size=150)
# Hyperparameters
hparams = dict(k=20)
# Initialize the model
model = pt.models.KNN(hparams, data=train_ds)
# Callbacks
vis = pt.models.VisGLVQ2D(data=(x_train, y_train))
# Setup trainer
trainer = pl.Trainer(max_epochs=1, callbacks=[vis])
# Training loop
# This is only for visualization. k-NN has no training phase.
trainer.fit(model, train_loader)
# Recall
y_pred = model.predict(torch.tensor(x_train))
print(y_pred)

View File

@ -24,9 +24,7 @@ class Backbone(torch.nn.Module):
if __name__ == "__main__": if __name__ == "__main__":
# Dataset # Dataset
from sklearn.datasets import load_iris train_ds = pt.datasets.Iris()
x_train, y_train = load_iris(return_X_y=True)
train_ds = pt.datasets.NumpyDataset(x_train, y_train)
# Reproducibility # Reproducibility
pl.utilities.seed.seed_everything(seed=2) pl.utilities.seed.seed_everything(seed=2)
@ -39,7 +37,7 @@ if __name__ == "__main__":
# Hyperparameters # Hyperparameters
hparams = dict( hparams = dict(
distribution=[1, 2, 3], distribution=[1, 2, 3],
prototype_initializer=pt.components.SMI((x_train, y_train)), prototype_initializer=pt.components.SMI(train_ds),
proto_lr=0.01, proto_lr=0.01,
bb_lr=0.01, bb_lr=0.01,
) )
@ -54,7 +52,7 @@ if __name__ == "__main__":
print(model) print(model)
# Callbacks # Callbacks
vis = pt.models.VisSiameseGLVQ2D(data=(x_train, y_train), border=0.1) vis = pt.models.VisSiameseGLVQ2D(data=train_ds, border=0.1)
# Setup trainer # Setup trainer
trainer = pl.Trainer(max_epochs=100, callbacks=[vis]) trainer = pl.Trainer(max_epochs=100, callbacks=[vis])

View File

@ -1,7 +1,9 @@
from importlib.metadata import PackageNotFoundError, version from importlib.metadata import PackageNotFoundError, version
from .cbc import CBC from .cbc import CBC
from .glvq import GLVQ, GMLVQ, GRLVQ, LVQMLN, ImageGLVQ, SiameseGLVQ, LVQ1, LVQ21 from .glvq import (GLVQ, GMLVQ, GRLVQ, LVQ1, LVQ21, LVQMLN, ImageGLVQ,
SiameseGLVQ)
from .knn import KNN
from .neural_gas import NeuralGas from .neural_gas import NeuralGas
from .vis import * from .vis import *

62
prototorch/models/knn.py Normal file
View File

@ -0,0 +1,62 @@
"""The popular K-Nearest-Neighbors classification algorithm."""
import warnings
import torch
import torchmetrics
from prototorch.components import LabeledComponents
from prototorch.components.initializers import parse_init_arg
from prototorch.functions.competitions import knnc
from prototorch.functions.distances import euclidean_distance
from .abstract import AbstractPrototypeModel
class KNN(AbstractPrototypeModel):
"""K-Nearest-Neighbors classification algorithm."""
def __init__(self, hparams, **kwargs):
super().__init__()
self.save_hyperparameters(hparams)
# Default Values
self.hparams.setdefault("k", 1)
self.hparams.setdefault("distance", euclidean_distance)
data = kwargs.get("data")
x_train, y_train = parse_init_arg(data)
self.proto_layer = LabeledComponents(initialized_components=(x_train,
y_train))
self.train_acc = torchmetrics.Accuracy()
@property
def prototype_labels(self):
return self.proto_layer.component_labels.detach().cpu()
def forward(self, x):
protos, _ = self.proto_layer()
dis = self.hparams.distance(x, protos)
return dis
def predict(self, x):
# model.eval() # ?!
with torch.no_grad():
d = self(x)
plabels = self.proto_layer.component_labels
y_pred = knnc(d, plabels, k=self.hparams.k)
return y_pred.numpy()
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
return 1
def on_train_batch_start(self,
train_batch,
batch_idx,
dataloader_idx=None):
warnings.warn("k-NN has no training, skipping!")
return -1
def configure_optimizers(self):
return None

View File

@ -19,7 +19,7 @@ DOWNLOAD_URL = "https://github.com/si-cim/prototorch_models.git"
with open("README.md", "r") as fh: with open("README.md", "r") as fh:
long_description = fh.read() long_description = fh.read()
INSTALL_REQUIRES = ["prototorch>=0.4.2", "pytorch_lightning", "torchmetrics"] INSTALL_REQUIRES = ["prototorch>=0.4.4", "pytorch_lightning", "torchmetrics"]
DEV = ["bumpversion"] DEV = ["bumpversion"]
EXAMPLES = ["matplotlib", "scikit-learn"] EXAMPLES = ["matplotlib", "scikit-learn"]
TESTS = ["codecov", "pytest"] TESTS = ["codecov", "pytest"]