diff --git a/docs/source/tutorial.ipynb b/docs/source/tutorial.ipynb index 3329848..ddadec5 100644 --- a/docs/source/tutorial.ipynb +++ b/docs/source/tutorial.ipynb @@ -422,7 +422,7 @@ "trained_model = pt.models.GLVQ.load_from_checkpoint(\"./checkpoints/glvq_iris.ckpt\", strict=False)\n", "model = pt.models.SiameseGMLVQ(\n", " dict(input_dim=2,\n", - " output_dim=2,\n", + " latent_dim=2,\n", " distribution=(3, 2),\n", " proto_lr=0.0001,\n", " bb_lr=0.0001),\n", diff --git a/prototorch/models/glvq.py b/prototorch/models/glvq.py index 71573b3..dc2858a 100644 --- a/prototorch/models/glvq.py +++ b/prototorch/models/glvq.py @@ -230,7 +230,7 @@ class SiameseGMLVQ(SiameseGLVQ): EyeTransformInitializer()) self.backbone = LinearTransform( self.hparams.input_dim, - self.hparams.output_dim, + self.hparams.latent_dim, initializer=omega_initializer, )