Refactor functions/losses.py

This commit is contained in:
blackfly 2020-04-27 12:47:15 +02:00
parent dab91e471a
commit c11a3860df

View File

@ -3,20 +3,24 @@
import torch
def glvq_loss(distances, target_labels, prototype_labels):
"""GLVQ loss function with support for one-hot labels."""
matcher = torch.eq(target_labels.unsqueeze(dim=1), prototype_labels)
if prototype_labels.ndim == 2:
def _get_dp_dm(distances, targets, plabels):
matcher = torch.eq(targets.unsqueeze(dim=1), plabels)
if plabels.ndim == 2:
# if the labels are one-hot vectors
nclasses = target_labels.size()[1]
nclasses = targets.size()[1]
matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
not_matcher = torch.bitwise_not(matcher)
inf = torch.full_like(distances, fill_value=float('inf'))
distances_to_wpluses = torch.where(matcher, distances, inf)
distances_to_wminuses = torch.where(not_matcher, distances, inf)
dpluses = torch.min(distances_to_wpluses, dim=1, keepdim=True).values
dminuses = torch.min(distances_to_wminuses, dim=1, keepdim=True).values
d_matching = torch.where(matcher, distances, inf)
d_unmatching = torch.where(not_matcher, distances, inf)
dp = torch.min(d_matching, dim=1, keepdim=True).values
dm = torch.min(d_unmatching, dim=1, keepdim=True).values
return dp, dm
mu = (dpluses - dminuses) / (dpluses + dminuses)
def glvq_loss(distances, target_labels, prototype_labels):
"""GLVQ loss function with support for one-hot labels."""
dp, dm = _get_dp_dm(distances, target_labels, prototype_labels)
mu = (dp - dm) / (dp + dm)
return mu