Update _get_dp_dm

This commit is contained in:
Jensun Ravichandran 2021-05-18 13:09:11 +02:00
parent 503ef0e05f
commit b935e9caf3

View File

@ -13,7 +13,7 @@ def _get_matcher(targets, labels):
return matcher return matcher
def _get_dp_dm(distances, targets, plabels): def _get_dp_dm(distances, targets, plabels, with_indices=False):
"""Returns the d+ and d- values for a batch of distances.""" """Returns the d+ and d- values for a batch of distances."""
matcher = _get_matcher(targets, plabels) matcher = _get_matcher(targets, plabels)
not_matcher = torch.bitwise_not(matcher) not_matcher = torch.bitwise_not(matcher)
@ -21,9 +21,11 @@ def _get_dp_dm(distances, targets, plabels):
inf = torch.full_like(distances, fill_value=float("inf")) inf = torch.full_like(distances, fill_value=float("inf"))
d_matching = torch.where(matcher, distances, inf) d_matching = torch.where(matcher, distances, inf)
d_unmatching = torch.where(not_matcher, distances, inf) d_unmatching = torch.where(not_matcher, distances, inf)
dp = torch.min(d_matching, dim=1, keepdim=True).values dp = torch.min(d_matching, dim=-1, keepdim=True)
dm = torch.min(d_unmatching, dim=1, keepdim=True).values dm = torch.min(d_unmatching, dim=-1, keepdim=True)
if with_indices:
return dp, dm return dp, dm
return dp.values, dm.values
def glvq_loss(distances, target_labels, prototype_labels): def glvq_loss(distances, target_labels, prototype_labels):