fix: Add support for other LinearTransform initializers

Author: Alexander Engelsberger
Date:   2022-08-16 15:55:05 +02:00
Parent: 5a89f24c10
Commit: ec61881ca8
2 changed files with 10 additions and 6 deletions


@@ -2,7 +2,7 @@ import logging
 
 import pytorch_lightning as pl
 import torchmetrics
-from prototorch.core import SMCI
+from prototorch.core import SMCI, PCALinearTransformInitializer
 from prototorch.datasets import Iris
 from prototorch.models.architectures.base import Steps
 from prototorch.models.callbacks import (
@@ -71,7 +71,9 @@ def main():
             per_class=1,
         ),
         component_initializer=components_initializer,
-    )
+        omega_initializer=PCALinearTransformInitializer,
+        omega_initializer_kwargs=dict(
+            data=train_dataset.dataset[train_dataset.indices][0]))
 
     # Create Model
     model = GMLVQ(hyperparameters)
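The kwargs are needed because a PCA initializer, unlike the identity default, has to see data when it is constructed; the expression train_dataset.dataset[train_dataset.indices][0] pulls the raw feature tensor back out of a torch Subset for exactly that purpose. A minimal sketch of that indexing pattern with plain torch objects (the random data and split below are illustrative stand-ins, not the Iris setup from the example script):

import torch
from torch.utils.data import TensorDataset, random_split

# Illustrative stand-in for the Iris features/targets used in the example.
features = torch.randn(150, 4)
targets = torch.randint(0, 3, (150, ))
full_dataset = TensorDataset(features, targets)

train_dataset, _ = random_split(full_dataset, [100, 50])

# A Subset only stores indices; indexing the wrapped dataset with those
# indices returns an (inputs, targets) tuple, and [0] selects the inputs.
train_data = train_dataset.dataset[train_dataset.indices][0]
print(train_data.shape)  # torch.Size([100, 4])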


@@ -86,6 +86,7 @@ class OmegaComparisonMixin(SimpleComparisonMixin):
     latent_dim: int = 2
     omega_initializer: type[
         AbstractLinearTransformInitializer] = EyeLinearTransformInitializer
+    omega_initializer_kwargs: dict = field(default_factory=lambda: dict())
 
     # Steps
     # ----------------------------------------------------------------------------------------------
@@ -96,10 +97,11 @@ class OmegaComparisonMixin(SimpleComparisonMixin):
         if hparams.input_dim is None:
             raise ValueError("input_dim must be specified.")
         else:
-            omega = hparams.omega_initializer().generate(
-                hparams.input_dim,
-                hparams.latent_dim,
-            )
+            omega = hparams.omega_initializer(
+                **hparams.omega_initializer_kwargs).generate(
+                    hparams.input_dim,
+                    hparams.latent_dim,
+                )
             self.register_parameter("_omega", Parameter(omega))
             self.comparison_kwargs = dict(omega=self._omega)
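In effect, the mixin now instantiates the configured initializer class with the stored kwargs before calling generate(), and the new field needs a default_factory because a plain dict is a mutable dataclass default. A library-free sketch of that forwarding pattern (EyeInit and HParams below are illustrative stand-ins, not the prototorch classes):

from dataclasses import dataclass, field

import torch


class EyeInit:
    """Stand-in initializer: accepts arbitrary kwargs, returns an identity-like matrix."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs  # e.g. data=... for a PCA-style initializer

    def generate(self, in_dim, out_dim):
        return torch.eye(in_dim, out_dim)


@dataclass
class HParams:
    input_dim: int = 4
    latent_dim: int = 2
    omega_initializer: type = EyeInit
    # Mutable default, hence default_factory (mirrors the commit).
    omega_initializer_kwargs: dict = field(default_factory=dict)


hparams = HParams()
# The kwargs dict is unpacked into the initializer's constructor,
# then generate() produces the omega matrix.
omega = hparams.omega_initializer(**hparams.omega_initializer_kwargs).generate(
    hparams.input_dim,
    hparams.latent_dim,
)
print(omega.shape)  # torch.Size([4, 2])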