From 2a4f1841637afa50939dda8aaaaa4fc0ad2ed93d Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Tue, 11 May 2021 16:15:08 +0200 Subject: [PATCH] Update example scripts --- examples/glvq_iris.py | 7 +++--- examples/glvq_spiral.py | 7 +++--- examples/gmlvq_iris.py | 5 +++-- examples/liramlvq_tecator.py | 5 +++-- examples/lvq_iris.py | 42 ----------------------------------- examples/siamese_glvq_iris.py | 7 +++--- 6 files changed, 17 insertions(+), 56 deletions(-) delete mode 100644 examples/lvq_iris.py diff --git a/examples/glvq_iris.py b/examples/glvq_iris.py index 95982e7..94c15ce 100644 --- a/examples/glvq_iris.py +++ b/examples/glvq_iris.py @@ -17,15 +17,16 @@ if __name__ == "__main__": batch_size=150) # Hyperparameters + nclasses = 3 + prototypes_per_class = 2 hparams = dict( - nclasses=3, - prototypes_per_class=2, + distribution=(nclasses, prototypes_per_class), prototype_initializer=pt.components.SMI(train_ds), lr=0.01, ) # Initialize the model - model = pt.models.GLVQ(hparams) + model = pt.models.GLVQ(hparams, optimizer=torch.optim.Adam) # Callbacks vis = pt.models.VisGLVQ2D(data=(x_train, y_train)) diff --git a/examples/glvq_spiral.py b/examples/glvq_spiral.py index 3fee454..fea4bf0 100644 --- a/examples/glvq_spiral.py +++ b/examples/glvq_spiral.py @@ -25,10 +25,11 @@ if __name__ == "__main__": batch_size=256) # Hyperparameters + nclasses = 2 + prototypes_per_class = 20 hparams = dict( - nclasses=2, - prototypes_per_class=20, - prototype_initializer=pt.components.SSI(train_ds, noise=1e-7), + distribution=(nclasses, prototypes_per_class), + prototype_initializer=pt.components.SSI(train_ds, noise=1e-1), transfer_function="sigmoid_beta", transfer_beta=10.0, lr=0.01, diff --git a/examples/gmlvq_iris.py b/examples/gmlvq_iris.py index ba903e7..c8c5425 100644 --- a/examples/gmlvq_iris.py +++ b/examples/gmlvq_iris.py @@ -15,9 +15,10 @@ if __name__ == "__main__": num_workers=0, batch_size=150) # Hyperparameters + nclasses = 3 + prototypes_per_class = 1 hparams = dict( - nclasses=3, - prototypes_per_class=1, + distribution=(nclasses, prototypes_per_class), input_dim=x_train.shape[1], latent_dim=x_train.shape[1], prototype_initializer=pt.components.SMI(train_ds), diff --git a/examples/liramlvq_tecator.py b/examples/liramlvq_tecator.py index f948b87..1b07a0d 100644 --- a/examples/liramlvq_tecator.py +++ b/examples/liramlvq_tecator.py @@ -17,9 +17,10 @@ if __name__ == "__main__": batch_size=32) # Hyperparameters + nclasses = 2 + prototypes_per_class = 2 hparams = dict( - nclasses=2, - prototypes_per_class=2, + distribution=(nclasses, prototypes_per_class), input_dim=100, latent_dim=2, prototype_initializer=pt.components.SMI(train_ds), diff --git a/examples/lvq_iris.py b/examples/lvq_iris.py deleted file mode 100644 index af2beb7..0000000 --- a/examples/lvq_iris.py +++ /dev/null @@ -1,42 +0,0 @@ -"""Classical LVQ using GLVQ example on the Iris dataset.""" - -import prototorch as pt -import pytorch_lightning as pl -import torch - -if __name__ == "__main__": - # Dataset - from sklearn.datasets import load_iris - x_train, y_train = load_iris(return_X_y=True) - x_train = x_train[:, [0, 2]] - train_ds = pt.datasets.NumpyDataset(x_train, y_train) - - # Dataloaders - train_loader = torch.utils.data.DataLoader(train_ds, - num_workers=0, - batch_size=150) - - # Hyperparameters - hparams = dict( - nclasses=3, - prototypes_per_class=2, - prototype_initializer=pt.components.SMI(train_ds), - #prototype_initializer=pt.components.Random(2), - lr=0.005, - ) - - # Initialize the model - model = pt.models.LVQ1(hparams) - #model = pt.models.LVQ21(hparams) - - # Callbacks - vis = pt.models.VisGLVQ2D(data=(x_train, y_train)) - - # Setup trainer - trainer = pl.Trainer( - max_epochs=200, - callbacks=[vis], - ) - - # Training loop - trainer.fit(model, train_loader) diff --git a/examples/siamese_glvq_iris.py b/examples/siamese_glvq_iris.py index a6390b2..2f64ab4 100644 --- a/examples/siamese_glvq_iris.py +++ b/examples/siamese_glvq_iris.py @@ -38,11 +38,10 @@ if __name__ == "__main__": # Hyperparameters hparams = dict( - nclasses=3, - prototypes_per_class=2, + distribution=[1, 2, 3], prototype_initializer=pt.components.SMI((x_train, y_train)), - proto_lr=0.001, - bb_lr=0.001, + proto_lr=0.01, + bb_lr=0.01, ) # Initialize the model