fix: Add support for other LinearTransform initializers

This commit is contained in:
Alexander Engelsberger
2022-08-16 15:55:05 +02:00
parent 5a89f24c10
commit ec61881ca8
2 changed files with 10 additions and 6 deletions

View File

@@ -2,7 +2,7 @@ import logging
import pytorch_lightning as pl
import torchmetrics
from prototorch.core import SMCI
from prototorch.core import SMCI, PCALinearTransformInitializer
from prototorch.datasets import Iris
from prototorch.models.architectures.base import Steps
from prototorch.models.callbacks import (
@@ -71,7 +71,9 @@ def main():
per_class=1,
),
component_initializer=components_initializer,
)
omega_initializer=PCALinearTransformInitializer,
omega_initializer_kwargs=dict(
data=train_dataset.dataset[train_dataset.indices][0]))
# Create Model
model = GMLVQ(hyperparameters)