Add support for multiple optimizers
This commit is contained in:
parent
042b3fcaa2
commit
96aeaa3448
@ -1,11 +1,20 @@
|
|||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
from torch.optim.lr_scheduler import ExponentialLR
|
||||||
|
|
||||||
|
|
||||||
class AbstractLightningModel(pl.LightningModule):
|
class AbstractLightningModel(pl.LightningModule):
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
|
optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
|
||||||
return optimizer
|
scheduler = ExponentialLR(optimizer,
|
||||||
|
gamma=0.99,
|
||||||
|
last_epoch=-1,
|
||||||
|
verbose=False)
|
||||||
|
sch = {
|
||||||
|
"scheduler": scheduler,
|
||||||
|
"interval": "step",
|
||||||
|
} # called after each training step
|
||||||
|
return [optimizer], [sch]
|
||||||
|
|
||||||
|
|
||||||
class AbstractPrototypeModel(AbstractLightningModel):
|
class AbstractPrototypeModel(AbstractLightningModel):
|
||||||
|
@ -19,6 +19,7 @@ class GLVQ(AbstractPrototypeModel):
|
|||||||
|
|
||||||
# Default Values
|
# Default Values
|
||||||
self.hparams.setdefault("distance", euclidean_distance)
|
self.hparams.setdefault("distance", euclidean_distance)
|
||||||
|
self.hparams.setdefault("optimizer", torch.optim.Adam)
|
||||||
|
|
||||||
self.proto_layer = LabeledComponents(
|
self.proto_layer = LabeledComponents(
|
||||||
labels=(self.hparams.nclasses, self.hparams.prototypes_per_class),
|
labels=(self.hparams.nclasses, self.hparams.prototypes_per_class),
|
||||||
@ -35,7 +36,7 @@ class GLVQ(AbstractPrototypeModel):
|
|||||||
dis = self.hparams.distance(x, protos)
|
dis = self.hparams.distance(x, protos)
|
||||||
return dis
|
return dis
|
||||||
|
|
||||||
def training_step(self, train_batch, batch_idx):
|
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
||||||
x, y = train_batch
|
x, y = train_batch
|
||||||
x = x.view(x.size(0), -1) # flatten
|
x = x.view(x.size(0), -1) # flatten
|
||||||
dis = self(x)
|
dis = self(x)
|
||||||
@ -102,6 +103,13 @@ class SiameseGLVQ(GLVQ):
|
|||||||
master_state = self.backbone.state_dict()
|
master_state = self.backbone.state_dict()
|
||||||
self.backbone_dependent.load_state_dict(master_state, strict=True)
|
self.backbone_dependent.load_state_dict(master_state, strict=True)
|
||||||
|
|
||||||
|
def configure_optimizers(self):
|
||||||
|
optim = self.hparams.optimizer
|
||||||
|
proto_opt = optim(self.proto_layer.parameters(),
|
||||||
|
lr=self.hparams.proto_lr)
|
||||||
|
bb_opt = optim(self.backbone.parameters(), lr=self.hparams.bb_lr)
|
||||||
|
return proto_opt, bb_opt
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
self.sync_backbones()
|
self.sync_backbones()
|
||||||
protos, _ = self.proto_layer()
|
protos, _ = self.proto_layer()
|
||||||
|
Loading…
Reference in New Issue
Block a user