From 1c658cdc1b38685767490d5c3e732076cb699732 Mon Sep 17 00:00:00 2001
From: Jensun Ravichandran
Date: Mon, 14 Jun 2021 20:42:57 +0200
Subject: [PATCH] [FEATURE] Add warm-starting example

---
 examples/warm_starting.py      | 84 ++++++++++++++++++++++++++++++++++
 prototorch/models/abstract.py  |  4 ++
 prototorch/models/callbacks.py |  1 +
 3 files changed, 89 insertions(+)
 create mode 100644 examples/warm_starting.py

diff --git a/examples/warm_starting.py b/examples/warm_starting.py
new file mode 100644
index 0000000..1a966f3
--- /dev/null
+++ b/examples/warm_starting.py
@@ -0,0 +1,84 @@
+"""Warm-starting GLVQ with prototypes from Growing Neural Gas."""
+
+import argparse
+
+import prototorch as pt
+import pytorch_lightning as pl
+import torch
+from torch.optim.lr_scheduler import ExponentialLR
+
+if __name__ == "__main__":
+    # Command-line arguments
+    parser = argparse.ArgumentParser()
+    parser = pl.Trainer.add_argparse_args(parser)
+    args = parser.parse_args()
+
+    # Prepare the data
+    train_ds = pt.datasets.Iris(dims=[0, 2])
+    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)
+
+    # Initialize the GNG
+    gng = pt.models.GrowingNeuralGas(
+        hparams=dict(num_prototypes=5, insert_freq=2, lr=0.1),
+        prototypes_initializer=pt.initializers.ZCI(2),
+        lr_scheduler=ExponentialLR,
+        lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
+    )
+
+    # Callbacks
+    es = pl.callbacks.EarlyStopping(
+        monitor="loss",
+        min_delta=0.001,
+        patience=20,
+        mode="min",
+        verbose=False,
+        check_on_train_epoch_end=True,
+    )
+
+    # Setup trainer for GNG
+    trainer = pl.Trainer(
+        max_epochs=200,
+        callbacks=[es],
+        weights_summary=None,
+    )
+
+    # Training loop
+    trainer.fit(gng, train_loader)
+
+    # Hyperparameters
+    hparams = dict(
+        distribution=[],
+        lr=0.01,
+    )
+
+    # Warm-start prototypes
+    knn = pt.models.KNN(dict(k=1), data=train_ds)
+    prototypes = gng.prototypes
+    plabels = knn.predict(prototypes)
+
+    # Initialize the model
+    model = pt.models.GLVQ(
+        hparams,
+        optimizer=torch.optim.Adam,
+        prototypes_initializer=pt.initializers.LCI(prototypes),
+        labels_initializer=pt.initializers.LLI(plabels),
+        lr_scheduler=ExponentialLR,
+        lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
+    )
+
+    # Compute intermediate input and output sizes
+    model.example_input_array = torch.zeros(4, 2)
+
+    # Callbacks
+    vis = pt.models.VisGLVQ2D(data=train_ds)
+
+    # Setup trainer
+    trainer = pl.Trainer.from_argparse_args(
+        args,
+        callbacks=[vis],
+        weights_summary="full",
+        accelerator="ddp",
+    )
+
+    # Training loop
+    trainer.fit(model, train_loader)
diff --git a/prototorch/models/abstract.py b/prototorch/models/abstract.py
index eacf586..4e79d4a 100644
--- a/prototorch/models/abstract.py
+++ b/prototorch/models/abstract.py
@@ -9,6 +9,7 @@ import torchmetrics
 from ..core.competitions import WTAC
 from ..core.components import Components, LabeledComponents
 from ..core.distances import euclidean_distance
+from ..core.initializers import LabelsInitializer
 from ..core.pooling import stratified_min_pooling
 from ..nn.wrappers import LambdaLayer
 
@@ -111,10 +112,13 @@ class SupervisedPrototypeModel(PrototypeModel):
 
         # Layers
         prototypes_initializer = kwargs.get("prototypes_initializer", None)
+        labels_initializer = kwargs.get("labels_initializer",
+                                        LabelsInitializer())
         if prototypes_initializer is not None:
             self.proto_layer = LabeledComponents(
                 distribution=self.hparams.distribution,
                 components_initializer=prototypes_initializer,
+                labels_initializer=labels_initializer,
             )
 
         self.competition_layer = WTAC()
diff --git a/prototorch/models/callbacks.py b/prototorch/models/callbacks.py
index 0267270..62b628a 100644
--- a/prototorch/models/callbacks.py
+++ b/prototorch/models/callbacks.py
@@ -118,6 +118,7 @@ class GNGCallback(pl.Callback):
 
         # Add component
         pl_module.proto_layer.add_components(
+            None,
             initializer=LiteralCompInitializer(new_component.unsqueeze(0)))
 
         # Adjust Topology
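
Note on the abstract.py change: callers that do not pass a
labels_initializer fall back to LabelsInitializer(), so existing models are
unaffected. Below is a minimal sketch of the call pattern the new kwarg
enables, namely warm-starting GLVQ from prototypes and labels computed
elsewhere. It assumes only what the example above already uses
(pt.initializers.LCI/LLI as literal component/label initializers); the two
tensors are hypothetical stand-ins for values that would normally come from
a pretrained model such as the GNG.

    import torch

    import prototorch as pt

    # Hypothetical warm-start values; in examples/warm_starting.py they
    # come from a trained GrowingNeuralGas plus a 1-NN relabeling pass.
    protos = torch.rand(5, 2)
    plabels = torch.randint(0, 3, (5, ))

    # LCI hands the literal prototypes to the proto_layer, and LLI hands
    # over the matching labels via the new labels_initializer kwarg. With
    # literal initializers, the distribution can stay empty.
    model = pt.models.GLVQ(
        dict(distribution=[], lr=0.01),
        prototypes_initializer=pt.initializers.LCI(protos),
        labels_initializer=pt.initializers.LLI(plabels),
    )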