refactor: minor changes in probabilistic.py
This commit is contained in:
parent
cb7fb91c95
commit
aeb6417c28
@ -1,5 +1,4 @@
|
||||
"""Probabilistic GLVQ methods"""
|
||||
|
||||
import torch
|
||||
|
||||
from ..core.losses import nllr_loss, rslvq_loss
|
||||
@ -32,7 +31,7 @@ class ProbabilisticLVQ(GLVQ):
|
||||
def __init__(self, hparams, rejection_confidence=0.0, **kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
|
||||
self.conditional_distribution = None
|
||||
self.conditional_distribution = GaussianPrior(self.hparams.variance)
|
||||
self.rejection_confidence = rejection_confidence
|
||||
|
||||
def forward(self, x):
|
||||
@ -56,8 +55,9 @@ class ProbabilisticLVQ(GLVQ):
|
||||
out = self.forward(x)
|
||||
plabels = self.proto_layer.labels
|
||||
batch_loss = self.loss(out, y, plabels)
|
||||
loss = batch_loss.sum()
|
||||
return loss
|
||||
train_loss = batch_loss.sum()
|
||||
self.log("train_loss", train_loss)
|
||||
return train_loss
|
||||
|
||||
|
||||
class SLVQ(ProbabilisticLVQ):
|
||||
@ -65,7 +65,6 @@ class SLVQ(ProbabilisticLVQ):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.loss = LossLayer(nllr_loss)
|
||||
self.conditional_distribution = GaussianPrior(self.hparams.variance)
|
||||
|
||||
|
||||
class RSLVQ(ProbabilisticLVQ):
|
||||
@ -73,7 +72,6 @@ class RSLVQ(ProbabilisticLVQ):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.loss = LossLayer(rslvq_loss)
|
||||
self.conditional_distribution = GaussianPrior(self.hparams.variance)
|
||||
|
||||
|
||||
class PLVQ(ProbabilisticLVQ, SiameseGMLVQ):
|
||||
|
Loading…
Reference in New Issue
Block a user