Fix zero-distance bug in glvq_loss

This commit is contained in:
blackfly 2020-04-08 22:46:08 +02:00
parent 7d5ab81dbf
commit b19cbcb76a
2 changed files with 3 additions and 6 deletions

View File

@ -12,12 +12,9 @@ def glvq_loss(distances, target_labels, prototype_labels):
matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses) matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
not_matcher = torch.bitwise_not(matcher) not_matcher = torch.bitwise_not(matcher)
dplus_criterion = distances * matcher > 0.0
dminus_criterion = distances * not_matcher > 0.0
inf = torch.full_like(distances, fill_value=float('inf')) inf = torch.full_like(distances, fill_value=float('inf'))
distances_to_wpluses = torch.where(dplus_criterion, distances, inf) distances_to_wpluses = torch.where(matcher, distances, inf)
distances_to_wminuses = torch.where(dminus_criterion, distances, inf) distances_to_wminuses = torch.where(not_matcher, distances, inf)
dpluses = torch.min(distances_to_wpluses, dim=1, keepdim=True).values dpluses = torch.min(distances_to_wpluses, dim=1, keepdim=True).values
dminuses = torch.min(distances_to_wminuses, dim=1, keepdim=True).values dminuses = torch.min(distances_to_wminuses, dim=1, keepdim=True).values

View File

@ -12,7 +12,7 @@ class GLVQLoss(torch.nn.Module):
super().__init__(**kwargs) super().__init__(**kwargs)
self.margin = margin self.margin = margin
self.squashing = get_activation(squashing) self.squashing = get_activation(squashing)
self.beta = beta self.beta = torch.tensor(beta)
def forward(self, outputs, targets): def forward(self, outputs, targets):
distances, plabels = outputs distances, plabels = outputs