[FEATURE] Add warm-starting example
This commit is contained in:
		
							
								
								
									
										84
									
								
								examples/warm_starting.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										84
									
								
								examples/warm_starting.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,84 @@
 | 
				
			|||||||
 | 
					"""Warm-starting GLVQ with prototypes from Growing Neural Gas."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import argparse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import prototorch as pt
 | 
				
			||||||
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					from torch.optim.lr_scheduler import ExponentialLR
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
					    # Command-line arguments
 | 
				
			||||||
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
 | 
					    parser = pl.Trainer.add_argparse_args(parser)
 | 
				
			||||||
 | 
					    args = parser.parse_args()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Prepare the data
 | 
				
			||||||
 | 
					    train_ds = pt.datasets.Iris(dims=[0, 2])
 | 
				
			||||||
 | 
					    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Initialize the gng
 | 
				
			||||||
 | 
					    gng = pt.models.GrowingNeuralGas(
 | 
				
			||||||
 | 
					        hparams=dict(num_prototypes=5, insert_freq=2, lr=0.1),
 | 
				
			||||||
 | 
					        prototypes_initializer=pt.initializers.ZCI(2),
 | 
				
			||||||
 | 
					        lr_scheduler=ExponentialLR,
 | 
				
			||||||
 | 
					        lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Callbacks
 | 
				
			||||||
 | 
					    es = pl.callbacks.EarlyStopping(
 | 
				
			||||||
 | 
					        monitor="loss",
 | 
				
			||||||
 | 
					        min_delta=0.001,
 | 
				
			||||||
 | 
					        patience=20,
 | 
				
			||||||
 | 
					        mode="min",
 | 
				
			||||||
 | 
					        verbose=False,
 | 
				
			||||||
 | 
					        check_on_train_epoch_end=True,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Setup trainer for GNG
 | 
				
			||||||
 | 
					    trainer = pl.Trainer(
 | 
				
			||||||
 | 
					        max_epochs=200,
 | 
				
			||||||
 | 
					        callbacks=[es],
 | 
				
			||||||
 | 
					        weights_summary=None,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Training loop
 | 
				
			||||||
 | 
					    trainer.fit(gng, train_loader)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Hyperparameters
 | 
				
			||||||
 | 
					    hparams = dict(
 | 
				
			||||||
 | 
					        distribution=[],
 | 
				
			||||||
 | 
					        lr=0.01,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Warm-start prototypes
 | 
				
			||||||
 | 
					    knn = pt.models.KNN(dict(k=1), data=train_ds)
 | 
				
			||||||
 | 
					    prototypes = gng.prototypes
 | 
				
			||||||
 | 
					    plabels = knn.predict(prototypes)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Initialize the model
 | 
				
			||||||
 | 
					    model = pt.models.GLVQ(
 | 
				
			||||||
 | 
					        hparams,
 | 
				
			||||||
 | 
					        optimizer=torch.optim.Adam,
 | 
				
			||||||
 | 
					        prototypes_initializer=pt.initializers.LCI(prototypes),
 | 
				
			||||||
 | 
					        labels_initializer=pt.initializers.LLI(plabels),
 | 
				
			||||||
 | 
					        lr_scheduler=ExponentialLR,
 | 
				
			||||||
 | 
					        lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Compute intermediate input and output sizes
 | 
				
			||||||
 | 
					    model.example_input_array = torch.zeros(4, 2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Callbacks
 | 
				
			||||||
 | 
					    vis = pt.models.VisGLVQ2D(data=train_ds)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Setup trainer
 | 
				
			||||||
 | 
					    trainer = pl.Trainer.from_argparse_args(
 | 
				
			||||||
 | 
					        args,
 | 
				
			||||||
 | 
					        callbacks=[vis],
 | 
				
			||||||
 | 
					        weights_summary="full",
 | 
				
			||||||
 | 
					        accelerator="ddp",
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Training loop
 | 
				
			||||||
 | 
					    trainer.fit(model, train_loader)
 | 
				
			||||||
@@ -9,6 +9,7 @@ import torchmetrics
 | 
				
			|||||||
from ..core.competitions import WTAC
 | 
					from ..core.competitions import WTAC
 | 
				
			||||||
from ..core.components import Components, LabeledComponents
 | 
					from ..core.components import Components, LabeledComponents
 | 
				
			||||||
from ..core.distances import euclidean_distance
 | 
					from ..core.distances import euclidean_distance
 | 
				
			||||||
 | 
					from ..core.initializers import LabelsInitializer
 | 
				
			||||||
from ..core.pooling import stratified_min_pooling
 | 
					from ..core.pooling import stratified_min_pooling
 | 
				
			||||||
from ..nn.wrappers import LambdaLayer
 | 
					from ..nn.wrappers import LambdaLayer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -111,10 +112,13 @@ class SupervisedPrototypeModel(PrototypeModel):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        # Layers
 | 
					        # Layers
 | 
				
			||||||
        prototypes_initializer = kwargs.get("prototypes_initializer", None)
 | 
					        prototypes_initializer = kwargs.get("prototypes_initializer", None)
 | 
				
			||||||
 | 
					        labels_initializer = kwargs.get("labels_initializer",
 | 
				
			||||||
 | 
					                                        LabelsInitializer())
 | 
				
			||||||
        if prototypes_initializer is not None:
 | 
					        if prototypes_initializer is not None:
 | 
				
			||||||
            self.proto_layer = LabeledComponents(
 | 
					            self.proto_layer = LabeledComponents(
 | 
				
			||||||
                distribution=self.hparams.distribution,
 | 
					                distribution=self.hparams.distribution,
 | 
				
			||||||
                components_initializer=prototypes_initializer,
 | 
					                components_initializer=prototypes_initializer,
 | 
				
			||||||
 | 
					                labels_initializer=labels_initializer,
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
        self.competition_layer = WTAC()
 | 
					        self.competition_layer = WTAC()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -118,6 +118,7 @@ class GNGCallback(pl.Callback):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
            # Add component
 | 
					            # Add component
 | 
				
			||||||
            pl_module.proto_layer.add_components(
 | 
					            pl_module.proto_layer.add_components(
 | 
				
			||||||
 | 
					                None,
 | 
				
			||||||
                initializer=LiteralCompInitializer(new_component.unsqueeze(0)))
 | 
					                initializer=LiteralCompInitializer(new_component.unsqueeze(0)))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            # Adjust Topology
 | 
					            # Adjust Topology
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user