From 21b0279839145a0ad64cab29e50d658f12e9b9a9 Mon Sep 17 00:00:00 2001 From: blackfly Date: Wed, 8 Apr 2020 22:47:08 +0200 Subject: [PATCH] Add test cases --- tests/test_functions.py | 236 +++++++++++++++++++++++----------------- tests/test_modules.py | 14 ++- 2 files changed, 148 insertions(+), 102 deletions(-) diff --git a/tests/test_functions.py b/tests/test_functions.py index f2891d5..3f27ba1 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -6,7 +6,107 @@ import numpy as np import torch from prototorch.functions import (activations, competitions, distances, - initializers) + initializers, losses) + + +class TestActivations(unittest.TestCase): + def setUp(self): + self.flist = ['identity', 'sigmoid_beta', 'swish_beta'] + self.x = torch.randn(1024, 1) + + def test_registry(self): + self.assertIsNotNone(activations.ACTIVATIONS) + + def test_funcname_deserialization(self): + for funcname in self.flist: + f = activations.get_activation(funcname) + iscallable = callable(f) + self.assertTrue(iscallable) + + # def test_torch_script(self): + # for funcname in self.flist: + # f = activations.get_activation(funcname) + # self.assertIsInstance(f, torch.jit.ScriptFunction) + + def test_callable_deserialization(self): + def dummy(x, **kwargs): + return x + + for f in [dummy, lambda x: x]: + f = activations.get_activation(f) + iscallable = callable(f) + self.assertTrue(iscallable) + self.assertEqual(1, f(1)) + + def test_unknown_deserialization(self): + for funcname in ['blubb', 'foobar']: + with self.assertRaises(NameError): + _ = activations.get_activation(funcname) + + def test_identity(self): + actual = activations.identity(self.x) + desired = self.x + mismatch = np.testing.assert_array_almost_equal(actual, + desired, + decimal=5) + self.assertIsNone(mismatch) + + def test_sigmoid_beta1(self): + actual = activations.sigmoid_beta(self.x, beta=torch.tensor(1)) + desired = torch.sigmoid(self.x) + mismatch = np.testing.assert_array_almost_equal(actual, + desired, + decimal=5) + self.assertIsNone(mismatch) + + def test_swish_beta1(self): + actual = activations.swish_beta(self.x, beta=torch.tensor(1)) + desired = self.x * torch.sigmoid(self.x) + mismatch = np.testing.assert_array_almost_equal(actual, + desired, + decimal=5) + self.assertIsNone(mismatch) + + def tearDown(self): + del self.x + + +class TestCompetitions(unittest.TestCase): + def setUp(self): + pass + + def test_wtac(self): + d = torch.tensor([[2., 3., 1.99, 3.01], [2., 3., 2.01, 3.]]) + labels = torch.tensor([0, 1, 2, 3]) + actual = competitions.wtac(d, labels) + desired = torch.tensor([2, 0]) + mismatch = np.testing.assert_array_almost_equal(actual, + desired, + decimal=5) + self.assertIsNone(mismatch) + + def test_wtac_one_hot(self): + d = torch.tensor([[1.99, 3.01], [3., 2.01]]) + labels = torch.tensor([[0, 1], [1, 0]]) + actual = competitions.wtac(d, labels) + desired = torch.tensor([[0, 1], [1, 0]]) + mismatch = np.testing.assert_array_almost_equal(actual, + desired, + decimal=5) + self.assertIsNone(mismatch) + + def test_knnc_k1(self): + d = torch.tensor([[2., 3., 1.99, 3.01], [2., 3., 2.01, 3.]]) + labels = torch.tensor([0, 1, 2, 3]) + actual = competitions.knnc(d, labels, k=torch.tensor([1])) + desired = torch.tensor([2, 0]) + mismatch = np.testing.assert_array_almost_equal(actual, + desired, + decimal=5) + self.assertIsNone(mismatch) + + def tearDown(self): + pass class TestDistances(unittest.TestCase): @@ -167,103 +267,12 @@ class TestDistances(unittest.TestCase): del self.x, self.y -class TestActivations(unittest.TestCase): - def setUp(self): - self.x = torch.randn(1024, 1) - - def test_registry(self): - self.assertIsNotNone(activations.ACTIVATIONS) - - def test_funcname_deserialization(self): - flist = ['identity', 'sigmoid_beta', 'swish_beta'] - for funcname in flist: - f = activations.get_activation(funcname) - iscallable = callable(f) - self.assertTrue(iscallable) - - def test_callable_deserialization(self): - def dummy(x, **kwargs): - return x - - for f in [dummy, lambda x: x]: - f = activations.get_activation(f) - iscallable = callable(f) - self.assertTrue(iscallable) - self.assertEqual(1, f(1)) - - def test_unknown_deserialization(self): - for funcname in ['blubb', 'foobar']: - with self.assertRaises(NameError): - _ = activations.get_activation(funcname) - - def test_identity(self): - actual = activations.identity(self.x) - desired = self.x - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def test_sigmoid_beta1(self): - actual = activations.sigmoid_beta(self.x, beta=1) - desired = torch.sigmoid(self.x) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def test_swish_beta1(self): - actual = activations.swish_beta(self.x, beta=1) - desired = self.x * torch.sigmoid(self.x) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def tearDown(self): - del self.x - - -class TestCompetitions(unittest.TestCase): - def setUp(self): - pass - - def test_wtac(self): - d = torch.tensor([[2., 3., 1.99, 3.01], [2., 3., 2.01, 3.]]) - labels = torch.tensor([0, 1, 2, 3]) - actual = competitions.wtac(d, labels) - desired = torch.tensor([2, 0]) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def test_wtac_one_hot(self): - d = torch.tensor([[1.99, 3.01], [3., 2.01]]) - labels = torch.tensor([[0, 1], [1, 0]]) - actual = competitions.wtac(d, labels) - desired = torch.tensor([[0, 1], [1, 0]]) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def test_knnc_k1(self): - d = torch.tensor([[2., 3., 1.99, 3.01], [2., 3., 2.01, 3.]]) - labels = torch.tensor([0, 1, 2, 3]) - actual = competitions.knnc(d, labels, k=1) - desired = torch.tensor([2, 0]) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def tearDown(self): - pass - - class TestInitializers(unittest.TestCase): def setUp(self): + self.flist = [ + 'zeros', 'ones', 'rand', 'randn', 'stratified_mean', + 'stratified_random' + ] self.x = torch.tensor( [[0, -1, -2], [10, 11, 12], [0, 0, 0], [2, 2, 2]], dtype=torch.float32) @@ -274,11 +283,7 @@ class TestInitializers(unittest.TestCase): self.assertIsNotNone(initializers.INITIALIZERS) def test_funcname_deserialization(self): - flist = [ - 'zeros', 'ones', 'rand', 'randn', 'stratified_mean', - 'stratified_random' - ] - for funcname in flist: + for funcname in self.flist: f = initializers.get_initializer(funcname) iscallable = callable(f) self.assertTrue(iscallable) @@ -385,3 +390,32 @@ class TestInitializers(unittest.TestCase): def tearDown(self): del self.x, self.y, self.gen _ = torch.seed() + + +class TestLosses(unittest.TestCase): + def setUp(self): + pass + + def test_glvq_loss_int_labels(self): + d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1) + labels = torch.tensor([0, 1]) + targets = torch.ones(100) + batch_loss = losses.glvq_loss(distances=d, + target_labels=targets, + prototype_labels=labels) + loss_value = torch.sum(batch_loss, dim=0) + self.assertEqual(loss_value, -100) + + def test_glvq_loss_one_hot_labels(self): + d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1) + labels = torch.tensor([[0, 1], [1, 0]]) + wl = torch.tensor([1, 0]) + targets = torch.stack([wl for _ in range(100)], dim=0) + batch_loss = losses.glvq_loss(distances=d, + target_labels=targets, + prototype_labels=labels) + loss_value = torch.sum(batch_loss, dim=0) + self.assertEqual(loss_value, -100) + + def tearDown(self): + pass diff --git a/tests/test_modules.py b/tests/test_modules.py index 8a8911a..47ae613 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -123,7 +123,19 @@ class TestLosses(unittest.TestCase): pass def test_glvqloss_init(self): - _ = losses.GLVQLoss() + _ = losses.GLVQLoss(0, 'swish_beta', beta=20) + + def test_glvqloss_forward(self): + criterion = losses.GLVQLoss(margin=0, + squashing='sigmoid_beta', + beta=100) + d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1) + labels = torch.tensor([0, 1]) + targets = torch.ones(100) + outputs = [d, labels] + loss = criterion(outputs, targets) + loss_value = loss.item() + self.assertAlmostEqual(loss_value, 0.0) def tearDown(self): pass