[REFACTOR] Probabilistic losses
This commit is contained in:
parent
b724a28a6f
commit
ca8ac7a43b
@ -60,34 +60,33 @@ def lvq21_loss(distances, target_labels, prototype_labels):
|
|||||||
|
|
||||||
|
|
||||||
# Probabilistic
|
# Probabilistic
|
||||||
def log_likelihood_ratio_loss(probabilities, target, prototype_labels):
|
def _get_class_probabilities(probabilities, targets, prototype_labels):
|
||||||
|
# Create Label Mapping
|
||||||
uniques = prototype_labels.unique(sorted=True).tolist()
|
uniques = prototype_labels.unique(sorted=True).tolist()
|
||||||
labels = target.tolist()
|
|
||||||
|
|
||||||
key_val = {key: val for key, val in zip(uniques, range(len(uniques)))}
|
key_val = {key: val for key, val in zip(uniques, range(len(uniques)))}
|
||||||
target_indices = torch.LongTensor(list(map(key_val.get, labels)))
|
|
||||||
|
|
||||||
whole_probability = probabilities.sum(dim=1)
|
target_indices = torch.LongTensor(list(map(key_val.get, targets.tolist())))
|
||||||
correct_probability = probabilities[torch.arange(len(probabilities)),
|
|
||||||
target_indices]
|
|
||||||
wrong_probability = whole_probability - correct_probability
|
|
||||||
|
|
||||||
likelihood = correct_probability / wrong_probability
|
whole = probabilities.sum(dim=1)
|
||||||
|
correct = probabilities[torch.arange(len(probabilities)), target_indices]
|
||||||
|
wrong = whole - correct
|
||||||
|
|
||||||
|
return whole, correct, wrong
|
||||||
|
|
||||||
|
|
||||||
|
def log_likelihood_ratio_loss(probabilities, targets, prototype_labels):
|
||||||
|
_, correct, wrong = _get_class_probabilities(probabilities, targets,
|
||||||
|
prototype_labels)
|
||||||
|
|
||||||
|
likelihood = correct / wrong
|
||||||
log_likelihood = torch.log(likelihood)
|
log_likelihood = torch.log(likelihood)
|
||||||
return log_likelihood
|
return -1.0 * log_likelihood
|
||||||
|
|
||||||
|
|
||||||
def robust_soft_loss(probabilities, target, prototype_labels):
|
def robust_soft_loss(probabilities, targets, prototype_labels):
|
||||||
uniques = prototype_labels.unique(sorted=True).tolist()
|
whole, correct, _ = _get_class_probabilities(probabilities, targets,
|
||||||
labels = target.tolist()
|
prototype_labels)
|
||||||
|
|
||||||
key_val = {key: val for key, val in zip(uniques, range(len(uniques)))}
|
likelihood = correct / whole
|
||||||
target_indices = torch.LongTensor(list(map(key_val.get, labels)))
|
|
||||||
|
|
||||||
whole_probability = probabilities.sum(dim=1)
|
|
||||||
correct_probability = probabilities[torch.arange(len(probabilities)),
|
|
||||||
target_indices]
|
|
||||||
|
|
||||||
likelihood = correct_probability / whole_probability
|
|
||||||
log_likelihood = torch.log(likelihood)
|
log_likelihood = torch.log(likelihood)
|
||||||
return log_likelihood
|
return -1.0 * log_likelihood
|
||||||
|
Loading…
Reference in New Issue
Block a user