From a14e3aa611a1661fce9de1e0539765b8aac49daf Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Tue, 18 May 2021 10:17:51 +0200 Subject: [PATCH] Add argparse to mnist example script --- examples/gmlvq_mnist.py | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/examples/gmlvq_mnist.py b/examples/gmlvq_mnist.py index 0219101..a1bf43a 100644 --- a/examples/gmlvq_mnist.py +++ b/examples/gmlvq_mnist.py @@ -1,5 +1,7 @@ """GMLVQ example using the MNIST dataset.""" +import argparse + import prototorch as pt import pytorch_lightning as pl import torch @@ -7,6 +9,11 @@ from torchvision import transforms from torchvision.datasets import MNIST if __name__ == "__main__": + # Command-line arguments + parser = argparse.ArgumentParser() + parser = pl.Trainer.add_argparse_args(parser) + args = parser.parse_args() + # Dataset train_ds = MNIST( "~/datasets", @@ -40,7 +47,8 @@ if __name__ == "__main__": input_dim=28 * 28, latent_dim=28 * 28, distribution=(nclasses, prototypes_per_class), - lr=0.01, + proto_lr=0.01, + bb_lr=0.01, ) # Initialize the model @@ -51,17 +59,19 @@ if __name__ == "__main__": ) # Callbacks - vis = pt.models.VisImgComp(data=train_ds, - nrow=5, - show=True, - tensorboard=True, - pause_time=0.5) + vis = pt.models.VisImgComp( + data=train_ds, + nrow=5, + show=False, + tensorboard=True, + ) # Setup trainer - trainer = pl.Trainer( - max_epochs=50, + trainer = pl.Trainer.from_argparse_args( + args, callbacks=[vis], - gpus=0, + # kwargs override the cli-arguments + # max_epochs=50, # overfit_batches=1, # fast_dev_run=3, )