[refactor] Move probabilistic to Prototorch
This commit is contained in:
parent
e3392ee952
commit
a60337ff27
@ -2,44 +2,13 @@
|
||||
|
||||
import torch
|
||||
from prototorch.functions.competitions import stratified_sum
|
||||
from prototorch.functions.losses import (log_likelihood_ratio_loss,
|
||||
robust_soft_loss)
|
||||
from prototorch.functions.transform import gaussian
|
||||
|
||||
from .glvq import GLVQ
|
||||
|
||||
|
||||
def likelihood_loss(probabilities, target, prototype_labels):
|
||||
uniques = prototype_labels.unique(sorted=True).tolist()
|
||||
labels = target.tolist()
|
||||
|
||||
key_val = {key: val for key, val in zip(uniques, range(len(uniques)))}
|
||||
target_indices = torch.LongTensor(list(map(key_val.get, labels)))
|
||||
|
||||
whole_probability = probabilities.sum(dim=1)
|
||||
correct_probability = probabilities[torch.arange(len(probabilities)),
|
||||
target_indices]
|
||||
wrong_probability = whole_probability - correct_probability
|
||||
|
||||
likelihood = correct_probability / wrong_probability
|
||||
log_likelihood = torch.log(likelihood)
|
||||
return log_likelihood
|
||||
|
||||
|
||||
def robust_soft_loss(probabilities, target, prototype_labels):
|
||||
uniques = prototype_labels.unique(sorted=True).tolist()
|
||||
labels = target.tolist()
|
||||
|
||||
key_val = {key: val for key, val in zip(uniques, range(len(uniques)))}
|
||||
target_indices = torch.LongTensor(list(map(key_val.get, labels)))
|
||||
|
||||
whole_probability = probabilities.sum(dim=1)
|
||||
correct_probability = probabilities[torch.arange(len(probabilities)),
|
||||
target_indices]
|
||||
|
||||
likelihood = correct_probability / whole_probability
|
||||
log_likelihood = torch.log(likelihood)
|
||||
return log_likelihood
|
||||
|
||||
|
||||
class ProbabilisticLVQ(GLVQ):
|
||||
def __init__(self, hparams, rejection_confidence=1.0, **kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
@ -78,17 +47,17 @@ class ProbabilisticLVQ(GLVQ):
|
||||
class LikelihoodRatioLVQ(ProbabilisticLVQ):
|
||||
"""Learning Vector Quantization based on Likelihood Ratios
|
||||
"""
|
||||
@property
|
||||
def loss_fn(self):
|
||||
return likelihood_loss
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.loss_fn = log_likelihood_ratio_loss
|
||||
|
||||
|
||||
class RSLVQ(ProbabilisticLVQ):
|
||||
"""Learning Vector Quantization based on Likelihood Ratios
|
||||
"""
|
||||
@property
|
||||
def loss_fn(self):
|
||||
return robust_soft_loss
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.loss_fn = robust_soft_loss
|
||||
|
||||
|
||||
__all__ = ["LikelihoodRatioLVQ", "RSLVQ"]
|
||||
|
Loading…
Reference in New Issue
Block a user