Add loss transfer function to glvq

This commit is contained in:
Jensun Ravichandran 2021-05-04 20:56:16 +02:00
parent f402eea884
commit d644114090
2 changed files with 9 additions and 1 deletions

View File

@ -35,6 +35,8 @@ if __name__ == "__main__":
prototype_initializer=cinit.SSI(torch.Tensor(x_train), prototype_initializer=cinit.SSI(torch.Tensor(x_train),
torch.Tensor(y_train), torch.Tensor(y_train),
noise=1e-7), noise=1e-7),
transfer_function="sigmoid_beta",
transfer_beta=10.0,
lr=0.01, lr=0.01,
) )

View File

@ -2,6 +2,7 @@ import pytorch_lightning as pl
import torch import torch
import torchmetrics import torchmetrics
from prototorch.components import LabeledComponents from prototorch.components import LabeledComponents
from prototorch.functions.activations import get_activation
from prototorch.functions.competitions import wtac from prototorch.functions.competitions import wtac
from prototorch.functions.distances import (euclidean_distance, from prototorch.functions.distances import (euclidean_distance,
squared_euclidean_distance) squared_euclidean_distance)
@ -21,11 +22,14 @@ class GLVQ(AbstractPrototypeModel):
# Default Values # Default Values
self.hparams.setdefault("distance", euclidean_distance) self.hparams.setdefault("distance", euclidean_distance)
self.hparams.setdefault("optimizer", torch.optim.Adam) self.hparams.setdefault("optimizer", torch.optim.Adam)
self.hparams.setdefault("transfer_function", "identity")
self.hparams.setdefault("transfer_beta", 10.0)
self.proto_layer = LabeledComponents( self.proto_layer = LabeledComponents(
labels=(self.hparams.nclasses, self.hparams.prototypes_per_class), labels=(self.hparams.nclasses, self.hparams.prototypes_per_class),
initializer=self.hparams.prototype_initializer) initializer=self.hparams.prototype_initializer)
self.transfer_function = get_activation(self.hparams.transfer_function)
self.train_acc = torchmetrics.Accuracy() self.train_acc = torchmetrics.Accuracy()
@property @property
@ -43,7 +47,9 @@ class GLVQ(AbstractPrototypeModel):
dis = self(x) dis = self(x)
plabels = self.proto_layer.component_labels plabels = self.proto_layer.component_labels
mu = glvq_loss(dis, y, prototype_labels=plabels) mu = glvq_loss(dis, y, prototype_labels=plabels)
loss = mu.sum(dim=0) batch_loss = self.transfer_function(mu,
beta=self.hparams.transfer_beta)
loss = batch_loss.sum(dim=0)
# Compute training accuracy # Compute training accuracy
with torch.no_grad(): with torch.no_grad():