Fix divide-by-zero in example

This commit is contained in:
Jensun Ravichandran
2020-09-23 15:29:26 +02:00
parent 3e6aa6a20b
commit d5ab9c3771
3 changed files with 19 additions and 11 deletions

View File

@@ -15,6 +15,6 @@ class GLVQLoss(torch.nn.Module):
def forward(self, outputs, targets):
distances, plabels = outputs
mu = glvq_loss(distances, targets, plabels)
mu = glvq_loss(distances, targets, prototype_labels=plabels)
batch_loss = self.squashing(mu + self.margin, beta=self.beta)
return torch.sum(batch_loss, dim=0)