diff --git a/prototorch/models/abstract.py b/prototorch/models/abstract.py index 562b6d3..082b7ed 100644 --- a/prototorch/models/abstract.py +++ b/prototorch/models/abstract.py @@ -3,9 +3,13 @@ import torch from torch.optim.lr_scheduler import ExponentialLR -class AbstractLightningModel(pl.LightningModule): +class AbstractPrototypeModel(pl.LightningModule): + @property + def prototypes(self): + return self.proto_layer.components.detach().cpu() + def configure_optimizers(self): - optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr) + optimizer = self.optimizer(self.parameters(), lr=self.hparams.lr) scheduler = ExponentialLR(optimizer, gamma=0.99, last_epoch=-1, @@ -15,9 +19,3 @@ class AbstractLightningModel(pl.LightningModule): "interval": "step", } # called after each training step return [optimizer], [sch] - - -class AbstractPrototypeModel(AbstractLightningModel): - @property - def prototypes(self): - return self.proto_layer.components.detach().cpu() diff --git a/prototorch/models/glvq.py b/prototorch/models/glvq.py index b88ea23..1766064 100644 --- a/prototorch/models/glvq.py +++ b/prototorch/models/glvq.py @@ -9,8 +9,6 @@ from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss from .abstract import AbstractPrototypeModel -from torch.optim.lr_scheduler import ExponentialLR - class GLVQ(AbstractPrototypeModel): """Generalized Learning Vector Quantization.""" @@ -19,14 +17,15 @@ class GLVQ(AbstractPrototypeModel): self.save_hyperparameters(hparams) + self.optimizer = kwargs.get("optimizer", torch.optim.Adam) + # Default Values self.hparams.setdefault("distance", euclidean_distance) - self.hparams.setdefault("optimizer", torch.optim.Adam) self.hparams.setdefault("transfer_function", "identity") self.hparams.setdefault("transfer_beta", 10.0) self.proto_layer = LabeledComponents( - labels=(self.hparams.nclasses, self.hparams.prototypes_per_class), + distribution=self.hparams.distribution, initializer=self.hparams.prototype_initializer) self.transfer_function = get_activation(self.hparams.transfer_function) @@ -81,39 +80,19 @@ class GLVQ(AbstractPrototypeModel): class LVQ1(GLVQ): + """Learning Vector Quantization 1.""" def __init__(self, hparams, **kwargs): super().__init__(hparams, **kwargs) self.loss = lvq1_loss - - def configure_optimizers(self): - optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr) - scheduler = ExponentialLR(optimizer, - gamma=0.99, - last_epoch=-1, - verbose=False) - sch = { - "scheduler": scheduler, - "interval": "step", - } # called after each training step - return [optimizer], [sch] + self.optimizer = torch.optim.SGD class LVQ21(GLVQ): + """Learning Vector Quantization 2.1.""" def __init__(self, hparams, **kwargs): super().__init__(hparams, **kwargs) self.loss = lvq21_loss - - def configure_optimizers(self): - optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr) - scheduler = ExponentialLR(optimizer, - gamma=0.99, - last_epoch=-1, - verbose=False) - sch = { - "scheduler": scheduler, - "interval": "step", - } # called after each training step - return [optimizer], [sch] + self.optimizer = torch.optim.SGD class ImageGLVQ(GLVQ): @@ -152,13 +131,13 @@ class SiameseGLVQ(GLVQ): self.backbone_dependent.load_state_dict(master_state, strict=True) def configure_optimizers(self): - optim = self.hparams.optimizer - proto_opt = optim(self.proto_layer.parameters(), - lr=self.hparams.proto_lr) + proto_opt = self.optimizer(self.proto_layer.parameters(), + lr=self.hparams.proto_lr) if list(self.backbone.parameters()): # only add an optimizer is the backbone has trainable parameters # otherwise, the next line fails - bb_opt = optim(self.backbone.parameters(), lr=self.hparams.bb_lr) + bb_opt = self.optimizer(self.backbone.parameters(), + lr=self.hparams.bb_lr) return proto_opt, bb_opt else: return proto_opt diff --git a/prototorch/models/neural_gas.py b/prototorch/models/neural_gas.py index bebd289..29e2b83 100644 --- a/prototorch/models/neural_gas.py +++ b/prototorch/models/neural_gas.py @@ -1,6 +1,7 @@ import torch from prototorch.components import Components from prototorch.components import initializers as cinit +from prototorch.components.initializers import ZerosInitializer from prototorch.functions.distances import euclidean_distance from prototorch.modules.losses import NeuralGasEnergy @@ -41,12 +42,14 @@ class NeuralGas(AbstractPrototypeModel): self.save_hyperparameters(hparams) + self.optimizer = kwargs.get("optimizer", torch.optim.Adam) + # Default Values self.hparams.setdefault("input_dim", 2) self.hparams.setdefault("agelimit", 10) self.hparams.setdefault("lm", 1) self.hparams.setdefault("prototype_initializer", - cinit.ZerosInitializer(self.hparams.input_dim)) + ZerosInitializer(self.hparams.input_dim)) self.proto_layer = Components( self.hparams.num_prototypes,