fix: Add support for other LinearTransform initializers
This commit is contained in:
@@ -2,7 +2,7 @@ import logging
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torchmetrics
|
||||
from prototorch.core import SMCI
|
||||
from prototorch.core import SMCI, PCALinearTransformInitializer
|
||||
from prototorch.datasets import Iris
|
||||
from prototorch.models.architectures.base import Steps
|
||||
from prototorch.models.callbacks import (
|
||||
@@ -71,7 +71,9 @@ def main():
|
||||
per_class=1,
|
||||
),
|
||||
component_initializer=components_initializer,
|
||||
)
|
||||
omega_initializer=PCALinearTransformInitializer,
|
||||
omega_initializer_kwargs=dict(
|
||||
data=train_dataset.dataset[train_dataset.indices][0]))
|
||||
|
||||
# Create Model
|
||||
model = GMLVQ(hyperparameters)
|
||||
|
Reference in New Issue
Block a user