[TEST] Add tests for reasonings initializers

This commit is contained in:
Jensun Ravichandran 2021-06-14 17:20:57 +02:00
parent 9241475570
commit 549e6a10c1

View File

@ -41,6 +41,13 @@ def test_parse_distribution_list():
assert distribution == {0: 1, 1: 1, 2: 0, 3: 2}
def test_parse_distribution_custom_labels():
distribution = [1, 1, 0, 2]
clabels = [1, 2, 5, 3]
distribution = parse_distribution(distribution, clabels)
assert distribution == {1: 1, 2: 1, 5: 0, 3: 2}
# Components initializers
def test_shape_aware_raises_error():
with pytest.raises(TypeError):
@ -147,6 +154,50 @@ def test_labels_init_from_tuple_illegal():
_ = l.generate(distribution=(1, 1, 1))
# Reasonings initializers
def test_random_reasonings_init():
r = pt.initializers.RandomReasoningsInitializer(0.2, 0.8)
reasonings = r.generate(distribution=[0, 1])
assert torch.numel(reasonings) == 1 * 2 * 2
assert reasonings.min() >= 0.2
assert reasonings.max() <= 0.8
def test_zeros_reasonings_init():
r = pt.initializers.ZerosReasoningsInitializer()
reasonings = r.generate(distribution=[0, 1])
assert torch.allclose(reasonings, torch.zeros(1, 2, 2))
def test_ones_reasonings_init():
r = pt.initializers.ZerosReasoningsInitializer()
reasonings = r.generate(distribution=[1, 2, 3])
assert torch.allclose(reasonings, torch.zeros(6, 3, 2))
def test_random_reasonings_init_channels_not_first():
r = pt.initializers.RandomReasoningsInitializer(components_first=False)
reasonings = r.generate(distribution=[1, 2])
assert reasonings.shape[0] == 2
assert reasonings.shape[-1] == 3
def test_pure_positive_reasonings_init_one_per_class():
r = pt.initializers.PurePositiveReasoningsInitializer(
components_first=False)
reasonings = r.generate(distribution=(4, 1))
assert torch.allclose(reasonings[0], torch.eye(4))
def test_pure_positive_reasonings_init_unrepresented_class():
r = pt.initializers.PurePositiveReasoningsInitializer(
components_first=False)
reasonings = r.generate(distribution=[1, 0, 1])
assert reasonings.shape[0] == 2
assert reasonings.shape[1] == 2
assert reasonings.shape[2] == 3
# Components
def test_components_no_initializer():
with pytest.raises(TypeError):