Fix glvq mnist example script

This commit is contained in:
Jensun Ravichandran 2021-04-23 17:49:29 +02:00
parent 688f09ca23
commit 3148684812

View File

@ -1,9 +1,5 @@
"""GLVQ example using the MNIST dataset. """GLVQ example using the MNIST dataset.
TODO
- Add model serialization/deserialization
- Add evaluation metrics
This script also shows how to use Tensorboard for visualizing the prototypes. This script also shows how to use Tensorboard for visualizing the prototypes.
""" """
@ -91,14 +87,18 @@ if __name__ == "__main__":
x, y = next(iter(DataLoader(mnist_train, batch_size=len(mnist_train)))) x, y = next(iter(DataLoader(mnist_train, batch_size=len(mnist_train))))
x = x.view(len(mnist_train), -1) x = x.view(len(mnist_train), -1)
# Initialize the model # Hyperparameters
model = ImageGLVQ( hparams = dict(
input_dim=28 * 28, input_dim=28 * 28,
nclasses=10, nclasses=10,
prototypes_per_class=args.ppc, prototypes_per_class=1,
prototype_initializer="stratified_mean", prototype_initializer="stratified_mean",
data=[x, y], lr=args.lr,
) )
# Initialize the model
model = ImageGLVQ(hparams, data=[x, y])
# Model summary # Model summary
print(model) print(model)