[REFACTOR] Probabilistic losses

This commit is contained in:
Alexander Engelsberger 2021-06-03 14:01:13 +02:00
parent b724a28a6f
commit ca8ac7a43b

View File

@ -60,34 +60,33 @@ def lvq21_loss(distances, target_labels, prototype_labels):
# Probabilistic
def log_likelihood_ratio_loss(probabilities, target, prototype_labels):
def _get_class_probabilities(probabilities, targets, prototype_labels):
# Create Label Mapping
uniques = prototype_labels.unique(sorted=True).tolist()
labels = target.tolist()
key_val = {key: val for key, val in zip(uniques, range(len(uniques)))}
target_indices = torch.LongTensor(list(map(key_val.get, labels)))
whole_probability = probabilities.sum(dim=1)
correct_probability = probabilities[torch.arange(len(probabilities)),
target_indices]
wrong_probability = whole_probability - correct_probability
target_indices = torch.LongTensor(list(map(key_val.get, targets.tolist())))
likelihood = correct_probability / wrong_probability
whole = probabilities.sum(dim=1)
correct = probabilities[torch.arange(len(probabilities)), target_indices]
wrong = whole - correct
return whole, correct, wrong
def log_likelihood_ratio_loss(probabilities, targets, prototype_labels):
_, correct, wrong = _get_class_probabilities(probabilities, targets,
prototype_labels)
likelihood = correct / wrong
log_likelihood = torch.log(likelihood)
return log_likelihood
return -1.0 * log_likelihood
def robust_soft_loss(probabilities, target, prototype_labels):
uniques = prototype_labels.unique(sorted=True).tolist()
labels = target.tolist()
def robust_soft_loss(probabilities, targets, prototype_labels):
whole, correct, _ = _get_class_probabilities(probabilities, targets,
prototype_labels)
key_val = {key: val for key, val in zip(uniques, range(len(uniques)))}
target_indices = torch.LongTensor(list(map(key_val.get, labels)))
whole_probability = probabilities.sum(dim=1)
correct_probability = probabilities[torch.arange(len(probabilities)),
target_indices]
likelihood = correct_probability / whole_probability
likelihood = correct / whole
log_likelihood = torch.log(likelihood)
return log_likelihood
return -1.0 * log_likelihood