Add test cases to test recently added features
This commit is contained in:
parent
88cbe0a126
commit
a0f20a40f6
@ -85,6 +85,16 @@ class TestCompetitions(unittest.TestCase):
|
|||||||
decimal=5)
|
decimal=5)
|
||||||
self.assertIsNone(mismatch)
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
|
def test_wtac_unequal_dist(self):
|
||||||
|
d = torch.tensor([[2., 3., 4.], [2., 3., 1.]])
|
||||||
|
labels = torch.tensor([0, 1, 1])
|
||||||
|
actual = competitions.wtac(d, labels)
|
||||||
|
desired = torch.tensor([0, 1])
|
||||||
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
|
desired,
|
||||||
|
decimal=5)
|
||||||
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
def test_wtac_one_hot(self):
|
def test_wtac_one_hot(self):
|
||||||
d = torch.tensor([[1.99, 3.01], [3., 2.01]])
|
d = torch.tensor([[1.99, 3.01], [3., 2.01]])
|
||||||
labels = torch.tensor([[0, 1], [1, 0]])
|
labels = torch.tensor([[0, 1], [1, 0]])
|
||||||
@ -95,6 +105,27 @@ class TestCompetitions(unittest.TestCase):
|
|||||||
decimal=5)
|
decimal=5)
|
||||||
self.assertIsNone(mismatch)
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
|
def test_stratified_min(self):
|
||||||
|
d = torch.tensor([[1., 0., 2., 3.], [9., 8., 0, 1]])
|
||||||
|
labels = torch.tensor([0, 0, 1, 2])
|
||||||
|
actual = competitions.stratified_min(d, labels)
|
||||||
|
desired = torch.tensor([[0., 2., 3.], [8., 0., 1.]])
|
||||||
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
|
desired,
|
||||||
|
decimal=5)
|
||||||
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
|
def test_stratified_min_one_hot(self):
|
||||||
|
d = torch.tensor([[1., 0., 2., 3.], [9., 8., 0, 1]])
|
||||||
|
labels = torch.tensor([0, 0, 1, 2])
|
||||||
|
labels = torch.eye(3)[labels]
|
||||||
|
actual = competitions.stratified_min(d, labels)
|
||||||
|
desired = torch.tensor([[0., 2., 3.], [8., 0., 1.]])
|
||||||
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
|
desired,
|
||||||
|
decimal=5)
|
||||||
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
def test_knnc_k1(self):
|
def test_knnc_k1(self):
|
||||||
d = torch.tensor([[2., 3., 1.99, 3.01], [2., 3., 2.01, 3.]])
|
d = torch.tensor([[2., 3., 1.99, 3.01], [2., 3., 2.01, 3.]])
|
||||||
labels = torch.tensor([0, 1, 2, 3])
|
labels = torch.tensor([0, 1, 2, 3])
|
||||||
@ -351,7 +382,7 @@ class TestInitializers(unittest.TestCase):
|
|||||||
def test_stratified_random_equal1(self):
|
def test_stratified_random_equal1(self):
|
||||||
pdist = torch.tensor([1, 1])
|
pdist = torch.tensor([1, 1])
|
||||||
actual, _ = initializers.stratified_random(self.x, self.y, pdist)
|
actual, _ = initializers.stratified_random(self.x, self.y, pdist)
|
||||||
desired = torch.tensor([[0., -1., -2.], [2., 2., 2.]])
|
desired = torch.tensor([[0., -1., -2.], [0., 0., 0.]])
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
desired,
|
desired,
|
||||||
decimal=5)
|
decimal=5)
|
||||||
@ -367,6 +398,16 @@ class TestInitializers(unittest.TestCase):
|
|||||||
decimal=5)
|
decimal=5)
|
||||||
self.assertIsNone(mismatch)
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
|
def test_stratified_random_equal2(self):
|
||||||
|
pdist = torch.tensor([2, 2])
|
||||||
|
actual, _ = initializers.stratified_random(self.x, self.y, pdist)
|
||||||
|
desired = torch.tensor([[0., -1., -2.], [0., -1., -2.], [0., 0., 0.],
|
||||||
|
[0., 0., 0.]])
|
||||||
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
|
desired,
|
||||||
|
decimal=5)
|
||||||
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
def test_stratified_mean_unequal(self):
|
def test_stratified_mean_unequal(self):
|
||||||
pdist = torch.tensor([1, 3])
|
pdist = torch.tensor([1, 3])
|
||||||
actual, _ = initializers.stratified_mean(self.x, self.y, pdist)
|
actual, _ = initializers.stratified_mean(self.x, self.y, pdist)
|
||||||
@ -380,7 +421,7 @@ class TestInitializers(unittest.TestCase):
|
|||||||
def test_stratified_random_unequal(self):
|
def test_stratified_random_unequal(self):
|
||||||
pdist = torch.tensor([1, 3])
|
pdist = torch.tensor([1, 3])
|
||||||
actual, _ = initializers.stratified_random(self.x, self.y, pdist)
|
actual, _ = initializers.stratified_random(self.x, self.y, pdist)
|
||||||
desired = torch.tensor([[0., -1., -2.], [2., 2., 2.], [0., 0., 0.],
|
desired = torch.tensor([[0., -1., -2.], [0., 0., 0.], [0., 0., 0.],
|
||||||
[0., 0., 0.]])
|
[0., 0., 0.]])
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
desired,
|
desired,
|
||||||
@ -417,5 +458,17 @@ class TestLosses(unittest.TestCase):
|
|||||||
loss_value = torch.sum(batch_loss, dim=0)
|
loss_value = torch.sum(batch_loss, dim=0)
|
||||||
self.assertEqual(loss_value, -100)
|
self.assertEqual(loss_value, -100)
|
||||||
|
|
||||||
|
def test_glvq_loss_one_hot_unequal(self):
|
||||||
|
dlist = [torch.ones(100), torch.zeros(100), torch.zeros(100)]
|
||||||
|
d = torch.stack(dlist, dim=1)
|
||||||
|
labels = torch.tensor([[0, 1], [1, 0], [1, 0]])
|
||||||
|
wl = torch.tensor([1, 0])
|
||||||
|
targets = torch.stack([wl for _ in range(100)], dim=0)
|
||||||
|
batch_loss = losses.glvq_loss(distances=d,
|
||||||
|
target_labels=targets,
|
||||||
|
prototype_labels=labels)
|
||||||
|
loss_value = torch.sum(batch_loss, dim=0)
|
||||||
|
self.assertEqual(loss_value, -100)
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
pass
|
pass
|
||||||
|
@ -51,7 +51,7 @@ class TestPrototypes(unittest.TestCase):
|
|||||||
self.assertIsNone(mismatch)
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
def test_prototypes1d_proto_init_without_data(self):
|
def test_prototypes1d_proto_init_without_data(self):
|
||||||
with self.assertWarns(Warning):
|
with self.assertWarns(UserWarning):
|
||||||
_ = prototypes.Prototypes1D(
|
_ = prototypes.Prototypes1D(
|
||||||
input_dim=3,
|
input_dim=3,
|
||||||
nclasses=2,
|
nclasses=2,
|
||||||
@ -168,6 +168,16 @@ class TestPrototypes(unittest.TestCase):
|
|||||||
decimal=5)
|
decimal=5)
|
||||||
self.assertIsNone(mismatch)
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
|
def test_prototypes1d_dist_check(self):
|
||||||
|
p1 = prototypes.Prototypes1D(input_dim=0, prototype_distribution=[0])
|
||||||
|
with self.assertWarns(UserWarning):
|
||||||
|
_ = p1._check_prototype_distribution()
|
||||||
|
|
||||||
|
def test_prototypes1d_check_extra_repr_not_empty(self):
|
||||||
|
p1 = prototypes.Prototypes1D(input_dim=0, prototype_distribution=[0])
|
||||||
|
rep = p1.extra_repr()
|
||||||
|
self.assertNotEqual(rep, '')
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
del self.x, self.y, self.gen
|
del self.x, self.y, self.gen
|
||||||
_ = torch.seed()
|
_ = torch.seed()
|
||||||
@ -194,7 +204,3 @@ class TestLosses(unittest.TestCase):
|
|||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
unittest.main()
|
|
||||||
|
Loading…
Reference in New Issue
Block a user