Update example scripts
This commit is contained in:
@@ -1,14 +1,8 @@
|
||||
"""GLVQ example using the spiral dataset."""
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from prototorch.components import initializers as cinit
|
||||
from prototorch.datasets.abstract import NumpyDataset
|
||||
from prototorch.datasets.spiral import make_spiral
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from prototorch.models.callbacks.visualization import VisGLVQ2D
|
||||
from prototorch.models.glvq import GLVQ
|
||||
|
||||
|
||||
class StopOnNaN(pl.Callback):
|
||||
@@ -23,29 +17,28 @@ class StopOnNaN(pl.Callback):
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Dataset
|
||||
x_train, y_train = make_spiral(n_samples=600, noise=0.6)
|
||||
train_ds = NumpyDataset(x_train, y_train)
|
||||
train_ds = pt.datasets.Spiral(n_samples=600, noise=0.6)
|
||||
|
||||
# Dataloaders
|
||||
train_loader = DataLoader(train_ds, num_workers=0, batch_size=256)
|
||||
train_loader = torch.utils.data.DataLoader(train_ds,
|
||||
num_workers=0,
|
||||
batch_size=256)
|
||||
|
||||
# Hyperparameters
|
||||
hparams = dict(
|
||||
nclasses=2,
|
||||
prototypes_per_class=20,
|
||||
prototype_initializer=cinit.SSI(torch.Tensor(x_train),
|
||||
torch.Tensor(y_train),
|
||||
noise=1e-7),
|
||||
prototype_initializer=pt.components.SSI(train_ds, noise=1e-7),
|
||||
transfer_function="sigmoid_beta",
|
||||
transfer_beta=10.0,
|
||||
lr=0.01,
|
||||
)
|
||||
|
||||
# Initialize the model
|
||||
model = GLVQ(hparams)
|
||||
model = pt.models.GLVQ(hparams)
|
||||
|
||||
# Callbacks
|
||||
vis = VisGLVQ2D(x_train, y_train, show_last_only=True, block=True)
|
||||
vis = pt.models.VisGLVQ2D(train_ds, show_last_only=True, block=True)
|
||||
snan = StopOnNaN(model.proto_layer.components)
|
||||
|
||||
# Setup trainer
|
||||
|
Reference in New Issue
Block a user