Add example for dynamic components in callbacks
This commit is contained in:
		| @@ -2,10 +2,24 @@ | ||||
|  | ||||
| import argparse | ||||
|  | ||||
| import prototorch as pt | ||||
| import pytorch_lightning as pl | ||||
| import torch | ||||
| from pytorch_lightning.callbacks import Callback | ||||
|  | ||||
|  | ||||
| class PrototypeScheduler(Callback): | ||||
|     def __init__(self, train_ds, freq=20): | ||||
|         self.train_ds = train_ds | ||||
|         self.freq = freq | ||||
|  | ||||
|     def on_epoch_end(self, trainer, pl_module): | ||||
|         if (trainer.current_epoch + 1) % self.freq == 0: | ||||
|             pl_module.increase_prototypes( | ||||
|                 pt.components.SMI(self.train_ds), | ||||
|                 distribution=[1, 1, 1], | ||||
|             ) | ||||
|  | ||||
| import prototorch as pt | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     # Command-line arguments | ||||
| @@ -33,24 +47,17 @@ if __name__ == "__main__": | ||||
|         prototype_initializer=pt.components.SMI(train_ds), | ||||
|     ) | ||||
|  | ||||
|     for _ in range(5): | ||||
|         # Callbacks | ||||
|         vis = pt.models.VisGLVQ2D(train_ds) | ||||
|     # Callbacks | ||||
|     vis = pt.models.VisGLVQ2D(train_ds) | ||||
|     proto_scheduler = PrototypeScheduler(train_ds, 10) | ||||
|  | ||||
|         # Setup trainer | ||||
|         trainer = pl.Trainer.from_argparse_args( | ||||
|             args, | ||||
|             max_epochs=20, | ||||
|             callbacks=[vis], | ||||
|             terminate_on_nan=True, | ||||
|             weights_summary=None, | ||||
|         ) | ||||
|     # Setup trainer | ||||
|     trainer = pl.Trainer.from_argparse_args(args, | ||||
|                                             max_epochs=100, | ||||
|                                             callbacks=[vis, proto_scheduler], | ||||
|                                             terminate_on_nan=True, | ||||
|                                             weights_summary=None, | ||||
|                                             accelerator='ddp') | ||||
|  | ||||
|         # Training loop | ||||
|         trainer.fit(model, train_loader) | ||||
|  | ||||
|         # Increase prototypes | ||||
|         model.increase_prototypes( | ||||
|             pt.components.SMI(train_ds), | ||||
|             distribution=[1, 1, 1], | ||||
|         ) | ||||
|     # Training loop | ||||
|     trainer.fit(model, train_loader) | ||||
|   | ||||
| @@ -137,7 +137,7 @@ class GLVQ(AbstractPrototypeModel): | ||||
|  | ||||
|     def increase_prototypes(self, initializer, distribution): | ||||
|         self.proto_layer.increase_components(initializer, distribution) | ||||
|         #self.trainer.accelerated_backend.setup_optimizers(self) | ||||
|         self.trainer.accelerator_backend.setup_optimizers(self.trainer) | ||||
|  | ||||
|     def __repr__(self): | ||||
|         super_repr = super().__repr__() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user