[BUGFIX] Fix RSLVQ
This commit is contained in:
parent
9c1a41997b
commit
21023a88d7
@ -4,6 +4,10 @@ from torch.optim.lr_scheduler import ExponentialLR
|
||||
|
||||
|
||||
class AbstractPrototypeModel(pl.LightningModule):
|
||||
@property
|
||||
def num_prototypes(self):
|
||||
return len(self.proto_layer.components)
|
||||
|
||||
@property
|
||||
def prototypes(self):
|
||||
return self.proto_layer.components.detach().cpu()
|
||||
|
@ -3,7 +3,7 @@
|
||||
import torch
|
||||
from prototorch.functions.competitions import stratified_sum
|
||||
from prototorch.functions.losses import log_likelihood_ratio_loss, robust_soft_loss
|
||||
from prototorch.functions.transform import gaussian
|
||||
from prototorch.functions.transforms import gaussian
|
||||
|
||||
from .glvq import GLVQ
|
||||
|
||||
@ -15,24 +15,22 @@ class ProbabilisticLVQ(GLVQ):
|
||||
self.conditional_distribution = gaussian
|
||||
self.rejection_confidence = rejection_confidence
|
||||
|
||||
def predict(self, x):
|
||||
probabilities = self.forward(x)
|
||||
confidence, prediction = torch.max(probabilities, dim=1)
|
||||
prediction[confidence < self.rejection_confidence] = -1
|
||||
return prediction
|
||||
|
||||
def forward(self, x):
|
||||
distances = self._forward(x)
|
||||
conditional = self.conditional_distribution(distances,
|
||||
self.hparams.variance)
|
||||
prior = 1.0 / torch.Tensor(self.proto_layer.distribution).sum().item()
|
||||
prior = (1. / self.num_prototypes) * torch.ones(self.num_prototypes)
|
||||
posterior = conditional * prior
|
||||
|
||||
plabels = torch.LongTensor(self.proto_layer.component_labels)
|
||||
y_pred = stratified_sum(posterior.T, plabels)
|
||||
|
||||
plabels = self.proto_layer._labels
|
||||
y_pred = stratified_sum(posterior, plabels)
|
||||
return y_pred
|
||||
|
||||
def predict(self, x):
|
||||
y_pred = self.forward(x)
|
||||
confidence, prediction = torch.max(y_pred, dim=1)
|
||||
prediction[confidence < self.rejection_confidence] = -1
|
||||
return prediction
|
||||
|
||||
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
||||
X, y = batch
|
||||
out = self.forward(X)
|
||||
|
Loading…
Reference in New Issue
Block a user