fix: Add support for other LinearTransform initializers
@@ -86,6 +86,7 @@ class OmegaComparisonMixin(SimpleComparisonMixin):
         latent_dim: int = 2
         omega_initializer: type[
             AbstractLinearTransformInitializer] = EyeLinearTransformInitializer
+        omega_initializer_kwargs: dict = field(default_factory=lambda: dict())
 
     # Steps
     # ----------------------------------------------------------------------------------------------
@@ -96,10 +97,11 @@ class OmegaComparisonMixin(SimpleComparisonMixin):
         if hparams.input_dim is None:
             raise ValueError("input_dim must be specified.")
         else:
-            omega = hparams.omega_initializer().generate(
-                hparams.input_dim,
-                hparams.latent_dim,
-            )
+            omega = hparams.omega_initializer(
+                **hparams.omega_initializer_kwargs).generate(
+                    hparams.input_dim,
+                    hparams.latent_dim,
+                )
             self.register_parameter("_omega", Parameter(omega))
             self.comparison_kwargs = dict(omega=self._omega)
 
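For context, a minimal usage sketch of the new omega_initializer_kwargs hyperparameter. It assumes the AbstractLinearTransformInitializer base class lives in prototorch.initializers and exposes a generate(in_dim, out_dim) method, as the diff above uses; the ScaledEyeLinearTransformInitializer class and the HyperParameters construction in the trailing comment are hypothetical, for illustration only.

import torch
from prototorch.initializers import AbstractLinearTransformInitializer


class ScaledEyeLinearTransformInitializer(AbstractLinearTransformInitializer):
    """Hypothetical initializer whose constructor takes an extra argument."""

    def __init__(self, scale: float = 1.0):
        super().__init__()
        self.scale = scale

    def generate(self, in_dim: int, out_dim: int) -> torch.Tensor:
        # A scaled (possibly rectangular) identity matrix of shape (in_dim, out_dim).
        return self.scale * torch.eye(in_dim, out_dim)


# Before this commit the initializer was always instantiated with no arguments
# (hparams.omega_initializer().generate(...)), so constructor arguments such as
# `scale` above could not be used. With omega_initializer_kwargs they can be
# forwarded through the hyperparameters, e.g. (hypothetical HyperParameters):
#
#   hparams = SomeModel.HyperParameters(
#       input_dim=100,
#       latent_dim=2,
#       omega_initializer=ScaledEyeLinearTransformInitializer,
#       omega_initializer_kwargs=dict(scale=0.1),
#   )
#
# The changed code then effectively calls
#   ScaledEyeLinearTransformInitializer(scale=0.1).generate(100, 2)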