refactor: minor changes in probabilistic.py

This commit is contained in:
Jensun Ravichandran 2021-08-06 13:49:29 +02:00
parent cb7fb91c95
commit aeb6417c28
No known key found for this signature in database
GPG Key ID: 3331B0F18B6D4D93

View File

@ -1,5 +1,4 @@
"""Probabilistic GLVQ methods""" """Probabilistic GLVQ methods"""
import torch import torch
from ..core.losses import nllr_loss, rslvq_loss from ..core.losses import nllr_loss, rslvq_loss
@ -32,7 +31,7 @@ class ProbabilisticLVQ(GLVQ):
def __init__(self, hparams, rejection_confidence=0.0, **kwargs): def __init__(self, hparams, rejection_confidence=0.0, **kwargs):
super().__init__(hparams, **kwargs) super().__init__(hparams, **kwargs)
self.conditional_distribution = None self.conditional_distribution = GaussianPrior(self.hparams.variance)
self.rejection_confidence = rejection_confidence self.rejection_confidence = rejection_confidence
def forward(self, x): def forward(self, x):
@ -56,8 +55,9 @@ class ProbabilisticLVQ(GLVQ):
out = self.forward(x) out = self.forward(x)
plabels = self.proto_layer.labels plabels = self.proto_layer.labels
batch_loss = self.loss(out, y, plabels) batch_loss = self.loss(out, y, plabels)
loss = batch_loss.sum() train_loss = batch_loss.sum()
return loss self.log("train_loss", train_loss)
return train_loss
class SLVQ(ProbabilisticLVQ): class SLVQ(ProbabilisticLVQ):
@ -65,7 +65,6 @@ class SLVQ(ProbabilisticLVQ):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.loss = LossLayer(nllr_loss) self.loss = LossLayer(nllr_loss)
self.conditional_distribution = GaussianPrior(self.hparams.variance)
class RSLVQ(ProbabilisticLVQ): class RSLVQ(ProbabilisticLVQ):
@ -73,7 +72,6 @@ class RSLVQ(ProbabilisticLVQ):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.loss = LossLayer(rslvq_loss) self.loss = LossLayer(rslvq_loss)
self.conditional_distribution = GaussianPrior(self.hparams.variance)
class PLVQ(ProbabilisticLVQ, SiameseGMLVQ): class PLVQ(ProbabilisticLVQ, SiameseGMLVQ):