Add test cases
This commit is contained in:
		@@ -6,7 +6,107 @@ import numpy as np
 | 
				
			|||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from prototorch.functions import (activations, competitions, distances,
 | 
					from prototorch.functions import (activations, competitions, distances,
 | 
				
			||||||
                                  initializers)
 | 
					                                  initializers, losses)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class TestActivations(unittest.TestCase):
 | 
				
			||||||
 | 
					    def setUp(self):
 | 
				
			||||||
 | 
					        self.flist = ['identity', 'sigmoid_beta', 'swish_beta']
 | 
				
			||||||
 | 
					        self.x = torch.randn(1024, 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_registry(self):
 | 
				
			||||||
 | 
					        self.assertIsNotNone(activations.ACTIVATIONS)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_funcname_deserialization(self):
 | 
				
			||||||
 | 
					        for funcname in self.flist:
 | 
				
			||||||
 | 
					            f = activations.get_activation(funcname)
 | 
				
			||||||
 | 
					            iscallable = callable(f)
 | 
				
			||||||
 | 
					            self.assertTrue(iscallable)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # def test_torch_script(self):
 | 
				
			||||||
 | 
					    #     for funcname in self.flist:
 | 
				
			||||||
 | 
					    #         f = activations.get_activation(funcname)
 | 
				
			||||||
 | 
					    #         self.assertIsInstance(f, torch.jit.ScriptFunction)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_callable_deserialization(self):
 | 
				
			||||||
 | 
					        def dummy(x, **kwargs):
 | 
				
			||||||
 | 
					            return x
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for f in [dummy, lambda x: x]:
 | 
				
			||||||
 | 
					            f = activations.get_activation(f)
 | 
				
			||||||
 | 
					            iscallable = callable(f)
 | 
				
			||||||
 | 
					            self.assertTrue(iscallable)
 | 
				
			||||||
 | 
					            self.assertEqual(1, f(1))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_unknown_deserialization(self):
 | 
				
			||||||
 | 
					        for funcname in ['blubb', 'foobar']:
 | 
				
			||||||
 | 
					            with self.assertRaises(NameError):
 | 
				
			||||||
 | 
					                _ = activations.get_activation(funcname)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_identity(self):
 | 
				
			||||||
 | 
					        actual = activations.identity(self.x)
 | 
				
			||||||
 | 
					        desired = self.x
 | 
				
			||||||
 | 
					        mismatch = np.testing.assert_array_almost_equal(actual,
 | 
				
			||||||
 | 
					                                                        desired,
 | 
				
			||||||
 | 
					                                                        decimal=5)
 | 
				
			||||||
 | 
					        self.assertIsNone(mismatch)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_sigmoid_beta1(self):
 | 
				
			||||||
 | 
					        actual = activations.sigmoid_beta(self.x, beta=torch.tensor(1))
 | 
				
			||||||
 | 
					        desired = torch.sigmoid(self.x)
 | 
				
			||||||
 | 
					        mismatch = np.testing.assert_array_almost_equal(actual,
 | 
				
			||||||
 | 
					                                                        desired,
 | 
				
			||||||
 | 
					                                                        decimal=5)
 | 
				
			||||||
 | 
					        self.assertIsNone(mismatch)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_swish_beta1(self):
 | 
				
			||||||
 | 
					        actual = activations.swish_beta(self.x, beta=torch.tensor(1))
 | 
				
			||||||
 | 
					        desired = self.x * torch.sigmoid(self.x)
 | 
				
			||||||
 | 
					        mismatch = np.testing.assert_array_almost_equal(actual,
 | 
				
			||||||
 | 
					                                                        desired,
 | 
				
			||||||
 | 
					                                                        decimal=5)
 | 
				
			||||||
 | 
					        self.assertIsNone(mismatch)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def tearDown(self):
 | 
				
			||||||
 | 
					        del self.x
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class TestCompetitions(unittest.TestCase):
 | 
				
			||||||
 | 
					    def setUp(self):
 | 
				
			||||||
 | 
					        pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_wtac(self):
 | 
				
			||||||
 | 
					        d = torch.tensor([[2., 3., 1.99, 3.01], [2., 3., 2.01, 3.]])
 | 
				
			||||||
 | 
					        labels = torch.tensor([0, 1, 2, 3])
 | 
				
			||||||
 | 
					        actual = competitions.wtac(d, labels)
 | 
				
			||||||
 | 
					        desired = torch.tensor([2, 0])
 | 
				
			||||||
 | 
					        mismatch = np.testing.assert_array_almost_equal(actual,
 | 
				
			||||||
 | 
					                                                        desired,
 | 
				
			||||||
 | 
					                                                        decimal=5)
 | 
				
			||||||
 | 
					        self.assertIsNone(mismatch)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_wtac_one_hot(self):
 | 
				
			||||||
 | 
					        d = torch.tensor([[1.99, 3.01], [3., 2.01]])
 | 
				
			||||||
 | 
					        labels = torch.tensor([[0, 1], [1, 0]])
 | 
				
			||||||
 | 
					        actual = competitions.wtac(d, labels)
 | 
				
			||||||
 | 
					        desired = torch.tensor([[0, 1], [1, 0]])
 | 
				
			||||||
 | 
					        mismatch = np.testing.assert_array_almost_equal(actual,
 | 
				
			||||||
 | 
					                                                        desired,
 | 
				
			||||||
 | 
					                                                        decimal=5)
 | 
				
			||||||
 | 
					        self.assertIsNone(mismatch)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_knnc_k1(self):
 | 
				
			||||||
 | 
					        d = torch.tensor([[2., 3., 1.99, 3.01], [2., 3., 2.01, 3.]])
 | 
				
			||||||
 | 
					        labels = torch.tensor([0, 1, 2, 3])
 | 
				
			||||||
 | 
					        actual = competitions.knnc(d, labels, k=torch.tensor([1]))
 | 
				
			||||||
 | 
					        desired = torch.tensor([2, 0])
 | 
				
			||||||
 | 
					        mismatch = np.testing.assert_array_almost_equal(actual,
 | 
				
			||||||
 | 
					                                                        desired,
 | 
				
			||||||
 | 
					                                                        decimal=5)
 | 
				
			||||||
 | 
					        self.assertIsNone(mismatch)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def tearDown(self):
 | 
				
			||||||
 | 
					        pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class TestDistances(unittest.TestCase):
 | 
					class TestDistances(unittest.TestCase):
 | 
				
			||||||
@@ -167,103 +267,12 @@ class TestDistances(unittest.TestCase):
 | 
				
			|||||||
        del self.x, self.y
 | 
					        del self.x, self.y
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class TestActivations(unittest.TestCase):
 | 
					 | 
				
			||||||
    def setUp(self):
 | 
					 | 
				
			||||||
        self.x = torch.randn(1024, 1)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_registry(self):
 | 
					 | 
				
			||||||
        self.assertIsNotNone(activations.ACTIVATIONS)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_funcname_deserialization(self):
 | 
					 | 
				
			||||||
        flist = ['identity', 'sigmoid_beta', 'swish_beta']
 | 
					 | 
				
			||||||
        for funcname in flist:
 | 
					 | 
				
			||||||
            f = activations.get_activation(funcname)
 | 
					 | 
				
			||||||
            iscallable = callable(f)
 | 
					 | 
				
			||||||
            self.assertTrue(iscallable)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_callable_deserialization(self):
 | 
					 | 
				
			||||||
        def dummy(x, **kwargs):
 | 
					 | 
				
			||||||
            return x
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        for f in [dummy, lambda x: x]:
 | 
					 | 
				
			||||||
            f = activations.get_activation(f)
 | 
					 | 
				
			||||||
            iscallable = callable(f)
 | 
					 | 
				
			||||||
            self.assertTrue(iscallable)
 | 
					 | 
				
			||||||
            self.assertEqual(1, f(1))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_unknown_deserialization(self):
 | 
					 | 
				
			||||||
        for funcname in ['blubb', 'foobar']:
 | 
					 | 
				
			||||||
            with self.assertRaises(NameError):
 | 
					 | 
				
			||||||
                _ = activations.get_activation(funcname)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_identity(self):
 | 
					 | 
				
			||||||
        actual = activations.identity(self.x)
 | 
					 | 
				
			||||||
        desired = self.x
 | 
					 | 
				
			||||||
        mismatch = np.testing.assert_array_almost_equal(actual,
 | 
					 | 
				
			||||||
                                                        desired,
 | 
					 | 
				
			||||||
                                                        decimal=5)
 | 
					 | 
				
			||||||
        self.assertIsNone(mismatch)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_sigmoid_beta1(self):
 | 
					 | 
				
			||||||
        actual = activations.sigmoid_beta(self.x, beta=1)
 | 
					 | 
				
			||||||
        desired = torch.sigmoid(self.x)
 | 
					 | 
				
			||||||
        mismatch = np.testing.assert_array_almost_equal(actual,
 | 
					 | 
				
			||||||
                                                        desired,
 | 
					 | 
				
			||||||
                                                        decimal=5)
 | 
					 | 
				
			||||||
        self.assertIsNone(mismatch)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_swish_beta1(self):
 | 
					 | 
				
			||||||
        actual = activations.swish_beta(self.x, beta=1)
 | 
					 | 
				
			||||||
        desired = self.x * torch.sigmoid(self.x)
 | 
					 | 
				
			||||||
        mismatch = np.testing.assert_array_almost_equal(actual,
 | 
					 | 
				
			||||||
                                                        desired,
 | 
					 | 
				
			||||||
                                                        decimal=5)
 | 
					 | 
				
			||||||
        self.assertIsNone(mismatch)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def tearDown(self):
 | 
					 | 
				
			||||||
        del self.x
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class TestCompetitions(unittest.TestCase):
 | 
					 | 
				
			||||||
    def setUp(self):
 | 
					 | 
				
			||||||
        pass
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_wtac(self):
 | 
					 | 
				
			||||||
        d = torch.tensor([[2., 3., 1.99, 3.01], [2., 3., 2.01, 3.]])
 | 
					 | 
				
			||||||
        labels = torch.tensor([0, 1, 2, 3])
 | 
					 | 
				
			||||||
        actual = competitions.wtac(d, labels)
 | 
					 | 
				
			||||||
        desired = torch.tensor([2, 0])
 | 
					 | 
				
			||||||
        mismatch = np.testing.assert_array_almost_equal(actual,
 | 
					 | 
				
			||||||
                                                        desired,
 | 
					 | 
				
			||||||
                                                        decimal=5)
 | 
					 | 
				
			||||||
        self.assertIsNone(mismatch)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_wtac_one_hot(self):
 | 
					 | 
				
			||||||
        d = torch.tensor([[1.99, 3.01], [3., 2.01]])
 | 
					 | 
				
			||||||
        labels = torch.tensor([[0, 1], [1, 0]])
 | 
					 | 
				
			||||||
        actual = competitions.wtac(d, labels)
 | 
					 | 
				
			||||||
        desired = torch.tensor([[0, 1], [1, 0]])
 | 
					 | 
				
			||||||
        mismatch = np.testing.assert_array_almost_equal(actual,
 | 
					 | 
				
			||||||
                                                        desired,
 | 
					 | 
				
			||||||
                                                        decimal=5)
 | 
					 | 
				
			||||||
        self.assertIsNone(mismatch)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_knnc_k1(self):
 | 
					 | 
				
			||||||
        d = torch.tensor([[2., 3., 1.99, 3.01], [2., 3., 2.01, 3.]])
 | 
					 | 
				
			||||||
        labels = torch.tensor([0, 1, 2, 3])
 | 
					 | 
				
			||||||
        actual = competitions.knnc(d, labels, k=1)
 | 
					 | 
				
			||||||
        desired = torch.tensor([2, 0])
 | 
					 | 
				
			||||||
        mismatch = np.testing.assert_array_almost_equal(actual,
 | 
					 | 
				
			||||||
                                                        desired,
 | 
					 | 
				
			||||||
                                                        decimal=5)
 | 
					 | 
				
			||||||
        self.assertIsNone(mismatch)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def tearDown(self):
 | 
					 | 
				
			||||||
        pass
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class TestInitializers(unittest.TestCase):
 | 
					class TestInitializers(unittest.TestCase):
 | 
				
			||||||
    def setUp(self):
 | 
					    def setUp(self):
 | 
				
			||||||
 | 
					        self.flist = [
 | 
				
			||||||
 | 
					            'zeros', 'ones', 'rand', 'randn', 'stratified_mean',
 | 
				
			||||||
 | 
					            'stratified_random'
 | 
				
			||||||
 | 
					        ]
 | 
				
			||||||
        self.x = torch.tensor(
 | 
					        self.x = torch.tensor(
 | 
				
			||||||
            [[0, -1, -2], [10, 11, 12], [0, 0, 0], [2, 2, 2]],
 | 
					            [[0, -1, -2], [10, 11, 12], [0, 0, 0], [2, 2, 2]],
 | 
				
			||||||
            dtype=torch.float32)
 | 
					            dtype=torch.float32)
 | 
				
			||||||
@@ -274,11 +283,7 @@ class TestInitializers(unittest.TestCase):
 | 
				
			|||||||
        self.assertIsNotNone(initializers.INITIALIZERS)
 | 
					        self.assertIsNotNone(initializers.INITIALIZERS)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_funcname_deserialization(self):
 | 
					    def test_funcname_deserialization(self):
 | 
				
			||||||
        flist = [
 | 
					        for funcname in self.flist:
 | 
				
			||||||
            'zeros', 'ones', 'rand', 'randn', 'stratified_mean',
 | 
					 | 
				
			||||||
            'stratified_random'
 | 
					 | 
				
			||||||
        ]
 | 
					 | 
				
			||||||
        for funcname in flist:
 | 
					 | 
				
			||||||
            f = initializers.get_initializer(funcname)
 | 
					            f = initializers.get_initializer(funcname)
 | 
				
			||||||
            iscallable = callable(f)
 | 
					            iscallable = callable(f)
 | 
				
			||||||
            self.assertTrue(iscallable)
 | 
					            self.assertTrue(iscallable)
 | 
				
			||||||
@@ -385,3 +390,32 @@ class TestInitializers(unittest.TestCase):
 | 
				
			|||||||
    def tearDown(self):
 | 
					    def tearDown(self):
 | 
				
			||||||
        del self.x, self.y, self.gen
 | 
					        del self.x, self.y, self.gen
 | 
				
			||||||
        _ = torch.seed()
 | 
					        _ = torch.seed()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class TestLosses(unittest.TestCase):
 | 
				
			||||||
 | 
					    def setUp(self):
 | 
				
			||||||
 | 
					        pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_glvq_loss_int_labels(self):
 | 
				
			||||||
 | 
					        d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1)
 | 
				
			||||||
 | 
					        labels = torch.tensor([0, 1])
 | 
				
			||||||
 | 
					        targets = torch.ones(100)
 | 
				
			||||||
 | 
					        batch_loss = losses.glvq_loss(distances=d,
 | 
				
			||||||
 | 
					                                      target_labels=targets,
 | 
				
			||||||
 | 
					                                      prototype_labels=labels)
 | 
				
			||||||
 | 
					        loss_value = torch.sum(batch_loss, dim=0)
 | 
				
			||||||
 | 
					        self.assertEqual(loss_value, -100)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_glvq_loss_one_hot_labels(self):
 | 
				
			||||||
 | 
					        d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1)
 | 
				
			||||||
 | 
					        labels = torch.tensor([[0, 1], [1, 0]])
 | 
				
			||||||
 | 
					        wl = torch.tensor([1, 0])
 | 
				
			||||||
 | 
					        targets = torch.stack([wl for _ in range(100)], dim=0)
 | 
				
			||||||
 | 
					        batch_loss = losses.glvq_loss(distances=d,
 | 
				
			||||||
 | 
					                                      target_labels=targets,
 | 
				
			||||||
 | 
					                                      prototype_labels=labels)
 | 
				
			||||||
 | 
					        loss_value = torch.sum(batch_loss, dim=0)
 | 
				
			||||||
 | 
					        self.assertEqual(loss_value, -100)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def tearDown(self):
 | 
				
			||||||
 | 
					        pass
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -123,7 +123,19 @@ class TestLosses(unittest.TestCase):
 | 
				
			|||||||
        pass
 | 
					        pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_glvqloss_init(self):
 | 
					    def test_glvqloss_init(self):
 | 
				
			||||||
        _ = losses.GLVQLoss()
 | 
					        _ = losses.GLVQLoss(0, 'swish_beta', beta=20)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_glvqloss_forward(self):
 | 
				
			||||||
 | 
					        criterion = losses.GLVQLoss(margin=0,
 | 
				
			||||||
 | 
					                                    squashing='sigmoid_beta',
 | 
				
			||||||
 | 
					                                    beta=100)
 | 
				
			||||||
 | 
					        d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1)
 | 
				
			||||||
 | 
					        labels = torch.tensor([0, 1])
 | 
				
			||||||
 | 
					        targets = torch.ones(100)
 | 
				
			||||||
 | 
					        outputs = [d, labels]
 | 
				
			||||||
 | 
					        loss = criterion(outputs, targets)
 | 
				
			||||||
 | 
					        loss_value = loss.item()
 | 
				
			||||||
 | 
					        self.assertAlmostEqual(loss_value, 0.0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def tearDown(self):
 | 
					    def tearDown(self):
 | 
				
			||||||
        pass
 | 
					        pass
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user