Add loss transfer function to glvq
This commit is contained in:
parent
f402eea884
commit
d644114090
@ -35,6 +35,8 @@ if __name__ == "__main__":
|
|||||||
prototype_initializer=cinit.SSI(torch.Tensor(x_train),
|
prototype_initializer=cinit.SSI(torch.Tensor(x_train),
|
||||||
torch.Tensor(y_train),
|
torch.Tensor(y_train),
|
||||||
noise=1e-7),
|
noise=1e-7),
|
||||||
|
transfer_function="sigmoid_beta",
|
||||||
|
transfer_beta=10.0,
|
||||||
lr=0.01,
|
lr=0.01,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -2,6 +2,7 @@ import pytorch_lightning as pl
|
|||||||
import torch
|
import torch
|
||||||
import torchmetrics
|
import torchmetrics
|
||||||
from prototorch.components import LabeledComponents
|
from prototorch.components import LabeledComponents
|
||||||
|
from prototorch.functions.activations import get_activation
|
||||||
from prototorch.functions.competitions import wtac
|
from prototorch.functions.competitions import wtac
|
||||||
from prototorch.functions.distances import (euclidean_distance,
|
from prototorch.functions.distances import (euclidean_distance,
|
||||||
squared_euclidean_distance)
|
squared_euclidean_distance)
|
||||||
@ -21,11 +22,14 @@ class GLVQ(AbstractPrototypeModel):
|
|||||||
# Default Values
|
# Default Values
|
||||||
self.hparams.setdefault("distance", euclidean_distance)
|
self.hparams.setdefault("distance", euclidean_distance)
|
||||||
self.hparams.setdefault("optimizer", torch.optim.Adam)
|
self.hparams.setdefault("optimizer", torch.optim.Adam)
|
||||||
|
self.hparams.setdefault("transfer_function", "identity")
|
||||||
|
self.hparams.setdefault("transfer_beta", 10.0)
|
||||||
|
|
||||||
self.proto_layer = LabeledComponents(
|
self.proto_layer = LabeledComponents(
|
||||||
labels=(self.hparams.nclasses, self.hparams.prototypes_per_class),
|
labels=(self.hparams.nclasses, self.hparams.prototypes_per_class),
|
||||||
initializer=self.hparams.prototype_initializer)
|
initializer=self.hparams.prototype_initializer)
|
||||||
|
|
||||||
|
self.transfer_function = get_activation(self.hparams.transfer_function)
|
||||||
self.train_acc = torchmetrics.Accuracy()
|
self.train_acc = torchmetrics.Accuracy()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -43,7 +47,9 @@ class GLVQ(AbstractPrototypeModel):
|
|||||||
dis = self(x)
|
dis = self(x)
|
||||||
plabels = self.proto_layer.component_labels
|
plabels = self.proto_layer.component_labels
|
||||||
mu = glvq_loss(dis, y, prototype_labels=plabels)
|
mu = glvq_loss(dis, y, prototype_labels=plabels)
|
||||||
loss = mu.sum(dim=0)
|
batch_loss = self.transfer_function(mu,
|
||||||
|
beta=self.hparams.transfer_beta)
|
||||||
|
loss = batch_loss.sum(dim=0)
|
||||||
|
|
||||||
# Compute training accuracy
|
# Compute training accuracy
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
Loading…
Reference in New Issue
Block a user