fix: Add support for other LinearTransform initializers

This commit is contained in:
Alexander Engelsberger 2022-08-16 15:55:05 +02:00
parent 5a89f24c10
commit ec61881ca8
No known key found for this signature in database
GPG Key ID: DE8669706B6AC2E7
2 changed files with 10 additions and 6 deletions

View File

@ -2,7 +2,7 @@ import logging
import pytorch_lightning as pl import pytorch_lightning as pl
import torchmetrics import torchmetrics
from prototorch.core import SMCI from prototorch.core import SMCI, PCALinearTransformInitializer
from prototorch.datasets import Iris from prototorch.datasets import Iris
from prototorch.models.architectures.base import Steps from prototorch.models.architectures.base import Steps
from prototorch.models.callbacks import ( from prototorch.models.callbacks import (
@ -71,7 +71,9 @@ def main():
per_class=1, per_class=1,
), ),
component_initializer=components_initializer, component_initializer=components_initializer,
) omega_initializer=PCALinearTransformInitializer,
omega_initializer_kwargs=dict(
data=train_dataset.dataset[train_dataset.indices][0]))
# Create Model # Create Model
model = GMLVQ(hyperparameters) model = GMLVQ(hyperparameters)

View File

@ -86,6 +86,7 @@ class OmegaComparisonMixin(SimpleComparisonMixin):
latent_dim: int = 2 latent_dim: int = 2
omega_initializer: type[ omega_initializer: type[
AbstractLinearTransformInitializer] = EyeLinearTransformInitializer AbstractLinearTransformInitializer] = EyeLinearTransformInitializer
omega_initializer_kwargs: dict = field(default_factory=lambda: dict())
# Steps # Steps
# ---------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------
@ -96,7 +97,8 @@ class OmegaComparisonMixin(SimpleComparisonMixin):
if hparams.input_dim is None: if hparams.input_dim is None:
raise ValueError("input_dim must be specified.") raise ValueError("input_dim must be specified.")
else: else:
omega = hparams.omega_initializer().generate( omega = hparams.omega_initializer(
**hparams.omega_initializer_kwargs).generate(
hparams.input_dim, hparams.input_dim,
hparams.latent_dim, hparams.latent_dim,
) )