Fix glvq mnist example script
This commit is contained in:
parent
688f09ca23
commit
3148684812
@ -1,9 +1,5 @@
|
|||||||
"""GLVQ example using the MNIST dataset.
|
"""GLVQ example using the MNIST dataset.
|
||||||
|
|
||||||
TODO
|
|
||||||
- Add model serialization/deserialization
|
|
||||||
- Add evaluation metrics
|
|
||||||
|
|
||||||
This script also shows how to use Tensorboard for visualizing the prototypes.
|
This script also shows how to use Tensorboard for visualizing the prototypes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -91,14 +87,18 @@ if __name__ == "__main__":
|
|||||||
x, y = next(iter(DataLoader(mnist_train, batch_size=len(mnist_train))))
|
x, y = next(iter(DataLoader(mnist_train, batch_size=len(mnist_train))))
|
||||||
x = x.view(len(mnist_train), -1)
|
x = x.view(len(mnist_train), -1)
|
||||||
|
|
||||||
# Initialize the model
|
# Hyperparameters
|
||||||
model = ImageGLVQ(
|
hparams = dict(
|
||||||
input_dim=28 * 28,
|
input_dim=28 * 28,
|
||||||
nclasses=10,
|
nclasses=10,
|
||||||
prototypes_per_class=args.ppc,
|
prototypes_per_class=1,
|
||||||
prototype_initializer="stratified_mean",
|
prototype_initializer="stratified_mean",
|
||||||
data=[x, y],
|
lr=args.lr,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Initialize the model
|
||||||
|
model = ImageGLVQ(hparams, data=[x, y])
|
||||||
|
|
||||||
# Model summary
|
# Model summary
|
||||||
print(model)
|
print(model)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user