[BUGFIX] Fix RSLVQ

Jensun Ravichandran 2021-06-01 17:44:10 +02:00
parent 9c1a41997b
commit 21023a88d7
2 changed files with 14 additions and 12 deletions

Changed file 1 of 2 (AbstractPrototypeModel):

@@ -4,6 +4,10 @@ from torch.optim.lr_scheduler import ExponentialLR

 class AbstractPrototypeModel(pl.LightningModule):
     @property
+    def num_prototypes(self):
+        return len(self.proto_layer.components)
+
+    @property
     def prototypes(self):
         return self.proto_layer.components.detach().cpu()

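The new num_prototypes property simply counts the rows of the component tensor (len() of a tensor is the size of its first dimension). A minimal sketch of that behaviour with a toy stand-in; ToyProtoLayer is illustrative and not prototorch's actual component layer:

import torch

class ToyProtoLayer(torch.nn.Module):
    # Illustrative stand-in for proto_layer holding a (num_prototypes, num_features) tensor.
    def __init__(self, num_prototypes, num_features):
        super().__init__()
        self.components = torch.nn.Parameter(torch.rand(num_prototypes, num_features))

layer = ToyProtoLayer(4, 2)
print(len(layer.components))                  # 4 -> what num_prototypes returns
print(layer.components.detach().cpu().shape)  # torch.Size([4, 2]) -> what prototypes returns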

Changed file 2 of 2 (ProbabilisticLVQ):

@@ -3,7 +3,7 @@
 import torch

 from prototorch.functions.competitions import stratified_sum
 from prototorch.functions.losses import log_likelihood_ratio_loss, robust_soft_loss
-from prototorch.functions.transform import gaussian
+from prototorch.functions.transforms import gaussian

 from .glvq import GLVQ
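For context, the renamed module provides gaussian, used below as the conditional distribution that maps prototype distances to (unnormalized) likelihoods. A rough sketch of that idea only; the library's exact formula may differ:

import torch

def gaussian_conditional(distances, variance):
    # Larger likelihood for smaller (squared) distances; a sketch, not prototorch's implementation.
    return torch.exp(-distances / (2.0 * variance))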
@@ -15,24 +15,22 @@ class ProbabilisticLVQ(GLVQ):
         self.conditional_distribution = gaussian
         self.rejection_confidence = rejection_confidence

-    def predict(self, x):
-        probabilities = self.forward(x)
-        confidence, prediction = torch.max(probabilities, dim=1)
-        prediction[confidence < self.rejection_confidence] = -1
-        return prediction
-
     def forward(self, x):
         distances = self._forward(x)
         conditional = self.conditional_distribution(distances,
                                                     self.hparams.variance)
-        prior = 1.0 / torch.Tensor(self.proto_layer.distribution).sum().item()
+        prior = (1. / self.num_prototypes) * torch.ones(self.num_prototypes)
         posterior = conditional * prior
-        plabels = torch.LongTensor(self.proto_layer.component_labels)
-        y_pred = stratified_sum(posterior.T, plabels)
+        plabels = self.proto_layer._labels
+        y_pred = stratified_sum(posterior, plabels)
         return y_pred

+    def predict(self, x):
+        y_pred = self.forward(x)
+        confidence, prediction = torch.max(y_pred, dim=1)
+        prediction[confidence < self.rejection_confidence] = -1
+        return prediction
+
     def training_step(self, batch, batch_idx, optimizer_idx=None):
         X, y = batch
         out = self.forward(X)
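In summary, the fix replaces the scalar prior with an explicit uniform prior over the prototypes, passes the posterior to stratified_sum without transposing it, reads the prototype labels directly from the component layer, and moves predict below forward. A hedged sketch of the corrected flow with illustrative helper names (not prototorch's API), assuming conditional likelihoods of shape (batch, num_prototypes):

import torch

def per_class_sum(posterior, plabels):
    # Sum posterior mass over all prototypes sharing a class label
    # (the role stratified_sum plays in the diff above).
    num_classes = int(plabels.max().item()) + 1
    out = torch.zeros(posterior.shape[0], num_classes)
    return out.index_add(1, plabels, posterior)

def class_scores(conditional, plabels):
    num_prototypes = conditional.shape[1]
    prior = torch.ones(num_prototypes) / num_prototypes  # uniform prior, as in the fixed forward()
    posterior = conditional * prior
    return per_class_sum(posterior, plabels)

def predict_with_rejection(scores, rejection_confidence):
    confidence, prediction = torch.max(scores, dim=1)
    prediction[confidence < rejection_confidence] = -1  # reject low-confidence samples
    return prediction

# Example: 3 samples, 4 prototypes labelled [0, 0, 1, 1].
conditional = torch.rand(3, 4)
plabels = torch.tensor([0, 0, 1, 1])
print(predict_with_rejection(class_scores(conditional, plabels), 0.2))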