fix: correct typo

This commit is contained in:
Jensun Ravichandran 2022-04-04 21:52:13 +02:00
parent 7d3f59e54b
commit 9c90c902dc
No known key found for this signature in database
GPG Key ID: 7612C0CAB643D921

View File

@ -9,7 +9,7 @@ from ..core.distances import (
omega_distance, omega_distance,
squared_euclidean_distance, squared_euclidean_distance,
) )
from ..core.initializers import EyeTransformInitializer from ..core.initializers import EyeLinearTransformInitializer
from ..core.losses import ( from ..core.losses import (
GLVQLoss, GLVQLoss,
lvq1_loss, lvq1_loss,
@ -231,7 +231,7 @@ class SiameseGMLVQ(SiameseGLVQ):
# Override the backbone # Override the backbone
omega_initializer = kwargs.get("omega_initializer", omega_initializer = kwargs.get("omega_initializer",
EyeTransformInitializer()) EyeLinearTransformInitializer())
self.backbone = LinearTransform( self.backbone = LinearTransform(
self.hparams.input_dim, self.hparams.input_dim,
self.hparams.latent_dim, self.hparams.latent_dim,
@ -263,7 +263,7 @@ class GMLVQ(GLVQ):
# Additional parameters # Additional parameters
omega_initializer = kwargs.get("omega_initializer", omega_initializer = kwargs.get("omega_initializer",
EyeTransformInitializer()) EyeLinearTransformInitializer())
omega = omega_initializer.generate(self.hparams.input_dim, omega = omega_initializer.generate(self.hparams.input_dim,
self.hparams.latent_dim) self.hparams.latent_dim)
self.register_parameter("_omega", Parameter(omega)) self.register_parameter("_omega", Parameter(omega))