Refactor functions/losses.py
This commit is contained in:
parent
9a7d3192c0
commit
466e9bde6b
@ -3,12 +3,19 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
def _get_dp_dm(distances, targets, plabels):
|
def _get_matcher(targets, labels):
|
||||||
matcher = torch.eq(targets.unsqueeze(dim=1), plabels)
|
"""Returns a boolean tensor."""
|
||||||
if plabels.ndim == 2:
|
matcher = torch.eq(targets.unsqueeze(dim=1), labels)
|
||||||
|
if labels.ndim == 2:
|
||||||
# if the labels are one-hot vectors
|
# if the labels are one-hot vectors
|
||||||
nclasses = targets.size()[1]
|
nclasses = targets.size()[1]
|
||||||
matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
|
matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
|
||||||
|
return matcher
|
||||||
|
|
||||||
|
|
||||||
|
def _get_dp_dm(distances, targets, plabels):
|
||||||
|
"""Returns the d+ and d- values for a batch of distances."""
|
||||||
|
matcher = _get_matcher(targets, plabels)
|
||||||
not_matcher = torch.bitwise_not(matcher)
|
not_matcher = torch.bitwise_not(matcher)
|
||||||
|
|
||||||
inf = torch.full_like(distances, fill_value=float("inf"))
|
inf = torch.full_like(distances, fill_value=float("inf"))
|
||||||
|
Loading…
Reference in New Issue
Block a user