Refactor functions/losses.py
This commit is contained in:
parent
dab91e471a
commit
c11a3860df
@ -3,20 +3,24 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
def glvq_loss(distances, target_labels, prototype_labels):
|
def _get_dp_dm(distances, targets, plabels):
|
||||||
"""GLVQ loss function with support for one-hot labels."""
|
matcher = torch.eq(targets.unsqueeze(dim=1), plabels)
|
||||||
matcher = torch.eq(target_labels.unsqueeze(dim=1), prototype_labels)
|
if plabels.ndim == 2:
|
||||||
if prototype_labels.ndim == 2:
|
|
||||||
# if the labels are one-hot vectors
|
# if the labels are one-hot vectors
|
||||||
nclasses = target_labels.size()[1]
|
nclasses = targets.size()[1]
|
||||||
matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
|
matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
|
||||||
not_matcher = torch.bitwise_not(matcher)
|
not_matcher = torch.bitwise_not(matcher)
|
||||||
|
|
||||||
inf = torch.full_like(distances, fill_value=float('inf'))
|
inf = torch.full_like(distances, fill_value=float('inf'))
|
||||||
distances_to_wpluses = torch.where(matcher, distances, inf)
|
d_matching = torch.where(matcher, distances, inf)
|
||||||
distances_to_wminuses = torch.where(not_matcher, distances, inf)
|
d_unmatching = torch.where(not_matcher, distances, inf)
|
||||||
dpluses = torch.min(distances_to_wpluses, dim=1, keepdim=True).values
|
dp = torch.min(d_matching, dim=1, keepdim=True).values
|
||||||
dminuses = torch.min(distances_to_wminuses, dim=1, keepdim=True).values
|
dm = torch.min(d_unmatching, dim=1, keepdim=True).values
|
||||||
|
return dp, dm
|
||||||
|
|
||||||
mu = (dpluses - dminuses) / (dpluses + dminuses)
|
|
||||||
|
def glvq_loss(distances, target_labels, prototype_labels):
|
||||||
|
"""GLVQ loss function with support for one-hot labels."""
|
||||||
|
dp, dm = _get_dp_dm(distances, target_labels, prototype_labels)
|
||||||
|
mu = (dp - dm) / (dp + dm)
|
||||||
return mu
|
return mu
|
||||||
|
Loading…
Reference in New Issue
Block a user