Update mnist example
This commit is contained in:
parent
985cdd3120
commit
5a1ef841d3
@ -1,8 +1,14 @@
|
||||
"""GLVQ example using the MNIST dataset.
|
||||
|
||||
TODO
|
||||
- Add model serialization/deserialization
|
||||
- Add evaluation metrics
|
||||
|
||||
This script also shows how to use Tensorboard for visualizing the prototypes.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torchvision
|
||||
from matplotlib import pyplot as plt
|
||||
@ -32,6 +38,30 @@ class VisualizationCallback(pl.Callback):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--epochs",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Epochs to train.")
|
||||
parser.add_argument("--lr",
|
||||
type=float,
|
||||
default=0.001,
|
||||
help="Learning rate.")
|
||||
parser.add_argument("--batch_size",
|
||||
type=int,
|
||||
default=256,
|
||||
help="Batch size.")
|
||||
parser.add_argument("--gpus",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Number of GPUs to use.")
|
||||
parser.add_argument("--ppc",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Prototypes-Per-Class.")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Dataset
|
||||
mnist_train = MNIST(
|
||||
"./datasets",
|
||||
@ -63,19 +93,19 @@ if __name__ == "__main__":
|
||||
# Initialize the model
|
||||
model = ImageGLVQ(input_dim=28 * 28,
|
||||
nclasses=10,
|
||||
prototypes_per_class=10,
|
||||
prototypes_per_class=args.ppc,
|
||||
prototype_initializer="stratified_mean",
|
||||
data=[x, y])
|
||||
# Model summary
|
||||
print(model)
|
||||
|
||||
# Callbacks
|
||||
vis = VisualizationCallback(to_shape=(-1, 1, 28, 28), nrow=10)
|
||||
vis = VisualizationCallback(to_shape=(-1, 1, 28, 28), nrow=args.ppc)
|
||||
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer(
|
||||
gpus=0, # change to use GPUs for training
|
||||
max_epochs=10,
|
||||
gpus=args.gpus, # change to use GPUs for training
|
||||
max_epochs=args.epochs,
|
||||
callbacks=[vis],
|
||||
# accelerator="ddp_cpu", # DEBUG-ONLY
|
||||
# num_processes=2, # DEBUG-ONLY
|
||||
|
Loading…
Reference in New Issue
Block a user