Add neural gas energy function as loss.
This commit is contained in:
parent
c88f288d12
commit
4540c8848e
@ -18,3 +18,23 @@ class GLVQLoss(torch.nn.Module):
|
||||
mu = glvq_loss(distances, targets, prototype_labels=plabels)
|
||||
batch_loss = self.squashing(mu + self.margin, beta=self.beta)
|
||||
return torch.sum(batch_loss, dim=0)
|
||||
|
||||
|
||||
class NeuralGasEnergy(torch.nn.Module):
|
||||
def __init__(self, lm):
|
||||
super().__init__()
|
||||
self.lm = lm
|
||||
|
||||
def forward(self, d):
|
||||
order = torch.argsort(d, dim=1)
|
||||
ranks = torch.argsort(order, dim=1)
|
||||
cost = torch.sum(self._nghood_fn(ranks, self.lm) * d)
|
||||
|
||||
return cost, order
|
||||
|
||||
def extra_repr(self):
|
||||
return f"lambda: {self.lm}"
|
||||
|
||||
@staticmethod
|
||||
def _nghood_fn(rankings, lm):
|
||||
return torch.exp(-rankings / lm)
|
Loading…
Reference in New Issue
Block a user