diff --git a/examples/dynamic_components.py b/examples/dynamic_components.py index 3b5f68b..7b29e11 100644 --- a/examples/dynamic_components.py +++ b/examples/dynamic_components.py @@ -2,10 +2,24 @@ import argparse +import prototorch as pt import pytorch_lightning as pl import torch +from pytorch_lightning.callbacks import Callback + + +class PrototypeScheduler(Callback): + def __init__(self, train_ds, freq=20): + self.train_ds = train_ds + self.freq = freq + + def on_epoch_end(self, trainer, pl_module): + if (trainer.current_epoch + 1) % self.freq == 0: + pl_module.increase_prototypes( + pt.components.SMI(self.train_ds), + distribution=[1, 1, 1], + ) -import prototorch as pt if __name__ == "__main__": # Command-line arguments @@ -33,24 +47,17 @@ if __name__ == "__main__": prototype_initializer=pt.components.SMI(train_ds), ) - for _ in range(5): - # Callbacks - vis = pt.models.VisGLVQ2D(train_ds) + # Callbacks + vis = pt.models.VisGLVQ2D(train_ds) + proto_scheduler = PrototypeScheduler(train_ds, 10) - # Setup trainer - trainer = pl.Trainer.from_argparse_args( - args, - max_epochs=20, - callbacks=[vis], - terminate_on_nan=True, - weights_summary=None, - ) + # Setup trainer + trainer = pl.Trainer.from_argparse_args(args, + max_epochs=100, + callbacks=[vis, proto_scheduler], + terminate_on_nan=True, + weights_summary=None, + accelerator='ddp') - # Training loop - trainer.fit(model, train_loader) - - # Increase prototypes - model.increase_prototypes( - pt.components.SMI(train_ds), - distribution=[1, 1, 1], - ) + # Training loop + trainer.fit(model, train_loader) diff --git a/prototorch/models/glvq.py b/prototorch/models/glvq.py index 29159eb..3e4baa5 100644 --- a/prototorch/models/glvq.py +++ b/prototorch/models/glvq.py @@ -137,7 +137,7 @@ class GLVQ(AbstractPrototypeModel): def increase_prototypes(self, initializer, distribution): self.proto_layer.increase_components(initializer, distribution) - #self.trainer.accelerated_backend.setup_optimizers(self) + self.trainer.accelerator_backend.setup_optimizers(self.trainer) def __repr__(self): super_repr = super().__repr__()