Add argparse to mnist example script

This commit is contained in:
Jensun Ravichandran 2021-05-18 10:17:51 +02:00
parent 00cdacf7ae
commit a14e3aa611

View File

@ -1,5 +1,7 @@
"""GMLVQ example using the MNIST dataset.""" """GMLVQ example using the MNIST dataset."""
import argparse
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
@ -7,6 +9,11 @@ from torchvision import transforms
from torchvision.datasets import MNIST from torchvision.datasets import MNIST
if __name__ == "__main__": if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Dataset # Dataset
train_ds = MNIST( train_ds = MNIST(
"~/datasets", "~/datasets",
@ -40,7 +47,8 @@ if __name__ == "__main__":
input_dim=28 * 28, input_dim=28 * 28,
latent_dim=28 * 28, latent_dim=28 * 28,
distribution=(nclasses, prototypes_per_class), distribution=(nclasses, prototypes_per_class),
lr=0.01, proto_lr=0.01,
bb_lr=0.01,
) )
# Initialize the model # Initialize the model
@ -51,17 +59,19 @@ if __name__ == "__main__":
) )
# Callbacks # Callbacks
vis = pt.models.VisImgComp(data=train_ds, vis = pt.models.VisImgComp(
nrow=5, data=train_ds,
show=True, nrow=5,
tensorboard=True, show=False,
pause_time=0.5) tensorboard=True,
)
# Setup trainer # Setup trainer
trainer = pl.Trainer( trainer = pl.Trainer.from_argparse_args(
max_epochs=50, args,
callbacks=[vis], callbacks=[vis],
gpus=0, # kwargs override the cli-arguments
# max_epochs=50,
# overfit_batches=1, # overfit_batches=1,
# fast_dev_run=3, # fast_dev_run=3,
) )