[TEST] Test literal initializers
This commit is contained in:
		| @@ -49,18 +49,41 @@ def test_parse_distribution_custom_labels(): | |||||||
|  |  | ||||||
|  |  | ||||||
| # Components initializers | # Components initializers | ||||||
|  | def test_literal_comp_generate(): | ||||||
|  |     protos = torch.rand(4, 3, 5, 5) | ||||||
|  |     c = pt.initializers.LiteralCompInitializer(protos) | ||||||
|  |     components = c.generate() | ||||||
|  |     assert torch.allclose(components, protos) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_literal_comp_generate_from_list(): | ||||||
|  |     protos = [[0, 1], [2, 3], [4, 5]] | ||||||
|  |     c = pt.initializers.LiteralCompInitializer(protos) | ||||||
|  |     with pytest.warns(UserWarning): | ||||||
|  |         components = c.generate() | ||||||
|  |     assert torch.allclose(components, torch.Tensor(protos)) | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_shape_aware_raises_error(): | def test_shape_aware_raises_error(): | ||||||
|     with pytest.raises(TypeError): |     with pytest.raises(TypeError): | ||||||
|         _ = pt.initializers.ShapeAwareCompInitializer(shape=(2, )) |         _ = pt.initializers.ShapeAwareCompInitializer(shape=(2, )) | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_literal_comp_generate(): | def test_data_aware_comp_generate(): | ||||||
|     protos = torch.rand(4, 3, 5, 5) |     protos = torch.rand(4, 3, 5, 5) | ||||||
|     c = pt.initializers.LiteralCompInitializer(protos) |     c = pt.initializers.DataAwareCompInitializer(protos) | ||||||
|     components = c.generate(num_components="IgnoreMe!") |     components = c.generate(num_components="IgnoreMe!") | ||||||
|     assert torch.allclose(components, protos) |     assert torch.allclose(components, protos) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_class_aware_comp_generate(): | ||||||
|  |     protos = torch.rand(4, 2, 3, 5, 5) | ||||||
|  |     plabels = torch.tensor([0, 0, 1, 1]).long() | ||||||
|  |     c = pt.initializers.ClassAwareCompInitializer([protos, plabels]) | ||||||
|  |     components = c.generate(distribution=[]) | ||||||
|  |     assert torch.allclose(components, protos) | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_zeros_comp_generate(): | def test_zeros_comp_generate(): | ||||||
|     shape = (3, 5, 5) |     shape = (3, 5, 5) | ||||||
|     c = pt.initializers.ZerosCompInitializer(shape) |     c = pt.initializers.ZerosCompInitializer(shape) | ||||||
| @@ -136,6 +159,13 @@ def test_stratified_selection_comp_generate(): | |||||||
|  |  | ||||||
|  |  | ||||||
| # Labels initializers | # Labels initializers | ||||||
|  | def test_literal_labels_init(): | ||||||
|  |     l = pt.initializers.LiteralLabelsInitializer([0, 0, 1, 2]) | ||||||
|  |     with pytest.warns(UserWarning): | ||||||
|  |         labels = l.generate() | ||||||
|  |     assert torch.allclose(labels, torch.LongTensor([0, 0, 1, 2])) | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_labels_init_from_list(): | def test_labels_init_from_list(): | ||||||
|     l = pt.initializers.LabelsInitializer() |     l = pt.initializers.LabelsInitializer() | ||||||
|     components = l.generate(distribution=[1, 1, 1]) |     components = l.generate(distribution=[1, 1, 1]) | ||||||
| @@ -154,7 +184,22 @@ def test_labels_init_from_tuple_illegal(): | |||||||
|         _ = l.generate(distribution=(1, 1, 1)) |         _ = l.generate(distribution=(1, 1, 1)) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_data_aware_labels_init(): | ||||||
|  |     data, targets = [0, 1, 2, 3], [0, 0, 1, 1] | ||||||
|  |     ds = pt.datasets.NumpyDataset(data, targets) | ||||||
|  |     l = pt.initializers.DataAwareLabelsInitializer(ds) | ||||||
|  |     labels = l.generate() | ||||||
|  |     assert torch.allclose(labels, torch.LongTensor(targets)) | ||||||
|  |  | ||||||
|  |  | ||||||
| # Reasonings initializers | # Reasonings initializers | ||||||
|  | def test_literal_reasonings_init(): | ||||||
|  |     r = pt.initializers.LiteralReasoningsInitializer([0, 0, 1, 2]) | ||||||
|  |     with pytest.warns(UserWarning): | ||||||
|  |         reasonings = r.generate() | ||||||
|  |     assert torch.allclose(reasonings, torch.Tensor([0, 0, 1, 2])) | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_random_reasonings_init(): | def test_random_reasonings_init(): | ||||||
|     r = pt.initializers.RandomReasoningsInitializer(0.2, 0.8) |     r = pt.initializers.RandomReasoningsInitializer(0.2, 0.8) | ||||||
|     reasonings = r.generate(distribution=[0, 1]) |     reasonings = r.generate(distribution=[0, 1]) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user