Fix zero-distance bug in glvq_loss
This commit is contained in:
parent
7d5ab81dbf
commit
b19cbcb76a
@ -12,12 +12,9 @@ def glvq_loss(distances, target_labels, prototype_labels):
|
||||
matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
|
||||
not_matcher = torch.bitwise_not(matcher)
|
||||
|
||||
dplus_criterion = distances * matcher > 0.0
|
||||
dminus_criterion = distances * not_matcher > 0.0
|
||||
|
||||
inf = torch.full_like(distances, fill_value=float('inf'))
|
||||
distances_to_wpluses = torch.where(dplus_criterion, distances, inf)
|
||||
distances_to_wminuses = torch.where(dminus_criterion, distances, inf)
|
||||
distances_to_wpluses = torch.where(matcher, distances, inf)
|
||||
distances_to_wminuses = torch.where(not_matcher, distances, inf)
|
||||
dpluses = torch.min(distances_to_wpluses, dim=1, keepdim=True).values
|
||||
dminuses = torch.min(distances_to_wminuses, dim=1, keepdim=True).values
|
||||
|
||||
|
@ -12,7 +12,7 @@ class GLVQLoss(torch.nn.Module):
|
||||
super().__init__(**kwargs)
|
||||
self.margin = margin
|
||||
self.squashing = get_activation(squashing)
|
||||
self.beta = beta
|
||||
self.beta = torch.tensor(beta)
|
||||
|
||||
def forward(self, outputs, targets):
|
||||
distances, plabels = outputs
|
||||
|
Loading…
Reference in New Issue
Block a user