From d644114090f56aae095ca38c72519178940c4284 Mon Sep 17 00:00:00 2001
From: Jensun Ravichandran
Date: Tue, 4 May 2021 20:56:16 +0200
Subject: [PATCH] Add loss transfer function to glvq

---
 examples/glvq_spiral.py   | 2 ++
 prototorch/models/glvq.py | 8 +++++++-
 2 files changed, 9 insertions(+), 1 deletion(-)

diff --git a/examples/glvq_spiral.py b/examples/glvq_spiral.py
index a2d3571..ccfa191 100644
--- a/examples/glvq_spiral.py
+++ b/examples/glvq_spiral.py
@@ -35,6 +35,8 @@ if __name__ == "__main__":
         prototype_initializer=cinit.SSI(torch.Tensor(x_train),
                                         torch.Tensor(y_train),
                                         noise=1e-7),
+        transfer_function="sigmoid_beta",
+        transfer_beta=10.0,
         lr=0.01,
     )
 
diff --git a/prototorch/models/glvq.py b/prototorch/models/glvq.py
index 6f851ca..fa0a46a 100644
--- a/prototorch/models/glvq.py
+++ b/prototorch/models/glvq.py
@@ -2,6 +2,7 @@ import pytorch_lightning as pl
 import torch
 import torchmetrics
 from prototorch.components import LabeledComponents
+from prototorch.functions.activations import get_activation
 from prototorch.functions.competitions import wtac
 from prototorch.functions.distances import (euclidean_distance,
                                             squared_euclidean_distance)
@@ -21,11 +22,14 @@ class GLVQ(AbstractPrototypeModel):
         # Default Values
         self.hparams.setdefault("distance", euclidean_distance)
         self.hparams.setdefault("optimizer", torch.optim.Adam)
+        self.hparams.setdefault("transfer_function", "identity")
+        self.hparams.setdefault("transfer_beta", 10.0)
 
         self.proto_layer = LabeledComponents(
             labels=(self.hparams.nclasses, self.hparams.prototypes_per_class),
             initializer=self.hparams.prototype_initializer)
 
+        self.transfer_function = get_activation(self.hparams.transfer_function)
         self.train_acc = torchmetrics.Accuracy()
 
     @property
@@ -43,7 +47,9 @@ class GLVQ(AbstractPrototypeModel):
         dis = self(x)
         plabels = self.proto_layer.component_labels
         mu = glvq_loss(dis, y, prototype_labels=plabels)
-        loss = mu.sum(dim=0)
+        batch_loss = self.transfer_function(mu,
+                                            beta=self.hparams.transfer_beta)
+        loss = batch_loss.sum(dim=0)
 
         # Compute training accuracy
         with torch.no_grad():
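
Note on the change: glvq_loss returns one value mu per sample (the relative distance difference), and the new transfer_function is applied element-wise to those values before they are summed into the scalar training loss, so "identity" reproduces the old behaviour. Below is a minimal standalone sketch of that step using a beta-scaled sigmoid. The exact definition of "sigmoid_beta" inside prototorch.functions.activations is assumed here to be the standard logistic 1 / (1 + exp(-beta * x)); treat this as an illustration of the idea rather than the library's exact code.

import torch

def sigmoid_beta(x, beta=10.0):
    # Assumed form of the transfer: squashes each per-sample GLVQ loss
    # value into (0, 1); a larger beta makes the transition around 0 steeper.
    return 1.0 / (1.0 + torch.exp(-beta * x))

# mu: per-sample relative distance differences, roughly in (-1, 1)
mu = torch.tensor([-0.4, -0.1, 0.05, 0.3])
batch_loss = sigmoid_beta(mu, beta=10.0)  # element-wise transfer
loss = batch_loss.sum(dim=0)              # scalar loss, as in training_step
print(loss)

Saturating the loss this way down-weights samples that are already classified with a large margin, so training focuses on samples near the decision boundary.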