Add test cases
This commit is contained in:
parent
b19cbcb76a
commit
21b0279839
@ -6,7 +6,107 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from prototorch.functions import (activations, competitions, distances,
|
from prototorch.functions import (activations, competitions, distances,
|
||||||
initializers)
|
initializers, losses)
|
||||||
|
|
||||||
|
|
||||||
|
class TestActivations(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.flist = ['identity', 'sigmoid_beta', 'swish_beta']
|
||||||
|
self.x = torch.randn(1024, 1)
|
||||||
|
|
||||||
|
def test_registry(self):
|
||||||
|
self.assertIsNotNone(activations.ACTIVATIONS)
|
||||||
|
|
||||||
|
def test_funcname_deserialization(self):
|
||||||
|
for funcname in self.flist:
|
||||||
|
f = activations.get_activation(funcname)
|
||||||
|
iscallable = callable(f)
|
||||||
|
self.assertTrue(iscallable)
|
||||||
|
|
||||||
|
# def test_torch_script(self):
|
||||||
|
# for funcname in self.flist:
|
||||||
|
# f = activations.get_activation(funcname)
|
||||||
|
# self.assertIsInstance(f, torch.jit.ScriptFunction)
|
||||||
|
|
||||||
|
def test_callable_deserialization(self):
|
||||||
|
def dummy(x, **kwargs):
|
||||||
|
return x
|
||||||
|
|
||||||
|
for f in [dummy, lambda x: x]:
|
||||||
|
f = activations.get_activation(f)
|
||||||
|
iscallable = callable(f)
|
||||||
|
self.assertTrue(iscallable)
|
||||||
|
self.assertEqual(1, f(1))
|
||||||
|
|
||||||
|
def test_unknown_deserialization(self):
|
||||||
|
for funcname in ['blubb', 'foobar']:
|
||||||
|
with self.assertRaises(NameError):
|
||||||
|
_ = activations.get_activation(funcname)
|
||||||
|
|
||||||
|
def test_identity(self):
|
||||||
|
actual = activations.identity(self.x)
|
||||||
|
desired = self.x
|
||||||
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
|
desired,
|
||||||
|
decimal=5)
|
||||||
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
|
def test_sigmoid_beta1(self):
|
||||||
|
actual = activations.sigmoid_beta(self.x, beta=torch.tensor(1))
|
||||||
|
desired = torch.sigmoid(self.x)
|
||||||
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
|
desired,
|
||||||
|
decimal=5)
|
||||||
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
|
def test_swish_beta1(self):
|
||||||
|
actual = activations.swish_beta(self.x, beta=torch.tensor(1))
|
||||||
|
desired = self.x * torch.sigmoid(self.x)
|
||||||
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
|
desired,
|
||||||
|
decimal=5)
|
||||||
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
del self.x
|
||||||
|
|
||||||
|
|
||||||
|
class TestCompetitions(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_wtac(self):
|
||||||
|
d = torch.tensor([[2., 3., 1.99, 3.01], [2., 3., 2.01, 3.]])
|
||||||
|
labels = torch.tensor([0, 1, 2, 3])
|
||||||
|
actual = competitions.wtac(d, labels)
|
||||||
|
desired = torch.tensor([2, 0])
|
||||||
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
|
desired,
|
||||||
|
decimal=5)
|
||||||
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
|
def test_wtac_one_hot(self):
|
||||||
|
d = torch.tensor([[1.99, 3.01], [3., 2.01]])
|
||||||
|
labels = torch.tensor([[0, 1], [1, 0]])
|
||||||
|
actual = competitions.wtac(d, labels)
|
||||||
|
desired = torch.tensor([[0, 1], [1, 0]])
|
||||||
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
|
desired,
|
||||||
|
decimal=5)
|
||||||
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
|
def test_knnc_k1(self):
|
||||||
|
d = torch.tensor([[2., 3., 1.99, 3.01], [2., 3., 2.01, 3.]])
|
||||||
|
labels = torch.tensor([0, 1, 2, 3])
|
||||||
|
actual = competitions.knnc(d, labels, k=torch.tensor([1]))
|
||||||
|
desired = torch.tensor([2, 0])
|
||||||
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
|
desired,
|
||||||
|
decimal=5)
|
||||||
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class TestDistances(unittest.TestCase):
|
class TestDistances(unittest.TestCase):
|
||||||
@ -167,103 +267,12 @@ class TestDistances(unittest.TestCase):
|
|||||||
del self.x, self.y
|
del self.x, self.y
|
||||||
|
|
||||||
|
|
||||||
class TestActivations(unittest.TestCase):
|
|
||||||
def setUp(self):
|
|
||||||
self.x = torch.randn(1024, 1)
|
|
||||||
|
|
||||||
def test_registry(self):
|
|
||||||
self.assertIsNotNone(activations.ACTIVATIONS)
|
|
||||||
|
|
||||||
def test_funcname_deserialization(self):
|
|
||||||
flist = ['identity', 'sigmoid_beta', 'swish_beta']
|
|
||||||
for funcname in flist:
|
|
||||||
f = activations.get_activation(funcname)
|
|
||||||
iscallable = callable(f)
|
|
||||||
self.assertTrue(iscallable)
|
|
||||||
|
|
||||||
def test_callable_deserialization(self):
|
|
||||||
def dummy(x, **kwargs):
|
|
||||||
return x
|
|
||||||
|
|
||||||
for f in [dummy, lambda x: x]:
|
|
||||||
f = activations.get_activation(f)
|
|
||||||
iscallable = callable(f)
|
|
||||||
self.assertTrue(iscallable)
|
|
||||||
self.assertEqual(1, f(1))
|
|
||||||
|
|
||||||
def test_unknown_deserialization(self):
|
|
||||||
for funcname in ['blubb', 'foobar']:
|
|
||||||
with self.assertRaises(NameError):
|
|
||||||
_ = activations.get_activation(funcname)
|
|
||||||
|
|
||||||
def test_identity(self):
|
|
||||||
actual = activations.identity(self.x)
|
|
||||||
desired = self.x
|
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
|
||||||
desired,
|
|
||||||
decimal=5)
|
|
||||||
self.assertIsNone(mismatch)
|
|
||||||
|
|
||||||
def test_sigmoid_beta1(self):
|
|
||||||
actual = activations.sigmoid_beta(self.x, beta=1)
|
|
||||||
desired = torch.sigmoid(self.x)
|
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
|
||||||
desired,
|
|
||||||
decimal=5)
|
|
||||||
self.assertIsNone(mismatch)
|
|
||||||
|
|
||||||
def test_swish_beta1(self):
|
|
||||||
actual = activations.swish_beta(self.x, beta=1)
|
|
||||||
desired = self.x * torch.sigmoid(self.x)
|
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
|
||||||
desired,
|
|
||||||
decimal=5)
|
|
||||||
self.assertIsNone(mismatch)
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
del self.x
|
|
||||||
|
|
||||||
|
|
||||||
class TestCompetitions(unittest.TestCase):
|
|
||||||
def setUp(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def test_wtac(self):
|
|
||||||
d = torch.tensor([[2., 3., 1.99, 3.01], [2., 3., 2.01, 3.]])
|
|
||||||
labels = torch.tensor([0, 1, 2, 3])
|
|
||||||
actual = competitions.wtac(d, labels)
|
|
||||||
desired = torch.tensor([2, 0])
|
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
|
||||||
desired,
|
|
||||||
decimal=5)
|
|
||||||
self.assertIsNone(mismatch)
|
|
||||||
|
|
||||||
def test_wtac_one_hot(self):
|
|
||||||
d = torch.tensor([[1.99, 3.01], [3., 2.01]])
|
|
||||||
labels = torch.tensor([[0, 1], [1, 0]])
|
|
||||||
actual = competitions.wtac(d, labels)
|
|
||||||
desired = torch.tensor([[0, 1], [1, 0]])
|
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
|
||||||
desired,
|
|
||||||
decimal=5)
|
|
||||||
self.assertIsNone(mismatch)
|
|
||||||
|
|
||||||
def test_knnc_k1(self):
|
|
||||||
d = torch.tensor([[2., 3., 1.99, 3.01], [2., 3., 2.01, 3.]])
|
|
||||||
labels = torch.tensor([0, 1, 2, 3])
|
|
||||||
actual = competitions.knnc(d, labels, k=1)
|
|
||||||
desired = torch.tensor([2, 0])
|
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
|
||||||
desired,
|
|
||||||
decimal=5)
|
|
||||||
self.assertIsNone(mismatch)
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class TestInitializers(unittest.TestCase):
|
class TestInitializers(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
self.flist = [
|
||||||
|
'zeros', 'ones', 'rand', 'randn', 'stratified_mean',
|
||||||
|
'stratified_random'
|
||||||
|
]
|
||||||
self.x = torch.tensor(
|
self.x = torch.tensor(
|
||||||
[[0, -1, -2], [10, 11, 12], [0, 0, 0], [2, 2, 2]],
|
[[0, -1, -2], [10, 11, 12], [0, 0, 0], [2, 2, 2]],
|
||||||
dtype=torch.float32)
|
dtype=torch.float32)
|
||||||
@ -274,11 +283,7 @@ class TestInitializers(unittest.TestCase):
|
|||||||
self.assertIsNotNone(initializers.INITIALIZERS)
|
self.assertIsNotNone(initializers.INITIALIZERS)
|
||||||
|
|
||||||
def test_funcname_deserialization(self):
|
def test_funcname_deserialization(self):
|
||||||
flist = [
|
for funcname in self.flist:
|
||||||
'zeros', 'ones', 'rand', 'randn', 'stratified_mean',
|
|
||||||
'stratified_random'
|
|
||||||
]
|
|
||||||
for funcname in flist:
|
|
||||||
f = initializers.get_initializer(funcname)
|
f = initializers.get_initializer(funcname)
|
||||||
iscallable = callable(f)
|
iscallable = callable(f)
|
||||||
self.assertTrue(iscallable)
|
self.assertTrue(iscallable)
|
||||||
@ -385,3 +390,32 @@ class TestInitializers(unittest.TestCase):
|
|||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
del self.x, self.y, self.gen
|
del self.x, self.y, self.gen
|
||||||
_ = torch.seed()
|
_ = torch.seed()
|
||||||
|
|
||||||
|
|
||||||
|
class TestLosses(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_glvq_loss_int_labels(self):
|
||||||
|
d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1)
|
||||||
|
labels = torch.tensor([0, 1])
|
||||||
|
targets = torch.ones(100)
|
||||||
|
batch_loss = losses.glvq_loss(distances=d,
|
||||||
|
target_labels=targets,
|
||||||
|
prototype_labels=labels)
|
||||||
|
loss_value = torch.sum(batch_loss, dim=0)
|
||||||
|
self.assertEqual(loss_value, -100)
|
||||||
|
|
||||||
|
def test_glvq_loss_one_hot_labels(self):
|
||||||
|
d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1)
|
||||||
|
labels = torch.tensor([[0, 1], [1, 0]])
|
||||||
|
wl = torch.tensor([1, 0])
|
||||||
|
targets = torch.stack([wl for _ in range(100)], dim=0)
|
||||||
|
batch_loss = losses.glvq_loss(distances=d,
|
||||||
|
target_labels=targets,
|
||||||
|
prototype_labels=labels)
|
||||||
|
loss_value = torch.sum(batch_loss, dim=0)
|
||||||
|
self.assertEqual(loss_value, -100)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
pass
|
||||||
|
@ -123,7 +123,19 @@ class TestLosses(unittest.TestCase):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def test_glvqloss_init(self):
|
def test_glvqloss_init(self):
|
||||||
_ = losses.GLVQLoss()
|
_ = losses.GLVQLoss(0, 'swish_beta', beta=20)
|
||||||
|
|
||||||
|
def test_glvqloss_forward(self):
|
||||||
|
criterion = losses.GLVQLoss(margin=0,
|
||||||
|
squashing='sigmoid_beta',
|
||||||
|
beta=100)
|
||||||
|
d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1)
|
||||||
|
labels = torch.tensor([0, 1])
|
||||||
|
targets = torch.ones(100)
|
||||||
|
outputs = [d, labels]
|
||||||
|
loss = criterion(outputs, targets)
|
||||||
|
loss_value = loss.item()
|
||||||
|
self.assertAlmostEqual(loss_value, 0.0)
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
pass
|
pass
|
||||||
|
Loading…
Reference in New Issue
Block a user