fix: Add support for other LinearTransform initializers
parent 5a89f24c10
commit ec61881ca8
@@ -2,7 +2,7 @@ import logging
 import pytorch_lightning as pl
 import torchmetrics
-from prototorch.core import SMCI
+from prototorch.core import SMCI, PCALinearTransformInitializer
 from prototorch.datasets import Iris
 from prototorch.models.architectures.base import Steps
 from prototorch.models.callbacks import (
@@ -71,7 +71,9 @@ def main():
             per_class=1,
         ),
         component_initializer=components_initializer,
-    )
+        omega_initializer=PCALinearTransformInitializer,
+        omega_initializer_kwargs=dict(
+            data=train_dataset.dataset[train_dataset.indices][0]))
 
     # Create Model
     model = GMLVQ(hyperparameters)
@@ -86,6 +86,7 @@ class OmegaComparisonMixin(SimpleComparisonMixin):
         latent_dim: int = 2
         omega_initializer: type[
             AbstractLinearTransformInitializer] = EyeLinearTransformInitializer
+        omega_initializer_kwargs: dict = field(default_factory=lambda: dict())
 
     # Steps
     # ----------------------------------------------------------------------------------------------
@@ -96,10 +97,11 @@ class OmegaComparisonMixin(SimpleComparisonMixin):
         if hparams.input_dim is None:
             raise ValueError("input_dim must be specified.")
         else:
-            omega = hparams.omega_initializer().generate(
-                hparams.input_dim,
-                hparams.latent_dim,
-            )
+            omega = hparams.omega_initializer(
+                **hparams.omega_initializer_kwargs).generate(
+                    hparams.input_dim,
+                    hparams.latent_dim,
+                )
             self.register_parameter("_omega", Parameter(omega))
             self.comparison_kwargs = dict(omega=self._omega)
 
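For context, a minimal sketch of the pattern this commit enables: the class stored in omega_initializer is instantiated with omega_initializer_kwargs before generate(input_dim, latent_dim) is called, so data-dependent initializers such as PCALinearTransformInitializer can be configured through the hyperparameters. The classes below (DemoHyperParameters and the two initializer stand-ins) are hypothetical simplifications, not the real prototorch API; only the "instantiate with kwargs, then generate" flow mirrors the patched code.

from dataclasses import dataclass, field

import torch


# Hypothetical stand-ins for the prototorch initializers.
class EyeLinearTransformInitializer:

    def generate(self, in_dim: int, latent_dim: int) -> torch.Tensor:
        # Identity-like omega, truncated to latent_dim columns.
        return torch.eye(in_dim, latent_dim)


class PCALinearTransformInitializer:

    def __init__(self, data: torch.Tensor):
        self.data = data

    def generate(self, in_dim: int, latent_dim: int) -> torch.Tensor:
        # Leading principal directions of the supplied data.
        _, _, v = torch.pca_lowrank(self.data, q=latent_dim)
        return v[:, :latent_dim]


@dataclass
class DemoHyperParameters:
    input_dim: int = 4
    latent_dim: int = 2
    omega_initializer: type = EyeLinearTransformInitializer
    omega_initializer_kwargs: dict = field(default_factory=dict)


def init_omega(hparams: DemoHyperParameters) -> torch.Tensor:
    # Mirrors the patched else-branch: pass the kwargs to the initializer
    # class, then generate omega from input_dim x latent_dim.
    return hparams.omega_initializer(
        **hparams.omega_initializer_kwargs).generate(
            hparams.input_dim,
            hparams.latent_dim,
        )


if __name__ == "__main__":
    data = torch.randn(150, 4)
    hparams = DemoHyperParameters(
        omega_initializer=PCALinearTransformInitializer,
        omega_initializer_kwargs=dict(data=data),
    )
    print(init_omega(hparams).shape)  # torch.Size([4, 2])

The default (EyeLinearTransformInitializer with an empty kwargs dict) reproduces the previous behaviour, so existing configurations are unaffected.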