[refactor] Move probabilistic to Prototorch
This commit is contained in:
parent
e3392ee952
commit
a60337ff27
@ -2,44 +2,13 @@
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from prototorch.functions.competitions import stratified_sum
|
from prototorch.functions.competitions import stratified_sum
|
||||||
|
from prototorch.functions.losses import (log_likelihood_ratio_loss,
|
||||||
|
robust_soft_loss)
|
||||||
from prototorch.functions.transform import gaussian
|
from prototorch.functions.transform import gaussian
|
||||||
|
|
||||||
from .glvq import GLVQ
|
from .glvq import GLVQ
|
||||||
|
|
||||||
|
|
||||||
def likelihood_loss(probabilities, target, prototype_labels):
|
|
||||||
uniques = prototype_labels.unique(sorted=True).tolist()
|
|
||||||
labels = target.tolist()
|
|
||||||
|
|
||||||
key_val = {key: val for key, val in zip(uniques, range(len(uniques)))}
|
|
||||||
target_indices = torch.LongTensor(list(map(key_val.get, labels)))
|
|
||||||
|
|
||||||
whole_probability = probabilities.sum(dim=1)
|
|
||||||
correct_probability = probabilities[torch.arange(len(probabilities)),
|
|
||||||
target_indices]
|
|
||||||
wrong_probability = whole_probability - correct_probability
|
|
||||||
|
|
||||||
likelihood = correct_probability / wrong_probability
|
|
||||||
log_likelihood = torch.log(likelihood)
|
|
||||||
return log_likelihood
|
|
||||||
|
|
||||||
|
|
||||||
def robust_soft_loss(probabilities, target, prototype_labels):
|
|
||||||
uniques = prototype_labels.unique(sorted=True).tolist()
|
|
||||||
labels = target.tolist()
|
|
||||||
|
|
||||||
key_val = {key: val for key, val in zip(uniques, range(len(uniques)))}
|
|
||||||
target_indices = torch.LongTensor(list(map(key_val.get, labels)))
|
|
||||||
|
|
||||||
whole_probability = probabilities.sum(dim=1)
|
|
||||||
correct_probability = probabilities[torch.arange(len(probabilities)),
|
|
||||||
target_indices]
|
|
||||||
|
|
||||||
likelihood = correct_probability / whole_probability
|
|
||||||
log_likelihood = torch.log(likelihood)
|
|
||||||
return log_likelihood
|
|
||||||
|
|
||||||
|
|
||||||
class ProbabilisticLVQ(GLVQ):
|
class ProbabilisticLVQ(GLVQ):
|
||||||
def __init__(self, hparams, rejection_confidence=1.0, **kwargs):
|
def __init__(self, hparams, rejection_confidence=1.0, **kwargs):
|
||||||
super().__init__(hparams, **kwargs)
|
super().__init__(hparams, **kwargs)
|
||||||
@ -78,17 +47,17 @@ class ProbabilisticLVQ(GLVQ):
|
|||||||
class LikelihoodRatioLVQ(ProbabilisticLVQ):
|
class LikelihoodRatioLVQ(ProbabilisticLVQ):
|
||||||
"""Learning Vector Quantization based on Likelihood Ratios
|
"""Learning Vector Quantization based on Likelihood Ratios
|
||||||
"""
|
"""
|
||||||
@property
|
def __init__(self, *args, **kwargs):
|
||||||
def loss_fn(self):
|
super().__init__(*args, **kwargs)
|
||||||
return likelihood_loss
|
self.loss_fn = log_likelihood_ratio_loss
|
||||||
|
|
||||||
|
|
||||||
class RSLVQ(ProbabilisticLVQ):
|
class RSLVQ(ProbabilisticLVQ):
|
||||||
"""Learning Vector Quantization based on Likelihood Ratios
|
"""Learning Vector Quantization based on Likelihood Ratios
|
||||||
"""
|
"""
|
||||||
@property
|
def __init__(self, *args, **kwargs):
|
||||||
def loss_fn(self):
|
super().__init__(*args, **kwargs)
|
||||||
return robust_soft_loss
|
self.loss_fn = robust_soft_loss
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["LikelihoodRatioLVQ", "RSLVQ"]
|
__all__ = ["LikelihoodRatioLVQ", "RSLVQ"]
|
||||||
|
Loading…
Reference in New Issue
Block a user