Fix zero-distance bug in glvq_loss

This commit is contained in:
blackfly
2020-04-08 22:46:08 +02:00
parent 7d5ab81dbf
commit b19cbcb76a
2 changed files with 3 additions and 6 deletions

View File

@@ -12,7 +12,7 @@ class GLVQLoss(torch.nn.Module):
super().__init__(**kwargs)
self.margin = margin
self.squashing = get_activation(squashing)
self.beta = beta
self.beta = torch.tensor(beta)
def forward(self, outputs, targets):
distances, plabels = outputs