7b9b767113
Should fix the problem with EarlyStopping callback.
20 lines
615 B
Python
20 lines
615 B
Python
"""GMLVQ example using the MNIST dataset."""
|
|
|
|
import prototorch as pt
|
|
import torch
|
|
from prototorch.models import ImageGMLVQ
|
|
from prototorch.models.abstract import PrototypeModel
|
|
from prototorch.models.data import MNISTDataModule
|
|
from pytorch_lightning.utilities.cli import LightningCLI
|
|
|
|
|
|
class ExperimentClass(ImageGMLVQ):
    """ImageGMLVQ variant with a fixed optimizer and prototype initializer.

    The constructor forwards ``hparams`` and any extra keyword arguments to
    ``ImageGMLVQ.__init__`` and additionally passes:

    * ``optimizer=torch.optim.Adam``
    * ``prototype_initializer=pt.components.zeros(28 * 28)``

    The ``28 * 28`` size presumably corresponds to one value per pixel of a
    flattened MNIST image -- TODO confirm against the data module.
    """

    def __init__(self, hparams, **kwargs):
        """Initialise the model with the preset optimizer and initializer.

        Args:
            hparams: Hyperparameters, handed to the parent class unchanged.
            **kwargs: Additional keyword arguments, handed to the parent
                class unchanged.
        """
        adam_optimizer = torch.optim.Adam
        zero_initializer = pt.components.zeros(28 * 28)
        super().__init__(
            hparams,
            optimizer=adam_optimizer,
            prototype_initializer=zero_initializer,
            **kwargs,
        )
|
|
|
|
|
|
# Build the Lightning command-line interface for this example, pairing the
# ImageGMLVQ model class with the MNIST data module.
# NOTE(review): ExperimentClass is defined in this file but is not the class
# handed to the CLI here -- confirm that using the base ImageGMLVQ is intended.
cli = LightningCLI(
    model_class=ImageGMLVQ,
    datamodule_class=MNISTDataModule,
)
|