fix: Add support for other LinearTransform initializers
This commit is contained in:
parent
5a89f24c10
commit
ec61881ca8
@ -2,7 +2,7 @@ import logging
|
|||||||
|
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torchmetrics
|
import torchmetrics
|
||||||
from prototorch.core import SMCI
|
from prototorch.core import SMCI, PCALinearTransformInitializer
|
||||||
from prototorch.datasets import Iris
|
from prototorch.datasets import Iris
|
||||||
from prototorch.models.architectures.base import Steps
|
from prototorch.models.architectures.base import Steps
|
||||||
from prototorch.models.callbacks import (
|
from prototorch.models.callbacks import (
|
||||||
@ -71,7 +71,9 @@ def main():
|
|||||||
per_class=1,
|
per_class=1,
|
||||||
),
|
),
|
||||||
component_initializer=components_initializer,
|
component_initializer=components_initializer,
|
||||||
)
|
omega_initializer=PCALinearTransformInitializer,
|
||||||
|
omega_initializer_kwargs=dict(
|
||||||
|
data=train_dataset.dataset[train_dataset.indices][0]))
|
||||||
|
|
||||||
# Create Model
|
# Create Model
|
||||||
model = GMLVQ(hyperparameters)
|
model = GMLVQ(hyperparameters)
|
||||||
|
@ -86,6 +86,7 @@ class OmegaComparisonMixin(SimpleComparisonMixin):
|
|||||||
latent_dim: int = 2
|
latent_dim: int = 2
|
||||||
omega_initializer: type[
|
omega_initializer: type[
|
||||||
AbstractLinearTransformInitializer] = EyeLinearTransformInitializer
|
AbstractLinearTransformInitializer] = EyeLinearTransformInitializer
|
||||||
|
omega_initializer_kwargs: dict = field(default_factory=lambda: dict())
|
||||||
|
|
||||||
# Steps
|
# Steps
|
||||||
# ----------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------
|
||||||
@ -96,10 +97,11 @@ class OmegaComparisonMixin(SimpleComparisonMixin):
|
|||||||
if hparams.input_dim is None:
|
if hparams.input_dim is None:
|
||||||
raise ValueError("input_dim must be specified.")
|
raise ValueError("input_dim must be specified.")
|
||||||
else:
|
else:
|
||||||
omega = hparams.omega_initializer().generate(
|
omega = hparams.omega_initializer(
|
||||||
hparams.input_dim,
|
**hparams.omega_initializer_kwargs).generate(
|
||||||
hparams.latent_dim,
|
hparams.input_dim,
|
||||||
)
|
hparams.latent_dim,
|
||||||
|
)
|
||||||
self.register_parameter("_omega", Parameter(omega))
|
self.register_parameter("_omega", Parameter(omega))
|
||||||
self.comparison_kwargs = dict(omega=self._omega)
|
self.comparison_kwargs = dict(omega=self._omega)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user