Fix divide-by-zero in example
This commit is contained in:
@@ -15,6 +15,6 @@ class GLVQLoss(torch.nn.Module):
|
||||
|
||||
def forward(self, outputs, targets):
|
||||
distances, plabels = outputs
|
||||
mu = glvq_loss(distances, targets, plabels)
|
||||
mu = glvq_loss(distances, targets, prototype_labels=plabels)
|
||||
batch_loss = self.squashing(mu + self.margin, beta=self.beta)
|
||||
return torch.sum(batch_loss, dim=0)
|
||||
|
Reference in New Issue
Block a user