[BUG] PLVQ seems broken

This commit is contained in:
Jensun Ravichandran 2021-06-14 20:56:38 +02:00
parent 24ebfdc667
commit a44219ee47
2 changed files with 12 additions and 16 deletions

View File

@ -5,7 +5,6 @@ import argparse
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from torchvision.transforms import Lambda
if __name__ == "__main__": if __name__ == "__main__":
# Command-line arguments # Command-line arguments
@ -27,19 +26,17 @@ if __name__ == "__main__":
distribution=[2, 2, 3], distribution=[2, 2, 3],
proto_lr=0.05, proto_lr=0.05,
lambd=0.1, lambd=0.1,
variance=1.0,
input_dim=2, input_dim=2,
latent_dim=2, latent_dim=2,
bb_lr=0.01, bb_lr=0.01,
) )
# Initialize the model # Initialize the model
model = pt.models.probabilistic.PLVQ( model = pt.models.RSLVQ(
hparams, hparams,
optimizer=torch.optim.Adam, optimizer=torch.optim.Adam,
# prototype_initializer=pt.components.SMI(train_ds), prototypes_initializer=pt.initializers.SSCI(train_ds, noise=0.2),
prototype_initializer=pt.components.SSI(train_ds, noise=0.2),
# prototype_initializer=pt.components.Zeros(2),
# prototype_initializer=pt.components.Ones(2, scale=2.0),
) )
# Compute intermediate input and output sizes # Compute intermediate input and output sizes
@ -49,7 +46,7 @@ if __name__ == "__main__":
print(model) print(model)
# Callbacks # Callbacks
vis = pt.models.VisSiameseGLVQ2D(data=train_ds) vis = pt.models.VisGLVQ2D(data=train_ds)
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer.from_argparse_args(

View File

@ -54,7 +54,7 @@ class ProbabilisticLVQ(GLVQ):
def training_step(self, batch, batch_idx, optimizer_idx=None): def training_step(self, batch, batch_idx, optimizer_idx=None):
x, y = batch x, y = batch
out = self.forward(x) out = self.forward(x)
plabels = self.proto_layer.component_labels plabels = self.proto_layer.labels
batch_loss = self.loss(out, y, plabels) batch_loss = self.loss(out, y, plabels)
loss = batch_loss.sum(dim=0) loss = batch_loss.sum(dim=0)
return loss return loss
@ -87,11 +87,10 @@ class PLVQ(ProbabilisticLVQ, SiameseGMLVQ):
self.hparams.lambd) self.hparams.lambd)
self.loss = torch.nn.KLDivLoss() self.loss = torch.nn.KLDivLoss()
def training_step(self, batch, batch_idx, optimizer_idx=None): # FIXME
x, y = batch # def training_step(self, batch, batch_idx, optimizer_idx=None):
out = self.forward(x) # x, y = batch
y_dist = torch.nn.functional.one_hot( # y_pred = self(x)
y.long(), num_classes=self.num_classes).float() # batch_loss = self.loss(y_pred, y)
batch_loss = self.loss(out, y_dist) # loss = batch_loss.sum(dim=0)
loss = batch_loss.sum(dim=0) # return loss
return loss