Stop passing component initializers as hparams
Passing the component initializer as an hparam slows the script down considerably. The API has been changed so that the initializer is now passed to the models as a kwarg instead, and the example scripts have been updated to reflect the change. In addition, ImageGMLVQ and an example script `gmlvq_mnist.py` that uses it have been added.
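A minimal sketch of the resulting usage change (the hparam keys/values and the `my_initializer` placeholder below are illustrative only; substitute a real prototorch component initializer):

```python
from prototorch.models import GLVQ

# Placeholder for a prototorch component initializer; the concrete
# initializer class is not part of this diff.
my_initializer = ...

hparams = dict(distribution=[1, 1, 1], lr=0.01)  # illustrative values only

# Before: the initializer travelled inside the hparams dict.
# model = GLVQ(dict(**hparams, prototype_initializer=my_initializer))

# After: the initializer is passed to the model as a keyword argument.
model = GLVQ(hparams, prototype_initializer=my_initializer)
```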
@@ -5,9 +5,18 @@ from prototorch.functions.activations import get_activation
 from prototorch.functions.competitions import wtac
 from prototorch.functions.distances import (euclidean_distance, omega_distance,
                                              squared_euclidean_distance)
+from prototorch.functions.helper import get_flat
 from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
+from prototorch.modules.mappings import OmegaMapping

-from .abstract import AbstractPrototypeModel
+from .abstract import AbstractPrototypeModel, PrototypeImageModel


 class GLVQ(AbstractPrototypeModel):
     """Generalized Learning Vector Quantization."""
@@ -18,6 +27,7 @@ class GLVQ(AbstractPrototypeModel):
         self.save_hyperparameters(hparams)

         self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
+        prototype_initializer = kwargs.get("prototype_initializer", None)

         # Default Values
         self.hparams.setdefault("distance", euclidean_distance)
@@ -26,7 +36,7 @@ class GLVQ(AbstractPrototypeModel):

         self.proto_layer = LabeledComponents(
             distribution=self.hparams.distribution,
-            initializer=self.hparams.prototype_initializer)
+            initializer=prototype_initializer)

         self.transfer_function = get_activation(self.hparams.transfer_function)
         self.train_acc = torchmetrics.Accuracy()
@@ -44,7 +54,6 @@ class GLVQ(AbstractPrototypeModel):
     def training_step(self, train_batch, batch_idx, optimizer_idx=None):
         x, y = train_batch
-        x = x.view(x.size(0), -1)  # flatten
         dis = self(x)
         plabels = self.proto_layer.component_labels
         mu = self.loss(dis, y, prototype_labels=plabels)
@@ -95,15 +104,14 @@ class LVQ21(GLVQ):
         self.optimizer = torch.optim.SGD


-class ImageGLVQ(GLVQ):
+class ImageGLVQ(GLVQ, PrototypeImageModel):
     """GLVQ for training on image data.

     GLVQ model that constrains the prototypes to the range [0, 1] by clamping
     after updates.

     """
-    def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
-        self.proto_layer.components.data.clamp_(0.0, 1.0)
+    pass


 class SiameseGLVQ(GLVQ):
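The clamping hook removed from `ImageGLVQ` above presumably moves into the new `PrototypeImageModel` mixin, whose definition lives in `.abstract` and is not shown in this diff. A minimal sketch of what that mixin could look like, based on the removed hook:

```python
class PrototypeImageModel:
    """Mixin that keeps prototype components inside the valid pixel range [0, 1]."""

    def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
        # Same clamping that was previously defined directly on ImageGLVQ.
        self.proto_layer.components.data.clamp_(0.0, 1.0)
```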
@@ -235,6 +243,7 @@ class GMLVQ(GLVQ):
     def forward(self, x):
         protos, _ = self.proto_layer()
+        x, protos = get_flat(x, protos)
         latent_x = self.omega_layer(x)
         latent_protos = self.omega_layer(protos)
         dis = squared_euclidean_distance(latent_x, latent_protos)
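`get_flat` takes over the flattening that `training_step` used to do with `x.view(x.size(0), -1)`. Its actual implementation in `prototorch.functions.helper` is not shown here; conceptually it behaves roughly like this sketch:

```python
def get_flat(*tensors):
    """Flatten each tensor to shape (batch, -1), like x.view(x.size(0), -1)."""
    return tuple(t.view(t.size(0), -1) for t in tensors)
```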
@@ -256,6 +265,16 @@ class GMLVQ(GLVQ):
         return y_pred.numpy()


+class ImageGMLVQ(GMLVQ, PrototypeImageModel):
+    """GMLVQ for training on image data.
+
+    GMLVQ model that constrains the prototypes to the range [0, 1] by clamping
+    after updates.
+
+    """
+    pass
+
+
 class LVQMLN(GLVQ):
     """Learning Vector Quantization Multi-Layer Network.
