diff --git a/tests/test_core.py b/tests/test_core.py index 1a1327e..757e678 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -41,6 +41,13 @@ def test_parse_distribution_list(): assert distribution == {0: 1, 1: 1, 2: 0, 3: 2} +def test_parse_distribution_custom_labels(): + distribution = [1, 1, 0, 2] + clabels = [1, 2, 5, 3] + distribution = parse_distribution(distribution, clabels) + assert distribution == {1: 1, 2: 1, 5: 0, 3: 2} + + # Components initializers def test_shape_aware_raises_error(): with pytest.raises(TypeError): @@ -147,6 +154,50 @@ def test_labels_init_from_tuple_illegal(): _ = l.generate(distribution=(1, 1, 1)) +# Reasonings initializers +def test_random_reasonings_init(): + r = pt.initializers.RandomReasoningsInitializer(0.2, 0.8) + reasonings = r.generate(distribution=[0, 1]) + assert torch.numel(reasonings) == 1 * 2 * 2 + assert reasonings.min() >= 0.2 + assert reasonings.max() <= 0.8 + + +def test_zeros_reasonings_init(): + r = pt.initializers.ZerosReasoningsInitializer() + reasonings = r.generate(distribution=[0, 1]) + assert torch.allclose(reasonings, torch.zeros(1, 2, 2)) + + +def test_ones_reasonings_init(): + r = pt.initializers.ZerosReasoningsInitializer() + reasonings = r.generate(distribution=[1, 2, 3]) + assert torch.allclose(reasonings, torch.zeros(6, 3, 2)) + + +def test_random_reasonings_init_channels_not_first(): + r = pt.initializers.RandomReasoningsInitializer(components_first=False) + reasonings = r.generate(distribution=[1, 2]) + assert reasonings.shape[0] == 2 + assert reasonings.shape[-1] == 3 + + +def test_pure_positive_reasonings_init_one_per_class(): + r = pt.initializers.PurePositiveReasoningsInitializer( + components_first=False) + reasonings = r.generate(distribution=(4, 1)) + assert torch.allclose(reasonings[0], torch.eye(4)) + + +def test_pure_positive_reasonings_init_unrepresented_class(): + r = pt.initializers.PurePositiveReasoningsInitializer( + components_first=False) + reasonings = r.generate(distribution=[1, 0, 1]) + assert reasonings.shape[0] == 2 + assert reasonings.shape[1] == 2 + assert reasonings.shape[2] == 3 + + # Components def test_components_no_initializer(): with pytest.raises(TypeError):