diff --git a/examples/glvq_mnist.py b/examples/glvq_mnist.py index 875fbfa..a98c958 100644 --- a/examples/glvq_mnist.py +++ b/examples/glvq_mnist.py @@ -1,8 +1,14 @@ """GLVQ example using the MNIST dataset. +TODO +- Add model serialization/deserialization +- Add evaluation metrics + This script also shows how to use Tensorboard for visualizing the prototypes. """ +import argparse + import pytorch_lightning as pl import torchvision from matplotlib import pyplot as plt @@ -32,6 +38,30 @@ class VisualizationCallback(pl.Callback): if __name__ == "__main__": + # Arguments + parser = argparse.ArgumentParser() + parser.add_argument("--epochs", + type=int, + default=10, + help="Epochs to train.") + parser.add_argument("--lr", + type=float, + default=0.001, + help="Learning rate.") + parser.add_argument("--batch_size", + type=int, + default=256, + help="Batch size.") + parser.add_argument("--gpus", + type=int, + default=0, + help="Number of GPUs to use.") + parser.add_argument("--ppc", + type=int, + default=1, + help="Prototypes-Per-Class.") + args = parser.parse_args() + # Dataset mnist_train = MNIST( "./datasets", @@ -63,19 +93,19 @@ if __name__ == "__main__": # Initialize the model model = ImageGLVQ(input_dim=28 * 28, nclasses=10, - prototypes_per_class=10, + prototypes_per_class=args.ppc, prototype_initializer="stratified_mean", data=[x, y]) # Model summary print(model) # Callbacks - vis = VisualizationCallback(to_shape=(-1, 1, 28, 28), nrow=10) + vis = VisualizationCallback(to_shape=(-1, 1, 28, 28), nrow=args.ppc) # Setup trainer trainer = pl.Trainer( - gpus=0, # change to use GPUs for training - max_epochs=10, + gpus=args.gpus, # change to use GPUs for training + max_epochs=args.epochs, callbacks=[vis], # accelerator="ddp_cpu", # DEBUG-ONLY # num_processes=2, # DEBUG-ONLY