Add support for multiple optimizers

Jensun Ravichandran 2021-05-03 13:20:49 +02:00
parent 042b3fcaa2
commit 96aeaa3448
2 changed files with 19 additions and 2 deletions

View File

@@ -1,11 +1,20 @@
 import pytorch_lightning as pl
 import torch
+from torch.optim.lr_scheduler import ExponentialLR


 class AbstractLightningModel(pl.LightningModule):
     def configure_optimizers(self):
         optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
-        return optimizer
+        scheduler = ExponentialLR(optimizer,
+                                  gamma=0.99,
+                                  last_epoch=-1,
+                                  verbose=False)
+        sch = {
+            "scheduler": scheduler,
+            "interval": "step",
+        }  # called after each training step
+        return [optimizer], [sch]


 class AbstractPrototypeModel(AbstractLightningModel):
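
For reference, a minimal sketch (not part of this commit, assuming only stock PyTorch) of what the new scheduler entry does: because the dict sets "interval": "step", Lightning steps the ExponentialLR after every training batch, so the learning rate shrinks by the factor gamma=0.99 per step rather than per epoch. The dummy parameter and the printed values are illustrative only.

# Minimal sketch; the dummy parameter stands in for real model weights.
import torch
from torch.optim.lr_scheduler import ExponentialLR

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=0.01)
scheduler = ExponentialLR(optimizer, gamma=0.99)

for step in range(3):
    optimizer.step()    # in real training this follows a backward pass
    scheduler.step()    # multiplies the lr by 0.99, mirroring "interval": "step"
    print(scheduler.get_last_lr())  # roughly [0.0099], [0.009801], [0.0097]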

View File

@@ -19,6 +19,7 @@ class GLVQ(AbstractPrototypeModel):
         # Default Values
         self.hparams.setdefault("distance", euclidean_distance)
+        self.hparams.setdefault("optimizer", torch.optim.Adam)

         self.proto_layer = LabeledComponents(
             labels=(self.hparams.nclasses, self.hparams.prototypes_per_class),
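
The setdefault call makes the optimizer class itself a hyperparameter. A hedged sketch of overriding it (how the hparams dict reaches GLVQ's constructor is not shown in this diff, so only the relevant key is given):

# Hypothetical: any torch.optim class with the usual (params, lr) signature works.
import torch
hparams = {"optimizer": torch.optim.SGD}  # overrides the Adam default set above
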
@@ -35,7 +36,7 @@ class GLVQ(AbstractPrototypeModel):
         dis = self.hparams.distance(x, protos)
         return dis

-    def training_step(self, train_batch, batch_idx):
+    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
         x, y = train_batch
         x = x.view(x.size(0), -1)  # flatten
         dis = self(x)
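
Why the extra optimizer_idx argument: under the Lightning 1.x automatic optimization used at the time, returning more than one optimizer from configure_optimizers makes the trainer call training_step once per optimizer for every batch, passing the index of the optimizer currently being stepped. Defaulting it to None keeps single-optimizer models such as plain GLVQ working unchanged. A hypothetical sketch of how a subclass could branch on it (not part of the commit; the class name and losses are stand-ins):

import pytorch_lightning as pl

class TwoOptimizerSketch(pl.LightningModule):
    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
        if optimizer_idx in (None, 0):
            ...  # loss that drives the first optimizer (e.g. the prototypes)
        else:
            ...  # loss that drives the second optimizer (e.g. the backbone)
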
@@ -102,6 +103,13 @@ class SiameseGLVQ(GLVQ):
         master_state = self.backbone.state_dict()
         self.backbone_dependent.load_state_dict(master_state, strict=True)

+    def configure_optimizers(self):
+        optim = self.hparams.optimizer
+        proto_opt = optim(self.proto_layer.parameters(),
+                          lr=self.hparams.proto_lr)
+        bb_opt = optim(self.backbone.parameters(), lr=self.hparams.bb_lr)
+        return proto_opt, bb_opt
+
     def forward(self, x):
         self.sync_backbones()
         protos, _ = self.proto_layer()
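
To illustrate what the new SiameseGLVQ.configure_optimizers amounts to, here is a stand-alone sketch that uses plain torch modules in place of the real prototype layer and backbone; every name and learning rate below is a stand-in for illustration, not the prototorch API or values from the commit.

# Stand-alone sketch with assumed stand-ins.
import torch

proto_layer = torch.nn.Linear(4, 3)   # stands in for self.proto_layer
backbone = torch.nn.Linear(8, 4)      # stands in for self.backbone

optim_cls = torch.optim.Adam          # the "optimizer" hparam default
proto_opt = optim_cls(proto_layer.parameters(), lr=0.05)   # "proto_lr"
bb_opt = optim_cls(backbone.parameters(), lr=0.001)        # "bb_lr"

# Returning the pair, as configure_optimizers above does, lets Lightning step
# the two optimizers independently, so the prototypes and the backbone can be
# trained with different learning rates inside one model.
optimizers = (proto_opt, bb_opt)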