diff --git a/prototorch/functions/losses.py b/prototorch/functions/losses.py index cc9df27..df29862 100644 --- a/prototorch/functions/losses.py +++ b/prototorch/functions/losses.py @@ -57,3 +57,37 @@ def lvq21_loss(distances, target_labels, prototype_labels): mu = dp - dm return mu + + +# Probabilistic +def log_likelihood_ratio_loss(probabilities, target, prototype_labels): + uniques = prototype_labels.unique(sorted=True).tolist() + labels = target.tolist() + + key_val = {key: val for key, val in zip(uniques, range(len(uniques)))} + target_indices = torch.LongTensor(list(map(key_val.get, labels))) + + whole_probability = probabilities.sum(dim=1) + correct_probability = probabilities[torch.arange(len(probabilities)), + target_indices] + wrong_probability = whole_probability - correct_probability + + likelihood = correct_probability / wrong_probability + log_likelihood = torch.log(likelihood) + return log_likelihood + + +def robust_soft_loss(probabilities, target, prototype_labels): + uniques = prototype_labels.unique(sorted=True).tolist() + labels = target.tolist() + + key_val = {key: val for key, val in zip(uniques, range(len(uniques)))} + target_indices = torch.LongTensor(list(map(key_val.get, labels))) + + whole_probability = probabilities.sum(dim=1) + correct_probability = probabilities[torch.arange(len(probabilities)), + target_indices] + + likelihood = correct_probability / whole_probability + log_likelihood = torch.log(likelihood) + return log_likelihood