Add argparse to mnist example script

This commit is contained in:
Jensun Ravichandran 2021-05-18 10:17:51 +02:00
parent 00cdacf7ae
commit a14e3aa611

View File

@ -1,5 +1,7 @@
"""GMLVQ example using the MNIST dataset."""
import argparse
import prototorch as pt
import pytorch_lightning as pl
import torch
@ -7,6 +9,11 @@ from torchvision import transforms
from torchvision.datasets import MNIST
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Dataset
train_ds = MNIST(
"~/datasets",
@ -40,7 +47,8 @@ if __name__ == "__main__":
input_dim=28 * 28,
latent_dim=28 * 28,
distribution=(nclasses, prototypes_per_class),
lr=0.01,
proto_lr=0.01,
bb_lr=0.01,
)
# Initialize the model
@ -51,17 +59,19 @@ if __name__ == "__main__":
)
# Callbacks
vis = pt.models.VisImgComp(data=train_ds,
vis = pt.models.VisImgComp(
data=train_ds,
nrow=5,
show=True,
show=False,
tensorboard=True,
pause_time=0.5)
)
# Setup trainer
trainer = pl.Trainer(
max_epochs=50,
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[vis],
gpus=0,
# kwargs override the cli-arguments
# max_epochs=50,
# overfit_batches=1,
# fast_dev_run=3,
)