[REFACTOR] Probabilistic loss signs changed

This commit is contained in:
Alexander Engelsberger 2021-06-03 14:00:47 +02:00
parent 5918f1cc21
commit 459f7c24be

View File

@ -2,7 +2,8 @@
import torch
from prototorch.functions.competitions import stratified_min, stratified_sum
from prototorch.functions.losses import log_likelihood_ratio_loss, robust_soft_loss
from prototorch.functions.losses import (log_likelihood_ratio_loss,
robust_soft_loss)
from prototorch.functions.transforms import gaussian
from .glvq import GLVQ
@ -51,7 +52,7 @@ class ProbabilisticLVQ(GLVQ):
X, y = batch
out = self.forward(X)
plabels = self.proto_layer.component_labels
batch_loss = -self.loss_fn(out, y, plabels)
batch_loss = self.loss_fn(out, y, plabels)
loss = batch_loss.sum(dim=0)
return loss