2021-05-25 19:13:37 +00:00
|
|
|
"""GLVQ example using the MNIST dataset."""
|
2021-05-21 15:10:36 +00:00
|
|
|
|
|
|
|
from prototorch.models import ImageGLVQ
|
2021-05-28 14:33:31 +00:00
|
|
|
from prototorch.models.data import train_on_mnist
|
2021-05-21 15:10:36 +00:00
|
|
|
from pytorch_lightning.utilities.cli import LightningCLI
|
|
|
|
|
|
|
|
|
2021-05-28 14:33:31 +00:00
|
|
|
class GLVQMNIST(train_on_mnist(batch_size=64), ImageGLVQ):
    """GLVQ image model set up for training on MNIST.

    The class is composed from two bases and adds no attributes or methods of
    its own:

    * the class returned by ``train_on_mnist(batch_size=64)`` (batch size is
      fixed at 64 here), and
    * ``ImageGLVQ``.

    Because the ``train_on_mnist`` mixin is listed first, its members take
    precedence over ``ImageGLVQ``'s in the method resolution order.

    NOTE(review): presumably the mixin supplies the MNIST data loading /
    training hooks that ``LightningCLI`` relies on -- confirm against
    ``prototorch.models.data.train_on_mnist``.
    """
|
2021-05-21 15:10:36 +00:00
|
|
|
|
|
|
|
|
2021-05-25 19:13:37 +00:00
|
|
|
if __name__ == "__main__":
    # Build the command-line interface only when this file is executed as a
    # script. Importing the module (e.g. to reuse GLVQMNIST) then no longer
    # constructs LightningCLI, which per the PyTorch Lightning docs parses
    # sys.argv and launches training from its constructor.
    cli = LightningCLI(GLVQMNIST)
|