[TEST] Test literal initializers
This commit is contained in:
parent
fc9edeaa97
commit
d45e71256c
@ -49,18 +49,41 @@ def test_parse_distribution_custom_labels():
|
||||
|
||||
|
||||
# Components initializers
|
||||
def test_literal_comp_generate():
|
||||
protos = torch.rand(4, 3, 5, 5)
|
||||
c = pt.initializers.LiteralCompInitializer(protos)
|
||||
components = c.generate()
|
||||
assert torch.allclose(components, protos)
|
||||
|
||||
|
||||
def test_literal_comp_generate_from_list():
|
||||
protos = [[0, 1], [2, 3], [4, 5]]
|
||||
c = pt.initializers.LiteralCompInitializer(protos)
|
||||
with pytest.warns(UserWarning):
|
||||
components = c.generate()
|
||||
assert torch.allclose(components, torch.Tensor(protos))
|
||||
|
||||
|
||||
def test_shape_aware_raises_error():
|
||||
with pytest.raises(TypeError):
|
||||
_ = pt.initializers.ShapeAwareCompInitializer(shape=(2, ))
|
||||
|
||||
|
||||
def test_literal_comp_generate():
|
||||
def test_data_aware_comp_generate():
|
||||
protos = torch.rand(4, 3, 5, 5)
|
||||
c = pt.initializers.LiteralCompInitializer(protos)
|
||||
c = pt.initializers.DataAwareCompInitializer(protos)
|
||||
components = c.generate(num_components="IgnoreMe!")
|
||||
assert torch.allclose(components, protos)
|
||||
|
||||
|
||||
def test_class_aware_comp_generate():
|
||||
protos = torch.rand(4, 2, 3, 5, 5)
|
||||
plabels = torch.tensor([0, 0, 1, 1]).long()
|
||||
c = pt.initializers.ClassAwareCompInitializer([protos, plabels])
|
||||
components = c.generate(distribution=[])
|
||||
assert torch.allclose(components, protos)
|
||||
|
||||
|
||||
def test_zeros_comp_generate():
|
||||
shape = (3, 5, 5)
|
||||
c = pt.initializers.ZerosCompInitializer(shape)
|
||||
@ -136,6 +159,13 @@ def test_stratified_selection_comp_generate():
|
||||
|
||||
|
||||
# Labels initializers
|
||||
def test_literal_labels_init():
|
||||
l = pt.initializers.LiteralLabelsInitializer([0, 0, 1, 2])
|
||||
with pytest.warns(UserWarning):
|
||||
labels = l.generate()
|
||||
assert torch.allclose(labels, torch.LongTensor([0, 0, 1, 2]))
|
||||
|
||||
|
||||
def test_labels_init_from_list():
|
||||
l = pt.initializers.LabelsInitializer()
|
||||
components = l.generate(distribution=[1, 1, 1])
|
||||
@ -154,7 +184,22 @@ def test_labels_init_from_tuple_illegal():
|
||||
_ = l.generate(distribution=(1, 1, 1))
|
||||
|
||||
|
||||
def test_data_aware_labels_init():
|
||||
data, targets = [0, 1, 2, 3], [0, 0, 1, 1]
|
||||
ds = pt.datasets.NumpyDataset(data, targets)
|
||||
l = pt.initializers.DataAwareLabelsInitializer(ds)
|
||||
labels = l.generate()
|
||||
assert torch.allclose(labels, torch.LongTensor(targets))
|
||||
|
||||
|
||||
# Reasonings initializers
|
||||
def test_literal_reasonings_init():
|
||||
r = pt.initializers.LiteralReasoningsInitializer([0, 0, 1, 2])
|
||||
with pytest.warns(UserWarning):
|
||||
reasonings = r.generate()
|
||||
assert torch.allclose(reasonings, torch.Tensor([0, 0, 1, 2]))
|
||||
|
||||
|
||||
def test_random_reasonings_init():
|
||||
r = pt.initializers.RandomReasoningsInitializer(0.2, 0.8)
|
||||
reasonings = r.generate(distribution=[0, 1])
|
||||
|
Loading…
Reference in New Issue
Block a user