test: add unit tests
This commit is contained in:
parent
9da47b1dba
commit
7d3f59e54b
@ -26,13 +26,11 @@ from .lvq import (
|
||||
)
|
||||
from .probabilistic import (
|
||||
CELVQ,
|
||||
PLVQ,
|
||||
RSLVQ,
|
||||
SLVQ,
|
||||
)
|
||||
from .unsupervised import (
|
||||
GrowingNeuralGas,
|
||||
HeskesSOM,
|
||||
KohonenSOM,
|
||||
NeuralGas,
|
||||
)
|
||||
|
@ -16,7 +16,7 @@ class KNN(SupervisedPrototypeModel):
|
||||
"""K-Nearest-Neighbors classification algorithm."""
|
||||
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
super().__init__(hparams, skip_proto_layer=True, **kwargs)
|
||||
|
||||
# Default hparams
|
||||
self.hparams.setdefault("k", 1)
|
||||
@ -28,7 +28,7 @@ class KNN(SupervisedPrototypeModel):
|
||||
|
||||
# Layers
|
||||
self.proto_layer = LabeledComponents(
|
||||
distribution=[],
|
||||
distribution=len(data) * [1],
|
||||
components_initializer=LiteralCompInitializer(data),
|
||||
labels_initializer=LiteralLabelsInitializer(targets))
|
||||
self.competition_layer = KNNC(k=self.hparams.k)
|
||||
|
@ -67,8 +67,13 @@ class SLVQ(ProbabilisticLVQ):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# Default hparams
|
||||
self.hparams.setdefault("variance", 1.0)
|
||||
variance = self.hparams.get("variance")
|
||||
|
||||
self.conditional_distribution = GaussianPrior(variance)
|
||||
self.loss = LossLayer(nllr_loss)
|
||||
self.conditional_distribution = GaussianPrior(self.hparams.variance)
|
||||
|
||||
|
||||
class RSLVQ(ProbabilisticLVQ):
|
||||
@ -76,8 +81,13 @@ class RSLVQ(ProbabilisticLVQ):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# Default hparams
|
||||
self.hparams.setdefault("variance", 1.0)
|
||||
variance = self.hparams.get("variance")
|
||||
|
||||
self.conditional_distribution = GaussianPrior(variance)
|
||||
self.loss = LossLayer(rslvq_loss)
|
||||
self.conditional_distribution = GaussianPrior(self.hparams.variance)
|
||||
|
||||
|
||||
class PLVQ(ProbabilisticLVQ, SiameseGMLVQ):
|
||||
@ -88,8 +98,12 @@ class PLVQ(ProbabilisticLVQ, SiameseGMLVQ):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.conditional_distribution = RankScaledGaussianPrior(
|
||||
self.hparams.lambd)
|
||||
|
||||
# Default hparams
|
||||
self.hparams.setdefault("lambda", 1.0)
|
||||
lam = self.hparams.get("lambda", 1.0)
|
||||
|
||||
self.conditional_distribution = RankScaledGaussianPrior(lam)
|
||||
self.loss = torch.nn.KLDivLoss()
|
||||
|
||||
# FIXME
|
||||
|
@ -35,7 +35,7 @@ class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
|
||||
|
||||
# Additional parameters
|
||||
x, y = torch.arange(h), torch.arange(w)
|
||||
grid = torch.stack(torch.meshgrid(x, y), dim=-1)
|
||||
grid = torch.stack(torch.meshgrid(x, y, indexing="ij"), dim=-1)
|
||||
self.register_buffer("_grid", grid)
|
||||
self._sigma = self.hparams.sigma
|
||||
self._lr = self.hparams.lr
|
||||
@ -88,12 +88,12 @@ class NeuralGas(UnsupervisedPrototypeModel):
|
||||
self.save_hyperparameters(hparams)
|
||||
|
||||
# Default hparams
|
||||
self.hparams.setdefault("agelimit", 10)
|
||||
self.hparams.setdefault("age_limit", 10)
|
||||
self.hparams.setdefault("lm", 1)
|
||||
|
||||
self.energy_layer = NeuralGasEnergy(lm=self.hparams.lm)
|
||||
self.topology_layer = ConnectionTopology(
|
||||
agelimit=self.hparams.agelimit,
|
||||
agelimit=self.hparams.age_limit,
|
||||
num_prototypes=self.hparams.num_prototypes,
|
||||
)
|
||||
|
||||
|
@ -1,15 +0,0 @@
|
||||
"""prototorch.models test suite."""
|
||||
|
||||
import unittest
|
||||
|
||||
|
||||
class TestDummy(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
pass
|
||||
|
||||
def test_dummy(self):
|
||||
pass
|
||||
|
||||
def tearDown(self):
|
||||
pass
|
195
tests/test_models.py
Normal file
195
tests/test_models.py
Normal file
@ -0,0 +1,195 @@
|
||||
"""prototorch.models test suite."""
|
||||
|
||||
import prototorch as pt
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
|
||||
def test_glvq_model_build():
|
||||
model = pt.models.GLVQ(
|
||||
{"distribution": (3, 2)},
|
||||
prototypes_initializer=pt.initializers.RNCI(2),
|
||||
)
|
||||
|
||||
|
||||
def test_glvq1_model_build():
|
||||
model = pt.models.GLVQ1(
|
||||
{"distribution": (3, 2)},
|
||||
prototypes_initializer=pt.initializers.RNCI(2),
|
||||
)
|
||||
|
||||
|
||||
def test_glvq21_model_build():
|
||||
model = pt.models.GLVQ1(
|
||||
{"distribution": (3, 2)},
|
||||
prototypes_initializer=pt.initializers.RNCI(2),
|
||||
)
|
||||
|
||||
|
||||
def test_gmlvq_model_build():
|
||||
model = pt.models.GMLVQ(
|
||||
{
|
||||
"distribution": (3, 2),
|
||||
"input_dim": 2,
|
||||
"latent_dim": 2,
|
||||
},
|
||||
prototypes_initializer=pt.initializers.RNCI(2),
|
||||
)
|
||||
|
||||
|
||||
def test_grlvq_model_build():
|
||||
model = pt.models.GRLVQ(
|
||||
{
|
||||
"distribution": (3, 2),
|
||||
"input_dim": 2,
|
||||
},
|
||||
prototypes_initializer=pt.initializers.RNCI(2),
|
||||
)
|
||||
|
||||
|
||||
def test_gtlvq_model_build():
|
||||
model = pt.models.GTLVQ(
|
||||
{
|
||||
"distribution": (3, 2),
|
||||
"input_dim": 4,
|
||||
"latent_dim": 2,
|
||||
},
|
||||
prototypes_initializer=pt.initializers.RNCI(2),
|
||||
)
|
||||
|
||||
|
||||
def test_lgmlvq_model_build():
|
||||
model = pt.models.LGMLVQ(
|
||||
{
|
||||
"distribution": (3, 2),
|
||||
"input_dim": 4,
|
||||
"latent_dim": 2,
|
||||
},
|
||||
prototypes_initializer=pt.initializers.RNCI(2),
|
||||
)
|
||||
|
||||
|
||||
def test_image_glvq_model_build():
|
||||
model = pt.models.ImageGLVQ(
|
||||
{"distribution": (3, 2)},
|
||||
prototypes_initializer=pt.initializers.RNCI(16),
|
||||
)
|
||||
|
||||
|
||||
def test_image_gmlvq_model_build():
|
||||
model = pt.models.ImageGMLVQ(
|
||||
{
|
||||
"distribution": (3, 2),
|
||||
"input_dim": 16,
|
||||
"latent_dim": 2,
|
||||
},
|
||||
prototypes_initializer=pt.initializers.RNCI(16),
|
||||
)
|
||||
|
||||
|
||||
def test_image_gtlvq_model_build():
|
||||
model = pt.models.ImageGMLVQ(
|
||||
{
|
||||
"distribution": (3, 2),
|
||||
"input_dim": 16,
|
||||
"latent_dim": 2,
|
||||
},
|
||||
prototypes_initializer=pt.initializers.RNCI(16),
|
||||
)
|
||||
|
||||
|
||||
def test_siamese_glvq_model_build():
|
||||
model = pt.models.SiameseGLVQ(
|
||||
{"distribution": (3, 2)},
|
||||
prototypes_initializer=pt.initializers.RNCI(4),
|
||||
)
|
||||
|
||||
|
||||
def test_siamese_gmlvq_model_build():
|
||||
model = pt.models.SiameseGMLVQ(
|
||||
{
|
||||
"distribution": (3, 2),
|
||||
"input_dim": 4,
|
||||
"latent_dim": 2,
|
||||
},
|
||||
prototypes_initializer=pt.initializers.RNCI(4),
|
||||
)
|
||||
|
||||
|
||||
def test_siamese_gtlvq_model_build():
|
||||
model = pt.models.SiameseGTLVQ(
|
||||
{
|
||||
"distribution": (3, 2),
|
||||
"input_dim": 4,
|
||||
"latent_dim": 2,
|
||||
},
|
||||
prototypes_initializer=pt.initializers.RNCI(4),
|
||||
)
|
||||
|
||||
|
||||
def test_knn_model_build():
|
||||
train_ds = pt.datasets.Iris(dims=[0, 2])
|
||||
model = pt.models.KNN(dict(k=3), data=train_ds)
|
||||
|
||||
|
||||
def test_lvq1_model_build():
|
||||
model = pt.models.LVQ1(
|
||||
{"distribution": (3, 2)},
|
||||
prototypes_initializer=pt.initializers.RNCI(2),
|
||||
)
|
||||
|
||||
|
||||
def test_lvq21_model_build():
|
||||
model = pt.models.LVQ21(
|
||||
{"distribution": (3, 2)},
|
||||
prototypes_initializer=pt.initializers.RNCI(2),
|
||||
)
|
||||
|
||||
|
||||
def test_median_lvq_model_build():
|
||||
model = pt.models.MedianLVQ(
|
||||
{"distribution": (3, 2)},
|
||||
prototypes_initializer=pt.initializers.RNCI(2),
|
||||
)
|
||||
|
||||
|
||||
def test_celvq_model_build():
|
||||
model = pt.models.CELVQ(
|
||||
{"distribution": (3, 2)},
|
||||
prototypes_initializer=pt.initializers.RNCI(2),
|
||||
)
|
||||
|
||||
|
||||
def test_rslvq_model_build():
|
||||
model = pt.models.RSLVQ(
|
||||
{"distribution": (3, 2)},
|
||||
prototypes_initializer=pt.initializers.RNCI(2),
|
||||
)
|
||||
|
||||
|
||||
def test_slvq_model_build():
|
||||
model = pt.models.SLVQ(
|
||||
{"distribution": (3, 2)},
|
||||
prototypes_initializer=pt.initializers.RNCI(2),
|
||||
)
|
||||
|
||||
|
||||
def test_growing_neural_gas_model_build():
|
||||
model = pt.models.GrowingNeuralGas(
|
||||
{"num_prototypes": 5},
|
||||
prototypes_initializer=pt.initializers.RNCI(2),
|
||||
)
|
||||
|
||||
|
||||
def test_kohonen_som_model_build():
|
||||
model = pt.models.KohonenSOM(
|
||||
{"shape": (3, 2)},
|
||||
prototypes_initializer=pt.initializers.RNCI(2),
|
||||
)
|
||||
|
||||
|
||||
def test_neural_gas_model_build():
|
||||
model = pt.models.NeuralGas(
|
||||
{"num_prototypes": 5},
|
||||
prototypes_initializer=pt.initializers.RNCI(2),
|
||||
)
|
Loading…
Reference in New Issue
Block a user