Add loss transfer function to glvq
parent f402eea884
commit d644114090
@@ -35,6 +35,8 @@ if __name__ == "__main__":
         prototype_initializer=cinit.SSI(torch.Tensor(x_train),
                                         torch.Tensor(y_train),
                                         noise=1e-7),
+        transfer_function="sigmoid_beta",
+        transfer_beta=10.0,
         lr=0.01,
     )
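Note: "sigmoid_beta" names one of the transfer functions available in prototorch, and transfer_beta controls how steep the transfer is around mu = 0. The snippet below is a minimal sketch of the assumed functional form, sigmoid(beta * x), and of the effect of beta; it is illustrative and not the library implementation.

import torch

def sigmoid_beta(x, beta=10.0):
    # Assumed form of the "sigmoid_beta" transfer: a logistic with steepness beta.
    return torch.sigmoid(beta * x)

mu = torch.tensor([-0.5, -0.1, 0.0, 0.1, 0.5])  # GLVQ classifier scores lie in (-1, 1)
print(sigmoid_beta(mu, beta=1.0))   # shallow slope: values stay close to 0.5
print(sigmoid_beta(mu, beta=10.0))  # steep slope: saturates quickly away from 0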
@@ -2,6 +2,7 @@ import pytorch_lightning as pl
 import torch
 import torchmetrics
 from prototorch.components import LabeledComponents
+from prototorch.functions.activations import get_activation
 from prototorch.functions.competitions import wtac
 from prototorch.functions.distances import (euclidean_distance,
                                             squared_euclidean_distance)
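Note: get_activation resolves the transfer_function hyperparameter (a string such as "sigmoid_beta") into a callable. The following is only an illustrative sketch of that kind of resolver; the registry contents and the callable pass-through are assumptions, not the prototorch implementation.

import torch

# Illustrative name-to-callable registry; the real get_activation lives in
# prototorch.functions.activations and is only sketched here.
_TRANSFERS = {
    "identity": lambda x, beta=0.0: x,
    "sigmoid_beta": lambda x, beta=10.0: torch.sigmoid(beta * x),
}

def get_activation_sketch(name_or_callable):
    # Accept either a registered name or an already-callable transfer function.
    if callable(name_or_callable):
        return name_or_callable
    return _TRANSFERS[name_or_callable]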
@@ -21,11 +22,14 @@ class GLVQ(AbstractPrototypeModel):
         # Default Values
         self.hparams.setdefault("distance", euclidean_distance)
         self.hparams.setdefault("optimizer", torch.optim.Adam)
+        self.hparams.setdefault("transfer_function", "identity")
+        self.hparams.setdefault("transfer_beta", 10.0)

         self.proto_layer = LabeledComponents(
             labels=(self.hparams.nclasses, self.hparams.prototypes_per_class),
             initializer=self.hparams.prototype_initializer)

+        self.transfer_function = get_activation(self.hparams.transfer_function)
         self.train_acc = torchmetrics.Accuracy()

     @property
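Note: because the default transfer_function is "identity", existing configurations that never set this hyperparameter keep the previous behavior: the transfer leaves mu untouched and the summed loss is unchanged. A quick check, assuming the identity transfer simply returns its input:

import torch

def identity(x, beta=0.0):
    # Assumed behavior of the default "identity" transfer: returns its input unchanged.
    return x

mu = torch.tensor([-0.4, 0.2, -0.1])
assert torch.equal(identity(mu, beta=10.0).sum(dim=0), mu.sum(dim=0))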
@@ -43,7 +47,9 @@ class GLVQ(AbstractPrototypeModel):
         dis = self(x)
         plabels = self.proto_layer.component_labels
         mu = glvq_loss(dis, y, prototype_labels=plabels)
-        loss = mu.sum(dim=0)
+        batch_loss = self.transfer_function(mu,
+                                            beta=self.hparams.transfer_beta)
+        loss = batch_loss.sum(dim=0)

         # Compute training accuracy
         with torch.no_grad():
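Note: putting the pieces together, training_step now computes the GLVQ classifier scores mu, squashes them with the configured transfer function, and sums the result into the batch loss. The sketch below reproduces that composition on a toy batch; glvq_mu is an illustrative stand-in for prototorch's glvq_loss, and the inline logistic stands in for the "sigmoid_beta" transfer.

import torch

def glvq_mu(distances, targets, plabels):
    # Classifier score per sample: (d_correct - d_incorrect) / (d_correct + d_incorrect),
    # where d_correct / d_incorrect are the distances to the closest prototype with a
    # matching / non-matching label. Negative mu means a correct classification.
    matching = targets.unsqueeze(1) == plabels.unsqueeze(0)  # (batch, n_prototypes)
    inf = torch.full_like(distances, float("inf"))
    d_correct = torch.where(matching, distances, inf).min(dim=1).values
    d_incorrect = torch.where(~matching, distances, inf).min(dim=1).values
    return (d_correct - d_incorrect) / (d_correct + d_incorrect)

# Toy batch: 4 samples, 3 prototypes labeled [0, 1, 2].
distances = torch.rand(4, 3)
targets = torch.tensor([0, 1, 2, 0])
plabels = torch.tensor([0, 1, 2])

mu = glvq_mu(distances, targets, plabels)
beta = 10.0
loss = torch.sigmoid(beta * mu).sum(dim=0)  # mirrors transfer_function(mu, beta=...).sum(dim=0)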