Add argparse to mnist example script
This commit is contained in:
parent
00cdacf7ae
commit
a14e3aa611
@ -1,5 +1,7 @@
|
|||||||
"""GMLVQ example using the MNIST dataset."""
|
"""GMLVQ example using the MNIST dataset."""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
@ -7,6 +9,11 @@ from torchvision import transforms
|
|||||||
from torchvision.datasets import MNIST
|
from torchvision.datasets import MNIST
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
# Command-line arguments
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser = pl.Trainer.add_argparse_args(parser)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Dataset
|
# Dataset
|
||||||
train_ds = MNIST(
|
train_ds = MNIST(
|
||||||
"~/datasets",
|
"~/datasets",
|
||||||
@ -40,7 +47,8 @@ if __name__ == "__main__":
|
|||||||
input_dim=28 * 28,
|
input_dim=28 * 28,
|
||||||
latent_dim=28 * 28,
|
latent_dim=28 * 28,
|
||||||
distribution=(nclasses, prototypes_per_class),
|
distribution=(nclasses, prototypes_per_class),
|
||||||
lr=0.01,
|
proto_lr=0.01,
|
||||||
|
bb_lr=0.01,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize the model
|
# Initialize the model
|
||||||
@ -51,17 +59,19 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Callbacks
|
# Callbacks
|
||||||
vis = pt.models.VisImgComp(data=train_ds,
|
vis = pt.models.VisImgComp(
|
||||||
nrow=5,
|
data=train_ds,
|
||||||
show=True,
|
nrow=5,
|
||||||
tensorboard=True,
|
show=False,
|
||||||
pause_time=0.5)
|
tensorboard=True,
|
||||||
|
)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer(
|
trainer = pl.Trainer.from_argparse_args(
|
||||||
max_epochs=50,
|
args,
|
||||||
callbacks=[vis],
|
callbacks=[vis],
|
||||||
gpus=0,
|
# kwargs override the cli-arguments
|
||||||
|
# max_epochs=50,
|
||||||
# overfit_batches=1,
|
# overfit_batches=1,
|
||||||
# fast_dev_run=3,
|
# fast_dev_run=3,
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user