[BUGFIX] Remove dangerous mutable default arguments
See https://stackoverflow.com/questions/1132941/least-astonishment-and-the-mutable-default-argument for more information.
This commit is contained in:
parent 1f458ac0cc
commit 0f450ed8a0
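
For context (not part of the commit itself), a minimal sketch of the pitfall behind the commit title: a mutable default such as `= []` is evaluated once, when the function is defined, so every call that falls back on the default shares the same list object. The function names below are made up for illustration; the commit takes the stricter route of removing the default entirely, so `distribution` must always be passed.

def shared_bucket(item, bucket=[]):   # BAD: the default list is created once, at def time
    bucket.append(item)
    return bucket

shared_bucket(1)   # [1]
shared_bucket(2)   # [1, 2]  <- surprise: the same list object is reused across calls

def fresh_bucket(item, bucket=None):  # common fix: use None as a sentinel and build a new list per call
    if bucket is None:
        bucket = []
    bucket.append(item)
    return bucket

fresh_bucket(1)    # [1]
fresh_bucket(2)    # [2]
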
@@ -179,7 +179,7 @@ class AbstractClassAwareCompInitializer(AbstractDataAwareCompInitializer):
         self.num_classes = len(self.clabels)
 
     @abstractmethod
-    def generate(self, distribution: Union[dict, list, tuple] = []):
+    def generate(self, distribution: Union[dict, list, tuple]):
         ...
         return self.generate_end_hook(...)
 
@@ -190,7 +190,7 @@ class AbstractClassAwareCompInitializer(AbstractDataAwareCompInitializer):
 
 class ClassAwareCompInitializer(AbstractClassAwareCompInitializer):
     """'Generate' components from provided data and requested distribution."""
-    def generate(self, distribution: Union[dict, list, tuple] = []):
+    def generate(self, distribution: Union[dict, list, tuple]):
         """Ignore `distribution` and simply return transformed `self.data`."""
         components = self.generate_end_hook(self.data)
         return components
@@ -249,7 +249,7 @@ class LiteralLabelsInitializer(AbstractLabelsInitializer):
     def __init__(self, labels):
         self.labels = labels
 
-    def generate(self, distribution: Union[dict, list, tuple] = []):
+    def generate(self, distribution: Union[dict, list, tuple]):
         """Ignore `distribution` and simply return `self.labels`.
 
         Convert to long tensor, if necessary.
@@ -267,7 +267,7 @@ class DataAwareLabelsInitializer(AbstractLabelsInitializer):
     def __init__(self, data):
         self.data, self.targets = parse_data_arg(data)
 
-    def generate(self, distribution: Union[dict, list, tuple] = []):
+    def generate(self, distribution: Union[dict, list, tuple]):
         """Ignore `num_components` and simply return `self.targets`."""
         return self.targets
 
@@ -326,7 +326,7 @@ class LiteralReasoningsInitializer(AbstractReasoningsInitializer):
         super().__init__(**kwargs)
         self.reasonings = reasonings
 
-    def generate(self, distribution: Union[dict, list, tuple] = []):
+    def generate(self, distribution: Union[dict, list, tuple]):
         """Ignore `distributuion` and simply return self.reasonings."""
         reasonings = self.reasonings
         if not isinstance(reasonings, torch.Tensor):
@@ -52,7 +52,7 @@ def test_parse_distribution_custom_labels():
 def test_literal_comp_generate():
     protos = torch.rand(4, 3, 5, 5)
     c = pt.initializers.LiteralCompInitializer(protos)
-    components = c.generate()
+    components = c.generate([])
     assert torch.allclose(components, protos)
 
 
@@ -60,7 +60,7 @@ def test_literal_comp_generate_from_list():
     protos = [[0, 1], [2, 3], [4, 5]]
     c = pt.initializers.LiteralCompInitializer(protos)
     with pytest.warns(UserWarning):
-        components = c.generate()
+        components = c.generate([])
     assert torch.allclose(components, torch.Tensor(protos))
 
 
@@ -162,7 +162,7 @@ def test_stratified_selection_comp_generate():
 def test_literal_labels_init():
     l = pt.initializers.LiteralLabelsInitializer([0, 0, 1, 2])
    with pytest.warns(UserWarning):
-        labels = l.generate()
+        labels = l.generate([])
     assert torch.allclose(labels, torch.LongTensor([0, 0, 1, 2]))
 
 
@@ -188,7 +188,7 @@ def test_data_aware_labels_init():
     data, targets = [0, 1, 2, 3], [0, 0, 1, 1]
     ds = pt.datasets.NumpyDataset(data, targets)
     l = pt.initializers.DataAwareLabelsInitializer(ds)
-    labels = l.generate()
+    labels = l.generate([])
     assert torch.allclose(labels, torch.LongTensor(targets))
 
 
@@ -196,7 +196,7 @@ def test_data_aware_labels_init():
 def test_literal_reasonings_init():
     r = pt.initializers.LiteralReasoningsInitializer([0, 0, 1, 2])
     with pytest.warns(UserWarning):
-        reasonings = r.generate()
+        reasonings = r.generate([])
     assert torch.allclose(reasonings, torch.Tensor([0, 0, 1, 2]))
 
 
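
As the test hunks above show, the practical effect of the change is that `generate` no longer supplies a default, so callers must pass the distribution explicitly; the literal and data-aware initializers simply ignore it. A minimal usage sketch mirroring the updated tests (using the `pt` alias for prototorch from the test file):

import torch
import prototorch as pt

protos = torch.rand(4, 3, 5, 5)
c = pt.initializers.LiteralCompInitializer(protos)

# `distribution` is now a required argument; LiteralCompInitializer ignores it,
# so an empty list is passed explicitly, exactly as in the updated tests.
components = c.generate([])
assert torch.allclose(components, protos)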