[REFACTOR] Probabilistic losses
This commit is contained in:
parent
b724a28a6f
commit
ca8ac7a43b
@ -60,34 +60,33 @@ def lvq21_loss(distances, target_labels, prototype_labels):
|
||||
|
||||
|
||||
# Probabilistic
|
||||
def log_likelihood_ratio_loss(probabilities, target, prototype_labels):
|
||||
def _get_class_probabilities(probabilities, targets, prototype_labels):
|
||||
# Create Label Mapping
|
||||
uniques = prototype_labels.unique(sorted=True).tolist()
|
||||
labels = target.tolist()
|
||||
|
||||
key_val = {key: val for key, val in zip(uniques, range(len(uniques)))}
|
||||
target_indices = torch.LongTensor(list(map(key_val.get, labels)))
|
||||
|
||||
whole_probability = probabilities.sum(dim=1)
|
||||
correct_probability = probabilities[torch.arange(len(probabilities)),
|
||||
target_indices]
|
||||
wrong_probability = whole_probability - correct_probability
|
||||
target_indices = torch.LongTensor(list(map(key_val.get, targets.tolist())))
|
||||
|
||||
likelihood = correct_probability / wrong_probability
|
||||
whole = probabilities.sum(dim=1)
|
||||
correct = probabilities[torch.arange(len(probabilities)), target_indices]
|
||||
wrong = whole - correct
|
||||
|
||||
return whole, correct, wrong
|
||||
|
||||
|
||||
def log_likelihood_ratio_loss(probabilities, targets, prototype_labels):
|
||||
_, correct, wrong = _get_class_probabilities(probabilities, targets,
|
||||
prototype_labels)
|
||||
|
||||
likelihood = correct / wrong
|
||||
log_likelihood = torch.log(likelihood)
|
||||
return log_likelihood
|
||||
return -1.0 * log_likelihood
|
||||
|
||||
|
||||
def robust_soft_loss(probabilities, target, prototype_labels):
|
||||
uniques = prototype_labels.unique(sorted=True).tolist()
|
||||
labels = target.tolist()
|
||||
def robust_soft_loss(probabilities, targets, prototype_labels):
|
||||
whole, correct, _ = _get_class_probabilities(probabilities, targets,
|
||||
prototype_labels)
|
||||
|
||||
key_val = {key: val for key, val in zip(uniques, range(len(uniques)))}
|
||||
target_indices = torch.LongTensor(list(map(key_val.get, labels)))
|
||||
|
||||
whole_probability = probabilities.sum(dim=1)
|
||||
correct_probability = probabilities[torch.arange(len(probabilities)),
|
||||
target_indices]
|
||||
|
||||
likelihood = correct_probability / whole_probability
|
||||
likelihood = correct / whole
|
||||
log_likelihood = torch.log(likelihood)
|
||||
return log_likelihood
|
||||
return -1.0 * log_likelihood
|
||||
|
Loading…
Reference in New Issue
Block a user