From 68034d56f6f57c3b8c1758a7347c3cf0f99167e3 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Mon, 14 Jun 2021 20:13:25 +0200 Subject: [PATCH] [BUGFIX] `examples/glvq_iris.py` works again --- examples/glvq_iris.py | 4 ++-- prototorch/models/glvq.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/glvq_iris.py b/examples/glvq_iris.py index 486ef5e..f9556ae 100644 --- a/examples/glvq_iris.py +++ b/examples/glvq_iris.py @@ -23,7 +23,7 @@ if __name__ == "__main__": hparams = dict( distribution={ "num_classes": 3, - "prototypes_per_class": 4 + "per_class": 4 }, lr=0.01, ) @@ -32,7 +32,7 @@ if __name__ == "__main__": model = pt.models.GLVQ( hparams, optimizer=torch.optim.Adam, - prototype_initializer=pt.components.SMI(train_ds), + prototypes_initializer=pt.initializers.SMCI(train_ds), lr_scheduler=ExponentialLR, lr_scheduler_kwargs=dict(gamma=0.99, verbose=False), ) diff --git a/prototorch/models/glvq.py b/prototorch/models/glvq.py index 15a5ca3..953a1c6 100644 --- a/prototorch/models/glvq.py +++ b/prototorch/models/glvq.py @@ -56,7 +56,7 @@ class GLVQ(SupervisedPrototypeModel): def shared_step(self, batch, batch_idx, optimizer_idx=None): x, y = batch out = self.compute_distances(x) - plabels = self.proto_layer.component_labels + plabels = self.proto_layer.labels mu = self.loss(out, y, prototype_labels=plabels) batch_loss = self.transfer_layer(mu, beta=self.hparams.transfer_beta) loss = batch_loss.sum(dim=0)