Add argparse to mnist example script
This commit is contained in:
parent
00cdacf7ae
commit
a14e3aa611
@ -1,5 +1,7 @@
|
||||
"""GMLVQ example using the MNIST dataset."""
|
||||
|
||||
import argparse
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
@ -7,6 +9,11 @@ from torchvision import transforms
|
||||
from torchvision.datasets import MNIST
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Command-line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Dataset
|
||||
train_ds = MNIST(
|
||||
"~/datasets",
|
||||
@ -40,7 +47,8 @@ if __name__ == "__main__":
|
||||
input_dim=28 * 28,
|
||||
latent_dim=28 * 28,
|
||||
distribution=(nclasses, prototypes_per_class),
|
||||
lr=0.01,
|
||||
proto_lr=0.01,
|
||||
bb_lr=0.01,
|
||||
)
|
||||
|
||||
# Initialize the model
|
||||
@ -51,17 +59,19 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
# Callbacks
|
||||
vis = pt.models.VisImgComp(data=train_ds,
|
||||
nrow=5,
|
||||
show=True,
|
||||
tensorboard=True,
|
||||
pause_time=0.5)
|
||||
vis = pt.models.VisImgComp(
|
||||
data=train_ds,
|
||||
nrow=5,
|
||||
show=False,
|
||||
tensorboard=True,
|
||||
)
|
||||
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer(
|
||||
max_epochs=50,
|
||||
trainer = pl.Trainer.from_argparse_args(
|
||||
args,
|
||||
callbacks=[vis],
|
||||
gpus=0,
|
||||
# kwargs override the cli-arguments
|
||||
# max_epochs=50,
|
||||
# overfit_batches=1,
|
||||
# fast_dev_run=3,
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user