[BUGFIX] Update unit tests
This commit is contained in:
parent
2272c55092
commit
4f1c879528
@ -5,7 +5,7 @@ import unittest
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from prototorch.functions import (activations, competitions, distances,
|
from prototorch.functions import (activations, competitions, distances,
|
||||||
initializers, losses)
|
initializers, losses, pooling)
|
||||||
|
|
||||||
|
|
||||||
class TestActivations(unittest.TestCase):
|
class TestActivations(unittest.TestCase):
|
||||||
@ -104,10 +104,28 @@ class TestCompetitions(unittest.TestCase):
|
|||||||
decimal=5)
|
decimal=5)
|
||||||
self.assertIsNone(mismatch)
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
|
def test_knnc_k1(self):
|
||||||
|
d = torch.tensor([[2.0, 3.0, 1.99, 3.01], [2.0, 3.0, 2.01, 3.0]])
|
||||||
|
labels = torch.tensor([0, 1, 2, 3])
|
||||||
|
actual = competitions.knnc(d, labels, k=1)
|
||||||
|
desired = torch.tensor([2, 0])
|
||||||
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
|
desired,
|
||||||
|
decimal=5)
|
||||||
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TestPooling(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
pass
|
||||||
|
|
||||||
def test_stratified_min(self):
|
def test_stratified_min(self):
|
||||||
d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
|
d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
|
||||||
labels = torch.tensor([0, 0, 1, 2])
|
labels = torch.tensor([0, 0, 1, 2])
|
||||||
actual = competitions.stratified_min(d, labels)
|
actual = pooling.stratified_min_pooling(d, labels)
|
||||||
desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]])
|
desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]])
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
desired,
|
desired,
|
||||||
@ -118,7 +136,7 @@ class TestCompetitions(unittest.TestCase):
|
|||||||
d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
|
d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
|
||||||
labels = torch.tensor([0, 0, 1, 2])
|
labels = torch.tensor([0, 0, 1, 2])
|
||||||
labels = torch.eye(3)[labels]
|
labels = torch.eye(3)[labels]
|
||||||
actual = competitions.stratified_min(d, labels)
|
actual = pooling.stratified_min_pooling(d, labels)
|
||||||
desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]])
|
desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]])
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
desired,
|
desired,
|
||||||
@ -128,7 +146,7 @@ class TestCompetitions(unittest.TestCase):
|
|||||||
def test_stratified_min_trivial(self):
|
def test_stratified_min_trivial(self):
|
||||||
d = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0, 1]])
|
d = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0, 1]])
|
||||||
labels = torch.tensor([0, 1, 2])
|
labels = torch.tensor([0, 1, 2])
|
||||||
actual = competitions.stratified_min(d, labels)
|
actual = pooling.stratified_min_pooling(d, labels)
|
||||||
desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]])
|
desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]])
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
desired,
|
desired,
|
||||||
@ -138,7 +156,7 @@ class TestCompetitions(unittest.TestCase):
|
|||||||
def test_stratified_max(self):
|
def test_stratified_max(self):
|
||||||
d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]])
|
d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]])
|
||||||
labels = torch.tensor([0, 0, 3, 2, 0])
|
labels = torch.tensor([0, 0, 3, 2, 0])
|
||||||
actual = competitions.stratified_max(d, labels)
|
actual = pooling.stratified_max_pooling(d, labels)
|
||||||
desired = torch.tensor([[9.0, 3.0, 2.0], [9.0, 1.0, 0.0]])
|
desired = torch.tensor([[9.0, 3.0, 2.0], [9.0, 1.0, 0.0]])
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
desired,
|
desired,
|
||||||
@ -149,7 +167,7 @@ class TestCompetitions(unittest.TestCase):
|
|||||||
d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]])
|
d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]])
|
||||||
labels = torch.tensor([0, 0, 2, 1, 0])
|
labels = torch.tensor([0, 0, 2, 1, 0])
|
||||||
labels = torch.nn.functional.one_hot(labels, num_classes=3)
|
labels = torch.nn.functional.one_hot(labels, num_classes=3)
|
||||||
actual = competitions.stratified_max(d, labels)
|
actual = pooling.stratified_max_pooling(d, labels)
|
||||||
desired = torch.tensor([[9.0, 3.0, 2.0], [9.0, 1.0, 0.0]])
|
desired = torch.tensor([[9.0, 3.0, 2.0], [9.0, 1.0, 0.0]])
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
desired,
|
desired,
|
||||||
@ -159,7 +177,7 @@ class TestCompetitions(unittest.TestCase):
|
|||||||
def test_stratified_sum(self):
|
def test_stratified_sum(self):
|
||||||
d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
|
d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
|
||||||
labels = torch.LongTensor([0, 0, 1, 2])
|
labels = torch.LongTensor([0, 0, 1, 2])
|
||||||
actual = competitions.stratified_sum(d, labels)
|
actual = pooling.stratified_sum_pooling(d, labels)
|
||||||
desired = torch.tensor([[1.0, 2.0, 3.0], [17.0, 0.0, 1.0]])
|
desired = torch.tensor([[1.0, 2.0, 3.0], [17.0, 0.0, 1.0]])
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
desired,
|
desired,
|
||||||
@ -170,7 +188,7 @@ class TestCompetitions(unittest.TestCase):
|
|||||||
d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
|
d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
|
||||||
labels = torch.tensor([0, 0, 1, 2])
|
labels = torch.tensor([0, 0, 1, 2])
|
||||||
labels = torch.eye(3)[labels]
|
labels = torch.eye(3)[labels]
|
||||||
actual = competitions.stratified_sum(d, labels)
|
actual = pooling.stratified_sum_pooling(d, labels)
|
||||||
desired = torch.tensor([[1.0, 2.0, 3.0], [17.0, 0.0, 1.0]])
|
desired = torch.tensor([[1.0, 2.0, 3.0], [17.0, 0.0, 1.0]])
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
desired,
|
desired,
|
||||||
@ -180,23 +198,13 @@ class TestCompetitions(unittest.TestCase):
|
|||||||
def test_stratified_prod(self):
|
def test_stratified_prod(self):
|
||||||
d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]])
|
d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]])
|
||||||
labels = torch.tensor([0, 0, 3, 2, 0])
|
labels = torch.tensor([0, 0, 3, 2, 0])
|
||||||
actual = competitions.stratified_prod(d, labels)
|
actual = pooling.stratified_prod_pooling(d, labels)
|
||||||
desired = torch.tensor([[0.0, 3.0, 2.0], [504.0, 1.0, 0.0]])
|
desired = torch.tensor([[0.0, 3.0, 2.0], [504.0, 1.0, 0.0]])
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
desired,
|
desired,
|
||||||
decimal=5)
|
decimal=5)
|
||||||
self.assertIsNone(mismatch)
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
def test_knnc_k1(self):
|
|
||||||
d = torch.tensor([[2.0, 3.0, 1.99, 3.01], [2.0, 3.0, 2.01, 3.0]])
|
|
||||||
labels = torch.tensor([0, 1, 2, 3])
|
|
||||||
actual = competitions.knnc(d, labels, k=1)
|
|
||||||
desired = torch.tensor([2, 0])
|
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
|
||||||
desired,
|
|
||||||
decimal=5)
|
|
||||||
self.assertIsNone(mismatch)
|
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user