diff --git a/examples/dynamic_components.py b/examples/dynamic_components.py new file mode 100644 index 0000000..3b5f68b --- /dev/null +++ b/examples/dynamic_components.py @@ -0,0 +1,56 @@ +"""Dynamically update the number of prototypes in GLVQ.""" + +import argparse + +import pytorch_lightning as pl +import torch + +import prototorch as pt + +if __name__ == "__main__": + # Command-line arguments + parser = argparse.ArgumentParser() + parser = pl.Trainer.add_argparse_args(parser) + args = parser.parse_args() + + # Dataset + train_ds = pt.datasets.Iris(dims=[0, 2]) + + # Dataloaders + train_loader = torch.utils.data.DataLoader(train_ds, batch_size=32) + + # Hyperparameters + hparams = dict( + distribution=[1, 1, 1], + transfer_function="sigmoid_beta", + transfer_beta=10.0, + lr=0.01, + ) + + # Initialize the model + model = pt.models.GLVQ( + hparams, + prototype_initializer=pt.components.SMI(train_ds), + ) + + for _ in range(5): + # Callbacks + vis = pt.models.VisGLVQ2D(train_ds) + + # Setup trainer + trainer = pl.Trainer.from_argparse_args( + args, + max_epochs=20, + callbacks=[vis], + terminate_on_nan=True, + weights_summary=None, + ) + + # Training loop + trainer.fit(model, train_loader) + + # Increase prototypes + model.increase_prototypes( + pt.components.SMI(train_ds), + distribution=[1, 1, 1], + ) diff --git a/prototorch/models/abstract.py b/prototorch/models/abstract.py index dc3266a..7495ac4 100644 --- a/prototorch/models/abstract.py +++ b/prototorch/models/abstract.py @@ -1,6 +1,5 @@ import pytorch_lightning as pl import torch -from prototorch.functions.competitions import wtac from torch.optim.lr_scheduler import ExponentialLR diff --git a/prototorch/models/glvq.py b/prototorch/models/glvq.py index 2b5f925..c8b5d7a 100644 --- a/prototorch/models/glvq.py +++ b/prototorch/models/glvq.py @@ -120,6 +120,9 @@ class GLVQ(AbstractPrototypeModel): # def predict_step(self, batch, batch_idx, dataloader_idx=None): # pass + def increase_prototypes(self, initializer, distribution): + self.proto_layer.increase_components(initializer, distribution) + def __repr__(self): super_repr = super().__repr__() return f"{super_repr}"