[TEST] Test literal initializers
This commit is contained in:
parent
fc9edeaa97
commit
d45e71256c
@ -49,18 +49,41 @@ def test_parse_distribution_custom_labels():
|
|||||||
|
|
||||||
|
|
||||||
# Components initializers
|
# Components initializers
|
||||||
|
def test_literal_comp_generate():
|
||||||
|
protos = torch.rand(4, 3, 5, 5)
|
||||||
|
c = pt.initializers.LiteralCompInitializer(protos)
|
||||||
|
components = c.generate()
|
||||||
|
assert torch.allclose(components, protos)
|
||||||
|
|
||||||
|
|
||||||
|
def test_literal_comp_generate_from_list():
|
||||||
|
protos = [[0, 1], [2, 3], [4, 5]]
|
||||||
|
c = pt.initializers.LiteralCompInitializer(protos)
|
||||||
|
with pytest.warns(UserWarning):
|
||||||
|
components = c.generate()
|
||||||
|
assert torch.allclose(components, torch.Tensor(protos))
|
||||||
|
|
||||||
|
|
||||||
def test_shape_aware_raises_error():
|
def test_shape_aware_raises_error():
|
||||||
with pytest.raises(TypeError):
|
with pytest.raises(TypeError):
|
||||||
_ = pt.initializers.ShapeAwareCompInitializer(shape=(2, ))
|
_ = pt.initializers.ShapeAwareCompInitializer(shape=(2, ))
|
||||||
|
|
||||||
|
|
||||||
def test_literal_comp_generate():
|
def test_data_aware_comp_generate():
|
||||||
protos = torch.rand(4, 3, 5, 5)
|
protos = torch.rand(4, 3, 5, 5)
|
||||||
c = pt.initializers.LiteralCompInitializer(protos)
|
c = pt.initializers.DataAwareCompInitializer(protos)
|
||||||
components = c.generate(num_components="IgnoreMe!")
|
components = c.generate(num_components="IgnoreMe!")
|
||||||
assert torch.allclose(components, protos)
|
assert torch.allclose(components, protos)
|
||||||
|
|
||||||
|
|
||||||
|
def test_class_aware_comp_generate():
|
||||||
|
protos = torch.rand(4, 2, 3, 5, 5)
|
||||||
|
plabels = torch.tensor([0, 0, 1, 1]).long()
|
||||||
|
c = pt.initializers.ClassAwareCompInitializer([protos, plabels])
|
||||||
|
components = c.generate(distribution=[])
|
||||||
|
assert torch.allclose(components, protos)
|
||||||
|
|
||||||
|
|
||||||
def test_zeros_comp_generate():
|
def test_zeros_comp_generate():
|
||||||
shape = (3, 5, 5)
|
shape = (3, 5, 5)
|
||||||
c = pt.initializers.ZerosCompInitializer(shape)
|
c = pt.initializers.ZerosCompInitializer(shape)
|
||||||
@ -136,6 +159,13 @@ def test_stratified_selection_comp_generate():
|
|||||||
|
|
||||||
|
|
||||||
# Labels initializers
|
# Labels initializers
|
||||||
|
def test_literal_labels_init():
|
||||||
|
l = pt.initializers.LiteralLabelsInitializer([0, 0, 1, 2])
|
||||||
|
with pytest.warns(UserWarning):
|
||||||
|
labels = l.generate()
|
||||||
|
assert torch.allclose(labels, torch.LongTensor([0, 0, 1, 2]))
|
||||||
|
|
||||||
|
|
||||||
def test_labels_init_from_list():
|
def test_labels_init_from_list():
|
||||||
l = pt.initializers.LabelsInitializer()
|
l = pt.initializers.LabelsInitializer()
|
||||||
components = l.generate(distribution=[1, 1, 1])
|
components = l.generate(distribution=[1, 1, 1])
|
||||||
@ -154,7 +184,22 @@ def test_labels_init_from_tuple_illegal():
|
|||||||
_ = l.generate(distribution=(1, 1, 1))
|
_ = l.generate(distribution=(1, 1, 1))
|
||||||
|
|
||||||
|
|
||||||
|
def test_data_aware_labels_init():
|
||||||
|
data, targets = [0, 1, 2, 3], [0, 0, 1, 1]
|
||||||
|
ds = pt.datasets.NumpyDataset(data, targets)
|
||||||
|
l = pt.initializers.DataAwareLabelsInitializer(ds)
|
||||||
|
labels = l.generate()
|
||||||
|
assert torch.allclose(labels, torch.LongTensor(targets))
|
||||||
|
|
||||||
|
|
||||||
# Reasonings initializers
|
# Reasonings initializers
|
||||||
|
def test_literal_reasonings_init():
|
||||||
|
r = pt.initializers.LiteralReasoningsInitializer([0, 0, 1, 2])
|
||||||
|
with pytest.warns(UserWarning):
|
||||||
|
reasonings = r.generate()
|
||||||
|
assert torch.allclose(reasonings, torch.Tensor([0, 0, 1, 2]))
|
||||||
|
|
||||||
|
|
||||||
def test_random_reasonings_init():
|
def test_random_reasonings_init():
|
||||||
r = pt.initializers.RandomReasoningsInitializer(0.2, 0.8)
|
r = pt.initializers.RandomReasoningsInitializer(0.2, 0.8)
|
||||||
reasonings = r.generate(distribution=[0, 1])
|
reasonings = r.generate(distribution=[0, 1])
|
||||||
|
Loading…
Reference in New Issue
Block a user