Fix zero-distance bug in glvq_loss
This commit is contained in:
@@ -12,7 +12,7 @@ class GLVQLoss(torch.nn.Module):
|
||||
super().__init__(**kwargs)
|
||||
self.margin = margin
|
||||
self.squashing = get_activation(squashing)
|
||||
self.beta = beta
|
||||
self.beta = torch.tensor(beta)
|
||||
|
||||
def forward(self, outputs, targets):
|
||||
distances, plabels = outputs
|
||||
|
Reference in New Issue
Block a user