From 9c90c902dc3d994ea90cbabbb5358319b8e6fe36 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Mon, 4 Apr 2022 21:52:13 +0200 Subject: [PATCH] fix: correct typo --- prototorch/models/glvq.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/prototorch/models/glvq.py b/prototorch/models/glvq.py index 0074a78..c5556ca 100644 --- a/prototorch/models/glvq.py +++ b/prototorch/models/glvq.py @@ -9,7 +9,7 @@ from ..core.distances import ( omega_distance, squared_euclidean_distance, ) -from ..core.initializers import EyeTransformInitializer +from ..core.initializers import EyeLinearTransformInitializer from ..core.losses import ( GLVQLoss, lvq1_loss, @@ -231,7 +231,7 @@ class SiameseGMLVQ(SiameseGLVQ): # Override the backbone omega_initializer = kwargs.get("omega_initializer", - EyeTransformInitializer()) + EyeLinearTransformInitializer()) self.backbone = LinearTransform( self.hparams.input_dim, self.hparams.latent_dim, @@ -263,7 +263,7 @@ class GMLVQ(GLVQ): # Additional parameters omega_initializer = kwargs.get("omega_initializer", - EyeTransformInitializer()) + EyeLinearTransformInitializer()) omega = omega_initializer.generate(self.hparams.input_dim, self.hparams.latent_dim) self.register_parameter("_omega", Parameter(omega))