prototorch_models/examples/glvq_iris.py
Jensun Ravichandran ca39aa00d5 Stop passing component initializers as hparams
Passing the component initializer as an hparam slows the script down considerably.
The API has now been changed to pass it to the models as a kwarg instead.

The example scripts have also been updated to reflect the new changes.

Also, ImageGMLVQ and an example script, `gmlvq_mnist.py`, that uses it have
been added.
2021-05-12 16:36:22 +02:00
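The commit message does not say why routing the initializer through hparams was slow. The sketch below illustrates one plausible factor, which is an assumption only. When a LightningModule calls `save_hyperparameters()`, Lightning stores those values in checkpoints and passes them to loggers. If the model records its hparams this way, an initializer that holds a reference to the whole dataset (as `pt.components.SMI(train_ds)` receives the dataset) makes that payload far larger than a few scalars. `HypotheticalModel` and the dict-based initializer stand-in are illustrative names, not prototorch APIs.

```python
import pickle

# Fake 150x2 dataset standing in for the two Iris features used in the script.
data = [[i * 0.1, i * 0.05] for i in range(150)]

# Hypothetical stand-in for an initializer object that keeps the training data
# (a plain dict is used so the example pickles without any custom class).
initializer = {"kind": "SMI", "data": data}

# Old style: the initializer rides along inside the hyperparameters.
hparams_old = dict(distribution=(3, 2), lr=0.01,
                   prototype_initializer=initializer)

# New style: hparams carries only small values; the initializer is a kwarg.
hparams_new = dict(distribution=(3, 2), lr=0.01)


class HypotheticalModel:
    """Hypothetical model, NOT prototorch's GLVQ; it only mirrors the call shape."""

    def __init__(self, hparams, prototype_initializer=None):
        self.hparams = dict(hparams)  # small and cheap to copy, log or checkpoint
        self.prototype_initializer = prototype_initializer  # kept out of hparams


model = HypotheticalModel(hparams_new, prototype_initializer=initializer)

# Compare the serialized size of the two hparams styles.
print(len(pickle.dumps(hparams_old)), len(pickle.dumps(model.hparams)))
```

Keeping heavy objects out of hparams means whatever copies or serializes hyperparameters only ever handles a handful of scalars.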


"""GLVQ example using the Iris dataset."""

import prototorch as pt
import pytorch_lightning as pl
import torch

if __name__ == "__main__":
    # Dataset
    from sklearn.datasets import load_iris
    x_train, y_train = load_iris(return_X_y=True)
    # Keep only sepal length (column 0) and petal length (column 2) so the
    # data is two-dimensional and can be plotted by VisGLVQ2D
    x_train = x_train[:, [0, 2]]
    train_ds = pt.datasets.NumpyDataset(x_train, y_train)

    # Dataloaders (batch_size=150 puts all 150 Iris samples in one batch)
    train_loader = torch.utils.data.DataLoader(train_ds,
                                               num_workers=0,
                                               batch_size=150)

    # Hyperparameters
    nclasses = 3
    prototypes_per_class = 2
    hparams = dict(
        distribution=(nclasses, prototypes_per_class),
        lr=0.01,
    )

    # Initialize the model; the prototype initializer is passed as a kwarg,
    # not as an entry in hparams
    model = pt.models.GLVQ(hparams,
                           optimizer=torch.optim.Adam,
                           prototype_initializer=pt.components.SMI(train_ds))

    # Callbacks
    vis = pt.models.VisGLVQ2D(data=(x_train, y_train))

    # Setup trainer
    trainer = pl.Trainer(
        max_epochs=50,
        callbacks=[vis],
    )

    # Training loop
    trainer.fit(model, train_loader)
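To see roughly what the model does without prototorch, its two ingredients can be sketched in plain Python. This sketch assumes that `SMI` stands for a stratified-mean initializer, which starts every prototype of a class at that class's mean. It also uses the standard GLVQ cost: the mean over samples of mu = (d+ - d-) / (d+ + d-) with squared Euclidean distances, where d+ is the distance to the closest prototype of the correct class and d- the distance to the closest prototype of any other class. The helper names are hypothetical, and prototorch's implementation may differ, for example in its distance or reduction.

```python
from collections import defaultdict


def stratified_mean_init(x, y, prototypes_per_class):
    """Hypothetical helper: place every prototype of a class at the class mean."""
    sums = defaultdict(lambda: [0.0] * len(x[0]))
    counts = defaultdict(int)
    for xi, yi in zip(x, y):
        counts[yi] += 1
        sums[yi] = [s + v for s, v in zip(sums[yi], xi)]
    protos, labels = [], []
    for c in sorted(counts):
        mean = [s / counts[c] for s in sums[c]]
        for _ in range(prototypes_per_class):
            protos.append(list(mean))  # one copy per prototype of class c
            labels.append(c)
    return protos, labels


def sq_euclidean(a, b):
    """Squared Euclidean distance between two points."""
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b))


def glvq_cost(x, y, protos, labels):
    """Mean GLVQ relative distance mu = (d+ - d-) / (d+ + d-) over all samples."""
    total = 0.0
    for xi, yi in zip(x, y):
        d = [sq_euclidean(xi, p) for p in protos]
        d_plus = min(di for di, lab in zip(d, labels) if lab == yi)
        d_minus = min(di for di, lab in zip(d, labels) if lab != yi)
        total += (d_plus - d_minus) / (d_plus + d_minus)
    return total / len(x)


if __name__ == "__main__":
    # Toy 2-D data with two classes, two prototypes per class
    x = [[0.0, 0.0], [0.0, 2.0], [4.0, 0.0], [4.0, 2.0]]
    y = [0, 0, 1, 1]
    protos, labels = stratified_mean_init(x, y, prototypes_per_class=2)
    print(protos, labels, glvq_cost(x, y, protos, labels))
```

Each mu lies between -1 and 1 and is negative exactly when the sample is closer to a prototype of its own class than to any other, so minimizing the mean cost pulls correct prototypes toward the data and pushes wrong ones away.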