Add loss transfer function to glvq
This commit is contained in:
		@@ -35,6 +35,8 @@ if __name__ == "__main__":
 | 
			
		||||
        prototype_initializer=cinit.SSI(torch.Tensor(x_train),
 | 
			
		||||
                                        torch.Tensor(y_train),
 | 
			
		||||
                                        noise=1e-7),
 | 
			
		||||
        transfer_function="sigmoid_beta",
 | 
			
		||||
        transfer_beta=10.0,
 | 
			
		||||
        lr=0.01,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -2,6 +2,7 @@ import pytorch_lightning as pl
 | 
			
		||||
import torch
 | 
			
		||||
import torchmetrics
 | 
			
		||||
from prototorch.components import LabeledComponents
 | 
			
		||||
from prototorch.functions.activations import get_activation
 | 
			
		||||
from prototorch.functions.competitions import wtac
 | 
			
		||||
from prototorch.functions.distances import (euclidean_distance,
 | 
			
		||||
                                            squared_euclidean_distance)
 | 
			
		||||
@@ -21,11 +22,14 @@ class GLVQ(AbstractPrototypeModel):
 | 
			
		||||
        # Default Values
 | 
			
		||||
        self.hparams.setdefault("distance", euclidean_distance)
 | 
			
		||||
        self.hparams.setdefault("optimizer", torch.optim.Adam)
 | 
			
		||||
        self.hparams.setdefault("transfer_function", "identity")
 | 
			
		||||
        self.hparams.setdefault("transfer_beta", 10.0)
 | 
			
		||||
 | 
			
		||||
        self.proto_layer = LabeledComponents(
 | 
			
		||||
            labels=(self.hparams.nclasses, self.hparams.prototypes_per_class),
 | 
			
		||||
            initializer=self.hparams.prototype_initializer)
 | 
			
		||||
 | 
			
		||||
        self.transfer_function = get_activation(self.hparams.transfer_function)
 | 
			
		||||
        self.train_acc = torchmetrics.Accuracy()
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
@@ -43,7 +47,9 @@ class GLVQ(AbstractPrototypeModel):
 | 
			
		||||
        dis = self(x)
 | 
			
		||||
        plabels = self.proto_layer.component_labels
 | 
			
		||||
        mu = glvq_loss(dis, y, prototype_labels=plabels)
 | 
			
		||||
        loss = mu.sum(dim=0)
 | 
			
		||||
        batch_loss = self.transfer_function(mu,
 | 
			
		||||
                                            beta=self.hparams.transfer_beta)
 | 
			
		||||
        loss = batch_loss.sum(dim=0)
 | 
			
		||||
 | 
			
		||||
        # Compute training accuracy
 | 
			
		||||
        with torch.no_grad():
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user