Change optimizer using kwargs
This commit is contained in:
parent
b38acd58a8
commit
eab1ec72c2
@ -3,9 +3,13 @@ import torch
|
|||||||
from torch.optim.lr_scheduler import ExponentialLR
|
from torch.optim.lr_scheduler import ExponentialLR
|
||||||
|
|
||||||
|
|
||||||
class AbstractLightningModel(pl.LightningModule):
|
class AbstractPrototypeModel(pl.LightningModule):
|
||||||
|
@property
|
||||||
|
def prototypes(self):
|
||||||
|
return self.proto_layer.components.detach().cpu()
|
||||||
|
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
|
optimizer = self.optimizer(self.parameters(), lr=self.hparams.lr)
|
||||||
scheduler = ExponentialLR(optimizer,
|
scheduler = ExponentialLR(optimizer,
|
||||||
gamma=0.99,
|
gamma=0.99,
|
||||||
last_epoch=-1,
|
last_epoch=-1,
|
||||||
@ -15,9 +19,3 @@ class AbstractLightningModel(pl.LightningModule):
|
|||||||
"interval": "step",
|
"interval": "step",
|
||||||
} # called after each training step
|
} # called after each training step
|
||||||
return [optimizer], [sch]
|
return [optimizer], [sch]
|
||||||
|
|
||||||
|
|
||||||
class AbstractPrototypeModel(AbstractLightningModel):
|
|
||||||
@property
|
|
||||||
def prototypes(self):
|
|
||||||
return self.proto_layer.components.detach().cpu()
|
|
||||||
|
@ -9,8 +9,6 @@ from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
|
|||||||
|
|
||||||
from .abstract import AbstractPrototypeModel
|
from .abstract import AbstractPrototypeModel
|
||||||
|
|
||||||
from torch.optim.lr_scheduler import ExponentialLR
|
|
||||||
|
|
||||||
|
|
||||||
class GLVQ(AbstractPrototypeModel):
|
class GLVQ(AbstractPrototypeModel):
|
||||||
"""Generalized Learning Vector Quantization."""
|
"""Generalized Learning Vector Quantization."""
|
||||||
@ -19,14 +17,15 @@ class GLVQ(AbstractPrototypeModel):
|
|||||||
|
|
||||||
self.save_hyperparameters(hparams)
|
self.save_hyperparameters(hparams)
|
||||||
|
|
||||||
|
self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
|
||||||
|
|
||||||
# Default Values
|
# Default Values
|
||||||
self.hparams.setdefault("distance", euclidean_distance)
|
self.hparams.setdefault("distance", euclidean_distance)
|
||||||
self.hparams.setdefault("optimizer", torch.optim.Adam)
|
|
||||||
self.hparams.setdefault("transfer_function", "identity")
|
self.hparams.setdefault("transfer_function", "identity")
|
||||||
self.hparams.setdefault("transfer_beta", 10.0)
|
self.hparams.setdefault("transfer_beta", 10.0)
|
||||||
|
|
||||||
self.proto_layer = LabeledComponents(
|
self.proto_layer = LabeledComponents(
|
||||||
labels=(self.hparams.nclasses, self.hparams.prototypes_per_class),
|
distribution=self.hparams.distribution,
|
||||||
initializer=self.hparams.prototype_initializer)
|
initializer=self.hparams.prototype_initializer)
|
||||||
|
|
||||||
self.transfer_function = get_activation(self.hparams.transfer_function)
|
self.transfer_function = get_activation(self.hparams.transfer_function)
|
||||||
@ -81,39 +80,19 @@ class GLVQ(AbstractPrototypeModel):
|
|||||||
|
|
||||||
|
|
||||||
class LVQ1(GLVQ):
|
class LVQ1(GLVQ):
|
||||||
|
"""Learning Vector Quantization 1."""
|
||||||
def __init__(self, hparams, **kwargs):
|
def __init__(self, hparams, **kwargs):
|
||||||
super().__init__(hparams, **kwargs)
|
super().__init__(hparams, **kwargs)
|
||||||
self.loss = lvq1_loss
|
self.loss = lvq1_loss
|
||||||
|
self.optimizer = torch.optim.SGD
|
||||||
def configure_optimizers(self):
|
|
||||||
optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr)
|
|
||||||
scheduler = ExponentialLR(optimizer,
|
|
||||||
gamma=0.99,
|
|
||||||
last_epoch=-1,
|
|
||||||
verbose=False)
|
|
||||||
sch = {
|
|
||||||
"scheduler": scheduler,
|
|
||||||
"interval": "step",
|
|
||||||
} # called after each training step
|
|
||||||
return [optimizer], [sch]
|
|
||||||
|
|
||||||
|
|
||||||
class LVQ21(GLVQ):
|
class LVQ21(GLVQ):
|
||||||
|
"""Learning Vector Quantization 2.1."""
|
||||||
def __init__(self, hparams, **kwargs):
|
def __init__(self, hparams, **kwargs):
|
||||||
super().__init__(hparams, **kwargs)
|
super().__init__(hparams, **kwargs)
|
||||||
self.loss = lvq21_loss
|
self.loss = lvq21_loss
|
||||||
|
self.optimizer = torch.optim.SGD
|
||||||
def configure_optimizers(self):
|
|
||||||
optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr)
|
|
||||||
scheduler = ExponentialLR(optimizer,
|
|
||||||
gamma=0.99,
|
|
||||||
last_epoch=-1,
|
|
||||||
verbose=False)
|
|
||||||
sch = {
|
|
||||||
"scheduler": scheduler,
|
|
||||||
"interval": "step",
|
|
||||||
} # called after each training step
|
|
||||||
return [optimizer], [sch]
|
|
||||||
|
|
||||||
|
|
||||||
class ImageGLVQ(GLVQ):
|
class ImageGLVQ(GLVQ):
|
||||||
@ -152,13 +131,13 @@ class SiameseGLVQ(GLVQ):
|
|||||||
self.backbone_dependent.load_state_dict(master_state, strict=True)
|
self.backbone_dependent.load_state_dict(master_state, strict=True)
|
||||||
|
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
optim = self.hparams.optimizer
|
proto_opt = self.optimizer(self.proto_layer.parameters(),
|
||||||
proto_opt = optim(self.proto_layer.parameters(),
|
|
||||||
lr=self.hparams.proto_lr)
|
lr=self.hparams.proto_lr)
|
||||||
if list(self.backbone.parameters()):
|
if list(self.backbone.parameters()):
|
||||||
# only add an optimizer is the backbone has trainable parameters
|
# only add an optimizer is the backbone has trainable parameters
|
||||||
# otherwise, the next line fails
|
# otherwise, the next line fails
|
||||||
bb_opt = optim(self.backbone.parameters(), lr=self.hparams.bb_lr)
|
bb_opt = self.optimizer(self.backbone.parameters(),
|
||||||
|
lr=self.hparams.bb_lr)
|
||||||
return proto_opt, bb_opt
|
return proto_opt, bb_opt
|
||||||
else:
|
else:
|
||||||
return proto_opt
|
return proto_opt
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
from prototorch.components import Components
|
from prototorch.components import Components
|
||||||
from prototorch.components import initializers as cinit
|
from prototorch.components import initializers as cinit
|
||||||
|
from prototorch.components.initializers import ZerosInitializer
|
||||||
from prototorch.functions.distances import euclidean_distance
|
from prototorch.functions.distances import euclidean_distance
|
||||||
from prototorch.modules.losses import NeuralGasEnergy
|
from prototorch.modules.losses import NeuralGasEnergy
|
||||||
|
|
||||||
@ -41,12 +42,14 @@ class NeuralGas(AbstractPrototypeModel):
|
|||||||
|
|
||||||
self.save_hyperparameters(hparams)
|
self.save_hyperparameters(hparams)
|
||||||
|
|
||||||
|
self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
|
||||||
|
|
||||||
# Default Values
|
# Default Values
|
||||||
self.hparams.setdefault("input_dim", 2)
|
self.hparams.setdefault("input_dim", 2)
|
||||||
self.hparams.setdefault("agelimit", 10)
|
self.hparams.setdefault("agelimit", 10)
|
||||||
self.hparams.setdefault("lm", 1)
|
self.hparams.setdefault("lm", 1)
|
||||||
self.hparams.setdefault("prototype_initializer",
|
self.hparams.setdefault("prototype_initializer",
|
||||||
cinit.ZerosInitializer(self.hparams.input_dim))
|
ZerosInitializer(self.hparams.input_dim))
|
||||||
|
|
||||||
self.proto_layer = Components(
|
self.proto_layer = Components(
|
||||||
self.hparams.num_prototypes,
|
self.hparams.num_prototypes,
|
||||||
|
Loading…
Reference in New Issue
Block a user