diff --git a/examples/rslvq_iris.py b/examples/rslvq_iris.py index d70a763..c7d3961 100644 --- a/examples/rslvq_iris.py +++ b/examples/rslvq_iris.py @@ -5,7 +5,6 @@ import argparse import prototorch as pt import pytorch_lightning as pl import torch -from torchvision.transforms import Lambda if __name__ == "__main__": # Command-line arguments @@ -27,19 +26,17 @@ if __name__ == "__main__": distribution=[2, 2, 3], proto_lr=0.05, lambd=0.1, + variance=1.0, input_dim=2, latent_dim=2, bb_lr=0.01, ) # Initialize the model - model = pt.models.probabilistic.PLVQ( + model = pt.models.RSLVQ( hparams, optimizer=torch.optim.Adam, - # prototype_initializer=pt.components.SMI(train_ds), - prototype_initializer=pt.components.SSI(train_ds, noise=0.2), - # prototype_initializer=pt.components.Zeros(2), - # prototype_initializer=pt.components.Ones(2, scale=2.0), + prototypes_initializer=pt.initializers.SSCI(train_ds, noise=0.2), ) # Compute intermediate input and output sizes @@ -49,7 +46,7 @@ if __name__ == "__main__": print(model) # Callbacks - vis = pt.models.VisSiameseGLVQ2D(data=train_ds) + vis = pt.models.VisGLVQ2D(data=train_ds) # Setup trainer trainer = pl.Trainer.from_argparse_args( diff --git a/prototorch/models/probabilistic.py b/prototorch/models/probabilistic.py index 6129bcf..d7276f3 100644 --- a/prototorch/models/probabilistic.py +++ b/prototorch/models/probabilistic.py @@ -54,7 +54,7 @@ class ProbabilisticLVQ(GLVQ): def training_step(self, batch, batch_idx, optimizer_idx=None): x, y = batch out = self.forward(x) - plabels = self.proto_layer.component_labels + plabels = self.proto_layer.labels batch_loss = self.loss(out, y, plabels) loss = batch_loss.sum(dim=0) return loss @@ -87,11 +87,10 @@ class PLVQ(ProbabilisticLVQ, SiameseGMLVQ): self.hparams.lambd) self.loss = torch.nn.KLDivLoss() - def training_step(self, batch, batch_idx, optimizer_idx=None): - x, y = batch - out = self.forward(x) - y_dist = torch.nn.functional.one_hot( - y.long(), num_classes=self.num_classes).float() - batch_loss = self.loss(out, y_dist) - loss = batch_loss.sum(dim=0) - return loss + # FIXME + # def training_step(self, batch, batch_idx, optimizer_idx=None): + # x, y = batch + # y_pred = self(x) + # batch_loss = self.loss(y_pred, y) + # loss = batch_loss.sum(dim=0) + # return loss