Fix glvq mnist example script

This commit is contained in:
Jensun Ravichandran 2021-04-23 17:49:29 +02:00
parent 688f09ca23
commit 3148684812

View File

@ -1,9 +1,5 @@
"""GLVQ example using the MNIST dataset.
TODO
- Add model serialization/deserialization
- Add evaluation metrics
This script also shows how to use Tensorboard for visualizing the prototypes.
"""
@ -91,14 +87,18 @@ if __name__ == "__main__":
x, y = next(iter(DataLoader(mnist_train, batch_size=len(mnist_train))))
x = x.view(len(mnist_train), -1)
# Initialize the model
model = ImageGLVQ(
# Hyperparameters
hparams = dict(
input_dim=28 * 28,
nclasses=10,
prototypes_per_class=args.ppc,
prototypes_per_class=1,
prototype_initializer="stratified_mean",
data=[x, y],
lr=args.lr,
)
# Initialize the model
model = ImageGLVQ(hparams, data=[x, y])
# Model summary
print(model)