Fix zero-distance bug in glvq_loss
This commit is contained in:
parent
7d5ab81dbf
commit
b19cbcb76a
@@ -12,12 +12,9 @@ def glvq_loss(distances, target_labels, prototype_labels):
     matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
     not_matcher = torch.bitwise_not(matcher)
 
-    dplus_criterion = distances * matcher > 0.0
-    dminus_criterion = distances * not_matcher > 0.0
-
     inf = torch.full_like(distances, fill_value=float('inf'))
-    distances_to_wpluses = torch.where(dplus_criterion, distances, inf)
-    distances_to_wminuses = torch.where(dminus_criterion, distances, inf)
+    distances_to_wpluses = torch.where(matcher, distances, inf)
+    distances_to_wminuses = torch.where(not_matcher, distances, inf)
     dpluses = torch.min(distances_to_wpluses, dim=1, keepdim=True).values
     dminuses = torch.min(distances_to_wminuses, dim=1, keepdim=True).values
 
@@ -12,7 +12,7 @@ class GLVQLoss(torch.nn.Module):
         super().__init__(**kwargs)
         self.margin = margin
         self.squashing = get_activation(squashing)
-        self.beta = beta
+        self.beta = torch.tensor(beta)
 
     def forward(self, outputs, targets):
         distances, plabels = outputs
Loading…
Reference in New Issue
Block a user