prototorch_models/examples/glvq_spiral.py
Jensun Ravichandran ca39aa00d5 Stop passing component initializers as hparams
Passing the component initializer as an hparam slows the script down considerably. The
API has now been changed to pass it to the models as a kwarg instead.

The example scripts have also been updated to reflect the new changes.

Also, ImageGMLVQ and an example script, `gmlvq_mnist.py`, that uses it have
been added.
2021-05-12 16:36:22 +02:00

54 lines
1.4 KiB
Python

"""GLVQ example using the spiral dataset."""
import prototorch as pt
import pytorch_lightning as pl
import torch
class StopOnNaN(pl.Callback):
    """Lightning callback that aborts training when a watched tensor has NaNs.

    Args:
        param: Tensor checked in the ``on_epoch_end`` hook. Only a reference
            is stored, so the check sees the tensor's current values
            (presumably the model updates it in place during training --
            confirm for the model in use).
    """

    def __init__(self, param):
        super().__init__()
        self.param = param

    def on_epoch_end(self, trainer, pl_module, logs=None):
        """Raise ``ValueError`` if ``self.param`` contains any NaN element.

        Args:
            trainer: The Lightning trainer (not used by this check).
            pl_module: The module being trained (not used by this check).
            logs: Unused. It defaults to ``None`` rather than ``{}`` so that
                no mutable default object is shared between calls.

        Raises:
            ValueError: If at least one element of ``self.param`` is NaN.
        """
        if torch.isnan(self.param).any():
            raise ValueError("NaN encountered. Stopping.")
if __name__ == "__main__":
    # Noisy spiral dataset, fed to training in mini-batches of 256 with
    # num_workers=0.
    spiral = pt.datasets.Spiral(n_samples=600, noise=0.6)
    loader = torch.utils.data.DataLoader(spiral, batch_size=256, num_workers=0)

    # Two classes with twenty prototypes each. The remaining entries select
    # the "sigmoid_beta" transfer function with beta=10.0 and set lr=0.01.
    num_classes, protos_per_class = 2, 20
    hyperparams = {
        "distribution": (num_classes, protos_per_class),
        "transfer_function": "sigmoid_beta",
        "transfer_beta": 10.0,
        "lr": 0.01,
    }

    # The prototype initializer is handed to the model as a keyword argument,
    # not bundled into the hyperparameter dict.
    initializer = pt.components.SSI(spiral, noise=1e-1)
    glvq = pt.models.GLVQ(hyperparams, prototype_initializer=initializer)

    # Visualization callback, then a guard that raises once the prototype
    # components contain a NaN.
    callbacks = [
        pt.models.VisGLVQ2D(spiral, show_last_only=True, block=True),
        StopOnNaN(glvq.proto_layer.components),
    ]

    # Train for at most 200 epochs.
    pl.Trainer(max_epochs=200, callbacks=callbacks).fit(glvq, loader)