prototorch_models/examples/cli/gmlvq.py
Alexander Engelsberger 7b9b767113 fix: training loss is a zero dimensional tensor
Should fix the problem with EarlyStopping callback.
2021-06-25 17:07:06 +02:00

20 lines
615 B
Python

"""GMLVQ example using the MNIST dataset."""
import prototorch as pt
import torch
from prototorch.models import ImageGMLVQ
from prototorch.models.abstract import PrototypeModel
from prototorch.models.data import MNISTDataModule
from pytorch_lightning.utilities.cli import LightningCLI
class ExperimentClass(ImageGMLVQ):
    """ImageGMLVQ variant with a fixed optimizer and prototype initializer.

    The constructor forwards ``hparams`` and any extra keyword arguments to
    ``ImageGMLVQ`` unchanged, while always supplying
    ``optimizer=torch.optim.Adam`` and a ``pt.components.zeros(28 * 28)``
    prototype initializer.
    """

    def __init__(self, hparams, **kwargs):
        """Initialize the model with Adam and zero-initialized prototypes.

        Args:
            hparams: Hyperparameters, passed through to ``ImageGMLVQ``.
            **kwargs: Additional keyword arguments passed through to
                ``ImageGMLVQ``. Supplying ``optimizer`` or
                ``prototype_initializer`` here raises ``TypeError`` because
                both are already given explicitly below.
        """
        # 28 * 28 presumably matches a flattened MNIST image -- TODO confirm.
        zero_initializer = pt.components.zeros(28 * 28)
        super().__init__(
            hparams,
            optimizer=torch.optim.Adam,
            prototype_initializer=zero_initializer,
            **kwargs,
        )
# Build the command-line interface from the model class and the MNIST data
# module class.
# NOTE(review): ExperimentClass is defined in this file but ImageGMLVQ is the
# class handed to the CLI -- confirm which one is intended.
cli = LightningCLI(
    model_class=ImageGMLVQ,
    datamodule_class=MNISTDataModule,
)