diff --git a/tests/test_core.py b/tests/test_core.py index 757e678..4862758 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -49,18 +49,41 @@ def test_parse_distribution_custom_labels(): # Components initializers +def test_literal_comp_generate(): + protos = torch.rand(4, 3, 5, 5) + c = pt.initializers.LiteralCompInitializer(protos) + components = c.generate() + assert torch.allclose(components, protos) + + +def test_literal_comp_generate_from_list(): + protos = [[0, 1], [2, 3], [4, 5]] + c = pt.initializers.LiteralCompInitializer(protos) + with pytest.warns(UserWarning): + components = c.generate() + assert torch.allclose(components, torch.Tensor(protos)) + + def test_shape_aware_raises_error(): with pytest.raises(TypeError): _ = pt.initializers.ShapeAwareCompInitializer(shape=(2, )) -def test_literal_comp_generate(): +def test_data_aware_comp_generate(): protos = torch.rand(4, 3, 5, 5) - c = pt.initializers.LiteralCompInitializer(protos) + c = pt.initializers.DataAwareCompInitializer(protos) components = c.generate(num_components="IgnoreMe!") assert torch.allclose(components, protos) +def test_class_aware_comp_generate(): + protos = torch.rand(4, 2, 3, 5, 5) + plabels = torch.tensor([0, 0, 1, 1]).long() + c = pt.initializers.ClassAwareCompInitializer([protos, plabels]) + components = c.generate(distribution=[]) + assert torch.allclose(components, protos) + + def test_zeros_comp_generate(): shape = (3, 5, 5) c = pt.initializers.ZerosCompInitializer(shape) @@ -136,6 +159,13 @@ def test_stratified_selection_comp_generate(): # Labels initializers +def test_literal_labels_init(): + l = pt.initializers.LiteralLabelsInitializer([0, 0, 1, 2]) + with pytest.warns(UserWarning): + labels = l.generate() + assert torch.allclose(labels, torch.LongTensor([0, 0, 1, 2])) + + def test_labels_init_from_list(): l = pt.initializers.LabelsInitializer() components = l.generate(distribution=[1, 1, 1]) @@ -154,7 +184,22 @@ def test_labels_init_from_tuple_illegal(): _ = l.generate(distribution=(1, 1, 1)) +def test_data_aware_labels_init(): + data, targets = [0, 1, 2, 3], [0, 0, 1, 1] + ds = pt.datasets.NumpyDataset(data, targets) + l = pt.initializers.DataAwareLabelsInitializer(ds) + labels = l.generate() + assert torch.allclose(labels, torch.LongTensor(targets)) + + # Reasonings initializers +def test_literal_reasonings_init(): + r = pt.initializers.LiteralReasoningsInitializer([0, 0, 1, 2]) + with pytest.warns(UserWarning): + reasonings = r.generate() + assert torch.allclose(reasonings, torch.Tensor([0, 0, 1, 2])) + + def test_random_reasonings_init(): r = pt.initializers.RandomReasoningsInitializer(0.2, 0.8) reasonings = r.generate(distribution=[0, 1])