From 76fea3f881fe26b040b46f86b21093857f4f7c11 Mon Sep 17 00:00:00 2001 From: Alexander Engelsberger Date: Tue, 17 May 2022 12:03:43 +0200 Subject: [PATCH] chore: update all examples to pytorch 1.6 --- examples/cbc_iris.py | 40 ++++++++++++++++++--------- examples/dynamic_pruning.py | 50 +++++++++++++++++++++++----------- examples/glvq_iris.py | 29 +++++++++++++++----- examples/gmlvq_iris.py | 27 ++++++++++++++---- examples/gmlvq_mnist.py | 39 ++++++++++++++++---------- examples/gmlvq_spiral.py | 30 ++++++++++++++++---- examples/gng_iris.py | 26 +++++++++++++----- examples/gtlvq_mnist.py | 40 +++++++++++++++++---------- examples/gtlvq_moons.py | 35 ++++++++++++++++-------- examples/knn_iris.py | 38 +++++++++++++++++--------- examples/ksom_colors.py | 44 ++++++++++++++++++++---------- examples/lgmlvq_moons.py | 29 +++++++++++++------- examples/lvqmln_iris.py | 28 +++++++++++++------ examples/median_lvq_iris.py | 28 +++++++++++++++---- examples/ng_iris.py | 30 ++++++++++++++------ examples/rslvq_iris.py | 29 ++++++++++++-------- examples/siamese_glvq_iris.py | 26 ++++++++++++------ examples/siamese_gtlvq_iris.py | 38 +++++++++++++++++--------- examples/warm_starting.py | 47 +++++++++++++++++++++++--------- 19 files changed, 453 insertions(+), 200 deletions(-) diff --git a/examples/cbc_iris.py b/examples/cbc_iris.py index f0561af..30610af 100644 --- a/examples/cbc_iris.py +++ b/examples/cbc_iris.py @@ -1,12 +1,22 @@ """CBC example using the Iris dataset.""" import argparse +import warnings import prototorch as pt import pytorch_lightning as pl -import torch +from prototorch.models import CBC, VisCBC2D +from pytorch_lightning.utilities.seed import seed_everything +from pytorch_lightning.utilities.warnings import PossibleUserWarning +from torch.utils.data import DataLoader + +warnings.filterwarnings("ignore", category=PossibleUserWarning) +warnings.filterwarnings("ignore", category=UserWarning) if __name__ == "__main__": + # Reproducibility + seed_everything(seed=4) + # Command-line arguments parser = argparse.ArgumentParser() parser = pl.Trainer.add_argparse_args(parser) @@ -15,11 +25,8 @@ if __name__ == "__main__": # Dataset train_ds = pt.datasets.Iris(dims=[0, 2]) - # Reproducibility - pl.utilities.seed.seed_everything(seed=42) - # Dataloaders - train_loader = torch.utils.data.DataLoader(train_ds, batch_size=32) + train_loader = DataLoader(train_ds, batch_size=32) # Hyperparameters hparams = dict( @@ -30,23 +37,30 @@ if __name__ == "__main__": ) # Initialize the model - model = pt.models.CBC( + model = CBC( hparams, - components_initializer=pt.initializers.SSCI(train_ds, noise=0.01), - reasonings_iniitializer=pt.initializers. + components_initializer=pt.initializers.SSCI(train_ds, noise=0.1), + reasonings_initializer=pt.initializers. PurePositiveReasoningsInitializer(), ) # Callbacks - vis = pt.models.VisCBC2D(data=train_ds, - title="CBC Iris Example", - resolution=100, - axis_off=True) + vis = VisCBC2D( + data=train_ds, + title="CBC Iris Example", + resolution=100, + axis_off=True, + ) # Setup trainer trainer = pl.Trainer.from_argparse_args( args, - callbacks=[vis], + callbacks=[ + vis, + ], + detect_anomaly=True, + log_every_n_steps=1, + max_epochs=1000, ) # Training loop diff --git a/examples/dynamic_pruning.py b/examples/dynamic_pruning.py index 454b66e..275cdff 100644 --- a/examples/dynamic_pruning.py +++ b/examples/dynamic_pruning.py @@ -1,12 +1,29 @@ """Dynamically prune 'loser' prototypes in GLVQ-type models.""" import argparse +import logging +import warnings import prototorch as pt import pytorch_lightning as pl import torch +from prototorch.models import ( + CELVQ, + PruneLoserPrototypes, + VisGLVQ2D, +) +from pytorch_lightning.callbacks import EarlyStopping +from pytorch_lightning.utilities.seed import seed_everything +from pytorch_lightning.utilities.warnings import PossibleUserWarning +from torch.utils.data import DataLoader + +warnings.filterwarnings("ignore", category=PossibleUserWarning) +warnings.filterwarnings("ignore", category=UserWarning) if __name__ == "__main__": + # Reproducibility + seed_everything(seed=4) + # Command-line arguments parser = argparse.ArgumentParser() parser = pl.Trainer.add_argparse_args(parser) @@ -16,15 +33,17 @@ if __name__ == "__main__": num_classes = 4 num_features = 2 num_clusters = 1 - train_ds = pt.datasets.Random(num_samples=500, - num_classes=num_classes, - num_features=num_features, - num_clusters=num_clusters, - separation=3.0, - seed=42) + train_ds = pt.datasets.Random( + num_samples=500, + num_classes=num_classes, + num_features=num_features, + num_clusters=num_clusters, + separation=3.0, + seed=42, + ) # Dataloaders - train_loader = torch.utils.data.DataLoader(train_ds, batch_size=256) + train_loader = DataLoader(train_ds, batch_size=256) # Hyperparameters prototypes_per_class = num_clusters * 5 @@ -34,7 +53,7 @@ if __name__ == "__main__": ) # Initialize the model - model = pt.models.CELVQ( + model = CELVQ( hparams, prototypes_initializer=pt.initializers.FVCI(2, 3.0), ) @@ -43,18 +62,18 @@ if __name__ == "__main__": model.example_input_array = torch.zeros(4, 2) # Summary - print(model) + logging.info(model) # Callbacks - vis = pt.models.VisGLVQ2D(train_ds) - pruning = pt.models.PruneLoserPrototypes( + vis = VisGLVQ2D(train_ds) + pruning = PruneLoserPrototypes( threshold=0.01, # prune prototype if it wins less than 1% idle_epochs=20, # pruning too early may cause problems prune_quota_per_epoch=2, # prune at most 2 prototypes per epoch frequency=1, # prune every epoch verbose=True, ) - es = pl.callbacks.EarlyStopping( + es = EarlyStopping( monitor="train_loss", min_delta=0.001, patience=20, @@ -71,10 +90,9 @@ if __name__ == "__main__": pruning, es, ], - progress_bar_refresh_rate=0, - terminate_on_nan=True, - weights_summary="full", - accelerator="ddp", + detect_anomaly=True, + log_every_n_steps=1, + max_epochs=1000, ) # Training loop diff --git a/examples/glvq_iris.py b/examples/glvq_iris.py index c2a4114..d23e82a 100644 --- a/examples/glvq_iris.py +++ b/examples/glvq_iris.py @@ -1,13 +1,24 @@ """GLVQ example using the Iris dataset.""" import argparse +import logging +import warnings import prototorch as pt import pytorch_lightning as pl import torch +from prototorch.models import GLVQ, VisGLVQ2D +from pytorch_lightning.utilities.seed import seed_everything +from pytorch_lightning.utilities.warnings import PossibleUserWarning from torch.optim.lr_scheduler import ExponentialLR +from torch.utils.data import DataLoader + +warnings.filterwarnings("ignore", category=UserWarning) +warnings.filterwarnings("ignore", category=PossibleUserWarning) if __name__ == "__main__": + # Reproducibility + seed_everything(seed=4) # Command-line arguments parser = argparse.ArgumentParser() parser = pl.Trainer.add_argparse_args(parser) @@ -17,7 +28,7 @@ if __name__ == "__main__": train_ds = pt.datasets.Iris(dims=[0, 2]) # Dataloaders - train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64) + train_loader = DataLoader(train_ds, batch_size=64, num_workers=4) # Hyperparameters hparams = dict( @@ -29,7 +40,7 @@ if __name__ == "__main__": ) # Initialize the model - model = pt.models.GLVQ( + model = GLVQ( hparams, optimizer=torch.optim.Adam, prototypes_initializer=pt.initializers.SMCI(train_ds), @@ -41,13 +52,17 @@ if __name__ == "__main__": model.example_input_array = torch.zeros(4, 2) # Callbacks - vis = pt.models.VisGLVQ2D(data=train_ds) + vis = VisGLVQ2D(data=train_ds) # Setup trainer trainer = pl.Trainer.from_argparse_args( args, - callbacks=[vis], - weights_summary="full", + callbacks=[ + vis, + ], + max_epochs=100, + log_every_n_steps=1, + detect_anomaly=True, ) # Training loop @@ -57,8 +72,8 @@ if __name__ == "__main__": trainer.save_checkpoint("./glvq_iris.ckpt") # Load saved model - new_model = pt.models.GLVQ.load_from_checkpoint( + new_model = GLVQ.load_from_checkpoint( checkpoint_path="./glvq_iris.ckpt", strict=False, ) - print(new_model) + logging.info(new_model) diff --git a/examples/gmlvq_iris.py b/examples/gmlvq_iris.py index f326f8b..718a129 100644 --- a/examples/gmlvq_iris.py +++ b/examples/gmlvq_iris.py @@ -1,13 +1,25 @@ """GMLVQ example using the Iris dataset.""" import argparse +import warnings import prototorch as pt import pytorch_lightning as pl import torch +from prototorch.models import GMLVQ, VisGMLVQ2D +from pytorch_lightning.utilities.seed import seed_everything +from pytorch_lightning.utilities.warnings import PossibleUserWarning from torch.optim.lr_scheduler import ExponentialLR +from torch.utils.data import DataLoader + +warnings.filterwarnings("ignore", category=PossibleUserWarning) +warnings.filterwarnings("ignore", category=UserWarning) if __name__ == "__main__": + + # Reproducibility + seed_everything(seed=4) + # Command-line arguments parser = argparse.ArgumentParser() parser = pl.Trainer.add_argparse_args(parser) @@ -17,7 +29,7 @@ if __name__ == "__main__": train_ds = pt.datasets.Iris() # Dataloaders - train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64) + train_loader = DataLoader(train_ds, batch_size=64) # Hyperparameters hparams = dict( @@ -32,7 +44,7 @@ if __name__ == "__main__": ) # Initialize the model - model = pt.models.GMLVQ( + model = GMLVQ( hparams, optimizer=torch.optim.Adam, prototypes_initializer=pt.initializers.SMCI(train_ds), @@ -44,14 +56,17 @@ if __name__ == "__main__": model.example_input_array = torch.zeros(4, 4) # Callbacks - vis = pt.models.VisGMLVQ2D(data=train_ds) + vis = VisGMLVQ2D(data=train_ds) # Setup trainer trainer = pl.Trainer.from_argparse_args( args, - callbacks=[vis], - weights_summary="full", - accelerator="ddp", + callbacks=[ + vis, + ], + max_epochs=100, + log_every_n_steps=1, + detect_anomaly=True, ) # Training loop diff --git a/examples/gmlvq_mnist.py b/examples/gmlvq_mnist.py index 54aae2e..64eef26 100644 --- a/examples/gmlvq_mnist.py +++ b/examples/gmlvq_mnist.py @@ -1,14 +1,29 @@ """GMLVQ example using the MNIST dataset.""" import argparse +import warnings import prototorch as pt import pytorch_lightning as pl import torch +from prototorch.models import ( + ImageGMLVQ, + PruneLoserPrototypes, + VisImgComp, +) +from pytorch_lightning.callbacks import EarlyStopping +from pytorch_lightning.utilities.seed import seed_everything +from pytorch_lightning.utilities.warnings import PossibleUserWarning +from torch.utils.data import DataLoader from torchvision import transforms from torchvision.datasets import MNIST +warnings.filterwarnings("ignore", category=PossibleUserWarning) +warnings.filterwarnings("ignore", category=UserWarning) + if __name__ == "__main__": + # Reproducibility + seed_everything(seed=4) # Command-line arguments parser = argparse.ArgumentParser() parser = pl.Trainer.add_argparse_args(parser) @@ -33,12 +48,8 @@ if __name__ == "__main__": ) # Dataloaders - train_loader = torch.utils.data.DataLoader(train_ds, - num_workers=0, - batch_size=256) - test_loader = torch.utils.data.DataLoader(test_ds, - num_workers=0, - batch_size=256) + train_loader = DataLoader(train_ds, num_workers=4, batch_size=256) + test_loader = DataLoader(test_ds, num_workers=4, batch_size=256) # Hyperparameters num_classes = 10 @@ -52,14 +63,14 @@ if __name__ == "__main__": ) # Initialize the model - model = pt.models.ImageGMLVQ( + model = ImageGMLVQ( hparams, optimizer=torch.optim.Adam, prototypes_initializer=pt.initializers.SMCI(train_ds), ) # Callbacks - vis = pt.models.VisImgComp( + vis = VisImgComp( data=train_ds, num_columns=10, show=False, @@ -69,14 +80,14 @@ if __name__ == "__main__": embedding_data=200, flatten_data=False, ) - pruning = pt.models.PruneLoserPrototypes( + pruning = PruneLoserPrototypes( threshold=0.01, idle_epochs=1, prune_quota_per_epoch=10, frequency=1, verbose=True, ) - es = pl.callbacks.EarlyStopping( + es = EarlyStopping( monitor="train_loss", min_delta=0.001, patience=15, @@ -90,11 +101,11 @@ if __name__ == "__main__": callbacks=[ vis, pruning, - # es, + es, ], - terminate_on_nan=True, - weights_summary=None, - # accelerator="ddp", + max_epochs=1000, + log_every_n_steps=1, + detect_anomaly=True, ) # Training loop diff --git a/examples/gmlvq_spiral.py b/examples/gmlvq_spiral.py index 68c7e3b..4cd2e7d 100644 --- a/examples/gmlvq_spiral.py +++ b/examples/gmlvq_spiral.py @@ -1,12 +1,28 @@ """GMLVQ example using the spiral dataset.""" import argparse +import warnings import prototorch as pt import pytorch_lightning as pl import torch +from prototorch.models import ( + GMLVQ, + PruneLoserPrototypes, + VisGLVQ2D, +) +from pytorch_lightning.callbacks import EarlyStopping +from pytorch_lightning.utilities.seed import seed_everything +from pytorch_lightning.utilities.warnings import PossibleUserWarning +from torch.utils.data import DataLoader + +warnings.filterwarnings("ignore", category=PossibleUserWarning) +warnings.filterwarnings("ignore", category=UserWarning) if __name__ == "__main__": + # Reproducibility + seed_everything(seed=4) + # Command-line arguments parser = argparse.ArgumentParser() parser = pl.Trainer.add_argparse_args(parser) @@ -16,7 +32,7 @@ if __name__ == "__main__": train_ds = pt.datasets.Spiral(num_samples=500, noise=0.5) # Dataloaders - train_loader = torch.utils.data.DataLoader(train_ds, batch_size=256) + train_loader = DataLoader(train_ds, batch_size=256) # Hyperparameters num_classes = 2 @@ -32,19 +48,19 @@ if __name__ == "__main__": ) # Initialize the model - model = pt.models.GMLVQ( + model = GMLVQ( hparams, optimizer=torch.optim.Adam, prototypes_initializer=pt.initializers.SSCI(train_ds, noise=1e-2), ) # Callbacks - vis = pt.models.VisGLVQ2D( + vis = VisGLVQ2D( train_ds, show_last_only=False, block=False, ) - pruning = pt.models.PruneLoserPrototypes( + pruning = PruneLoserPrototypes( threshold=0.01, idle_epochs=10, prune_quota_per_epoch=5, @@ -53,7 +69,7 @@ if __name__ == "__main__": prototypes_initializer=pt.initializers.SSCI(train_ds, noise=1e-1), verbose=True, ) - es = pl.callbacks.EarlyStopping( + es = EarlyStopping( monitor="train_loss", min_delta=1.0, patience=5, @@ -69,7 +85,9 @@ if __name__ == "__main__": es, pruning, ], - terminate_on_nan=True, + max_epochs=1000, + log_every_n_steps=1, + detect_anomaly=True, ) # Training loop diff --git a/examples/gng_iris.py b/examples/gng_iris.py index 7f1275d..0783b04 100644 --- a/examples/gng_iris.py +++ b/examples/gng_iris.py @@ -1,10 +1,19 @@ """Growing Neural Gas example using the Iris dataset.""" import argparse +import logging +import warnings import prototorch as pt import pytorch_lightning as pl import torch +from prototorch.models import GrowingNeuralGas, VisNG2D +from pytorch_lightning.utilities.seed import seed_everything +from pytorch_lightning.utilities.warnings import PossibleUserWarning +from torch.utils.data import DataLoader + +warnings.filterwarnings("ignore", category=PossibleUserWarning) +warnings.filterwarnings("ignore", category=UserWarning) if __name__ == "__main__": # Command-line arguments @@ -13,11 +22,11 @@ if __name__ == "__main__": args = parser.parse_args() # Reproducibility - pl.utilities.seed.seed_everything(seed=42) + seed_everything(seed=42) # Prepare the data train_ds = pt.datasets.Iris(dims=[0, 2]) - train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64) + train_loader = DataLoader(train_ds, batch_size=64) # Hyperparameters hparams = dict( @@ -27,7 +36,7 @@ if __name__ == "__main__": ) # Initialize the model - model = pt.models.GrowingNeuralGas( + model = GrowingNeuralGas( hparams, prototypes_initializer=pt.initializers.ZCI(2), ) @@ -36,17 +45,20 @@ if __name__ == "__main__": model.example_input_array = torch.zeros(4, 2) # Model summary - print(model) + logging.info(model) # Callbacks - vis = pt.models.VisNG2D(data=train_loader) + vis = VisNG2D(data=train_loader) # Setup trainer trainer = pl.Trainer.from_argparse_args( args, + callbacks=[ + vis, + ], max_epochs=100, - callbacks=[vis], - weights_summary="full", + log_every_n_steps=1, + detect_anomaly=True, ) # Training loop diff --git a/examples/gtlvq_mnist.py b/examples/gtlvq_mnist.py index 481065a..e5af4eb 100644 --- a/examples/gtlvq_mnist.py +++ b/examples/gtlvq_mnist.py @@ -1,14 +1,30 @@ """GTLVQ example using the MNIST dataset.""" import argparse +import warnings import prototorch as pt import pytorch_lightning as pl import torch +from prototorch.models import ( + ImageGTLVQ, + PruneLoserPrototypes, + VisImgComp, +) +from pytorch_lightning.callbacks import EarlyStopping +from pytorch_lightning.utilities.seed import seed_everything +from pytorch_lightning.utilities.warnings import PossibleUserWarning +from torch.utils.data import DataLoader from torchvision import transforms from torchvision.datasets import MNIST +warnings.filterwarnings("ignore", category=PossibleUserWarning) +warnings.filterwarnings("ignore", category=UserWarning) + if __name__ == "__main__": + # Reproducibility + seed_everything(seed=4) + # Command-line arguments parser = argparse.ArgumentParser() parser = pl.Trainer.add_argparse_args(parser) @@ -33,12 +49,8 @@ if __name__ == "__main__": ) # Dataloaders - train_loader = torch.utils.data.DataLoader(train_ds, - num_workers=0, - batch_size=256) - test_loader = torch.utils.data.DataLoader(test_ds, - num_workers=0, - batch_size=256) + train_loader = DataLoader(train_ds, num_workers=0, batch_size=256) + test_loader = DataLoader(test_ds, num_workers=0, batch_size=256) # Hyperparameters num_classes = 10 @@ -52,7 +64,7 @@ if __name__ == "__main__": ) # Initialize the model - model = pt.models.ImageGTLVQ( + model = ImageGTLVQ( hparams, optimizer=torch.optim.Adam, prototypes_initializer=pt.initializers.SMCI(train_ds), @@ -61,7 +73,7 @@ if __name__ == "__main__": next(iter(train_loader))[0].reshape(256, 28 * 28))) # Callbacks - vis = pt.models.VisImgComp( + vis = VisImgComp( data=train_ds, num_columns=10, show=False, @@ -71,14 +83,14 @@ if __name__ == "__main__": embedding_data=200, flatten_data=False, ) - pruning = pt.models.PruneLoserPrototypes( + pruning = PruneLoserPrototypes( threshold=0.01, idle_epochs=1, prune_quota_per_epoch=10, frequency=1, verbose=True, ) - es = pl.callbacks.EarlyStopping( + es = EarlyStopping( monitor="train_loss", min_delta=0.001, patience=15, @@ -93,11 +105,11 @@ if __name__ == "__main__": callbacks=[ vis, pruning, - # es, + es, ], - terminate_on_nan=True, - weights_summary=None, - accelerator="ddp", + max_epochs=1000, + log_every_n_steps=1, + detect_anomaly=True, ) # Training loop diff --git a/examples/gtlvq_moons.py b/examples/gtlvq_moons.py index 79ff32f..7e8368b 100644 --- a/examples/gtlvq_moons.py +++ b/examples/gtlvq_moons.py @@ -1,10 +1,20 @@ """Localized-GTLVQ example using the Moons dataset.""" import argparse +import logging +import warnings import prototorch as pt import pytorch_lightning as pl import torch +from prototorch.models import GTLVQ, VisGLVQ2D +from pytorch_lightning.callbacks import EarlyStopping +from pytorch_lightning.utilities.seed import seed_everything +from pytorch_lightning.utilities.warnings import PossibleUserWarning +from torch.utils.data import DataLoader + +warnings.filterwarnings("ignore", category=PossibleUserWarning) +warnings.filterwarnings("ignore", category=UserWarning) if __name__ == "__main__": # Command-line arguments @@ -13,33 +23,35 @@ if __name__ == "__main__": args = parser.parse_args() # Reproducibility - pl.utilities.seed.seed_everything(seed=2) + seed_everything(seed=2) # Dataset train_ds = pt.datasets.Moons(num_samples=300, noise=0.2, seed=42) # Dataloaders - train_loader = torch.utils.data.DataLoader(train_ds, - batch_size=256, - shuffle=True) + train_loader = DataLoader( + train_ds, + batch_size=256, + shuffle=True, + ) # Hyperparameters # Latent_dim should be lower than input dim. hparams = dict(distribution=[1, 3], input_dim=2, latent_dim=1) # Initialize the model - model = pt.models.GTLVQ( - hparams, prototypes_initializer=pt.initializers.SMCI(train_ds)) + model = GTLVQ(hparams, + prototypes_initializer=pt.initializers.SMCI(train_ds)) # Compute intermediate input and output sizes model.example_input_array = torch.zeros(4, 2) # Summary - print(model) + logging.info(model) # Callbacks - vis = pt.models.VisGLVQ2D(data=train_ds) - es = pl.callbacks.EarlyStopping( + vis = VisGLVQ2D(data=train_ds) + es = EarlyStopping( monitor="train_acc", min_delta=0.001, patience=20, @@ -55,8 +67,9 @@ if __name__ == "__main__": vis, es, ], - weights_summary="full", - accelerator="ddp", + max_epochs=1000, + log_every_n_steps=1, + detect_anomaly=True, ) # Training loop diff --git a/examples/knn_iris.py b/examples/knn_iris.py index 85dd187..c1b3279 100644 --- a/examples/knn_iris.py +++ b/examples/knn_iris.py @@ -1,12 +1,19 @@ """k-NN example using the Iris dataset from scikit-learn.""" import argparse +import logging +import warnings import prototorch as pt import pytorch_lightning as pl import torch +from prototorch.models import KNN, VisGLVQ2D +from pytorch_lightning.utilities.warnings import PossibleUserWarning from sklearn.datasets import load_iris from sklearn.model_selection import train_test_split +from torch.utils.data import DataLoader + +warnings.filterwarnings("ignore", category=PossibleUserWarning) if __name__ == "__main__": # Command-line arguments @@ -16,34 +23,36 @@ if __name__ == "__main__": # Dataset X, y = load_iris(return_X_y=True) - X = X[:, [0, 2]] + X = X[:, 0:3:2] - X_train, X_test, y_train, y_test = train_test_split(X, - y, - test_size=0.5, - random_state=42) + X_train, X_test, y_train, y_test = train_test_split( + X, + y, + test_size=0.5, + random_state=42, + ) train_ds = pt.datasets.NumpyDataset(X_train, y_train) test_ds = pt.datasets.NumpyDataset(X_test, y_test) # Dataloaders - train_loader = torch.utils.data.DataLoader(train_ds, batch_size=16) - test_loader = torch.utils.data.DataLoader(test_ds, batch_size=16) + train_loader = DataLoader(train_ds, batch_size=16) + test_loader = DataLoader(test_ds, batch_size=16) # Hyperparameters hparams = dict(k=5) # Initialize the model - model = pt.models.KNN(hparams, data=train_ds) + model = KNN(hparams, data=train_ds) # Compute intermediate input and output sizes model.example_input_array = torch.zeros(4, 2) # Summary - print(model) + logging.info(model) # Callbacks - vis = pt.models.VisGLVQ2D( + vis = VisGLVQ2D( data=(X_train, y_train), resolution=200, block=True, @@ -53,8 +62,11 @@ if __name__ == "__main__": trainer = pl.Trainer.from_argparse_args( args, max_epochs=1, - callbacks=[vis], - weights_summary="full", + callbacks=[ + vis, + ], + log_every_n_steps=1, + detect_anomaly=True, ) # Training loop @@ -63,7 +75,7 @@ if __name__ == "__main__": # Recall y_pred = model.predict(torch.tensor(X_train)) - print(y_pred) + logging.info(y_pred) # Test trainer.test(model, dataloaders=test_loader) diff --git a/examples/ksom_colors.py b/examples/ksom_colors.py index f86d3ac..bd2ded0 100644 --- a/examples/ksom_colors.py +++ b/examples/ksom_colors.py @@ -1,12 +1,21 @@ """Kohonen Self Organizing Map.""" import argparse +import logging +import warnings import prototorch as pt import pytorch_lightning as pl import torch from matplotlib import pyplot as plt +from prototorch.models import KohonenSOM from prototorch.utils.colors import hex_to_rgb +from pytorch_lightning.utilities.seed import seed_everything +from pytorch_lightning.utilities.warnings import PossibleUserWarning +from torch.utils.data import DataLoader, TensorDataset + +warnings.filterwarnings("ignore", category=PossibleUserWarning) +warnings.filterwarnings("ignore", category=UserWarning) class Vis2DColorSOM(pl.Callback): @@ -18,7 +27,7 @@ class Vis2DColorSOM(pl.Callback): self.data = data self.pause_time = pause_time - def on_epoch_end(self, trainer, pl_module): + def on_train_epoch_end(self, trainer, pl_module: KohonenSOM): ax = self.fig.gca() ax.cla() ax.set_title(self.title) @@ -31,12 +40,14 @@ class Vis2DColorSOM(pl.Callback): d = pl_module.compute_distances(self.data) wp = pl_module.predict_from_distances(d) for i, iloc in enumerate(wp): - plt.text(iloc[1], - iloc[0], - cnames[i], - ha="center", - va="center", - bbox=dict(facecolor="white", alpha=0.5, lw=0)) + plt.text( + iloc[1], + iloc[0], + color_names[i], + ha="center", + va="center", + bbox=dict(facecolor="white", alpha=0.5, lw=0), + ) if trainer.current_epoch != trainer.max_epochs - 1: plt.pause(self.pause_time) @@ -51,7 +62,7 @@ if __name__ == "__main__": args = parser.parse_args() # Reproducibility - pl.utilities.seed.seed_everything(seed=42) + seed_everything(seed=42) # Prepare the data hex_colors = [ @@ -59,15 +70,15 @@ if __name__ == "__main__": "#00ff00", "#ff0000", "#00ffff", "#ff00ff", "#ffff00", "#ffffff", "#545454", "#7f7f7f", "#a8a8a8", "#808000", "#800080", "#ffa500" ] - cnames = [ + color_names = [ "black", "blue", "darkblue", "skyblue", "greyblue", "lilac", "green", "red", "cyan", "magenta", "yellow", "white", "darkgrey", "mediumgrey", "lightgrey", "olive", "purple", "orange" ] colors = list(hex_to_rgb(hex_colors)) data = torch.Tensor(colors) / 255.0 - train_ds = torch.utils.data.TensorDataset(data) - train_loader = torch.utils.data.DataLoader(train_ds, batch_size=8) + train_ds = TensorDataset(data) + train_loader = DataLoader(train_ds, batch_size=8) # Hyperparameters hparams = dict( @@ -78,7 +89,7 @@ if __name__ == "__main__": ) # Initialize the model - model = pt.models.KohonenSOM( + model = KohonenSOM( hparams, prototypes_initializer=pt.initializers.RNCI(3), ) @@ -87,7 +98,7 @@ if __name__ == "__main__": model.example_input_array = torch.zeros(4, 3) # Model summary - print(model) + logging.info(model) # Callbacks vis = Vis2DColorSOM(data=data) @@ -96,8 +107,11 @@ if __name__ == "__main__": trainer = pl.Trainer.from_argparse_args( args, max_epochs=500, - callbacks=[vis], - weights_summary="full", + callbacks=[ + vis, + ], + log_every_n_steps=1, + detect_anomaly=True, ) # Training loop diff --git a/examples/lgmlvq_moons.py b/examples/lgmlvq_moons.py index cbdc9b4..155bf10 100644 --- a/examples/lgmlvq_moons.py +++ b/examples/lgmlvq_moons.py @@ -1,10 +1,20 @@ """Localized-GMLVQ example using the Moons dataset.""" import argparse +import logging +import warnings import prototorch as pt import pytorch_lightning as pl import torch +from prototorch.models import LGMLVQ, VisGLVQ2D +from pytorch_lightning.callbacks import EarlyStopping +from pytorch_lightning.utilities.seed import seed_everything +from pytorch_lightning.utilities.warnings import PossibleUserWarning +from torch.utils.data import DataLoader + +warnings.filterwarnings("ignore", category=PossibleUserWarning) +warnings.filterwarnings("ignore", category=UserWarning) if __name__ == "__main__": # Command-line arguments @@ -13,15 +23,13 @@ if __name__ == "__main__": args = parser.parse_args() # Reproducibility - pl.utilities.seed.seed_everything(seed=2) + seed_everything(seed=2) # Dataset train_ds = pt.datasets.Moons(num_samples=300, noise=0.2, seed=42) # Dataloaders - train_loader = torch.utils.data.DataLoader(train_ds, - batch_size=256, - shuffle=True) + train_loader = DataLoader(train_ds, batch_size=256, shuffle=True) # Hyperparameters hparams = dict( @@ -31,7 +39,7 @@ if __name__ == "__main__": ) # Initialize the model - model = pt.models.LGMLVQ( + model = LGMLVQ( hparams, prototypes_initializer=pt.initializers.SMCI(train_ds), ) @@ -40,11 +48,11 @@ if __name__ == "__main__": model.example_input_array = torch.zeros(4, 2) # Summary - print(model) + logging.info(model) # Callbacks - vis = pt.models.VisGLVQ2D(data=train_ds) - es = pl.callbacks.EarlyStopping( + vis = VisGLVQ2D(data=train_ds) + es = EarlyStopping( monitor="train_acc", min_delta=0.001, patience=20, @@ -60,8 +68,9 @@ if __name__ == "__main__": vis, es, ], - weights_summary="full", - accelerator="ddp", + log_every_n_steps=1, + max_epochs=1000, + detect_anomaly=True, ) # Training loop diff --git a/examples/lvqmln_iris.py b/examples/lvqmln_iris.py index 6a6023c..42c8b7b 100644 --- a/examples/lvqmln_iris.py +++ b/examples/lvqmln_iris.py @@ -1,10 +1,22 @@ """LVQMLN example using all four dimensions of the Iris dataset.""" import argparse +import warnings import prototorch as pt import pytorch_lightning as pl import torch +from prototorch.models import ( + LVQMLN, + PruneLoserPrototypes, + VisSiameseGLVQ2D, +) +from pytorch_lightning.utilities.seed import seed_everything +from pytorch_lightning.utilities.warnings import PossibleUserWarning +from torch.utils.data import DataLoader + +warnings.filterwarnings("ignore", category=PossibleUserWarning) +warnings.filterwarnings("ignore", category=UserWarning) class Backbone(torch.nn.Module): @@ -34,10 +46,10 @@ if __name__ == "__main__": train_ds = pt.datasets.Iris() # Reproducibility - pl.utilities.seed.seed_everything(seed=42) + seed_everything(seed=42) # Dataloaders - train_loader = torch.utils.data.DataLoader(train_ds, batch_size=150) + train_loader = DataLoader(train_ds, batch_size=150) # Hyperparameters hparams = dict( @@ -50,7 +62,7 @@ if __name__ == "__main__": backbone = Backbone() # Initialize the model - model = pt.models.LVQMLN( + model = LVQMLN( hparams, prototypes_initializer=pt.initializers.SSCI( train_ds, @@ -59,18 +71,15 @@ if __name__ == "__main__": backbone=backbone, ) - # Model summary - print(model) - # Callbacks - vis = pt.models.VisSiameseGLVQ2D( + vis = VisSiameseGLVQ2D( data=train_ds, map_protos=False, border=0.1, resolution=500, axis_off=True, ) - pruning = pt.models.PruneLoserPrototypes( + pruning = PruneLoserPrototypes( threshold=0.01, idle_epochs=20, prune_quota_per_epoch=2, @@ -85,6 +94,9 @@ if __name__ == "__main__": vis, pruning, ], + log_every_n_steps=1, + max_epochs=1000, + detect_anomaly=True, ) # Training loop diff --git a/examples/median_lvq_iris.py b/examples/median_lvq_iris.py index 44152b7..d709f80 100644 --- a/examples/median_lvq_iris.py +++ b/examples/median_lvq_iris.py @@ -1,12 +1,23 @@ """Median-LVQ example using the Iris dataset.""" import argparse +import warnings import prototorch as pt import pytorch_lightning as pl import torch +from prototorch.models import MedianLVQ, VisGLVQ2D +from pytorch_lightning.callbacks import EarlyStopping +from pytorch_lightning.utilities.seed import seed_everything +from pytorch_lightning.utilities.warnings import PossibleUserWarning +from torch.utils.data import DataLoader + +warnings.filterwarnings("ignore", category=PossibleUserWarning) +warnings.filterwarnings("ignore", category=UserWarning) if __name__ == "__main__": + # Reproducibility + seed_everything(seed=4) # Command-line arguments parser = argparse.ArgumentParser() parser = pl.Trainer.add_argparse_args(parser) @@ -16,13 +27,13 @@ if __name__ == "__main__": train_ds = pt.datasets.Iris(dims=[0, 2]) # Dataloaders - train_loader = torch.utils.data.DataLoader( + train_loader = DataLoader( train_ds, batch_size=len(train_ds), # MedianLVQ cannot handle mini-batches ) # Initialize the model - model = pt.models.MedianLVQ( + model = MedianLVQ( hparams=dict(distribution=(3, 2), lr=0.01), prototypes_initializer=pt.initializers.SSCI(train_ds), ) @@ -31,8 +42,8 @@ if __name__ == "__main__": model.example_input_array = torch.zeros(4, 2) # Callbacks - vis = pt.models.VisGLVQ2D(data=train_ds) - es = pl.callbacks.EarlyStopping( + vis = VisGLVQ2D(data=train_ds) + es = EarlyStopping( monitor="train_acc", min_delta=0.01, patience=5, @@ -44,8 +55,13 @@ if __name__ == "__main__": # Setup trainer trainer = pl.Trainer.from_argparse_args( args, - callbacks=[vis, es], - weights_summary="full", + callbacks=[ + vis, + es, + ], + max_epochs=1000, + log_every_n_steps=1, + detect_anomaly=True, ) # Training loop diff --git a/examples/ng_iris.py b/examples/ng_iris.py index fe5d7b3..c897fd0 100644 --- a/examples/ng_iris.py +++ b/examples/ng_iris.py @@ -1,15 +1,26 @@ """Neural Gas example using the Iris dataset.""" import argparse +import warnings import prototorch as pt import pytorch_lightning as pl import torch +from prototorch.models import NeuralGas, VisNG2D +from pytorch_lightning.utilities.seed import seed_everything +from pytorch_lightning.utilities.warnings import PossibleUserWarning from sklearn.datasets import load_iris from sklearn.preprocessing import StandardScaler from torch.optim.lr_scheduler import ExponentialLR +from torch.utils.data import DataLoader + +warnings.filterwarnings("ignore", category=PossibleUserWarning) +warnings.filterwarnings("ignore", category=UserWarning) if __name__ == "__main__": + # Reproducibility + seed_everything(seed=4) + # Command-line arguments parser = argparse.ArgumentParser() parser = pl.Trainer.add_argparse_args(parser) @@ -17,7 +28,7 @@ if __name__ == "__main__": # Prepare and pre-process the dataset x_train, y_train = load_iris(return_X_y=True) - x_train = x_train[:, [0, 2]] + x_train = x_train[:, 0:3:2] scaler = StandardScaler() scaler.fit(x_train) x_train = scaler.transform(x_train) @@ -25,7 +36,7 @@ if __name__ == "__main__": train_ds = pt.datasets.NumpyDataset(x_train, y_train) # Dataloaders - train_loader = torch.utils.data.DataLoader(train_ds, batch_size=150) + train_loader = DataLoader(train_ds, batch_size=150) # Hyperparameters hparams = dict( @@ -35,7 +46,7 @@ if __name__ == "__main__": ) # Initialize the model - model = pt.models.NeuralGas( + model = NeuralGas( hparams, prototypes_initializer=pt.core.ZCI(2), lr_scheduler=ExponentialLR, @@ -45,17 +56,18 @@ if __name__ == "__main__": # Compute intermediate input and output sizes model.example_input_array = torch.zeros(4, 2) - # Model summary - print(model) - # Callbacks - vis = pt.models.VisNG2D(data=train_ds) + vis = VisNG2D(data=train_ds) # Setup trainer trainer = pl.Trainer.from_argparse_args( args, - callbacks=[vis], - weights_summary="full", + callbacks=[ + vis, + ], + max_epochs=1000, + log_every_n_steps=1, + detect_anomaly=True, ) # Training loop diff --git a/examples/rslvq_iris.py b/examples/rslvq_iris.py index c7d3961..c561cbe 100644 --- a/examples/rslvq_iris.py +++ b/examples/rslvq_iris.py @@ -1,10 +1,18 @@ """RSLVQ example using the Iris dataset.""" import argparse +import warnings import prototorch as pt import pytorch_lightning as pl import torch +from prototorch.models import RSLVQ, VisGLVQ2D +from pytorch_lightning.utilities.seed import seed_everything +from pytorch_lightning.utilities.warnings import PossibleUserWarning +from torch.utils.data import DataLoader + +warnings.filterwarnings("ignore", category=PossibleUserWarning) +warnings.filterwarnings("ignore", category=UserWarning) if __name__ == "__main__": # Command-line arguments @@ -13,13 +21,13 @@ if __name__ == "__main__": args = parser.parse_args() # Reproducibility - pl.utilities.seed.seed_everything(seed=42) + seed_everything(seed=42) # Dataset train_ds = pt.datasets.Iris(dims=[0, 2]) # Dataloaders - train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64) + train_loader = DataLoader(train_ds, batch_size=64) # Hyperparameters hparams = dict( @@ -33,7 +41,7 @@ if __name__ == "__main__": ) # Initialize the model - model = pt.models.RSLVQ( + model = RSLVQ( hparams, optimizer=torch.optim.Adam, prototypes_initializer=pt.initializers.SSCI(train_ds, noise=0.2), @@ -42,19 +50,18 @@ if __name__ == "__main__": # Compute intermediate input and output sizes model.example_input_array = torch.zeros(4, 2) - # Summary - print(model) - # Callbacks - vis = pt.models.VisGLVQ2D(data=train_ds) + vis = VisGLVQ2D(data=train_ds) # Setup trainer trainer = pl.Trainer.from_argparse_args( args, - callbacks=[vis], - terminate_on_nan=True, - weights_summary="full", - accelerator="ddp", + callbacks=[ + vis, + ], + detect_anomaly=True, + max_epochs=100, + log_every_n_steps=1, ) # Training loop diff --git a/examples/siamese_glvq_iris.py b/examples/siamese_glvq_iris.py index e7a297b..19f990e 100644 --- a/examples/siamese_glvq_iris.py +++ b/examples/siamese_glvq_iris.py @@ -1,10 +1,18 @@ """Siamese GLVQ example using all four dimensions of the Iris dataset.""" import argparse +import warnings import prototorch as pt import pytorch_lightning as pl import torch +from prototorch.models import SiameseGLVQ, VisSiameseGLVQ2D +from pytorch_lightning.utilities.seed import seed_everything +from pytorch_lightning.utilities.warnings import PossibleUserWarning +from torch.utils.data import DataLoader + +warnings.filterwarnings("ignore", category=PossibleUserWarning) +warnings.filterwarnings("ignore", category=UserWarning) class Backbone(torch.nn.Module): @@ -34,10 +42,10 @@ if __name__ == "__main__": train_ds = pt.datasets.Iris() # Reproducibility - pl.utilities.seed.seed_everything(seed=2) + seed_everything(seed=2) # Dataloaders - train_loader = torch.utils.data.DataLoader(train_ds, batch_size=150) + train_loader = DataLoader(train_ds, batch_size=150) # Hyperparameters hparams = dict( @@ -50,23 +58,25 @@ if __name__ == "__main__": backbone = Backbone() # Initialize the model - model = pt.models.SiameseGLVQ( + model = SiameseGLVQ( hparams, prototypes_initializer=pt.initializers.SMCI(train_ds), backbone=backbone, both_path_gradients=False, ) - # Model summary - print(model) - # Callbacks - vis = pt.models.VisSiameseGLVQ2D(data=train_ds, border=0.1) + vis = VisSiameseGLVQ2D(data=train_ds, border=0.1) # Setup trainer trainer = pl.Trainer.from_argparse_args( args, - callbacks=[vis], + callbacks=[ + vis, + ], + max_epochs=1000, + log_every_n_steps=1, + detect_anomaly=True, ) # Training loop diff --git a/examples/siamese_gtlvq_iris.py b/examples/siamese_gtlvq_iris.py index 455c0fb..16db174 100644 --- a/examples/siamese_gtlvq_iris.py +++ b/examples/siamese_gtlvq_iris.py @@ -1,10 +1,18 @@ """Siamese GTLVQ example using all four dimensions of the Iris dataset.""" import argparse +import warnings import prototorch as pt import pytorch_lightning as pl import torch +from prototorch.models import SiameseGTLVQ, VisSiameseGLVQ2D +from pytorch_lightning.utilities.seed import seed_everything +from pytorch_lightning.utilities.warnings import PossibleUserWarning +from torch.utils.data import DataLoader + +warnings.filterwarnings("ignore", category=PossibleUserWarning) +warnings.filterwarnings("ignore", category=UserWarning) class Backbone(torch.nn.Module): @@ -34,39 +42,43 @@ if __name__ == "__main__": train_ds = pt.datasets.Iris() # Reproducibility - pl.utilities.seed.seed_everything(seed=2) + seed_everything(seed=2) # Dataloaders - train_loader = torch.utils.data.DataLoader(train_ds, batch_size=150) + train_loader = DataLoader(train_ds, batch_size=150) # Hyperparameters - hparams = dict(distribution=[1, 2, 3], - proto_lr=0.01, - bb_lr=0.01, - input_dim=2, - latent_dim=1) + hparams = dict( + distribution=[1, 2, 3], + proto_lr=0.01, + bb_lr=0.01, + input_dim=2, + latent_dim=1, + ) # Initialize the backbone backbone = Backbone(latent_size=hparams["input_dim"]) # Initialize the model - model = pt.models.SiameseGTLVQ( + model = SiameseGTLVQ( hparams, prototypes_initializer=pt.initializers.SMCI(train_ds), backbone=backbone, both_path_gradients=False, ) - # Model summary - print(model) - # Callbacks - vis = pt.models.VisSiameseGLVQ2D(data=train_ds, border=0.1) + vis = VisSiameseGLVQ2D(data=train_ds, border=0.1) # Setup trainer trainer = pl.Trainer.from_argparse_args( args, - callbacks=[vis], + callbacks=[ + vis, + ], + max_epochs=1000, + log_every_n_steps=1, + detect_anomaly=True, ) # Training loop diff --git a/examples/warm_starting.py b/examples/warm_starting.py index 8b3415c..98d0668 100644 --- a/examples/warm_starting.py +++ b/examples/warm_starting.py @@ -1,13 +1,30 @@ """Warm-starting GLVQ with prototypes from Growing Neural Gas.""" import argparse +import warnings import prototorch as pt import pytorch_lightning as pl import torch +from prototorch.models import ( + GLVQ, + KNN, + GrowingNeuralGas, + PruneLoserPrototypes, + VisGLVQ2D, +) +from pytorch_lightning.callbacks import EarlyStopping +from pytorch_lightning.utilities.seed import seed_everything +from pytorch_lightning.utilities.warnings import PossibleUserWarning from torch.optim.lr_scheduler import ExponentialLR +from torch.utils.data import DataLoader + +warnings.filterwarnings("ignore", category=PossibleUserWarning) if __name__ == "__main__": + + # Reproducibility + seed_everything(seed=4) # Command-line arguments parser = argparse.ArgumentParser() parser = pl.Trainer.add_argparse_args(parser) @@ -15,10 +32,10 @@ if __name__ == "__main__": # Prepare the data train_ds = pt.datasets.Iris(dims=[0, 2]) - train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64) + train_loader = DataLoader(train_ds, batch_size=64, num_workers=0) # Initialize the gng - gng = pt.models.GrowingNeuralGas( + gng = GrowingNeuralGas( hparams=dict(num_prototypes=5, insert_freq=2, lr=0.1), prototypes_initializer=pt.initializers.ZCI(2), lr_scheduler=ExponentialLR, @@ -26,7 +43,7 @@ if __name__ == "__main__": ) # Callbacks - es = pl.callbacks.EarlyStopping( + es = EarlyStopping( monitor="loss", min_delta=0.001, patience=20, @@ -37,9 +54,12 @@ if __name__ == "__main__": # Setup trainer for GNG trainer = pl.Trainer( - max_epochs=100, - callbacks=[es], - weights_summary=None, + max_epochs=1000, + callbacks=[ + es, + ], + log_every_n_steps=1, + detect_anomaly=True, ) # Training loop @@ -52,12 +72,12 @@ if __name__ == "__main__": ) # Warm-start prototypes - knn = pt.models.KNN(dict(k=1), data=train_ds) + knn = KNN(dict(k=1), data=train_ds) prototypes = gng.prototypes plabels = knn.predict(prototypes) # Initialize the model - model = pt.models.GLVQ( + model = GLVQ( hparams, optimizer=torch.optim.Adam, prototypes_initializer=pt.initializers.LCI(prototypes), @@ -70,15 +90,15 @@ if __name__ == "__main__": model.example_input_array = torch.zeros(4, 2) # Callbacks - vis = pt.models.VisGLVQ2D(data=train_ds) - pruning = pt.models.PruneLoserPrototypes( + vis = VisGLVQ2D(data=train_ds) + pruning = PruneLoserPrototypes( threshold=0.02, idle_epochs=2, prune_quota_per_epoch=5, frequency=1, verbose=True, ) - es = pl.callbacks.EarlyStopping( + es = EarlyStopping( monitor="train_loss", min_delta=0.001, patience=10, @@ -95,8 +115,9 @@ if __name__ == "__main__": pruning, es, ], - weights_summary="full", - accelerator="ddp", + max_epochs=1000, + log_every_n_steps=1, + detect_anomaly=True, ) # Training loop