diff --git a/prototorch/functions/losses.py b/prototorch/functions/losses.py index 5a8fe23..b7069d0 100644 --- a/prototorch/functions/losses.py +++ b/prototorch/functions/losses.py @@ -31,3 +31,26 @@ def glvq_loss(distances, target_labels, prototype_labels): dp, dm = _get_dp_dm(distances, target_labels, prototype_labels) mu = (dp - dm) / (dp + dm) return mu + + +def lvq1_loss(distances, target_labels, prototype_labels): + """LVQ1 loss function with support for one-hot labels. + + See Section 4 [Sado&Yamada] + https://papers.nips.cc/paper/1995/file/9c3b1830513cc3b8fc4b76635d32e692-Paper.pdf + """ + dp, dm = _get_dp_dm(distances, target_labels, prototype_labels) + mu = dp + mu[dp > dm] = -dm[dp > dm] + return mu + + +def lvq21_loss(distances, target_labels, prototype_labels): + """LVQ2.1 loss function with support for one-hot labels. + + See Section 4 [Sado&Yamada] + https://papers.nips.cc/paper/1995/file/9c3b1830513cc3b8fc4b76635d32e692-Paper.pdf + """ + dp, dm = _get_dp_dm(distances, target_labels, prototype_labels) + mu = dp - dm + return mu \ No newline at end of file