[refactor] Move probabilistic to Prototorch

This commit is contained in:
Alexander Engelsberger 2021-05-28 20:39:32 +02:00
parent e3392ee952
commit a60337ff27

View File

@ -2,44 +2,13 @@
import torch
from prototorch.functions.competitions import stratified_sum
from prototorch.functions.losses import (log_likelihood_ratio_loss,
robust_soft_loss)
from prototorch.functions.transform import gaussian
from .glvq import GLVQ
def likelihood_loss(probabilities, target, prototype_labels):
uniques = prototype_labels.unique(sorted=True).tolist()
labels = target.tolist()
key_val = {key: val for key, val in zip(uniques, range(len(uniques)))}
target_indices = torch.LongTensor(list(map(key_val.get, labels)))
whole_probability = probabilities.sum(dim=1)
correct_probability = probabilities[torch.arange(len(probabilities)),
target_indices]
wrong_probability = whole_probability - correct_probability
likelihood = correct_probability / wrong_probability
log_likelihood = torch.log(likelihood)
return log_likelihood
def robust_soft_loss(probabilities, target, prototype_labels):
uniques = prototype_labels.unique(sorted=True).tolist()
labels = target.tolist()
key_val = {key: val for key, val in zip(uniques, range(len(uniques)))}
target_indices = torch.LongTensor(list(map(key_val.get, labels)))
whole_probability = probabilities.sum(dim=1)
correct_probability = probabilities[torch.arange(len(probabilities)),
target_indices]
likelihood = correct_probability / whole_probability
log_likelihood = torch.log(likelihood)
return log_likelihood
class ProbabilisticLVQ(GLVQ):
def __init__(self, hparams, rejection_confidence=1.0, **kwargs):
super().__init__(hparams, **kwargs)
@ -78,17 +47,17 @@ class ProbabilisticLVQ(GLVQ):
class LikelihoodRatioLVQ(ProbabilisticLVQ):
"""Learning Vector Quantization based on Likelihood Ratios
"""
@property
def loss_fn(self):
return likelihood_loss
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.loss_fn = log_likelihood_ratio_loss
class RSLVQ(ProbabilisticLVQ):
"""Learning Vector Quantization based on Likelihood Ratios
"""
@property
def loss_fn(self):
return robust_soft_loss
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.loss_fn = robust_soft_loss
__all__ = ["LikelihoodRatioLVQ", "RSLVQ"]