From 60990f42d2f4eaac43a8f36f7b3e7a857780eaff Mon Sep 17 00:00:00 2001 From: Alexander Engelsberger Date: Tue, 20 Jun 2023 21:18:28 +0200 Subject: [PATCH] fix: update import in tests --- tests/test_models.py | 96 ++++++++++++++++++++++---------------------- 1 file changed, 47 insertions(+), 49 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index 2a632ae..8e45fd5 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,195 +1,193 @@ """prototorch.models test suite.""" -import prototorch as pt -import pytest -import torch +import prototorch.models def test_glvq_model_build(): - model = pt.models.GLVQ( + model = prototorch.models.GLVQ( {"distribution": (3, 2)}, - prototypes_initializer=pt.initializers.RNCI(2), + prototypes_initializer=prototorch.initializers.RNCI(2), ) def test_glvq1_model_build(): - model = pt.models.GLVQ1( + model = prototorch.models.GLVQ1( {"distribution": (3, 2)}, - prototypes_initializer=pt.initializers.RNCI(2), + prototypes_initializer=prototorch.initializers.RNCI(2), ) def test_glvq21_model_build(): - model = pt.models.GLVQ1( + model = prototorch.models.GLVQ1( {"distribution": (3, 2)}, - prototypes_initializer=pt.initializers.RNCI(2), + prototypes_initializer=prototorch.initializers.RNCI(2), ) def test_gmlvq_model_build(): - model = pt.models.GMLVQ( + model = prototorch.models.GMLVQ( { "distribution": (3, 2), "input_dim": 2, "latent_dim": 2, }, - prototypes_initializer=pt.initializers.RNCI(2), + prototypes_initializer=prototorch.initializers.RNCI(2), ) def test_grlvq_model_build(): - model = pt.models.GRLVQ( + model = prototorch.models.GRLVQ( { "distribution": (3, 2), "input_dim": 2, }, - prototypes_initializer=pt.initializers.RNCI(2), + prototypes_initializer=prototorch.initializers.RNCI(2), ) def test_gtlvq_model_build(): - model = pt.models.GTLVQ( + model = prototorch.models.GTLVQ( { "distribution": (3, 2), "input_dim": 4, "latent_dim": 2, }, - prototypes_initializer=pt.initializers.RNCI(2), + prototypes_initializer=prototorch.initializers.RNCI(2), ) def test_lgmlvq_model_build(): - model = pt.models.LGMLVQ( + model = prototorch.models.LGMLVQ( { "distribution": (3, 2), "input_dim": 4, "latent_dim": 2, }, - prototypes_initializer=pt.initializers.RNCI(2), + prototypes_initializer=prototorch.initializers.RNCI(2), ) def test_image_glvq_model_build(): - model = pt.models.ImageGLVQ( + model = prototorch.models.ImageGLVQ( {"distribution": (3, 2)}, - prototypes_initializer=pt.initializers.RNCI(16), + prototypes_initializer=prototorch.initializers.RNCI(16), ) def test_image_gmlvq_model_build(): - model = pt.models.ImageGMLVQ( + model = prototorch.models.ImageGMLVQ( { "distribution": (3, 2), "input_dim": 16, "latent_dim": 2, }, - prototypes_initializer=pt.initializers.RNCI(16), + prototypes_initializer=prototorch.initializers.RNCI(16), ) def test_image_gtlvq_model_build(): - model = pt.models.ImageGMLVQ( + model = prototorch.models.ImageGMLVQ( { "distribution": (3, 2), "input_dim": 16, "latent_dim": 2, }, - prototypes_initializer=pt.initializers.RNCI(16), + prototypes_initializer=prototorch.initializers.RNCI(16), ) def test_siamese_glvq_model_build(): - model = pt.models.SiameseGLVQ( + model = prototorch.models.SiameseGLVQ( {"distribution": (3, 2)}, - prototypes_initializer=pt.initializers.RNCI(4), + prototypes_initializer=prototorch.initializers.RNCI(4), ) def test_siamese_gmlvq_model_build(): - model = pt.models.SiameseGMLVQ( + model = prototorch.models.SiameseGMLVQ( { "distribution": (3, 2), "input_dim": 4, "latent_dim": 2, }, - prototypes_initializer=pt.initializers.RNCI(4), + prototypes_initializer=prototorch.initializers.RNCI(4), ) def test_siamese_gtlvq_model_build(): - model = pt.models.SiameseGTLVQ( + model = prototorch.models.SiameseGTLVQ( { "distribution": (3, 2), "input_dim": 4, "latent_dim": 2, }, - prototypes_initializer=pt.initializers.RNCI(4), + prototypes_initializer=prototorch.initializers.RNCI(4), ) def test_knn_model_build(): - train_ds = pt.datasets.Iris(dims=[0, 2]) - model = pt.models.KNN(dict(k=3), data=train_ds) + train_ds = prototorch.datasets.Iris(dims=[0, 2]) + model = prototorch.models.KNN(dict(k=3), data=train_ds) def test_lvq1_model_build(): - model = pt.models.LVQ1( + model = prototorch.models.LVQ1( {"distribution": (3, 2)}, - prototypes_initializer=pt.initializers.RNCI(2), + prototypes_initializer=prototorch.initializers.RNCI(2), ) def test_lvq21_model_build(): - model = pt.models.LVQ21( + model = prototorch.models.LVQ21( {"distribution": (3, 2)}, - prototypes_initializer=pt.initializers.RNCI(2), + prototypes_initializer=prototorch.initializers.RNCI(2), ) def test_median_lvq_model_build(): - model = pt.models.MedianLVQ( + model = prototorch.models.MedianLVQ( {"distribution": (3, 2)}, - prototypes_initializer=pt.initializers.RNCI(2), + prototypes_initializer=prototorch.initializers.RNCI(2), ) def test_celvq_model_build(): - model = pt.models.CELVQ( + model = prototorch.models.CELVQ( {"distribution": (3, 2)}, - prototypes_initializer=pt.initializers.RNCI(2), + prototypes_initializer=prototorch.initializers.RNCI(2), ) def test_rslvq_model_build(): - model = pt.models.RSLVQ( + model = prototorch.models.RSLVQ( {"distribution": (3, 2)}, - prototypes_initializer=pt.initializers.RNCI(2), + prototypes_initializer=prototorch.initializers.RNCI(2), ) def test_slvq_model_build(): - model = pt.models.SLVQ( + model = prototorch.models.SLVQ( {"distribution": (3, 2)}, - prototypes_initializer=pt.initializers.RNCI(2), + prototypes_initializer=prototorch.initializers.RNCI(2), ) def test_growing_neural_gas_model_build(): - model = pt.models.GrowingNeuralGas( + model = prototorch.models.GrowingNeuralGas( {"num_prototypes": 5}, - prototypes_initializer=pt.initializers.RNCI(2), + prototypes_initializer=prototorch.initializers.RNCI(2), ) def test_kohonen_som_model_build(): - model = pt.models.KohonenSOM( + model = prototorch.models.KohonenSOM( {"shape": (3, 2)}, - prototypes_initializer=pt.initializers.RNCI(2), + prototypes_initializer=prototorch.initializers.RNCI(2), ) def test_neural_gas_model_build(): - model = pt.models.NeuralGas( + model = prototorch.models.NeuralGas( {"num_prototypes": 5}, - prototypes_initializer=pt.initializers.RNCI(2), + prototypes_initializer=prototorch.initializers.RNCI(2), )