[TEST] Add more tests
This commit is contained in:
parent
d2d6f31e7b
commit
668c9a1fb7
@ -81,14 +81,9 @@ class ClassAwareCompInitializer(AbstractComponentsInitializer):
|
||||
|
||||
def generate(self, distribution: Union[dict, list, tuple]):
|
||||
distribution = parse_distribution(distribution)
|
||||
initializers = {
|
||||
k: self.subinit_type(self.data[self.targets == k])
|
||||
for k in distribution.keys()
|
||||
}
|
||||
components = torch.tensor([])
|
||||
for k, v in distribution.items():
|
||||
stratified_data = self.data[self.targets == k]
|
||||
# skip transform here
|
||||
initializer = self.subinit_type(
|
||||
stratified_data,
|
||||
noise=self.noise,
|
||||
@ -157,13 +152,14 @@ class UniformCompInitializer(OnesCompInitializer):
|
||||
|
||||
class RandomNormalCompInitializer(OnesCompInitializer):
|
||||
"""Generate components by sampling from a standard normal distribution."""
|
||||
def __init__(self, shape, scale=1.0):
|
||||
def __init__(self, shape, shift=0.0, scale=1.0):
|
||||
super().__init__(shape)
|
||||
self.shift = shift
|
||||
self.scale = scale
|
||||
|
||||
def generate(self, num_components: int):
|
||||
ones = super().generate(num_components)
|
||||
components = self.scale * torch.randn_like(ones)
|
||||
components = self.scale * (torch.randn_like(ones) + self.shift)
|
||||
return components
|
||||
|
||||
|
||||
|
@ -73,6 +73,22 @@ def test_fill_value_comp_generate():
|
||||
assert torch.allclose(components, torch.zeros(3, 2))
|
||||
|
||||
|
||||
def test_uniform_comp_generate_min_max_bound():
|
||||
c = pt.initializers.UniformCompInitializer(2, -1.0, 1.0)
|
||||
components = c.generate(num_components=1024)
|
||||
assert components.min() >= -1.0
|
||||
assert components.max() <= 1.0
|
||||
|
||||
|
||||
def test_random_comp_generate_mean():
|
||||
c = pt.initializers.RandomNormalCompInitializer(2, -1.0)
|
||||
components = c.generate(num_components=1024)
|
||||
assert torch.allclose(components.mean(),
|
||||
torch.tensor(-1.0),
|
||||
rtol=1e-05,
|
||||
atol=1e-01)
|
||||
|
||||
|
||||
def test_comp_generate_0_components():
|
||||
c = pt.initializers.ZerosCompInitializer(2)
|
||||
_ = c.generate(num_components=0)
|
||||
@ -294,7 +310,8 @@ class TestCompetitions(unittest.TestCase):
|
||||
def test_wtac(self):
|
||||
d = torch.tensor([[2.0, 3.0, 1.99, 3.01], [2.0, 3.0, 2.01, 3.0]])
|
||||
labels = torch.tensor([0, 1, 2, 3])
|
||||
actual = pt.competitions.wtac(d, labels)
|
||||
competition_layer = pt.competitions.WTAC()
|
||||
actual = competition_layer(d, labels)
|
||||
desired = torch.tensor([2, 0])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
@ -304,7 +321,8 @@ class TestCompetitions(unittest.TestCase):
|
||||
def test_wtac_unequal_dist(self):
|
||||
d = torch.tensor([[2.0, 3.0, 4.0], [2.0, 3.0, 1.0]])
|
||||
labels = torch.tensor([0, 1, 1])
|
||||
actual = pt.competitions.wtac(d, labels)
|
||||
competition_layer = pt.competitions.WTAC()
|
||||
actual = competition_layer(d, labels)
|
||||
desired = torch.tensor([0, 1])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
@ -314,7 +332,8 @@ class TestCompetitions(unittest.TestCase):
|
||||
def test_wtac_one_hot(self):
|
||||
d = torch.tensor([[1.99, 3.01], [3.0, 2.01]])
|
||||
labels = torch.tensor([[0, 1], [1, 0]])
|
||||
actual = pt.competitions.wtac(d, labels)
|
||||
competition_layer = pt.competitions.WTAC()
|
||||
actual = competition_layer(d, labels)
|
||||
desired = torch.tensor([[0, 1], [1, 0]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
@ -324,7 +343,8 @@ class TestCompetitions(unittest.TestCase):
|
||||
def test_knnc_k1(self):
|
||||
d = torch.tensor([[2.0, 3.0, 1.99, 3.01], [2.0, 3.0, 2.01, 3.0]])
|
||||
labels = torch.tensor([0, 1, 2, 3])
|
||||
actual = pt.competitions.knnc(d, labels, k=1)
|
||||
competition_layer = pt.competitions.KNNC(k=1)
|
||||
actual = competition_layer(d, labels)
|
||||
desired = torch.tensor([2, 0])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
@ -343,7 +363,8 @@ class TestPooling(unittest.TestCase):
|
||||
def test_stratified_min(self):
|
||||
d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
|
||||
labels = torch.tensor([0, 0, 1, 2])
|
||||
actual = pt.pooling.stratified_min_pooling(d, labels)
|
||||
pooling_layer = pt.pooling.StratifiedMinPooling()
|
||||
actual = pooling_layer(d, labels)
|
||||
desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
@ -354,7 +375,8 @@ class TestPooling(unittest.TestCase):
|
||||
d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
|
||||
labels = torch.tensor([0, 0, 1, 2])
|
||||
labels = torch.eye(3)[labels]
|
||||
actual = pt.pooling.stratified_min_pooling(d, labels)
|
||||
pooling_layer = pt.pooling.StratifiedMinPooling()
|
||||
actual = pooling_layer(d, labels)
|
||||
desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
@ -364,7 +386,8 @@ class TestPooling(unittest.TestCase):
|
||||
def test_stratified_min_trivial(self):
|
||||
d = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0, 1]])
|
||||
labels = torch.tensor([0, 1, 2])
|
||||
actual = pt.pooling.stratified_min_pooling(d, labels)
|
||||
pooling_layer = pt.pooling.StratifiedMinPooling()
|
||||
actual = pooling_layer(d, labels)
|
||||
desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
@ -374,7 +397,8 @@ class TestPooling(unittest.TestCase):
|
||||
def test_stratified_max(self):
|
||||
d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]])
|
||||
labels = torch.tensor([0, 0, 3, 2, 0])
|
||||
actual = pt.pooling.stratified_max_pooling(d, labels)
|
||||
pooling_layer = pt.pooling.StratifiedMaxPooling()
|
||||
actual = pooling_layer(d, labels)
|
||||
desired = torch.tensor([[9.0, 3.0, 2.0], [9.0, 1.0, 0.0]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
@ -385,7 +409,8 @@ class TestPooling(unittest.TestCase):
|
||||
d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]])
|
||||
labels = torch.tensor([0, 0, 2, 1, 0])
|
||||
labels = torch.nn.functional.one_hot(labels, num_classes=3)
|
||||
actual = pt.pooling.stratified_max_pooling(d, labels)
|
||||
pooling_layer = pt.pooling.StratifiedMaxPooling()
|
||||
actual = pooling_layer(d, labels)
|
||||
desired = torch.tensor([[9.0, 3.0, 2.0], [9.0, 1.0, 0.0]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
@ -395,7 +420,8 @@ class TestPooling(unittest.TestCase):
|
||||
def test_stratified_sum(self):
|
||||
d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
|
||||
labels = torch.LongTensor([0, 0, 1, 2])
|
||||
actual = pt.pooling.stratified_sum_pooling(d, labels)
|
||||
pooling_layer = pt.pooling.StratifiedSumPooling()
|
||||
actual = pooling_layer(d, labels)
|
||||
desired = torch.tensor([[1.0, 2.0, 3.0], [17.0, 0.0, 1.0]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
@ -406,7 +432,8 @@ class TestPooling(unittest.TestCase):
|
||||
d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
|
||||
labels = torch.tensor([0, 0, 1, 2])
|
||||
labels = torch.eye(3)[labels]
|
||||
actual = pt.pooling.stratified_sum_pooling(d, labels)
|
||||
pooling_layer = pt.pooling.StratifiedSumPooling()
|
||||
actual = pooling_layer(d, labels)
|
||||
desired = torch.tensor([[1.0, 2.0, 3.0], [17.0, 0.0, 1.0]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
@ -416,7 +443,8 @@ class TestPooling(unittest.TestCase):
|
||||
def test_stratified_prod(self):
|
||||
d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]])
|
||||
labels = torch.tensor([0, 0, 3, 2, 0])
|
||||
actual = pt.pooling.stratified_prod_pooling(d, labels)
|
||||
pooling_layer = pt.pooling.StratifiedProdPooling()
|
||||
actual = pooling_layer(d, labels)
|
||||
desired = torch.tensor([[0.0, 3.0, 2.0], [504.0, 1.0, 0.0]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
|
Loading…
Reference in New Issue
Block a user