[BUG] PLVQ seems broken

Author: Jensun Ravichandran
Date: 2021-06-14 20:56:38 +02:00
Parent: 24ebfdc667
Commit: a44219ee47

2 changed files with 12 additions and 16 deletions

@@ -54,7 +54,7 @@ class ProbabilisticLVQ(GLVQ):
     def training_step(self, batch, batch_idx, optimizer_idx=None):
         x, y = batch
         out = self.forward(x)
-        plabels = self.proto_layer.component_labels
+        plabels = self.proto_layer.labels
         batch_loss = self.loss(out, y, plabels)
         loss = batch_loss.sum(dim=0)
         return loss
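
The first hunk is a one-line fix: the prototype labels are now read from self.proto_layer.labels instead of self.proto_layer.component_labels. If downstream code had to tolerate both spellings of the accessor, a small shim could paper over the rename; the helper below is purely illustrative and not part of the commit:

    def get_prototype_labels(proto_layer):
        """Hypothetical helper (not in the commit): read prototype labels
        whether the layer exposes the old `component_labels` or the new
        `labels` accessor."""
        try:
            return proto_layer.labels
        except AttributeError:
            return proto_layer.component_labels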
@@ -87,11 +87,10 @@ class PLVQ(ProbabilisticLVQ, SiameseGMLVQ):
             self.hparams.lambd)
         self.loss = torch.nn.KLDivLoss()

-    def training_step(self, batch, batch_idx, optimizer_idx=None):
-        x, y = batch
-        out = self.forward(x)
-        y_dist = torch.nn.functional.one_hot(
-            y.long(), num_classes=self.num_classes).float()
-        batch_loss = self.loss(out, y_dist)
-        loss = batch_loss.sum(dim=0)
-        return loss
+    # FIXME
+    # def training_step(self, batch, batch_idx, optimizer_idx=None):
+    #     x, y = batch
+    #     y_pred = self(x)
+    #     batch_loss = self.loss(y_pred, y)
+    #     loss = batch_loss.sum(dim=0)
+    #     return loss