Fix glvq mnist example script
This commit is contained in:
parent
688f09ca23
commit
3148684812
@ -1,9 +1,5 @@
|
||||
"""GLVQ example using the MNIST dataset.
|
||||
|
||||
TODO
|
||||
- Add model serialization/deserialization
|
||||
- Add evaluation metrics
|
||||
|
||||
This script also shows how to use Tensorboard for visualizing the prototypes.
|
||||
"""
|
||||
|
||||
@ -91,14 +87,18 @@ if __name__ == "__main__":
|
||||
x, y = next(iter(DataLoader(mnist_train, batch_size=len(mnist_train))))
|
||||
x = x.view(len(mnist_train), -1)
|
||||
|
||||
# Initialize the model
|
||||
model = ImageGLVQ(
|
||||
# Hyperparameters
|
||||
hparams = dict(
|
||||
input_dim=28 * 28,
|
||||
nclasses=10,
|
||||
prototypes_per_class=args.ppc,
|
||||
prototypes_per_class=1,
|
||||
prototype_initializer="stratified_mean",
|
||||
data=[x, y],
|
||||
lr=args.lr,
|
||||
)
|
||||
|
||||
# Initialize the model
|
||||
model = ImageGLVQ(hparams, data=[x, y])
|
||||
|
||||
# Model summary
|
||||
print(model)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user