fix: Add support for other LinearTransform initializers

commit ec61881ca8
parent 5a89f24c10
Author: Alexander Engelsberger
Date: 2022-08-16 15:55:05 +02:00
2 changed files with 10 additions and 6 deletions

@@ -86,6 +86,7 @@ class OmegaComparisonMixin(SimpleComparisonMixin):
         latent_dim: int = 2
         omega_initializer: type[
             AbstractLinearTransformInitializer] = EyeLinearTransformInitializer
+        omega_initializer_kwargs: dict = field(default_factory=lambda: dict())

     # Steps
     # ----------------------------------------------------------------------------------------------
@@ -96,10 +97,11 @@ class OmegaComparisonMixin(SimpleComparisonMixin):
         if hparams.input_dim is None:
             raise ValueError("input_dim must be specified.")
         else:
-            omega = hparams.omega_initializer().generate(
-                hparams.input_dim,
-                hparams.latent_dim,
-            )
+            omega = hparams.omega_initializer(
+                **hparams.omega_initializer_kwargs).generate(
+                    hparams.input_dim,
+                    hparams.latent_dim,
+                )
         self.register_parameter("_omega", Parameter(omega))
         self.comparison_kwargs = dict(omega=self._omega)
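
Before this change the omega matrix was always built with the initializer's default constructor, so LinearTransform initializers that require constructor arguments could not be configured. The new omega_initializer_kwargs field forwards arbitrary keyword arguments to whichever AbstractLinearTransformInitializer subclass is set as omega_initializer. The snippet below is a minimal, self-contained sketch of that pattern; ScaledLinearTransformInitializer, the standalone HParams dataclass, and init_omega are hypothetical stand-ins for illustration, not ProtoTorch API.

from dataclasses import dataclass, field

import torch


class EyeLinearTransformInitializer:
    """Stand-in for the default: an identity-like omega matrix."""

    def generate(self, in_dim: int, out_dim: int) -> torch.Tensor:
        return torch.eye(in_dim, out_dim)


class ScaledLinearTransformInitializer:
    """Hypothetical initializer that needs a constructor argument."""

    def __init__(self, scale: float = 1.0):
        self.scale = scale

    def generate(self, in_dim: int, out_dim: int) -> torch.Tensor:
        return self.scale * torch.rand(in_dim, out_dim)


@dataclass
class HParams:
    input_dim: int = 100
    latent_dim: int = 2
    omega_initializer: type = EyeLinearTransformInitializer
    # New in this commit: kwargs forwarded to the initializer's constructor.
    omega_initializer_kwargs: dict = field(default_factory=dict)


def init_omega(hparams: HParams) -> torch.Tensor:
    # Same call pattern as the patched init_comparison step above.
    return hparams.omega_initializer(
        **hparams.omega_initializer_kwargs).generate(
            hparams.input_dim,
            hparams.latent_dim,
        )


# Default behaviour is unchanged: no kwargs are needed.
omega_default = init_omega(HParams())

# Other initializers, including ones that take arguments, now work too.
omega_scaled = init_omega(
    HParams(
        omega_initializer=ScaledLinearTransformInitializer,
        omega_initializer_kwargs={"scale": 0.1},
    ))

Passing a class plus a separate kwargs dict, rather than a pre-built initializer instance, keeps the old no-argument default (EyeLinearTransformInitializer) working exactly as before while still letting any other initializer be configured through the hparams.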