[BUGFIX] Fix reasonings initializer dimension bug

This commit is contained in:
Jensun Ravichandran
2021-06-17 18:10:05 +02:00
parent ae11fedbf3
commit de61bf7bca
4 changed files with 23 additions and 20 deletions

View File

@@ -220,13 +220,6 @@ def test_ones_reasonings_init():
assert torch.allclose(reasonings, torch.zeros(6, 3, 2))
def test_random_reasonings_init_channels_not_first():
    """Random reasonings with ``components_first=False``.

    With the distribution ``[1, 2]`` (2 classes, 3 components in total) the
    generated tensor is expected to carry the components on the LAST axis.
    """
    init = pt.initializers.RandomReasoningsInitializer(components_first=False)
    generated = init.generate(distribution=[1, 2])
    # Leading axis and trailing (components) axis are pinned; the middle
    # axis is intentionally left unchecked by this test.
    assert generated.shape[0] == 2
    assert generated.shape[-1] == 3
def test_pure_positive_reasonings_init_one_per_class():
r = pt.initializers.PurePositiveReasoningsInitializer(
components_first=False)
@@ -234,13 +227,20 @@ def test_pure_positive_reasonings_init_one_per_class():
assert torch.allclose(reasonings[0], torch.eye(4))
def test_pure_positive_reasonings_init_unrepresented_class():
    # Smoke test: generating pure-positive reasonings must succeed when the
    # class distribution contains a zero-count (unrepresented) class.
    # NOTE(review): no assertions follow — presumably this only verifies that
    # generate() does not raise; consider asserting the output shape.
    r = pt.initializers.PurePositiveReasoningsInitializer(
        components_first=False)
    reasonings = r.generate(distribution=[1, 0, 1])
def test_pure_positive_reasonings_init_unrepresented_classes():
    """Pure-positive reasonings with several unrepresented classes.

    ``distribution=[9, 0, 0, 0]`` asks for 9 components spread over 4
    classes, three of which receive none; the default (components-first)
    layout is expected.
    """
    initializer = pt.initializers.PurePositiveReasoningsInitializer()
    result = initializer.generate(distribution=[9, 0, 0, 0])
    # Expected layout: (num_components, num_classes, 2).
    assert result.shape[0] == 9
    assert result.shape[1] == 4
    assert result.shape[2] == 2
def test_random_reasonings_init_channels_not_first():
    """Random reasonings with ``components_first=False``.

    ``distribution=[0, 0, 0, 1]`` describes 4 classes with a single
    component overall, so the channels-last layout is expected to be
    ``(2, num_classes, num_components) == (2, 4, 1)``.
    """
    r = pt.initializers.RandomReasoningsInitializer(components_first=False)
    reasonings = r.generate(distribution=[0, 0, 0, 1])
    # BUG: the original text interleaved the pre-fix assertions
    # (shape[1] == 2, shape[2] == 3) with the post-fix ones, leaving
    # mutually contradictory asserts; only the consistent post-fix set
    # for this distribution is kept.
    assert reasonings.shape[0] == 2
    assert reasonings.shape[1] == 4
    assert reasonings.shape[2] == 1
# Transform initializers