No implicit learning rate scheduling
@@ -2,10 +2,10 @@
 
 import argparse
 
+import prototorch as pt
 import pytorch_lightning as pl
 import torch
-
-import prototorch as pt
+from torch.optim.lr_scheduler import ExponentialLR
 
 if __name__ == "__main__":
     # Command-line arguments
@@ -29,9 +29,16 @@ if __name__ == "__main__":
     )
 
     # Initialize the model
-    model = pt.models.GLVQ(hparams,
-                           optimizer=torch.optim.Adam,
-                           prototype_initializer=pt.components.SMI(train_ds))
+    model = pt.models.GLVQ(
+        hparams,
+        optimizer=torch.optim.Adam,
+        prototype_initializer=pt.components.SMI(train_ds),
+        lr_scheduler=ExponentialLR,
+        lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
+    )
+
+    # Compute intermediate input and output sizes
+    model.example_input_array = torch.zeros(4, 2)
 
     # Callbacks
     vis = pt.models.VisGLVQ2D(data=train_ds)
@@ -40,6 +47,8 @@ if __name__ == "__main__":
     trainer = pl.Trainer.from_argparse_args(
         args,
         callbacks=[vis],
+        weights_summary="full",
+        accelerator="ddp",
     )
 
     # Training loop
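The scheduler is now chosen explicitly at the call site, so swapping it only touches the example script. A minimal sketch (not part of the diff) that reuses the hparams and train_ds objects defined earlier in the example, with StepLR and illustrative decay settings instead of ExponentialLR:

from torch.optim.lr_scheduler import StepLR

# Same constructor call as above; any torch.optim.lr_scheduler class can be
# plugged in, together with its keyword arguments.
model = pt.models.GLVQ(
    hparams,
    optimizer=torch.optim.Adam,
    prototype_initializer=pt.components.SMI(train_ds),
    lr_scheduler=StepLR,
    lr_scheduler_kwargs=dict(step_size=100, gamma=0.5),
)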
@@ -1,5 +1,4 @@
 import pytorch_lightning as pl
-from torch.optim.lr_scheduler import ExponentialLR
 
 
 class AbstractPrototypeModel(pl.LightningModule):
@@ -18,15 +17,16 @@ class AbstractPrototypeModel(pl.LightningModule):
 
     def configure_optimizers(self):
         optimizer = self.optimizer(self.parameters(), lr=self.hparams.lr)
-        scheduler = ExponentialLR(optimizer,
-                                  gamma=0.99,
-                                  last_epoch=-1,
-                                  verbose=False)
-        sch = {
-            "scheduler": scheduler,
-            "interval": "step",
-        }  # called after each training step
-        return [optimizer], [sch]
+        if self.lr_scheduler is not None:
+            scheduler = self.lr_scheduler(optimizer,
+                                          **self.lr_scheduler_kwargs)
+            sch = {
+                "scheduler": scheduler,
+                "interval": "step",
+            }  # called after each training step
+            return [optimizer], [sch]
+        else:
+            return optimizer
 
 
 class PrototypeImageModel(pl.LightningModule):
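For illustration, here is the same opt-in pattern in a self-contained LightningModule; a sketch assuming a toy linear model, showing the dict-style scheduler entry that Lightning consumes (interval "step" advances the scheduler after every training batch rather than every epoch):

import pytorch_lightning as pl
import torch
from torch.optim.lr_scheduler import ExponentialLR


class ToyModule(pl.LightningModule):
    """Toy module using the same optional lr_scheduler hooks."""
    def __init__(self, lr=0.01, lr_scheduler=None, lr_scheduler_kwargs=None):
        super().__init__()
        self.layer = torch.nn.Linear(2, 2)
        self.lr = lr
        self.lr_scheduler = lr_scheduler
        self.lr_scheduler_kwargs = lr_scheduler_kwargs or dict()

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        if self.lr_scheduler is None:
            return optimizer  # no scheduler unless one was requested
        scheduler = self.lr_scheduler(optimizer, **self.lr_scheduler_kwargs)
        sch = {"scheduler": scheduler, "interval": "step"}
        return [optimizer], [sch]


# Scheduling is opt-in: pass the scheduler class and its kwargs explicitly.
module = ToyModule(lr_scheduler=ExponentialLR,
                   lr_scheduler_kwargs=dict(gamma=0.99))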
@@ -5,9 +5,12 @@ import torchmetrics
 from prototorch.components import LabeledComponents
 from prototorch.functions.activations import get_activation
 from prototorch.functions.competitions import wtac
-from prototorch.functions.distances import (euclidean_distance,
-                                            lomega_distance, omega_distance,
-                                            squared_euclidean_distance)
+from prototorch.functions.distances import (
+    euclidean_distance,
+    lomega_distance,
+    omega_distance,
+    squared_euclidean_distance,
+)
 from prototorch.functions.helper import get_flat
 from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
 from prototorch.modules import LambdaLayer
@@ -47,6 +50,8 @@ class GLVQ(AbstractPrototypeModel):
         self.initialize_prototype_win_ratios()
 
         self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
+        self.lr_scheduler = kwargs.get("lr_scheduler", None)
+        self.lr_scheduler_kwargs = kwargs.get("lr_scheduler_kwargs", dict())
 
     @property
     def prototype_labels(self):
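Both keys are read with kwargs.get and fall back to None and an empty dict, so existing call sites keep working: omitting them means configure_optimizers() returns a bare optimizer and no schedule is applied. A short sketch, reusing hparams and train_ds from the example script above:

# No lr_scheduler keyword: the model trains with a constant learning rate,
# which is the explicit, non-implicit default this commit introduces.
plain = pt.models.GLVQ(hparams,
                       prototype_initializer=pt.components.SMI(train_ds))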
@@ -187,14 +192,25 @@ class SiameseGLVQ(GLVQ):
     def configure_optimizers(self):
         proto_opt = self.optimizer(self.proto_layer.parameters(),
                                    lr=self.hparams.proto_lr)
+        optimizer = None
         if list(self.backbone.parameters()):
             # only add an optimizer if the backbone has trainable parameters
             # otherwise, the next line fails
             bb_opt = self.optimizer(self.backbone.parameters(),
                                     lr=self.hparams.bb_lr)
-            return proto_opt, bb_opt
+            optimizer = [proto_opt, bb_opt]
         else:
-            return proto_opt
+            optimizer = proto_opt
+        if self.lr_scheduler is not None:
+            scheduler = self.lr_scheduler(optimizer,
+                                          **self.lr_scheduler_kwargs)
+            sch = {
+                "scheduler": scheduler,
+                "interval": "step",
+            }  # called after each training step
+            return optimizer, [sch]
+        else:
+            return optimizer
 
     def _forward(self, x):
         protos, _ = self.proto_layer()
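Each torch.optim.lr_scheduler instance wraps the single optimizer it is constructed with, and Lightning accepts one scheduler entry per optimizer when several optimizers are returned. A self-contained sketch of that pairing for a two-optimizer setup like the one above, using stand-in layers rather than the actual prototype layer and backbone:

import torch
from torch.optim.lr_scheduler import ExponentialLR

proto_layer = torch.nn.Linear(2, 3)  # stand-in for the prototype layer
backbone = torch.nn.Linear(2, 2)     # stand-in for the backbone

proto_opt = torch.optim.Adam(proto_layer.parameters(), lr=0.05)
bb_opt = torch.optim.Adam(backbone.parameters(), lr=0.01)

# Each scheduler wraps the optimizer it was built with; Lightning steps every
# entry in the list according to its "interval".
schedulers = [
    {"scheduler": ExponentialLR(opt, gamma=0.99), "interval": "step"}
    for opt in (proto_opt, bb_opt)
]
# Inside configure_optimizers() this would be returned as:
#     return [proto_opt, bb_opt], schedulers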