Update _get_dp_dm
This commit is contained in:
parent
503ef0e05f
commit
b935e9caf3
@ -13,7 +13,7 @@ def _get_matcher(targets, labels):
|
|||||||
return matcher
|
return matcher
|
||||||
|
|
||||||
|
|
||||||
def _get_dp_dm(distances, targets, plabels):
|
def _get_dp_dm(distances, targets, plabels, with_indices=False):
|
||||||
"""Returns the d+ and d- values for a batch of distances."""
|
"""Returns the d+ and d- values for a batch of distances."""
|
||||||
matcher = _get_matcher(targets, plabels)
|
matcher = _get_matcher(targets, plabels)
|
||||||
not_matcher = torch.bitwise_not(matcher)
|
not_matcher = torch.bitwise_not(matcher)
|
||||||
@ -21,9 +21,11 @@ def _get_dp_dm(distances, targets, plabels):
|
|||||||
inf = torch.full_like(distances, fill_value=float("inf"))
|
inf = torch.full_like(distances, fill_value=float("inf"))
|
||||||
d_matching = torch.where(matcher, distances, inf)
|
d_matching = torch.where(matcher, distances, inf)
|
||||||
d_unmatching = torch.where(not_matcher, distances, inf)
|
d_unmatching = torch.where(not_matcher, distances, inf)
|
||||||
dp = torch.min(d_matching, dim=1, keepdim=True).values
|
dp = torch.min(d_matching, dim=-1, keepdim=True)
|
||||||
dm = torch.min(d_unmatching, dim=1, keepdim=True).values
|
dm = torch.min(d_unmatching, dim=-1, keepdim=True)
|
||||||
return dp, dm
|
if with_indices:
|
||||||
|
return dp, dm
|
||||||
|
return dp.values, dm.values
|
||||||
|
|
||||||
|
|
||||||
def glvq_loss(distances, target_labels, prototype_labels):
|
def glvq_loss(distances, target_labels, prototype_labels):
|
||||||
|
Loading…
Reference in New Issue
Block a user