Add LVQ 1 and LVQ 2.1 loss functions.
This commit is contained in:
parent
2b676ee06e
commit
c1c21e92df
@ -31,3 +31,26 @@ def glvq_loss(distances, target_labels, prototype_labels):
|
||||
dp, dm = _get_dp_dm(distances, target_labels, prototype_labels)
|
||||
mu = (dp - dm) / (dp + dm)
|
||||
return mu
|
||||
|
||||
|
||||
def lvq1_loss(distances, target_labels, prototype_labels):
|
||||
"""LVQ1 loss function with support for one-hot labels.
|
||||
|
||||
See Section 4 [Sado&Yamada]
|
||||
https://papers.nips.cc/paper/1995/file/9c3b1830513cc3b8fc4b76635d32e692-Paper.pdf
|
||||
"""
|
||||
dp, dm = _get_dp_dm(distances, target_labels, prototype_labels)
|
||||
mu = dp
|
||||
mu[dp > dm] = -dm[dp > dm]
|
||||
return mu
|
||||
|
||||
|
||||
def lvq21_loss(distances, target_labels, prototype_labels):
|
||||
"""LVQ2.1 loss function with support for one-hot labels.
|
||||
|
||||
See Section 4 [Sado&Yamada]
|
||||
https://papers.nips.cc/paper/1995/file/9c3b1830513cc3b8fc4b76635d32e692-Paper.pdf
|
||||
"""
|
||||
dp, dm = _get_dp_dm(distances, target_labels, prototype_labels)
|
||||
mu = dp - dm
|
||||
return mu
|
Loading…
Reference in New Issue
Block a user