[BUGFIX] Remove dangerous mutable default arguments

See
https://stackoverflow.com/questions/1132941/least-astonishment-and-the-mutable-default-argument
for more information.
This commit is contained in:
Jensun Ravichandran 2021-06-15 00:14:34 +02:00
parent 1f458ac0cc
commit 0f450ed8a0
2 changed files with 10 additions and 10 deletions

View File

@ -179,7 +179,7 @@ class AbstractClassAwareCompInitializer(AbstractDataAwareCompInitializer):
self.num_classes = len(self.clabels) self.num_classes = len(self.clabels)
@abstractmethod @abstractmethod
def generate(self, distribution: Union[dict, list, tuple] = []): def generate(self, distribution: Union[dict, list, tuple]):
... ...
return self.generate_end_hook(...) return self.generate_end_hook(...)
@ -190,7 +190,7 @@ class AbstractClassAwareCompInitializer(AbstractDataAwareCompInitializer):
class ClassAwareCompInitializer(AbstractClassAwareCompInitializer): class ClassAwareCompInitializer(AbstractClassAwareCompInitializer):
"""'Generate' components from provided data and requested distribution.""" """'Generate' components from provided data and requested distribution."""
def generate(self, distribution: Union[dict, list, tuple] = []): def generate(self, distribution: Union[dict, list, tuple]):
"""Ignore `distribution` and simply return transformed `self.data`.""" """Ignore `distribution` and simply return transformed `self.data`."""
components = self.generate_end_hook(self.data) components = self.generate_end_hook(self.data)
return components return components
@ -249,7 +249,7 @@ class LiteralLabelsInitializer(AbstractLabelsInitializer):
def __init__(self, labels): def __init__(self, labels):
self.labels = labels self.labels = labels
def generate(self, distribution: Union[dict, list, tuple] = []): def generate(self, distribution: Union[dict, list, tuple]):
"""Ignore `distribution` and simply return `self.labels`. """Ignore `distribution` and simply return `self.labels`.
Convert to long tensor, if necessary. Convert to long tensor, if necessary.
@ -267,7 +267,7 @@ class DataAwareLabelsInitializer(AbstractLabelsInitializer):
def __init__(self, data): def __init__(self, data):
self.data, self.targets = parse_data_arg(data) self.data, self.targets = parse_data_arg(data)
def generate(self, distribution: Union[dict, list, tuple] = []): def generate(self, distribution: Union[dict, list, tuple]):
"""Ignore `num_components` and simply return `self.targets`.""" """Ignore `num_components` and simply return `self.targets`."""
return self.targets return self.targets
@ -326,7 +326,7 @@ class LiteralReasoningsInitializer(AbstractReasoningsInitializer):
super().__init__(**kwargs) super().__init__(**kwargs)
self.reasonings = reasonings self.reasonings = reasonings
def generate(self, distribution: Union[dict, list, tuple] = []): def generate(self, distribution: Union[dict, list, tuple]):
"""Ignore `distributuion` and simply return self.reasonings.""" """Ignore `distributuion` and simply return self.reasonings."""
reasonings = self.reasonings reasonings = self.reasonings
if not isinstance(reasonings, torch.Tensor): if not isinstance(reasonings, torch.Tensor):

View File

@ -52,7 +52,7 @@ def test_parse_distribution_custom_labels():
def test_literal_comp_generate(): def test_literal_comp_generate():
protos = torch.rand(4, 3, 5, 5) protos = torch.rand(4, 3, 5, 5)
c = pt.initializers.LiteralCompInitializer(protos) c = pt.initializers.LiteralCompInitializer(protos)
components = c.generate() components = c.generate([])
assert torch.allclose(components, protos) assert torch.allclose(components, protos)
@ -60,7 +60,7 @@ def test_literal_comp_generate_from_list():
protos = [[0, 1], [2, 3], [4, 5]] protos = [[0, 1], [2, 3], [4, 5]]
c = pt.initializers.LiteralCompInitializer(protos) c = pt.initializers.LiteralCompInitializer(protos)
with pytest.warns(UserWarning): with pytest.warns(UserWarning):
components = c.generate() components = c.generate([])
assert torch.allclose(components, torch.Tensor(protos)) assert torch.allclose(components, torch.Tensor(protos))
@ -162,7 +162,7 @@ def test_stratified_selection_comp_generate():
def test_literal_labels_init(): def test_literal_labels_init():
l = pt.initializers.LiteralLabelsInitializer([0, 0, 1, 2]) l = pt.initializers.LiteralLabelsInitializer([0, 0, 1, 2])
with pytest.warns(UserWarning): with pytest.warns(UserWarning):
labels = l.generate() labels = l.generate([])
assert torch.allclose(labels, torch.LongTensor([0, 0, 1, 2])) assert torch.allclose(labels, torch.LongTensor([0, 0, 1, 2]))
@ -188,7 +188,7 @@ def test_data_aware_labels_init():
data, targets = [0, 1, 2, 3], [0, 0, 1, 1] data, targets = [0, 1, 2, 3], [0, 0, 1, 1]
ds = pt.datasets.NumpyDataset(data, targets) ds = pt.datasets.NumpyDataset(data, targets)
l = pt.initializers.DataAwareLabelsInitializer(ds) l = pt.initializers.DataAwareLabelsInitializer(ds)
labels = l.generate() labels = l.generate([])
assert torch.allclose(labels, torch.LongTensor(targets)) assert torch.allclose(labels, torch.LongTensor(targets))
@ -196,7 +196,7 @@ def test_data_aware_labels_init():
def test_literal_reasonings_init(): def test_literal_reasonings_init():
r = pt.initializers.LiteralReasoningsInitializer([0, 0, 1, 2]) r = pt.initializers.LiteralReasoningsInitializer([0, 0, 1, 2])
with pytest.warns(UserWarning): with pytest.warns(UserWarning):
reasonings = r.generate() reasonings = r.generate([])
assert torch.allclose(reasonings, torch.Tensor([0, 0, 1, 2])) assert torch.allclose(reasonings, torch.Tensor([0, 0, 1, 2]))