diff --git a/examples/cbc_iris.py b/examples/cbc_iris.py index 08ced33..5b7d32d 100644 --- a/examples/cbc_iris.py +++ b/examples/cbc_iris.py @@ -2,10 +2,11 @@ import argparse -import prototorch as pt import pytorch_lightning as pl import torch +import prototorch as pt + if __name__ == "__main__": # Command-line arguments parser = argparse.ArgumentParser() diff --git a/examples/glvq_iris.py b/examples/glvq_iris.py index 2aeb713..7ffab6c 100644 --- a/examples/glvq_iris.py +++ b/examples/glvq_iris.py @@ -2,10 +2,10 @@ import argparse -import prototorch as pt import pytorch_lightning as pl import torch -from sklearn.datasets import load_iris + +import prototorch as pt if __name__ == "__main__": # Command-line arguments @@ -14,14 +14,10 @@ if __name__ == "__main__": args = parser.parse_args() # Dataset - x_train, y_train = load_iris(return_X_y=True) - x_train = x_train[:, [0, 2]] - train_ds = pt.datasets.NumpyDataset(x_train, y_train) + train_ds = pt.datasets.Iris(dims=[0, 2]) # Dataloaders - train_loader = torch.utils.data.DataLoader(train_ds, - num_workers=0, - batch_size=150) + train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64) # Hyperparameters hparams = dict( @@ -38,7 +34,7 @@ if __name__ == "__main__": prototype_initializer=pt.components.SMI(train_ds)) # Callbacks - vis = pt.models.VisGLVQ2D(data=(x_train, y_train), block=False) + vis = pt.models.VisGLVQ2D(data=train_ds) # Setup trainer trainer = pl.Trainer.from_argparse_args( diff --git a/examples/glvq_spiral.py b/examples/glvq_spiral.py index 297a099..8757844 100644 --- a/examples/glvq_spiral.py +++ b/examples/glvq_spiral.py @@ -2,10 +2,11 @@ import argparse -import prototorch as pt import pytorch_lightning as pl import torch +import prototorch as pt + if __name__ == "__main__": # Command-line arguments parser = argparse.ArgumentParser() diff --git a/examples/gmlvq_iris.py b/examples/gmlvq_iris.py index f81742e..9d55baf 100644 --- a/examples/gmlvq_iris.py +++ b/examples/gmlvq_iris.py @@ -2,11 +2,12 @@ import argparse -import prototorch as pt import pytorch_lightning as pl import torch from sklearn.datasets import load_iris +import prototorch as pt + if __name__ == "__main__": # Command-line arguments parser = argparse.ArgumentParser() @@ -18,9 +19,8 @@ if __name__ == "__main__": train_ds = pt.datasets.NumpyDataset(x_train, y_train) # Dataloaders - train_loader = torch.utils.data.DataLoader(train_ds, - num_workers=0, - batch_size=150) + train_loader = torch.utils.data.DataLoader(train_ds, batch_size=150) + # Hyperparameters num_classes = 3 prototypes_per_class = 1 diff --git a/examples/gmlvq_mnist.py b/examples/gmlvq_mnist.py index 0286f39..5083199 100644 --- a/examples/gmlvq_mnist.py +++ b/examples/gmlvq_mnist.py @@ -2,12 +2,13 @@ import argparse -import prototorch as pt import pytorch_lightning as pl import torch from torchvision import transforms from torchvision.datasets import MNIST +import prototorch as pt + if __name__ == "__main__": # Command-line arguments parser = argparse.ArgumentParser() diff --git a/examples/knn_iris.py b/examples/knn_iris.py index 5d708be..e8e805c 100644 --- a/examples/knn_iris.py +++ b/examples/knn_iris.py @@ -2,11 +2,12 @@ import argparse -import prototorch as pt import pytorch_lightning as pl import torch from sklearn.datasets import load_iris +import prototorch as pt + if __name__ == "__main__": # Command-line arguments parser = argparse.ArgumentParser() @@ -19,9 +20,7 @@ if __name__ == "__main__": train_ds = pt.datasets.NumpyDataset(x_train, y_train) # Dataloaders - train_loader = torch.utils.data.DataLoader(train_ds, - num_workers=0, - batch_size=150) + train_loader = torch.utils.data.DataLoader(train_ds, batch_size=150) # Hyperparameters hparams = dict(k=20) diff --git a/examples/liramlvq_tecator.py b/examples/liramlvq_tecator.py index 9d70dcf..21dc1b9 100644 --- a/examples/liramlvq_tecator.py +++ b/examples/liramlvq_tecator.py @@ -2,10 +2,11 @@ import argparse -import prototorch as pt import pytorch_lightning as pl import torch +import prototorch as pt + if __name__ == "__main__": # Command-line arguments parser = argparse.ArgumentParser() @@ -24,10 +25,11 @@ if __name__ == "__main__": test_loader = torch.utils.data.DataLoader(test_ds, batch_size=32) # Hyperparameters - num_classes = 2 - prototypes_per_class = 2 hparams = dict( - distribution=(num_classes, prototypes_per_class), + distribution={ + "num_classes": 3, + "prototypes_per_class": 4 + }, input_dim=100, latent_dim=2, proto_lr=0.001, diff --git a/examples/lvqmln_iris.py b/examples/lvqmln_iris.py index 7bf1e1f..c688788 100644 --- a/examples/lvqmln_iris.py +++ b/examples/lvqmln_iris.py @@ -1,10 +1,12 @@ """LVQMLN example using all four dimensions of the Iris dataset.""" + import argparse -import prototorch as pt import pytorch_lightning as pl import torch +import prototorch as pt + class Backbone(torch.nn.Module): def __init__(self, input_size=4, hidden_size=10, latent_size=2): @@ -35,9 +37,7 @@ if __name__ == "__main__": pl.utilities.seed.seed_everything(seed=42) # Dataloaders - train_loader = torch.utils.data.DataLoader(train_ds, - num_workers=0, - batch_size=150) + train_loader = torch.utils.data.DataLoader(train_ds, batch_size=150) # Hyperparameters hparams = dict( diff --git a/examples/ng_iris.py b/examples/ng_iris.py index 2dc11a7..71e2701 100644 --- a/examples/ng_iris.py +++ b/examples/ng_iris.py @@ -2,12 +2,13 @@ import argparse -import prototorch as pt import pytorch_lightning as pl import torch from sklearn.datasets import load_iris from sklearn.preprocessing import StandardScaler +import prototorch as pt + if __name__ == "__main__": # Command-line arguments parser = argparse.ArgumentParser() @@ -24,9 +25,7 @@ if __name__ == "__main__": train_ds = pt.datasets.NumpyDataset(x_train, y_train) # Dataloaders - train_loader = torch.utils.data.DataLoader(train_ds, - num_workers=0, - batch_size=150) + train_loader = torch.utils.data.DataLoader(train_ds, batch_size=150) # Hyperparameters hparams = dict(num_prototypes=30, lr=0.03) diff --git a/examples/probabilistic.py b/examples/probabilistic.py index 2182bf2..608789d 100644 --- a/examples/probabilistic.py +++ b/examples/probabilistic.py @@ -1,11 +1,11 @@ -"""GLVQ example using the Iris dataset.""" +"""Probabilistic-LVQ example using the Iris dataset.""" import argparse -import prototorch as pt import pytorch_lightning as pl import torch -from sklearn.datasets import load_iris + +import prototorch as pt if __name__ == "__main__": # Command-line arguments @@ -14,14 +14,10 @@ if __name__ == "__main__": args = parser.parse_args() # Dataset - x_train, y_train = load_iris(return_X_y=True) - x_train = x_train[:, [0, 2]] - train_ds = pt.datasets.NumpyDataset(x_train, y_train) + train_ds = pt.datasets.Iris(dims=[0, 2]) # Dataloaders - train_loader = torch.utils.data.DataLoader(train_ds, - num_workers=0, - batch_size=150) + train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64) # Hyperparameters num_classes = 3 @@ -29,20 +25,19 @@ if __name__ == "__main__": hparams = dict( distribution=(num_classes, prototypes_per_class), lr=0.05, - variance=1, + variance=1.0, ) # Initialize the model model = pt.models.probabilistic.LikelihoodRatioLVQ( - #model = pt.models.probabilistic.RSLVQ( hparams, optimizer=torch.optim.Adam, - #prototype_initializer=pt.components.SSI(train_ds, noise=2), - prototype_initializer=pt.components.UniformInitializer(2), + # prototype_initializer=pt.components.UniformInitializer(2), + prototype_initializer=pt.components.SMI(train_ds), ) # Callbacks - vis = pt.models.VisGLVQ2D(data=(x_train, y_train), block=False) + vis = pt.models.VisGLVQ2D(data=train_ds) # Setup trainer trainer = pl.Trainer.from_argparse_args( diff --git a/examples/siamese_glvq_iris.py b/examples/siamese_glvq_iris.py index 925b0a2..cdd279d 100644 --- a/examples/siamese_glvq_iris.py +++ b/examples/siamese_glvq_iris.py @@ -2,10 +2,11 @@ import argparse -import prototorch as pt import pytorch_lightning as pl import torch +import prototorch as pt + class Backbone(torch.nn.Module): def __init__(self, input_size=4, hidden_size=10, latent_size=2): @@ -36,9 +37,7 @@ if __name__ == "__main__": pl.utilities.seed.seed_everything(seed=2) # Dataloaders - train_loader = torch.utils.data.DataLoader(train_ds, - num_workers=0, - batch_size=150) + train_loader = torch.utils.data.DataLoader(train_ds, batch_size=150) # Hyperparameters hparams = dict(