56 lines
1.8 KiB
Python
56 lines
1.8 KiB
Python
"""ProtoTorch loss functions."""
|
|
|
|
import torch
|
|
|
|
|
|
def _get_matcher(targets, labels):
    """Build a boolean mask telling which prototype labels match each target.

    Args:
        targets: Target labels of the batch. When ``labels`` is 2-D these are
            expected to be one-hot rows, otherwise plain class labels.
        labels: Prototype labels, presumably in the same encoding as
            ``targets`` (TODO confirm against callers).

    Returns:
        A boolean tensor whose entry ``[i, j]`` is ``True`` when target ``i``
        equals prototype label ``j``.
    """
    # Compare every target against every prototype label via broadcasting.
    elementwise = torch.eq(targets.unsqueeze(dim=1), labels)
    if labels.ndim != 2:
        return elementwise
    # One-hot labels: a target matches only if all of its components agree.
    n_components = targets.size()[1]
    return torch.eq(torch.sum(elementwise, dim=-1), n_components)
|
|
|
|
|
|
def _get_dp_dm(distances, targets, plabels):
    """Find, per sample, the closest correct and closest incorrect distance.

    Args:
        distances: Sample-to-prototype distances, one row per sample
            (assumed floating point, shape ``(batch, num_prototypes)`` --
            TODO confirm against callers).
        targets: Target labels of the batch (class labels or one-hot rows).
        plabels: Prototype labels, in the same encoding as ``targets``.

    Returns:
        A tuple ``(dp, dm)`` of tensors reduced over dim 1 with the dimension
        kept: ``dp`` is the smallest distance to a prototype whose label
        matches, ``dm`` the smallest distance to one whose label differs.
        A row with no candidate of the required kind yields ``inf``.
    """
    same_class = _get_matcher(targets, plabels)
    other_class = torch.bitwise_not(same_class)

    # Non-candidates are replaced by +inf so they can never win the minimum.
    blocked = torch.full_like(distances, fill_value=float("inf"))

    def _closest(mask):
        # Minimum distance among the entries selected by ``mask``.
        candidates = torch.where(mask, distances, blocked)
        return candidates.min(dim=1, keepdim=True).values

    return _closest(same_class), _closest(other_class)
|
|
|
|
|
|
def glvq_loss(distances, target_labels, prototype_labels):
    """GLVQ loss function with support for one-hot labels.

    Computes the relative distance difference ``(dp - dm) / (dp + dm)`` per
    sample, where ``dp`` / ``dm`` are the distances to the closest prototype
    with the matching / a different label. A negative value means the closest
    prototype overall carries the correct label. If both distances are zero
    the result is NaN (0 / 0).
    """
    d_plus, d_minus = _get_dp_dm(distances, target_labels, prototype_labels)
    numerator = d_plus - d_minus
    denominator = d_plus + d_minus
    return numerator / denominator
|
|
|
|
|
|
def lvq1_loss(distances, target_labels, prototype_labels):
    """LVQ1 loss function with support for one-hot labels.

    See Section 4 [Sado&Yamada]
    https://papers.nips.cc/paper/1995/file/9c3b1830513cc3b8fc4b76635d32e692-Paper.pdf

    Per sample, the loss is ``dp`` (distance to the closest prototype with the
    matching label) when ``dp <= dm``, and ``-dm`` (negated distance to the
    closest prototype with a different label) when ``dp > dm``.
    """
    dp, dm = _get_dp_dm(distances, target_labels, prototype_labels)
    # Select out-of-place instead of aliasing ``dp`` and overwriting it with a
    # masked in-place assignment: the autograd intermediate is left untouched
    # and the comparison mask is computed only once.
    return torch.where(dp > dm, -dm, dp)
|
|
|
|
|
|
def lvq21_loss(distances, target_labels, prototype_labels):
    """LVQ2.1 loss function with support for one-hot labels.

    See Section 4 [Sado&Yamada]
    https://papers.nips.cc/paper/1995/file/9c3b1830513cc3b8fc4b76635d32e692-Paper.pdf

    Returns ``dp - dm`` per sample: the distance to the closest prototype with
    the matching label minus the distance to the closest prototype with a
    different label.
    """
    closest_correct, closest_wrong = _get_dp_dm(
        distances, target_labels, prototype_labels
    )
    return closest_correct - closest_wrong