Update mnist example

This commit is contained in:
Jensun Ravichandran 2021-04-21 16:28:20 +02:00
parent 985cdd3120
commit 5a1ef841d3

View File

@ -1,8 +1,14 @@
"""GLVQ example using the MNIST dataset. """GLVQ example using the MNIST dataset.
TODO
- Add model serialization/deserialization
- Add evaluation metrics
This script also shows how to use Tensorboard for visualizing the prototypes. This script also shows how to use Tensorboard for visualizing the prototypes.
""" """
import argparse
import pytorch_lightning as pl import pytorch_lightning as pl
import torchvision import torchvision
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
@ -32,6 +38,30 @@ class VisualizationCallback(pl.Callback):
if __name__ == "__main__": if __name__ == "__main__":
# Arguments
parser = argparse.ArgumentParser()
parser.add_argument("--epochs",
type=int,
default=10,
help="Epochs to train.")
parser.add_argument("--lr",
type=float,
default=0.001,
help="Learning rate.")
parser.add_argument("--batch_size",
type=int,
default=256,
help="Batch size.")
parser.add_argument("--gpus",
type=int,
default=0,
help="Number of GPUs to use.")
parser.add_argument("--ppc",
type=int,
default=1,
help="Prototypes-Per-Class.")
args = parser.parse_args()
# Dataset # Dataset
mnist_train = MNIST( mnist_train = MNIST(
"./datasets", "./datasets",
@ -63,19 +93,19 @@ if __name__ == "__main__":
# Initialize the model # Initialize the model
model = ImageGLVQ(input_dim=28 * 28, model = ImageGLVQ(input_dim=28 * 28,
nclasses=10, nclasses=10,
prototypes_per_class=10, prototypes_per_class=args.ppc,
prototype_initializer="stratified_mean", prototype_initializer="stratified_mean",
data=[x, y]) data=[x, y])
# Model summary # Model summary
print(model) print(model)
# Callbacks # Callbacks
vis = VisualizationCallback(to_shape=(-1, 1, 28, 28), nrow=10) vis = VisualizationCallback(to_shape=(-1, 1, 28, 28), nrow=args.ppc)
# Setup trainer # Setup trainer
trainer = pl.Trainer( trainer = pl.Trainer(
gpus=0, # change to use GPUs for training gpus=args.gpus, # change to use GPUs for training
max_epochs=10, max_epochs=args.epochs,
callbacks=[vis], callbacks=[vis],
# accelerator="ddp_cpu", # DEBUG-ONLY # accelerator="ddp_cpu", # DEBUG-ONLY
# num_processes=2, # DEBUG-ONLY # num_processes=2, # DEBUG-ONLY