[BUGFIX] Remove dangerous mutable default arguments
See https://stackoverflow.com/questions/1132941/least-astonishment-and-the-mutable-default-argument for more information.
This commit is contained in:
@@ -52,7 +52,7 @@ def test_parse_distribution_custom_labels():
|
||||
def test_literal_comp_generate():
|
||||
protos = torch.rand(4, 3, 5, 5)
|
||||
c = pt.initializers.LiteralCompInitializer(protos)
|
||||
components = c.generate()
|
||||
components = c.generate([])
|
||||
assert torch.allclose(components, protos)
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ def test_literal_comp_generate_from_list():
|
||||
protos = [[0, 1], [2, 3], [4, 5]]
|
||||
c = pt.initializers.LiteralCompInitializer(protos)
|
||||
with pytest.warns(UserWarning):
|
||||
components = c.generate()
|
||||
components = c.generate([])
|
||||
assert torch.allclose(components, torch.Tensor(protos))
|
||||
|
||||
|
||||
@@ -162,7 +162,7 @@ def test_stratified_selection_comp_generate():
|
||||
def test_literal_labels_init():
|
||||
l = pt.initializers.LiteralLabelsInitializer([0, 0, 1, 2])
|
||||
with pytest.warns(UserWarning):
|
||||
labels = l.generate()
|
||||
labels = l.generate([])
|
||||
assert torch.allclose(labels, torch.LongTensor([0, 0, 1, 2]))
|
||||
|
||||
|
||||
@@ -188,7 +188,7 @@ def test_data_aware_labels_init():
|
||||
data, targets = [0, 1, 2, 3], [0, 0, 1, 1]
|
||||
ds = pt.datasets.NumpyDataset(data, targets)
|
||||
l = pt.initializers.DataAwareLabelsInitializer(ds)
|
||||
labels = l.generate()
|
||||
labels = l.generate([])
|
||||
assert torch.allclose(labels, torch.LongTensor(targets))
|
||||
|
||||
|
||||
@@ -196,7 +196,7 @@ def test_data_aware_labels_init():
|
||||
def test_literal_reasonings_init():
|
||||
r = pt.initializers.LiteralReasoningsInitializer([0, 0, 1, 2])
|
||||
with pytest.warns(UserWarning):
|
||||
reasonings = r.generate()
|
||||
reasonings = r.generate([])
|
||||
assert torch.allclose(reasonings, torch.Tensor([0, 0, 1, 2]))
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user