[FEATURE] Add warm-starting example

This commit is contained in:
Jensun Ravichandran 2021-06-14 20:42:57 +02:00
parent 1911d4b33e
commit 1c658cdc1b
3 changed files with 89 additions and 0 deletions

84
examples/warm_starting.py Normal file
View File

@ -0,0 +1,84 @@
"""Warm-starting GLVQ with prototypes from Growing Neural Gas."""
import argparse
import prototorch as pt
import pytorch_lightning as pl
import torch
from torch.optim.lr_scheduler import ExponentialLR
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Prepare the data
train_ds = pt.datasets.Iris(dims=[0, 2])
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)
# Initialize the gng
gng = pt.models.GrowingNeuralGas(
hparams=dict(num_prototypes=5, insert_freq=2, lr=0.1),
prototypes_initializer=pt.initializers.ZCI(2),
lr_scheduler=ExponentialLR,
lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
)
# Callbacks
es = pl.callbacks.EarlyStopping(
monitor="loss",
min_delta=0.001,
patience=20,
mode="min",
verbose=False,
check_on_train_epoch_end=True,
)
# Setup trainer for GNG
trainer = pl.Trainer(
max_epochs=200,
callbacks=[es],
weights_summary=None,
)
# Training loop
trainer.fit(gng, train_loader)
# Hyperparameters
hparams = dict(
distribution=[],
lr=0.01,
)
# Warm-start prototypes
knn = pt.models.KNN(dict(k=1), data=train_ds)
prototypes = gng.prototypes
plabels = knn.predict(prototypes)
# Initialize the model
model = pt.models.GLVQ(
hparams,
optimizer=torch.optim.Adam,
prototypes_initializer=pt.initializers.LCI(prototypes),
labels_initializer=pt.initializers.LLI(plabels),
lr_scheduler=ExponentialLR,
lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
)
# Compute intermediate input and output sizes
model.example_input_array = torch.zeros(4, 2)
# Callbacks
vis = pt.models.VisGLVQ2D(data=train_ds)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[vis],
weights_summary="full",
accelerator="ddp",
)
# Training loop
trainer.fit(model, train_loader)

View File

@ -9,6 +9,7 @@ import torchmetrics
from ..core.competitions import WTAC
from ..core.components import Components, LabeledComponents
from ..core.distances import euclidean_distance
from ..core.initializers import LabelsInitializer
from ..core.pooling import stratified_min_pooling
from ..nn.wrappers import LambdaLayer
@ -111,10 +112,13 @@ class SupervisedPrototypeModel(PrototypeModel):
# Layers
prototypes_initializer = kwargs.get("prototypes_initializer", None)
labels_initializer = kwargs.get("labels_initializer",
LabelsInitializer())
if prototypes_initializer is not None:
self.proto_layer = LabeledComponents(
distribution=self.hparams.distribution,
components_initializer=prototypes_initializer,
labels_initializer=labels_initializer,
)
self.competition_layer = WTAC()

View File

@ -118,6 +118,7 @@ class GNGCallback(pl.Callback):
# Add component
pl_module.proto_layer.add_components(
None,
initializer=LiteralCompInitializer(new_component.unsqueeze(0)))
# Adjust Topology