[BUGFIX] Fix reasonings initializer dimension bug
This commit is contained in:
@@ -220,13 +220,6 @@ def test_ones_reasonings_init():
|
||||
assert torch.allclose(reasonings, torch.zeros(6, 3, 2))
|
||||
|
||||
|
||||
def test_random_reasonings_init_channels_not_first():
    """With components_first=False the component axis comes last.

    distribution=[1, 2] describes two classes with 1 + 2 = 3 components
    in total; presumably the leading axis is the class axis here and the
    trailing axis the component axis — confirm against the initializer.
    """
    initializer = pt.initializers.RandomReasoningsInitializer(
        components_first=False)
    generated = initializer.generate(distribution=[1, 2])
    assert generated.shape[0] == 2
    assert generated.shape[-1] == 3
def test_pure_positive_reasonings_init_one_per_class():
    """Pure-positive reasonings with exactly one component per class.

    The positive-reasoning slice (index 0) is expected to be the 4x4
    identity: each component reasons positively for exactly its own
    class.
    """
    r = pt.initializers.PurePositiveReasoningsInitializer(
        components_first=False)
    # NOTE(review): the generate() call was lost in the diff extraction
    # between the two hunks; distribution=[1, 1, 1, 1] (one component
    # for each of four classes) is the input consistent with the
    # torch.eye(4) expectation below — confirm against the original file.
    reasonings = r.generate(distribution=[1, 1, 1, 1])
    assert torch.allclose(reasonings[0], torch.eye(4))
||||
def test_pure_positive_reasonings_init_unrepresented_class():
    """Smoke test: generation must not fail when a class in the middle
    of the distribution has zero components (class 1 here)."""
    initializer = pt.initializers.PurePositiveReasoningsInitializer(
        components_first=False)
    # The result is intentionally unused — this only checks that
    # generate() handles an unrepresented class without raising.
    initializer.generate(distribution=[1, 0, 1])
def test_pure_positive_reasonings_init_unrepresented_classes():
    """All nine components belong to class 0; classes 1-3 are empty.

    The generated reasonings are still expected to span all four
    classes — presumably (components, classes, pos/neg) = (9, 4, 2)
    with the default axis order; confirm against the initializer.
    """
    initializer = pt.initializers.PurePositiveReasoningsInitializer()
    generated = initializer.generate(distribution=[9, 0, 0, 0])
    for axis, size in enumerate((9, 4, 2)):
        assert generated.shape[axis] == size
def test_random_reasonings_init_channels_not_first():
    """Shape of random reasonings for distribution=[0, 0, 0, 1] with
    components_first=False: four classes, one component in total.

    NOTE(review): this duplicates the name of an earlier test in the
    file, so pytest collects only this later definition — consider
    renaming one of them.
    """
    r = pt.initializers.RandomReasoningsInitializer(components_first=False)
    reasonings = r.generate(distribution=[0, 0, 0, 1])
    # BUGFIX: the block previously also asserted shape[1] == 2 and
    # shape[2] == 3 — stale expectations left over from the old
    # distribution=[1, 2] test, contradicting the assertions below.
    # Only the set consistent with four classes and one component
    # (presumably (2, num_classes, num_components)) is kept.
    assert reasonings.shape[0] == 2
    assert reasonings.shape[1] == 4
    assert reasonings.shape[2] == 1
||||
# Transform initializers