[TEST] Add more tests

This commit is contained in:
Jensun Ravichandran 2021-06-14 14:45:14 +02:00
parent d2d6f31e7b
commit 668c9a1fb7
2 changed files with 43 additions and 19 deletions

View File

@ -81,14 +81,9 @@ class ClassAwareCompInitializer(AbstractComponentsInitializer):
def generate(self, distribution: Union[dict, list, tuple]): def generate(self, distribution: Union[dict, list, tuple]):
distribution = parse_distribution(distribution) distribution = parse_distribution(distribution)
initializers = {
k: self.subinit_type(self.data[self.targets == k])
for k in distribution.keys()
}
components = torch.tensor([]) components = torch.tensor([])
for k, v in distribution.items(): for k, v in distribution.items():
stratified_data = self.data[self.targets == k] stratified_data = self.data[self.targets == k]
# skip transform here
initializer = self.subinit_type( initializer = self.subinit_type(
stratified_data, stratified_data,
noise=self.noise, noise=self.noise,
@ -157,13 +152,14 @@ class UniformCompInitializer(OnesCompInitializer):
class RandomNormalCompInitializer(OnesCompInitializer): class RandomNormalCompInitializer(OnesCompInitializer):
"""Generate components by sampling from a standard normal distribution.""" """Generate components by sampling from a standard normal distribution."""
def __init__(self, shape, scale=1.0): def __init__(self, shape, shift=0.0, scale=1.0):
super().__init__(shape) super().__init__(shape)
self.shift = shift
self.scale = scale self.scale = scale
def generate(self, num_components: int): def generate(self, num_components: int):
ones = super().generate(num_components) ones = super().generate(num_components)
components = self.scale * torch.randn_like(ones) components = self.scale * (torch.randn_like(ones) + self.shift)
return components return components

View File

@ -73,6 +73,22 @@ def test_fill_value_comp_generate():
assert torch.allclose(components, torch.zeros(3, 2)) assert torch.allclose(components, torch.zeros(3, 2))
def test_uniform_comp_generate_min_max_bound():
c = pt.initializers.UniformCompInitializer(2, -1.0, 1.0)
components = c.generate(num_components=1024)
assert components.min() >= -1.0
assert components.max() <= 1.0
def test_random_comp_generate_mean():
c = pt.initializers.RandomNormalCompInitializer(2, -1.0)
components = c.generate(num_components=1024)
assert torch.allclose(components.mean(),
torch.tensor(-1.0),
rtol=1e-05,
atol=1e-01)
def test_comp_generate_0_components(): def test_comp_generate_0_components():
c = pt.initializers.ZerosCompInitializer(2) c = pt.initializers.ZerosCompInitializer(2)
_ = c.generate(num_components=0) _ = c.generate(num_components=0)
@ -294,7 +310,8 @@ class TestCompetitions(unittest.TestCase):
def test_wtac(self): def test_wtac(self):
d = torch.tensor([[2.0, 3.0, 1.99, 3.01], [2.0, 3.0, 2.01, 3.0]]) d = torch.tensor([[2.0, 3.0, 1.99, 3.01], [2.0, 3.0, 2.01, 3.0]])
labels = torch.tensor([0, 1, 2, 3]) labels = torch.tensor([0, 1, 2, 3])
actual = pt.competitions.wtac(d, labels) competition_layer = pt.competitions.WTAC()
actual = competition_layer(d, labels)
desired = torch.tensor([2, 0]) desired = torch.tensor([2, 0])
mismatch = np.testing.assert_array_almost_equal(actual, mismatch = np.testing.assert_array_almost_equal(actual,
desired, desired,
@ -304,7 +321,8 @@ class TestCompetitions(unittest.TestCase):
def test_wtac_unequal_dist(self): def test_wtac_unequal_dist(self):
d = torch.tensor([[2.0, 3.0, 4.0], [2.0, 3.0, 1.0]]) d = torch.tensor([[2.0, 3.0, 4.0], [2.0, 3.0, 1.0]])
labels = torch.tensor([0, 1, 1]) labels = torch.tensor([0, 1, 1])
actual = pt.competitions.wtac(d, labels) competition_layer = pt.competitions.WTAC()
actual = competition_layer(d, labels)
desired = torch.tensor([0, 1]) desired = torch.tensor([0, 1])
mismatch = np.testing.assert_array_almost_equal(actual, mismatch = np.testing.assert_array_almost_equal(actual,
desired, desired,
@ -314,7 +332,8 @@ class TestCompetitions(unittest.TestCase):
def test_wtac_one_hot(self): def test_wtac_one_hot(self):
d = torch.tensor([[1.99, 3.01], [3.0, 2.01]]) d = torch.tensor([[1.99, 3.01], [3.0, 2.01]])
labels = torch.tensor([[0, 1], [1, 0]]) labels = torch.tensor([[0, 1], [1, 0]])
actual = pt.competitions.wtac(d, labels) competition_layer = pt.competitions.WTAC()
actual = competition_layer(d, labels)
desired = torch.tensor([[0, 1], [1, 0]]) desired = torch.tensor([[0, 1], [1, 0]])
mismatch = np.testing.assert_array_almost_equal(actual, mismatch = np.testing.assert_array_almost_equal(actual,
desired, desired,
@ -324,7 +343,8 @@ class TestCompetitions(unittest.TestCase):
def test_knnc_k1(self): def test_knnc_k1(self):
d = torch.tensor([[2.0, 3.0, 1.99, 3.01], [2.0, 3.0, 2.01, 3.0]]) d = torch.tensor([[2.0, 3.0, 1.99, 3.01], [2.0, 3.0, 2.01, 3.0]])
labels = torch.tensor([0, 1, 2, 3]) labels = torch.tensor([0, 1, 2, 3])
actual = pt.competitions.knnc(d, labels, k=1) competition_layer = pt.competitions.KNNC(k=1)
actual = competition_layer(d, labels)
desired = torch.tensor([2, 0]) desired = torch.tensor([2, 0])
mismatch = np.testing.assert_array_almost_equal(actual, mismatch = np.testing.assert_array_almost_equal(actual,
desired, desired,
@ -343,7 +363,8 @@ class TestPooling(unittest.TestCase):
def test_stratified_min(self): def test_stratified_min(self):
d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]]) d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
labels = torch.tensor([0, 0, 1, 2]) labels = torch.tensor([0, 0, 1, 2])
actual = pt.pooling.stratified_min_pooling(d, labels) pooling_layer = pt.pooling.StratifiedMinPooling()
actual = pooling_layer(d, labels)
desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]]) desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]])
mismatch = np.testing.assert_array_almost_equal(actual, mismatch = np.testing.assert_array_almost_equal(actual,
desired, desired,
@ -354,7 +375,8 @@ class TestPooling(unittest.TestCase):
d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]]) d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
labels = torch.tensor([0, 0, 1, 2]) labels = torch.tensor([0, 0, 1, 2])
labels = torch.eye(3)[labels] labels = torch.eye(3)[labels]
actual = pt.pooling.stratified_min_pooling(d, labels) pooling_layer = pt.pooling.StratifiedMinPooling()
actual = pooling_layer(d, labels)
desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]]) desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]])
mismatch = np.testing.assert_array_almost_equal(actual, mismatch = np.testing.assert_array_almost_equal(actual,
desired, desired,
@ -364,7 +386,8 @@ class TestPooling(unittest.TestCase):
def test_stratified_min_trivial(self): def test_stratified_min_trivial(self):
d = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0, 1]]) d = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0, 1]])
labels = torch.tensor([0, 1, 2]) labels = torch.tensor([0, 1, 2])
actual = pt.pooling.stratified_min_pooling(d, labels) pooling_layer = pt.pooling.StratifiedMinPooling()
actual = pooling_layer(d, labels)
desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]]) desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]])
mismatch = np.testing.assert_array_almost_equal(actual, mismatch = np.testing.assert_array_almost_equal(actual,
desired, desired,
@ -374,7 +397,8 @@ class TestPooling(unittest.TestCase):
def test_stratified_max(self): def test_stratified_max(self):
d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]]) d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]])
labels = torch.tensor([0, 0, 3, 2, 0]) labels = torch.tensor([0, 0, 3, 2, 0])
actual = pt.pooling.stratified_max_pooling(d, labels) pooling_layer = pt.pooling.StratifiedMaxPooling()
actual = pooling_layer(d, labels)
desired = torch.tensor([[9.0, 3.0, 2.0], [9.0, 1.0, 0.0]]) desired = torch.tensor([[9.0, 3.0, 2.0], [9.0, 1.0, 0.0]])
mismatch = np.testing.assert_array_almost_equal(actual, mismatch = np.testing.assert_array_almost_equal(actual,
desired, desired,
@ -385,7 +409,8 @@ class TestPooling(unittest.TestCase):
d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]]) d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]])
labels = torch.tensor([0, 0, 2, 1, 0]) labels = torch.tensor([0, 0, 2, 1, 0])
labels = torch.nn.functional.one_hot(labels, num_classes=3) labels = torch.nn.functional.one_hot(labels, num_classes=3)
actual = pt.pooling.stratified_max_pooling(d, labels) pooling_layer = pt.pooling.StratifiedMaxPooling()
actual = pooling_layer(d, labels)
desired = torch.tensor([[9.0, 3.0, 2.0], [9.0, 1.0, 0.0]]) desired = torch.tensor([[9.0, 3.0, 2.0], [9.0, 1.0, 0.0]])
mismatch = np.testing.assert_array_almost_equal(actual, mismatch = np.testing.assert_array_almost_equal(actual,
desired, desired,
@ -395,7 +420,8 @@ class TestPooling(unittest.TestCase):
def test_stratified_sum(self): def test_stratified_sum(self):
d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]]) d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
labels = torch.LongTensor([0, 0, 1, 2]) labels = torch.LongTensor([0, 0, 1, 2])
actual = pt.pooling.stratified_sum_pooling(d, labels) pooling_layer = pt.pooling.StratifiedSumPooling()
actual = pooling_layer(d, labels)
desired = torch.tensor([[1.0, 2.0, 3.0], [17.0, 0.0, 1.0]]) desired = torch.tensor([[1.0, 2.0, 3.0], [17.0, 0.0, 1.0]])
mismatch = np.testing.assert_array_almost_equal(actual, mismatch = np.testing.assert_array_almost_equal(actual,
desired, desired,
@ -406,7 +432,8 @@ class TestPooling(unittest.TestCase):
d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]]) d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
labels = torch.tensor([0, 0, 1, 2]) labels = torch.tensor([0, 0, 1, 2])
labels = torch.eye(3)[labels] labels = torch.eye(3)[labels]
actual = pt.pooling.stratified_sum_pooling(d, labels) pooling_layer = pt.pooling.StratifiedSumPooling()
actual = pooling_layer(d, labels)
desired = torch.tensor([[1.0, 2.0, 3.0], [17.0, 0.0, 1.0]]) desired = torch.tensor([[1.0, 2.0, 3.0], [17.0, 0.0, 1.0]])
mismatch = np.testing.assert_array_almost_equal(actual, mismatch = np.testing.assert_array_almost_equal(actual,
desired, desired,
@ -416,7 +443,8 @@ class TestPooling(unittest.TestCase):
def test_stratified_prod(self): def test_stratified_prod(self):
d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]]) d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]])
labels = torch.tensor([0, 0, 3, 2, 0]) labels = torch.tensor([0, 0, 3, 2, 0])
actual = pt.pooling.stratified_prod_pooling(d, labels) pooling_layer = pt.pooling.StratifiedProdPooling()
actual = pooling_layer(d, labels)
desired = torch.tensor([[0.0, 3.0, 2.0], [504.0, 1.0, 0.0]]) desired = torch.tensor([[0.0, 3.0, 2.0], [504.0, 1.0, 0.0]])
mismatch = np.testing.assert_array_almost_equal(actual, mismatch = np.testing.assert_array_almost_equal(actual,
desired, desired,