No implicit learning rate scheduling
This commit is contained in:
parent
b0df61d1c3
commit
42d974e08c
@ -2,10 +2,10 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
from torch.optim.lr_scheduler import ExponentialLR
|
||||||
import prototorch as pt
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
@ -29,9 +29,16 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Initialize the model
|
# Initialize the model
|
||||||
model = pt.models.GLVQ(hparams,
|
model = pt.models.GLVQ(
|
||||||
optimizer=torch.optim.Adam,
|
hparams,
|
||||||
prototype_initializer=pt.components.SMI(train_ds))
|
optimizer=torch.optim.Adam,
|
||||||
|
prototype_initializer=pt.components.SMI(train_ds),
|
||||||
|
lr_scheduler=ExponentialLR,
|
||||||
|
lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compute intermediate input and output sizes
|
||||||
|
model.example_input_array = torch.zeros(4, 2)
|
||||||
|
|
||||||
# Callbacks
|
# Callbacks
|
||||||
vis = pt.models.VisGLVQ2D(data=train_ds)
|
vis = pt.models.VisGLVQ2D(data=train_ds)
|
||||||
@ -40,6 +47,8 @@ if __name__ == "__main__":
|
|||||||
trainer = pl.Trainer.from_argparse_args(
|
trainer = pl.Trainer.from_argparse_args(
|
||||||
args,
|
args,
|
||||||
callbacks=[vis],
|
callbacks=[vis],
|
||||||
|
weights_summary="full",
|
||||||
|
accelerator="ddp",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Training loop
|
# Training loop
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
from torch.optim.lr_scheduler import ExponentialLR
|
|
||||||
|
|
||||||
|
|
||||||
class AbstractPrototypeModel(pl.LightningModule):
|
class AbstractPrototypeModel(pl.LightningModule):
|
||||||
@ -18,15 +17,16 @@ class AbstractPrototypeModel(pl.LightningModule):
|
|||||||
|
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
optimizer = self.optimizer(self.parameters(), lr=self.hparams.lr)
|
optimizer = self.optimizer(self.parameters(), lr=self.hparams.lr)
|
||||||
scheduler = ExponentialLR(optimizer,
|
if self.lr_scheduler is not None:
|
||||||
gamma=0.99,
|
scheduler = self.lr_scheduler(optimizer,
|
||||||
last_epoch=-1,
|
**self.lr_scheduler_kwargs)
|
||||||
verbose=False)
|
sch = {
|
||||||
sch = {
|
"scheduler": scheduler,
|
||||||
"scheduler": scheduler,
|
"interval": "step",
|
||||||
"interval": "step",
|
} # called after each training step
|
||||||
} # called after each training step
|
return [optimizer], [sch]
|
||||||
return [optimizer], [sch]
|
else:
|
||||||
|
return optimizer
|
||||||
|
|
||||||
|
|
||||||
class PrototypeImageModel(pl.LightningModule):
|
class PrototypeImageModel(pl.LightningModule):
|
||||||
|
@ -5,9 +5,12 @@ import torchmetrics
|
|||||||
from prototorch.components import LabeledComponents
|
from prototorch.components import LabeledComponents
|
||||||
from prototorch.functions.activations import get_activation
|
from prototorch.functions.activations import get_activation
|
||||||
from prototorch.functions.competitions import wtac
|
from prototorch.functions.competitions import wtac
|
||||||
from prototorch.functions.distances import (euclidean_distance,
|
from prototorch.functions.distances import (
|
||||||
lomega_distance, omega_distance,
|
euclidean_distance,
|
||||||
squared_euclidean_distance)
|
lomega_distance,
|
||||||
|
omega_distance,
|
||||||
|
squared_euclidean_distance,
|
||||||
|
)
|
||||||
from prototorch.functions.helper import get_flat
|
from prototorch.functions.helper import get_flat
|
||||||
from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
|
from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
|
||||||
from prototorch.modules import LambdaLayer
|
from prototorch.modules import LambdaLayer
|
||||||
@ -47,6 +50,8 @@ class GLVQ(AbstractPrototypeModel):
|
|||||||
self.initialize_prototype_win_ratios()
|
self.initialize_prototype_win_ratios()
|
||||||
|
|
||||||
self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
|
self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
|
||||||
|
self.lr_scheduler = kwargs.get("lr_scheduler", None)
|
||||||
|
self.lr_scheduler_kwargs = kwargs.get("lr_scheduler_kwargs", dict())
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def prototype_labels(self):
|
def prototype_labels(self):
|
||||||
@ -187,14 +192,25 @@ class SiameseGLVQ(GLVQ):
|
|||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
proto_opt = self.optimizer(self.proto_layer.parameters(),
|
proto_opt = self.optimizer(self.proto_layer.parameters(),
|
||||||
lr=self.hparams.proto_lr)
|
lr=self.hparams.proto_lr)
|
||||||
|
optimizer = None
|
||||||
if list(self.backbone.parameters()):
|
if list(self.backbone.parameters()):
|
||||||
# only add an optimizer if the backbone has trainable parameters
|
# only add an optimizer if the backbone has trainable parameters
|
||||||
# otherwise, the next line fails
|
# otherwise, the next line fails
|
||||||
bb_opt = self.optimizer(self.backbone.parameters(),
|
bb_opt = self.optimizer(self.backbone.parameters(),
|
||||||
lr=self.hparams.bb_lr)
|
lr=self.hparams.bb_lr)
|
||||||
return proto_opt, bb_opt
|
optimizer = [proto_opt, bb_opt]
|
||||||
else:
|
else:
|
||||||
return proto_opt
|
optimizer = proto_opt
|
||||||
|
if self.lr_scheduler is not None:
|
||||||
|
scheduler = self.lr_scheduler(optimizer,
|
||||||
|
**self.lr_scheduler_kwargs)
|
||||||
|
sch = {
|
||||||
|
"scheduler": scheduler,
|
||||||
|
"interval": "step",
|
||||||
|
} # called after each training step
|
||||||
|
return optimizer, [sch]
|
||||||
|
else:
|
||||||
|
return optimizer
|
||||||
|
|
||||||
def _forward(self, x):
|
def _forward(self, x):
|
||||||
protos, _ = self.proto_layer()
|
protos, _ = self.proto_layer()
|
||||||
|
Loading…
Reference in New Issue
Block a user