[TEST] Add tests for reasonings initializers
This commit is contained in:
parent
9241475570
commit
549e6a10c1
@ -41,6 +41,13 @@ def test_parse_distribution_list():
|
|||||||
assert distribution == {0: 1, 1: 1, 2: 0, 3: 2}
|
assert distribution == {0: 1, 1: 1, 2: 0, 3: 2}
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_distribution_custom_labels():
|
||||||
|
distribution = [1, 1, 0, 2]
|
||||||
|
clabels = [1, 2, 5, 3]
|
||||||
|
distribution = parse_distribution(distribution, clabels)
|
||||||
|
assert distribution == {1: 1, 2: 1, 5: 0, 3: 2}
|
||||||
|
|
||||||
|
|
||||||
# Components initializers
|
# Components initializers
|
||||||
def test_shape_aware_raises_error():
|
def test_shape_aware_raises_error():
|
||||||
with pytest.raises(TypeError):
|
with pytest.raises(TypeError):
|
||||||
@ -147,6 +154,50 @@ def test_labels_init_from_tuple_illegal():
|
|||||||
_ = l.generate(distribution=(1, 1, 1))
|
_ = l.generate(distribution=(1, 1, 1))
|
||||||
|
|
||||||
|
|
||||||
|
# Reasonings initializers
|
||||||
|
def test_random_reasonings_init():
|
||||||
|
r = pt.initializers.RandomReasoningsInitializer(0.2, 0.8)
|
||||||
|
reasonings = r.generate(distribution=[0, 1])
|
||||||
|
assert torch.numel(reasonings) == 1 * 2 * 2
|
||||||
|
assert reasonings.min() >= 0.2
|
||||||
|
assert reasonings.max() <= 0.8
|
||||||
|
|
||||||
|
|
||||||
|
def test_zeros_reasonings_init():
|
||||||
|
r = pt.initializers.ZerosReasoningsInitializer()
|
||||||
|
reasonings = r.generate(distribution=[0, 1])
|
||||||
|
assert torch.allclose(reasonings, torch.zeros(1, 2, 2))
|
||||||
|
|
||||||
|
|
||||||
|
def test_ones_reasonings_init():
|
||||||
|
r = pt.initializers.ZerosReasoningsInitializer()
|
||||||
|
reasonings = r.generate(distribution=[1, 2, 3])
|
||||||
|
assert torch.allclose(reasonings, torch.zeros(6, 3, 2))
|
||||||
|
|
||||||
|
|
||||||
|
def test_random_reasonings_init_channels_not_first():
|
||||||
|
r = pt.initializers.RandomReasoningsInitializer(components_first=False)
|
||||||
|
reasonings = r.generate(distribution=[1, 2])
|
||||||
|
assert reasonings.shape[0] == 2
|
||||||
|
assert reasonings.shape[-1] == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_pure_positive_reasonings_init_one_per_class():
|
||||||
|
r = pt.initializers.PurePositiveReasoningsInitializer(
|
||||||
|
components_first=False)
|
||||||
|
reasonings = r.generate(distribution=(4, 1))
|
||||||
|
assert torch.allclose(reasonings[0], torch.eye(4))
|
||||||
|
|
||||||
|
|
||||||
|
def test_pure_positive_reasonings_init_unrepresented_class():
|
||||||
|
r = pt.initializers.PurePositiveReasoningsInitializer(
|
||||||
|
components_first=False)
|
||||||
|
reasonings = r.generate(distribution=[1, 0, 1])
|
||||||
|
assert reasonings.shape[0] == 2
|
||||||
|
assert reasonings.shape[1] == 2
|
||||||
|
assert reasonings.shape[2] == 3
|
||||||
|
|
||||||
|
|
||||||
# Components
|
# Components
|
||||||
def test_components_no_initializer():
|
def test_components_no_initializer():
|
||||||
with pytest.raises(TypeError):
|
with pytest.raises(TypeError):
|
||||||
|
Loading…
Reference in New Issue
Block a user