fix: Add support for other LinearTransform initializers

This commit is contained in:
Alexander Engelsberger 2022-08-16 15:55:05 +02:00
parent 5a89f24c10
commit ec61881ca8
No known key found for this signature in database
GPG Key ID: DE8669706B6AC2E7
2 changed files with 10 additions and 6 deletions

View File

@ -2,7 +2,7 @@ import logging
import pytorch_lightning as pl
import torchmetrics
from prototorch.core import SMCI
from prototorch.core import SMCI, PCALinearTransformInitializer
from prototorch.datasets import Iris
from prototorch.models.architectures.base import Steps
from prototorch.models.callbacks import (
@ -71,7 +71,9 @@ def main():
per_class=1,
),
component_initializer=components_initializer,
)
omega_initializer=PCALinearTransformInitializer,
omega_initializer_kwargs=dict(
data=train_dataset.dataset[train_dataset.indices][0]))
# Create Model
model = GMLVQ(hyperparameters)

View File

@ -86,6 +86,7 @@ class OmegaComparisonMixin(SimpleComparisonMixin):
latent_dim: int = 2
omega_initializer: type[
AbstractLinearTransformInitializer] = EyeLinearTransformInitializer
omega_initializer_kwargs: dict = field(default_factory=lambda: dict())
# Steps
# ----------------------------------------------------------------------------------------------
@ -96,10 +97,11 @@ class OmegaComparisonMixin(SimpleComparisonMixin):
if hparams.input_dim is None:
raise ValueError("input_dim must be specified.")
else:
omega = hparams.omega_initializer().generate(
hparams.input_dim,
hparams.latent_dim,
)
omega = hparams.omega_initializer(
**hparams.omega_initializer_kwargs).generate(
hparams.input_dim,
hparams.latent_dim,
)
self.register_parameter("_omega", Parameter(omega))
self.comparison_kwargs = dict(omega=self._omega)