From aeb6417c28859ae460045838f5c632457259023d Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Fri, 6 Aug 2021 13:49:29 +0200 Subject: [PATCH] refactor: minor changes in `probabilistic.py` --- prototorch/models/probabilistic.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/prototorch/models/probabilistic.py b/prototorch/models/probabilistic.py index d5b5842..30ae139 100644 --- a/prototorch/models/probabilistic.py +++ b/prototorch/models/probabilistic.py @@ -1,5 +1,4 @@ """Probabilistic GLVQ methods""" - import torch from ..core.losses import nllr_loss, rslvq_loss @@ -32,7 +31,7 @@ class ProbabilisticLVQ(GLVQ): def __init__(self, hparams, rejection_confidence=0.0, **kwargs): super().__init__(hparams, **kwargs) - self.conditional_distribution = None + self.conditional_distribution = GaussianPrior(self.hparams.variance) self.rejection_confidence = rejection_confidence def forward(self, x): @@ -56,8 +55,9 @@ class ProbabilisticLVQ(GLVQ): out = self.forward(x) plabels = self.proto_layer.labels batch_loss = self.loss(out, y, plabels) - loss = batch_loss.sum() - return loss + train_loss = batch_loss.sum() + self.log("train_loss", train_loss) + return train_loss class SLVQ(ProbabilisticLVQ): @@ -65,7 +65,6 @@ class SLVQ(ProbabilisticLVQ): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.loss = LossLayer(nllr_loss) - self.conditional_distribution = GaussianPrior(self.hparams.variance) class RSLVQ(ProbabilisticLVQ): @@ -73,7 +72,6 @@ class RSLVQ(ProbabilisticLVQ): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.loss = LossLayer(rslvq_loss) - self.conditional_distribution = GaussianPrior(self.hparams.variance) class PLVQ(ProbabilisticLVQ, SiameseGMLVQ):