[BUGFIX] Fix RSLVQ
This commit is contained in:
parent
9c1a41997b
commit
21023a88d7
@ -4,6 +4,10 @@ from torch.optim.lr_scheduler import ExponentialLR
|
|||||||
|
|
||||||
|
|
||||||
class AbstractPrototypeModel(pl.LightningModule):
|
class AbstractPrototypeModel(pl.LightningModule):
|
||||||
|
@property
|
||||||
|
def num_prototypes(self):
|
||||||
|
return len(self.proto_layer.components)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def prototypes(self):
|
def prototypes(self):
|
||||||
return self.proto_layer.components.detach().cpu()
|
return self.proto_layer.components.detach().cpu()
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
import torch
|
import torch
|
||||||
from prototorch.functions.competitions import stratified_sum
|
from prototorch.functions.competitions import stratified_sum
|
||||||
from prototorch.functions.losses import log_likelihood_ratio_loss, robust_soft_loss
|
from prototorch.functions.losses import log_likelihood_ratio_loss, robust_soft_loss
|
||||||
from prototorch.functions.transform import gaussian
|
from prototorch.functions.transforms import gaussian
|
||||||
|
|
||||||
from .glvq import GLVQ
|
from .glvq import GLVQ
|
||||||
|
|
||||||
@ -15,24 +15,22 @@ class ProbabilisticLVQ(GLVQ):
|
|||||||
self.conditional_distribution = gaussian
|
self.conditional_distribution = gaussian
|
||||||
self.rejection_confidence = rejection_confidence
|
self.rejection_confidence = rejection_confidence
|
||||||
|
|
||||||
def predict(self, x):
|
|
||||||
probabilities = self.forward(x)
|
|
||||||
confidence, prediction = torch.max(probabilities, dim=1)
|
|
||||||
prediction[confidence < self.rejection_confidence] = -1
|
|
||||||
return prediction
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
distances = self._forward(x)
|
distances = self._forward(x)
|
||||||
conditional = self.conditional_distribution(distances,
|
conditional = self.conditional_distribution(distances,
|
||||||
self.hparams.variance)
|
self.hparams.variance)
|
||||||
prior = 1.0 / torch.Tensor(self.proto_layer.distribution).sum().item()
|
prior = (1. / self.num_prototypes) * torch.ones(self.num_prototypes)
|
||||||
posterior = conditional * prior
|
posterior = conditional * prior
|
||||||
|
plabels = self.proto_layer._labels
|
||||||
plabels = torch.LongTensor(self.proto_layer.component_labels)
|
y_pred = stratified_sum(posterior, plabels)
|
||||||
y_pred = stratified_sum(posterior.T, plabels)
|
|
||||||
|
|
||||||
return y_pred
|
return y_pred
|
||||||
|
|
||||||
|
def predict(self, x):
|
||||||
|
y_pred = self.forward(x)
|
||||||
|
confidence, prediction = torch.max(y_pred, dim=1)
|
||||||
|
prediction[confidence < self.rejection_confidence] = -1
|
||||||
|
return prediction
|
||||||
|
|
||||||
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
||||||
X, y = batch
|
X, y = batch
|
||||||
out = self.forward(X)
|
out = self.forward(X)
|
||||||
|
Loading…
Reference in New Issue
Block a user