Update mnist example
This commit is contained in:
parent
985cdd3120
commit
5a1ef841d3
@ -1,8 +1,14 @@
|
|||||||
"""GLVQ example using the MNIST dataset.
|
"""GLVQ example using the MNIST dataset.
|
||||||
|
|
||||||
|
TODO
|
||||||
|
- Add model serialization/deserialization
|
||||||
|
- Add evaluation metrics
|
||||||
|
|
||||||
This script also shows how to use Tensorboard for visualizing the prototypes.
|
This script also shows how to use Tensorboard for visualizing the prototypes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torchvision
|
import torchvision
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
@ -32,6 +38,30 @@ class VisualizationCallback(pl.Callback):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
# Arguments
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--epochs",
|
||||||
|
type=int,
|
||||||
|
default=10,
|
||||||
|
help="Epochs to train.")
|
||||||
|
parser.add_argument("--lr",
|
||||||
|
type=float,
|
||||||
|
default=0.001,
|
||||||
|
help="Learning rate.")
|
||||||
|
parser.add_argument("--batch_size",
|
||||||
|
type=int,
|
||||||
|
default=256,
|
||||||
|
help="Batch size.")
|
||||||
|
parser.add_argument("--gpus",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="Number of GPUs to use.")
|
||||||
|
parser.add_argument("--ppc",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Prototypes-Per-Class.")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Dataset
|
# Dataset
|
||||||
mnist_train = MNIST(
|
mnist_train = MNIST(
|
||||||
"./datasets",
|
"./datasets",
|
||||||
@ -63,19 +93,19 @@ if __name__ == "__main__":
|
|||||||
# Initialize the model
|
# Initialize the model
|
||||||
model = ImageGLVQ(input_dim=28 * 28,
|
model = ImageGLVQ(input_dim=28 * 28,
|
||||||
nclasses=10,
|
nclasses=10,
|
||||||
prototypes_per_class=10,
|
prototypes_per_class=args.ppc,
|
||||||
prototype_initializer="stratified_mean",
|
prototype_initializer="stratified_mean",
|
||||||
data=[x, y])
|
data=[x, y])
|
||||||
# Model summary
|
# Model summary
|
||||||
print(model)
|
print(model)
|
||||||
|
|
||||||
# Callbacks
|
# Callbacks
|
||||||
vis = VisualizationCallback(to_shape=(-1, 1, 28, 28), nrow=10)
|
vis = VisualizationCallback(to_shape=(-1, 1, 28, 28), nrow=args.ppc)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer(
|
trainer = pl.Trainer(
|
||||||
gpus=0, # change to use GPUs for training
|
gpus=args.gpus, # change to use GPUs for training
|
||||||
max_epochs=10,
|
max_epochs=args.epochs,
|
||||||
callbacks=[vis],
|
callbacks=[vis],
|
||||||
# accelerator="ddp_cpu", # DEBUG-ONLY
|
# accelerator="ddp_cpu", # DEBUG-ONLY
|
||||||
# num_processes=2, # DEBUG-ONLY
|
# num_processes=2, # DEBUG-ONLY
|
||||||
|
Loading…
Reference in New Issue
Block a user