Update _get_dp_dm
This commit is contained in:
parent
503ef0e05f
commit
b935e9caf3
@ -13,7 +13,7 @@ def _get_matcher(targets, labels):
|
||||
return matcher
|
||||
|
||||
|
||||
def _get_dp_dm(distances, targets, plabels):
|
||||
def _get_dp_dm(distances, targets, plabels, with_indices=False):
|
||||
"""Returns the d+ and d- values for a batch of distances."""
|
||||
matcher = _get_matcher(targets, plabels)
|
||||
not_matcher = torch.bitwise_not(matcher)
|
||||
@ -21,9 +21,11 @@ def _get_dp_dm(distances, targets, plabels):
|
||||
inf = torch.full_like(distances, fill_value=float("inf"))
|
||||
d_matching = torch.where(matcher, distances, inf)
|
||||
d_unmatching = torch.where(not_matcher, distances, inf)
|
||||
dp = torch.min(d_matching, dim=1, keepdim=True).values
|
||||
dm = torch.min(d_unmatching, dim=1, keepdim=True).values
|
||||
dp = torch.min(d_matching, dim=-1, keepdim=True)
|
||||
dm = torch.min(d_unmatching, dim=-1, keepdim=True)
|
||||
if with_indices:
|
||||
return dp, dm
|
||||
return dp.values, dm.values
|
||||
|
||||
|
||||
def glvq_loss(distances, target_labels, prototype_labels):
|
||||
|
Loading…
Reference in New Issue
Block a user