diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..b35d43e --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,13 @@ +"""prototorch.models test suite.""" + +import prototorch as pt +from prototorch.models.library import GLVQ + + +def test_glvq_model_build(): + hparams = GLVQ.HyperParameters( + distribution=dict(num_classes=2, per_class=1), + component_initializer=pt.initializers.RNCI(2), + ) + + model = GLVQ(hparams=hparams)