No implicit learning rate scheduling
This commit is contained in:
parent
b0df61d1c3
commit
42d974e08c
@ -2,10 +2,10 @@
|
||||
|
||||
import argparse
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
import prototorch as pt
|
||||
from torch.optim.lr_scheduler import ExponentialLR
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Command-line arguments
|
||||
@ -29,9 +29,16 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
# Initialize the model
|
||||
model = pt.models.GLVQ(hparams,
|
||||
model = pt.models.GLVQ(
|
||||
hparams,
|
||||
optimizer=torch.optim.Adam,
|
||||
prototype_initializer=pt.components.SMI(train_ds))
|
||||
prototype_initializer=pt.components.SMI(train_ds),
|
||||
lr_scheduler=ExponentialLR,
|
||||
lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
|
||||
)
|
||||
|
||||
# Compute intermediate input and output sizes
|
||||
model.example_input_array = torch.zeros(4, 2)
|
||||
|
||||
# Callbacks
|
||||
vis = pt.models.VisGLVQ2D(data=train_ds)
|
||||
@ -40,6 +47,8 @@ if __name__ == "__main__":
|
||||
trainer = pl.Trainer.from_argparse_args(
|
||||
args,
|
||||
callbacks=[vis],
|
||||
weights_summary="full",
|
||||
accelerator="ddp",
|
||||
)
|
||||
|
||||
# Training loop
|
||||
|
@ -1,5 +1,4 @@
|
||||
import pytorch_lightning as pl
|
||||
from torch.optim.lr_scheduler import ExponentialLR
|
||||
|
||||
|
||||
class AbstractPrototypeModel(pl.LightningModule):
|
||||
@ -18,15 +17,16 @@ class AbstractPrototypeModel(pl.LightningModule):
|
||||
|
||||
def configure_optimizers(self):
|
||||
optimizer = self.optimizer(self.parameters(), lr=self.hparams.lr)
|
||||
scheduler = ExponentialLR(optimizer,
|
||||
gamma=0.99,
|
||||
last_epoch=-1,
|
||||
verbose=False)
|
||||
if self.lr_scheduler is not None:
|
||||
scheduler = self.lr_scheduler(optimizer,
|
||||
**self.lr_scheduler_kwargs)
|
||||
sch = {
|
||||
"scheduler": scheduler,
|
||||
"interval": "step",
|
||||
} # called after each training step
|
||||
return [optimizer], [sch]
|
||||
else:
|
||||
return optimizer
|
||||
|
||||
|
||||
class PrototypeImageModel(pl.LightningModule):
|
||||
|
@ -5,9 +5,12 @@ import torchmetrics
|
||||
from prototorch.components import LabeledComponents
|
||||
from prototorch.functions.activations import get_activation
|
||||
from prototorch.functions.competitions import wtac
|
||||
from prototorch.functions.distances import (euclidean_distance,
|
||||
lomega_distance, omega_distance,
|
||||
squared_euclidean_distance)
|
||||
from prototorch.functions.distances import (
|
||||
euclidean_distance,
|
||||
lomega_distance,
|
||||
omega_distance,
|
||||
squared_euclidean_distance,
|
||||
)
|
||||
from prototorch.functions.helper import get_flat
|
||||
from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
|
||||
from prototorch.modules import LambdaLayer
|
||||
@ -47,6 +50,8 @@ class GLVQ(AbstractPrototypeModel):
|
||||
self.initialize_prototype_win_ratios()
|
||||
|
||||
self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
|
||||
self.lr_scheduler = kwargs.get("lr_scheduler", None)
|
||||
self.lr_scheduler_kwargs = kwargs.get("lr_scheduler_kwargs", dict())
|
||||
|
||||
@property
|
||||
def prototype_labels(self):
|
||||
@ -187,14 +192,25 @@ class SiameseGLVQ(GLVQ):
|
||||
def configure_optimizers(self):
|
||||
proto_opt = self.optimizer(self.proto_layer.parameters(),
|
||||
lr=self.hparams.proto_lr)
|
||||
optimizer = None
|
||||
if list(self.backbone.parameters()):
|
||||
# only add an optimizer is the backbone has trainable parameters
|
||||
# otherwise, the next line fails
|
||||
bb_opt = self.optimizer(self.backbone.parameters(),
|
||||
lr=self.hparams.bb_lr)
|
||||
return proto_opt, bb_opt
|
||||
optimizer = [proto_opt, bb_opt]
|
||||
else:
|
||||
return proto_opt
|
||||
optimizer = proto_opt
|
||||
if self.lr_scheduler is not None:
|
||||
scheduler = self.lr_scheduler(optimizer,
|
||||
**self.lr_scheduler_kwargs)
|
||||
sch = {
|
||||
"scheduler": scheduler,
|
||||
"interval": "step",
|
||||
} # called after each training step
|
||||
return optimizer, [sch]
|
||||
else:
|
||||
return optimizer
|
||||
|
||||
def _forward(self, x):
|
||||
protos, _ = self.proto_layer()
|
||||
|
Loading…
Reference in New Issue
Block a user