fix: correct typo
This commit is contained in:
parent
7d3f59e54b
commit
9c90c902dc
@ -9,7 +9,7 @@ from ..core.distances import (
|
|||||||
omega_distance,
|
omega_distance,
|
||||||
squared_euclidean_distance,
|
squared_euclidean_distance,
|
||||||
)
|
)
|
||||||
from ..core.initializers import EyeTransformInitializer
|
from ..core.initializers import EyeLinearTransformInitializer
|
||||||
from ..core.losses import (
|
from ..core.losses import (
|
||||||
GLVQLoss,
|
GLVQLoss,
|
||||||
lvq1_loss,
|
lvq1_loss,
|
||||||
@ -231,7 +231,7 @@ class SiameseGMLVQ(SiameseGLVQ):
|
|||||||
|
|
||||||
# Override the backbone
|
# Override the backbone
|
||||||
omega_initializer = kwargs.get("omega_initializer",
|
omega_initializer = kwargs.get("omega_initializer",
|
||||||
EyeTransformInitializer())
|
EyeLinearTransformInitializer())
|
||||||
self.backbone = LinearTransform(
|
self.backbone = LinearTransform(
|
||||||
self.hparams.input_dim,
|
self.hparams.input_dim,
|
||||||
self.hparams.latent_dim,
|
self.hparams.latent_dim,
|
||||||
@ -263,7 +263,7 @@ class GMLVQ(GLVQ):
|
|||||||
|
|
||||||
# Additional parameters
|
# Additional parameters
|
||||||
omega_initializer = kwargs.get("omega_initializer",
|
omega_initializer = kwargs.get("omega_initializer",
|
||||||
EyeTransformInitializer())
|
EyeLinearTransformInitializer())
|
||||||
omega = omega_initializer.generate(self.hparams.input_dim,
|
omega = omega_initializer.generate(self.hparams.input_dim,
|
||||||
self.hparams.latent_dim)
|
self.hparams.latent_dim)
|
||||||
self.register_parameter("_omega", Parameter(omega))
|
self.register_parameter("_omega", Parameter(omega))
|
||||||
|
Loading…
Reference in New Issue
Block a user