prototorch_models/tests/test_models.py

194 lines
4.6 KiB
Python
Raw Normal View History

2022-03-30 13:12:33 +00:00
"""prototorch.models test suite."""
2023-06-20 19:18:28 +00:00
import prototorch.models
2022-03-30 13:12:33 +00:00
def test_glvq_model_build():
    """Smoke test: GLVQ builds with ``distribution=(3, 2)`` and an RNCI(2) initializer."""
    hparams = dict(distribution=(3, 2))
    initializer = prototorch.initializers.RNCI(2)
    model = prototorch.models.GLVQ(hparams, prototypes_initializer=initializer)
def test_glvq1_model_build():
    """Smoke test: GLVQ1 builds with ``distribution=(3, 2)`` and an RNCI(2) initializer."""
    hparams = dict(distribution=(3, 2))
    initializer = prototorch.initializers.RNCI(2)
    model = prototorch.models.GLVQ1(hparams, prototypes_initializer=initializer)
def test_glvq21_model_build():
    """Smoke test: GLVQ21 builds with ``distribution=(3, 2)`` and an RNCI(2) initializer."""
    # Build GLVQ21 itself (not GLVQ1) so this test covers the model it is named after.
    model = prototorch.models.GLVQ21(
        {"distribution": (3, 2)},
        prototypes_initializer=prototorch.initializers.RNCI(2),
    )
    assert isinstance(model, prototorch.models.GLVQ21)
def test_gmlvq_model_build():
    """Smoke test: GMLVQ builds with a 2-d input, a 2-d latent space and RNCI(2) prototypes."""
    hparams = dict(distribution=(3, 2), input_dim=2, latent_dim=2)
    initializer = prototorch.initializers.RNCI(2)
    model = prototorch.models.GMLVQ(hparams, prototypes_initializer=initializer)
def test_grlvq_model_build():
    """Smoke test: GRLVQ builds with ``input_dim=2`` and RNCI(2) prototypes."""
    hparams = dict(distribution=(3, 2), input_dim=2)
    initializer = prototorch.initializers.RNCI(2)
    model = prototorch.models.GRLVQ(hparams, prototypes_initializer=initializer)
def test_gtlvq_model_build():
    """Smoke test: GTLVQ builds with ``input_dim=4``, ``latent_dim=2`` and RNCI(2) prototypes."""
    hparams = dict(distribution=(3, 2), input_dim=4, latent_dim=2)
    initializer = prototorch.initializers.RNCI(2)
    model = prototorch.models.GTLVQ(hparams, prototypes_initializer=initializer)
def test_lgmlvq_model_build():
    """Smoke test: LGMLVQ builds with ``input_dim=4``, ``latent_dim=2`` and RNCI(2) prototypes."""
    hparams = dict(distribution=(3, 2), input_dim=4, latent_dim=2)
    initializer = prototorch.initializers.RNCI(2)
    model = prototorch.models.LGMLVQ(hparams, prototypes_initializer=initializer)
def test_image_glvq_model_build():
    """Smoke test: ImageGLVQ builds with ``distribution=(3, 2)`` and RNCI(16) prototypes."""
    hparams = dict(distribution=(3, 2))
    initializer = prototorch.initializers.RNCI(16)
    model = prototorch.models.ImageGLVQ(hparams, prototypes_initializer=initializer)
def test_image_gmlvq_model_build():
    """Smoke test: ImageGMLVQ builds with ``input_dim=16``, ``latent_dim=2`` and RNCI(16)."""
    hparams = dict(distribution=(3, 2), input_dim=16, latent_dim=2)
    initializer = prototorch.initializers.RNCI(16)
    model = prototorch.models.ImageGMLVQ(hparams, prototypes_initializer=initializer)
def test_image_gtlvq_model_build():
    """Smoke test: ImageGTLVQ builds with ``input_dim=16``, ``latent_dim=2`` and RNCI(16)."""
    # Build ImageGTLVQ itself; constructing ImageGMLVQ here would merely duplicate
    # test_image_gmlvq_model_build and leave ImageGTLVQ untested.
    model = prototorch.models.ImageGTLVQ(
        {
            "distribution": (3, 2),
            "input_dim": 16,
            "latent_dim": 2,
        },
        prototypes_initializer=prototorch.initializers.RNCI(16),
    )
    assert isinstance(model, prototorch.models.ImageGTLVQ)
def test_siamese_glvq_model_build():
    """Smoke test: SiameseGLVQ builds with ``distribution=(3, 2)`` and RNCI(4) prototypes."""
    hparams = dict(distribution=(3, 2))
    initializer = prototorch.initializers.RNCI(4)
    model = prototorch.models.SiameseGLVQ(hparams, prototypes_initializer=initializer)
def test_siamese_gmlvq_model_build():
    """Smoke test: SiameseGMLVQ builds with ``input_dim=4``, ``latent_dim=2`` and RNCI(4)."""
    hparams = dict(distribution=(3, 2), input_dim=4, latent_dim=2)
    initializer = prototorch.initializers.RNCI(4)
    model = prototorch.models.SiameseGMLVQ(hparams, prototypes_initializer=initializer)
def test_siamese_gtlvq_model_build():
    """Smoke test: SiameseGTLVQ builds with ``input_dim=4``, ``latent_dim=2`` and RNCI(4)."""
    hparams = dict(distribution=(3, 2), input_dim=4, latent_dim=2)
    initializer = prototorch.initializers.RNCI(4)
    model = prototorch.models.SiameseGTLVQ(hparams, prototypes_initializer=initializer)
def test_knn_model_build():
    """Smoke test: KNN (``k=3``) builds from ``prototorch.datasets.Iris(dims=[0, 2])``."""
    selected_dims = [0, 2]
    train_ds = prototorch.datasets.Iris(dims=selected_dims)
    model = prototorch.models.KNN({"k": 3}, data=train_ds)
def test_lvq1_model_build():
    """Smoke test: LVQ1 builds with ``distribution=(3, 2)`` and an RNCI(2) initializer."""
    hparams = dict(distribution=(3, 2))
    initializer = prototorch.initializers.RNCI(2)
    model = prototorch.models.LVQ1(hparams, prototypes_initializer=initializer)
def test_lvq21_model_build():
    """Smoke test: LVQ21 builds with ``distribution=(3, 2)`` and an RNCI(2) initializer."""
    hparams = dict(distribution=(3, 2))
    initializer = prototorch.initializers.RNCI(2)
    model = prototorch.models.LVQ21(hparams, prototypes_initializer=initializer)
def test_median_lvq_model_build():
    """Smoke test: MedianLVQ builds with ``distribution=(3, 2)`` and an RNCI(2) initializer."""
    hparams = dict(distribution=(3, 2))
    initializer = prototorch.initializers.RNCI(2)
    model = prototorch.models.MedianLVQ(hparams, prototypes_initializer=initializer)
def test_celvq_model_build():
    """Smoke test: CELVQ builds with ``distribution=(3, 2)`` and an RNCI(2) initializer."""
    hparams = dict(distribution=(3, 2))
    initializer = prototorch.initializers.RNCI(2)
    model = prototorch.models.CELVQ(hparams, prototypes_initializer=initializer)
def test_rslvq_model_build():
    """Smoke test: RSLVQ builds with ``distribution=(3, 2)`` and an RNCI(2) initializer."""
    hparams = dict(distribution=(3, 2))
    initializer = prototorch.initializers.RNCI(2)
    model = prototorch.models.RSLVQ(hparams, prototypes_initializer=initializer)
def test_slvq_model_build():
    """Smoke test: SLVQ builds with ``distribution=(3, 2)`` and an RNCI(2) initializer."""
    hparams = dict(distribution=(3, 2))
    initializer = prototorch.initializers.RNCI(2)
    model = prototorch.models.SLVQ(hparams, prototypes_initializer=initializer)
def test_growing_neural_gas_model_build():
    """Smoke test: GrowingNeuralGas builds with ``num_prototypes=5`` and RNCI(2)."""
    hparams = dict(num_prototypes=5)
    initializer = prototorch.initializers.RNCI(2)
    model = prototorch.models.GrowingNeuralGas(
        hparams,
        prototypes_initializer=initializer,
    )
def test_kohonen_som_model_build():
    """Smoke test: KohonenSOM builds with ``shape=(3, 2)`` and an RNCI(2) initializer."""
    hparams = dict(shape=(3, 2))
    initializer = prototorch.initializers.RNCI(2)
    model = prototorch.models.KohonenSOM(hparams, prototypes_initializer=initializer)
def test_neural_gas_model_build():
    """Smoke test: NeuralGas builds with ``num_prototypes=5`` and an RNCI(2) initializer."""
    hparams = dict(num_prototypes=5)
    initializer = prototorch.initializers.RNCI(2)
    model = prototorch.models.NeuralGas(hparams, prototypes_initializer=initializer)