Update mnist example

This commit is contained in:
Jensun Ravichandran 2021-04-21 16:28:20 +02:00
parent 985cdd3120
commit 5a1ef841d3

View File

@ -1,8 +1,14 @@
"""GLVQ example using the MNIST dataset.
TODO
- Add model serialization/deserialization
- Add evaluation metrics
This script also shows how to use Tensorboard for visualizing the prototypes.
"""
import argparse
import pytorch_lightning as pl
import torchvision
from matplotlib import pyplot as plt
@ -32,6 +38,30 @@ class VisualizationCallback(pl.Callback):
if __name__ == "__main__":
# Arguments
parser = argparse.ArgumentParser()
parser.add_argument("--epochs",
type=int,
default=10,
help="Epochs to train.")
parser.add_argument("--lr",
type=float,
default=0.001,
help="Learning rate.")
parser.add_argument("--batch_size",
type=int,
default=256,
help="Batch size.")
parser.add_argument("--gpus",
type=int,
default=0,
help="Number of GPUs to use.")
parser.add_argument("--ppc",
type=int,
default=1,
help="Prototypes-Per-Class.")
args = parser.parse_args()
# Dataset
mnist_train = MNIST(
"./datasets",
@ -63,19 +93,19 @@ if __name__ == "__main__":
# Initialize the model
model = ImageGLVQ(input_dim=28 * 28,
nclasses=10,
prototypes_per_class=10,
prototypes_per_class=args.ppc,
prototype_initializer="stratified_mean",
data=[x, y])
# Model summary
print(model)
# Callbacks
vis = VisualizationCallback(to_shape=(-1, 1, 28, 28), nrow=10)
vis = VisualizationCallback(to_shape=(-1, 1, 28, 28), nrow=args.ppc)
# Setup trainer
trainer = pl.Trainer(
gpus=0, # change to use GPUs for training
max_epochs=10,
gpus=args.gpus, # change to use GPUs for training
max_epochs=args.epochs,
callbacks=[vis],
# accelerator="ddp_cpu", # DEBUG-ONLY
# num_processes=2, # DEBUG-ONLY