diff --git a/prototorch/models/glvq.py b/prototorch/models/glvq.py index 0074a78..c5556ca 100644 --- a/prototorch/models/glvq.py +++ b/prototorch/models/glvq.py @@ -9,7 +9,7 @@ from ..core.distances import ( omega_distance, squared_euclidean_distance, ) -from ..core.initializers import EyeTransformInitializer +from ..core.initializers import EyeLinearTransformInitializer from ..core.losses import ( GLVQLoss, lvq1_loss, @@ -231,7 +231,7 @@ class SiameseGMLVQ(SiameseGLVQ): # Override the backbone omega_initializer = kwargs.get("omega_initializer", - EyeTransformInitializer()) + EyeLinearTransformInitializer()) self.backbone = LinearTransform( self.hparams.input_dim, self.hparams.latent_dim, @@ -263,7 +263,7 @@ class GMLVQ(GLVQ): # Additional parameters omega_initializer = kwargs.get("omega_initializer", - EyeTransformInitializer()) + EyeLinearTransformInitializer()) omega = omega_initializer.generate(self.hparams.input_dim, self.hparams.latent_dim) self.register_parameter("_omega", Parameter(omega))