Change optimizer using kwargs

Jensun Ravichandran 2021-05-11 16:13:00 +02:00
parent b38acd58a8
commit eab1ec72c2
3 changed files with 21 additions and 41 deletions

View File

@@ -3,9 +3,13 @@ import torch
 from torch.optim.lr_scheduler import ExponentialLR
 
 
-class AbstractLightningModel(pl.LightningModule):
+class AbstractPrototypeModel(pl.LightningModule):
+    @property
+    def prototypes(self):
+        return self.proto_layer.components.detach().cpu()
+
     def configure_optimizers(self):
-        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
+        optimizer = self.optimizer(self.parameters(), lr=self.hparams.lr)
         scheduler = ExponentialLR(optimizer,
                                   gamma=0.99,
                                   last_epoch=-1,
@@ -15,9 +19,3 @@ class AbstractLightningModel(pl.LightningModule):
             "interval": "step",
         }  # called after each training step
         return [optimizer], [sch]
-
-
-class AbstractPrototypeModel(AbstractLightningModel):
-    @property
-    def prototypes(self):
-        return self.proto_layer.components.detach().cpu()

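For readers skimming the diff: the abstract model no longer hardcodes Adam but instantiates whatever optimizer class is stored on the instance. A minimal, self-contained sketch of that pattern follows; the class name ConfigurableOptimizerModel and the toy Linear layer are illustrative and not part of the commit.

    import pytorch_lightning as pl
    import torch
    from torch.optim.lr_scheduler import ExponentialLR


    class ConfigurableOptimizerModel(pl.LightningModule):  # hypothetical name
        def __init__(self, hparams, **kwargs):
            super().__init__()
            self.save_hyperparameters(hparams)
            # Store the optimizer *class*; fall back to Adam when none is given.
            self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
            self.layer = torch.nn.Linear(2, 2)  # toy parameters for the optimizer

        def configure_optimizers(self):
            # Instantiate whichever optimizer class the caller supplied.
            optimizer = self.optimizer(self.parameters(), lr=self.hparams.lr)
            scheduler = ExponentialLR(optimizer, gamma=0.99)
            sch = {"scheduler": scheduler, "interval": "step"}
            return [optimizer], [sch]


    # Swap the optimizer without touching the model code:
    model = ConfigurableOptimizerModel({"lr": 0.01}, optimizer=torch.optim.SGD)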
View File

@@ -9,8 +9,6 @@ from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
 
 from .abstract import AbstractPrototypeModel
 
-from torch.optim.lr_scheduler import ExponentialLR
-
 
 class GLVQ(AbstractPrototypeModel):
     """Generalized Learning Vector Quantization."""
@@ -19,14 +17,15 @@ class GLVQ(AbstractPrototypeModel):
         self.save_hyperparameters(hparams)
 
+        self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
+
         # Default Values
         self.hparams.setdefault("distance", euclidean_distance)
-        self.hparams.setdefault("optimizer", torch.optim.Adam)
         self.hparams.setdefault("transfer_function", "identity")
         self.hparams.setdefault("transfer_beta", 10.0)
 
         self.proto_layer = LabeledComponents(
-            labels=(self.hparams.nclasses, self.hparams.prototypes_per_class),
+            distribution=self.hparams.distribution,
             initializer=self.hparams.prototype_initializer)
 
         self.transfer_function = get_activation(self.hparams.transfer_function)
 
@@ -81,39 +80,19 @@ class GLVQ(AbstractPrototypeModel):
 
 
 class LVQ1(GLVQ):
+    """Learning Vector Quantization 1."""
     def __init__(self, hparams, **kwargs):
         super().__init__(hparams, **kwargs)
         self.loss = lvq1_loss
+        self.optimizer = torch.optim.SGD
-
-    def configure_optimizers(self):
-        optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr)
-        scheduler = ExponentialLR(optimizer,
-                                  gamma=0.99,
-                                  last_epoch=-1,
-                                  verbose=False)
-        sch = {
-            "scheduler": scheduler,
-            "interval": "step",
-        }  # called after each training step
-        return [optimizer], [sch]
 
 
 class LVQ21(GLVQ):
+    """Learning Vector Quantization 2.1."""
     def __init__(self, hparams, **kwargs):
         super().__init__(hparams, **kwargs)
         self.loss = lvq21_loss
+        self.optimizer = torch.optim.SGD
-
-    def configure_optimizers(self):
-        optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr)
-        scheduler = ExponentialLR(optimizer,
-                                  gamma=0.99,
-                                  last_epoch=-1,
-                                  verbose=False)
-        sch = {
-            "scheduler": scheduler,
-            "interval": "step",
-        }  # called after each training step
-        return [optimizer], [sch]
 
 
 class ImageGLVQ(GLVQ):
@@ -152,13 +131,13 @@ class SiameseGLVQ(GLVQ):
         self.backbone_dependent.load_state_dict(master_state, strict=True)
 
     def configure_optimizers(self):
-        optim = self.hparams.optimizer
-        proto_opt = optim(self.proto_layer.parameters(),
-                          lr=self.hparams.proto_lr)
+        proto_opt = self.optimizer(self.proto_layer.parameters(),
+                                   lr=self.hparams.proto_lr)
         if list(self.backbone.parameters()):
             # only add an optimizer is the backbone has trainable parameters
             # otherwise, the next line fails
-            bb_opt = optim(self.backbone.parameters(), lr=self.hparams.bb_lr)
+            bb_opt = self.optimizer(self.backbone.parameters(),
+                                    lr=self.hparams.bb_lr)
             return proto_opt, bb_opt
         else:
             return proto_opt

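A hedged usage sketch for the GLVQ side of the change. Only the constructor signature and the optimizer kwarg are taken from the hunks above; the import path, the distribution format, and the initializer choice are assumptions made for illustration.

    import torch
    from prototorch.components import initializers as cinit
    from prototorch.models import GLVQ  # import path assumed

    hparams = dict(
        distribution=(2, 3),  # format assumed; forwarded to LabeledComponents as-is
        prototype_initializer=cinit.ZerosInitializer(2),  # assumes 2-d inputs
        lr=0.01,
    )

    # Before: the optimizer was an hparam via setdefault("optimizer", Adam).
    # After: it is a constructor keyword argument, with Adam still the default.
    model = GLVQ(hparams, optimizer=torch.optim.SGD)

LVQ1 and LVQ21 pin SGD by assigning self.optimizer in __init__, and SiameseGLVQ reuses the same self.optimizer for both of its parameter groups, which is why their separate configure_optimizers overrides could be dropped.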
View File

@@ -1,6 +1,7 @@
 import torch
 from prototorch.components import Components
 from prototorch.components import initializers as cinit
+from prototorch.components.initializers import ZerosInitializer
 from prototorch.functions.distances import euclidean_distance
 from prototorch.modules.losses import NeuralGasEnergy
 
@@ -41,12 +42,14 @@ class NeuralGas(AbstractPrototypeModel):
         self.save_hyperparameters(hparams)
 
+        self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
+
         # Default Values
         self.hparams.setdefault("input_dim", 2)
         self.hparams.setdefault("agelimit", 10)
         self.hparams.setdefault("lm", 1)
         self.hparams.setdefault("prototype_initializer",
-                                cinit.ZerosInitializer(self.hparams.input_dim))
+                                ZerosInitializer(self.hparams.input_dim))
 
         self.proto_layer = Components(
             self.hparams.num_prototypes,
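The same mechanism applies to NeuralGas; a short sketch under the assumption that num_prototypes and lr are supplied by the caller (the other keys shown have defaults in the hunk above) and that the import path is prototorch.models.

    import torch
    from prototorch.models import NeuralGas  # import path assumed

    hparams = dict(input_dim=2, num_prototypes=30, lr=0.1)

    # Any torch.optim class can be passed via kwargs; omitting it falls back to Adam.
    ng = NeuralGas(hparams, optimizer=torch.optim.RMSprop)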