No implicit learning rate scheduling

Jensun Ravichandran 2021-06-04 15:55:06 +02:00
parent b0df61d1c3
commit 42d974e08c
3 changed files with 45 additions and 20 deletions

View File

@@ -2,10 +2,10 @@
 import argparse
 
-import prototorch as pt
 import pytorch_lightning as pl
 import torch
+from torch.optim.lr_scheduler import ExponentialLR
+
+import prototorch as pt
 
 if __name__ == "__main__":
     # Command-line arguments
@@ -29,9 +29,16 @@ if __name__ == "__main__":
     )
 
     # Initialize the model
-    model = pt.models.GLVQ(hparams,
-                           optimizer=torch.optim.Adam,
-                           prototype_initializer=pt.components.SMI(train_ds))
+    model = pt.models.GLVQ(
+        hparams,
+        optimizer=torch.optim.Adam,
+        prototype_initializer=pt.components.SMI(train_ds),
+        lr_scheduler=ExponentialLR,
+        lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
+    )
+
+    # Compute intermediate input and output sizes
+    model.example_input_array = torch.zeros(4, 2)
 
     # Callbacks
     vis = pt.models.VisGLVQ2D(data=train_ds)
@@ -40,6 +47,8 @@ if __name__ == "__main__":
     trainer = pl.Trainer.from_argparse_args(
         args,
         callbacks=[vis],
+        weights_summary="full",
+        accelerator="ddp",
     )
 
     # Training loop

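The behavioral change this example exercises: before this commit, an ExponentialLR schedule was always created inside configure_optimizers; after it, a model constructed without the two new keyword arguments simply trains at a fixed learning rate. A minimal sketch of that default case, assuming hparams and train_ds are set up earlier in the example script (not shown in these hunks):

    # No lr_scheduler / lr_scheduler_kwargs given, so no schedule is attached
    # and configure_optimizers() returns a bare optimizer at a constant lr.
    model = pt.models.GLVQ(
        hparams,
        optimizer=torch.optim.Adam,
        prototype_initializer=pt.components.SMI(train_ds),
    )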
View File

@@ -1,5 +1,4 @@
 import pytorch_lightning as pl
-from torch.optim.lr_scheduler import ExponentialLR
 
 
 class AbstractPrototypeModel(pl.LightningModule):
@@ -18,15 +17,16 @@ class AbstractPrototypeModel(pl.LightningModule):
     def configure_optimizers(self):
         optimizer = self.optimizer(self.parameters(), lr=self.hparams.lr)
-        scheduler = ExponentialLR(optimizer,
-                                  gamma=0.99,
-                                  last_epoch=-1,
-                                  verbose=False)
-        sch = {
-            "scheduler": scheduler,
-            "interval": "step",
-        }  # called after each training step
-        return [optimizer], [sch]
+        if self.lr_scheduler is not None:
+            scheduler = self.lr_scheduler(optimizer,
+                                          **self.lr_scheduler_kwargs)
+            sch = {
+                "scheduler": scheduler,
+                "interval": "step",
+            }  # called after each training step
+            return [optimizer], [sch]
+        else:
+            return optimizer
 
 
 class PrototypeImageModel(pl.LightningModule):

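The scheduler is now a plug-in point rather than a hard-coded ExponentialLR: any class following the torch.optim.lr_scheduler convention (optimizer as the first argument, options as keyword arguments) can be passed, and lr_scheduler_kwargs is forwarded verbatim to its constructor. Because the returned dict uses "interval": "step", Lightning steps the schedule after every training batch rather than every epoch. A sketch with a different scheduler, purely as an illustration (StepLR is not used anywhere in this commit; hparams and train_ds as in the example above):

    from torch.optim.lr_scheduler import StepLR

    # Halve the learning rate every 50 scheduler steps, i.e. every 50 batches,
    # since the scheduler is stepped once per training step.
    model = pt.models.GLVQ(
        hparams,
        optimizer=torch.optim.SGD,
        prototype_initializer=pt.components.SMI(train_ds),
        lr_scheduler=StepLR,
        lr_scheduler_kwargs=dict(step_size=50, gamma=0.5),
    )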
View File

@@ -5,9 +5,12 @@ import torchmetrics
 from prototorch.components import LabeledComponents
 from prototorch.functions.activations import get_activation
 from prototorch.functions.competitions import wtac
-from prototorch.functions.distances import (euclidean_distance,
-                                            lomega_distance, omega_distance,
-                                            squared_euclidean_distance)
+from prototorch.functions.distances import (
+    euclidean_distance,
+    lomega_distance,
+    omega_distance,
+    squared_euclidean_distance,
+)
 from prototorch.functions.helper import get_flat
 from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
 from prototorch.modules import LambdaLayer
@@ -47,6 +50,8 @@ class GLVQ(AbstractPrototypeModel):
         self.initialize_prototype_win_ratios()
 
         self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
+        self.lr_scheduler = kwargs.get("lr_scheduler", None)
+        self.lr_scheduler_kwargs = kwargs.get("lr_scheduler_kwargs", dict())
 
     @property
     def prototype_labels(self):
@@ -187,14 +192,25 @@ class SiameseGLVQ(GLVQ):
     def configure_optimizers(self):
         proto_opt = self.optimizer(self.proto_layer.parameters(),
                                    lr=self.hparams.proto_lr)
+        optimizer = None
         if list(self.backbone.parameters()):
             # only add an optimizer if the backbone has trainable parameters
             # otherwise, the next line fails
             bb_opt = self.optimizer(self.backbone.parameters(),
                                     lr=self.hparams.bb_lr)
-            return proto_opt, bb_opt
+            optimizer = [proto_opt, bb_opt]
         else:
-            return proto_opt
+            optimizer = proto_opt
+        if self.lr_scheduler is not None:
+            scheduler = self.lr_scheduler(optimizer,
+                                          **self.lr_scheduler_kwargs)
+            sch = {
+                "scheduler": scheduler,
+                "interval": "step",
+            }  # called after each training step
+            return optimizer, [sch]
+        else:
+            return optimizer
 
     def _forward(self, x):
         protos, _ = self.proto_layer()
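SiameseGLVQ picks up the same two keyword arguments through GLVQ.__init__; the difference lies only in what gets scheduled, since configure_optimizers builds a single optimizer when the backbone has no parameters of its own and a list of two (prototype and backbone) otherwise. A usage sketch under two assumptions not shown in these hunks: hparams carries the proto_lr and bb_lr entries referenced above, and the backbone module is supplied via a backbone keyword argument:

    import torch
    from torch.optim.lr_scheduler import ExponentialLR

    # Hypothetical two-layer backbone; any torch.nn.Module would do here.
    backbone = torch.nn.Sequential(
        torch.nn.Linear(2, 8),
        torch.nn.ReLU(),
        torch.nn.Linear(8, 2),
    )

    model = pt.models.SiameseGLVQ(
        hparams,
        prototype_initializer=pt.components.SMI(train_ds),
        backbone=backbone,  # assumed keyword, see note above
        optimizer=torch.optim.Adam,
        lr_scheduler=ExponentialLR,
        lr_scheduler_kwargs=dict(gamma=0.99),
    )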