[BUGFIX] examples/glvq_iris.py works again
				
					
				
			This commit is contained in:
		| @@ -23,7 +23,7 @@ if __name__ == "__main__": | ||||
|     hparams = dict( | ||||
|         distribution={ | ||||
|             "num_classes": 3, | ||||
|             "prototypes_per_class": 4 | ||||
|             "per_class": 4 | ||||
|         }, | ||||
|         lr=0.01, | ||||
|     ) | ||||
| @@ -32,7 +32,7 @@ if __name__ == "__main__": | ||||
|     model = pt.models.GLVQ( | ||||
|         hparams, | ||||
|         optimizer=torch.optim.Adam, | ||||
|         prototype_initializer=pt.components.SMI(train_ds), | ||||
|         prototypes_initializer=pt.initializers.SMCI(train_ds), | ||||
|         lr_scheduler=ExponentialLR, | ||||
|         lr_scheduler_kwargs=dict(gamma=0.99, verbose=False), | ||||
|     ) | ||||
|   | ||||
| @@ -56,7 +56,7 @@ class GLVQ(SupervisedPrototypeModel): | ||||
|     def shared_step(self, batch, batch_idx, optimizer_idx=None): | ||||
|         x, y = batch | ||||
|         out = self.compute_distances(x) | ||||
|         plabels = self.proto_layer.component_labels | ||||
|         plabels = self.proto_layer.labels | ||||
|         mu = self.loss(out, y, prototype_labels=plabels) | ||||
|         batch_loss = self.transfer_layer(mu, beta=self.hparams.transfer_beta) | ||||
|         loss = batch_loss.sum(dim=0) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user