[REFACTOR] Simplify initializer validation

This commit is contained in:
Jensun Ravichandran 2021-06-13 23:04:07 +00:00
parent 2af1da7f23
commit 6ad665f8c2

View File

@ -29,18 +29,6 @@ def validate_initializer(initializer, instanceof):
return True
def validate_components_initializer(initializer):
return validate_initializer(initializer, AbstractComponentsInitializer)
def validate_labels_initializer(initializer):
return validate_initializer(initializer, AbstractLabelsInitializer)
def validate_reasonings_initializer(initializer):
return validate_initializer(initializer, AbstractReasoningsInitializer)
def gencat(ins, attr, init, *iargs, **ikwargs):
"""Generate new items and concatenate with existing items."""
new_items = init.generate(*iargs, **ikwargs)
@ -91,7 +79,7 @@ class Components(AbstractComponents):
def add_components(self, num_components: int,
initializer: AbstractComponentsInitializer):
"""Generate and add new components."""
assert validate_components_initializer(initializer)
assert validate_initializer(initializer, AbstractComponentsInitializer)
_components, new_components = gencat(self, "_components", initializer,
num_components)
self._register_components(_components)
@ -162,7 +150,7 @@ class Labels(AbstractLabels):
distribution: Union[dict, tuple, list],
initializer: AbstractLabelsInitializer = LabelsInitializer()):
"""Generate and add new labels."""
assert validate_labels_initializer(initializer)
assert validate_initializer(initializer, AbstractLabelsInitializer)
_labels, new_labels = gencat(self, "_labels", initializer,
distribution)
self._register_labels(_labels)
@ -202,8 +190,10 @@ class LabeledComponents(AbstractComponents):
components_initializer,
labels_initializer: AbstractLabelsInitializer = LabelsInitializer()):
"""Generate and add new components and labels."""
assert validate_components_initializer(components_initializer)
assert validate_labels_initializer(labels_initializer)
assert validate_initializer(components_initializer,
AbstractComponentsInitializer)
assert validate_initializer(labels_initializer,
AbstractLabelsInitializer)
if isinstance(components_initializer, ClassAwareCompInitializer):
cikwargs = dict(distribution=distribution)
else:
@ -270,8 +260,10 @@ class ReasoningComponents(AbstractComponents):
def add_components(self, distribution, components_initializer,
reasonings_initializer: AbstractReasoningsInitializer):
# Checks
assert validate_components_initializer(components_initializer)
assert validate_reasonings_initializer(reasonings_initializer)
assert validate_initializer(components_initializer,
AbstractComponentsInitializer)
assert validate_initializer(reasonings_initializer,
AbstractReasoningsInitializer)
distribution = parse_distribution(distribution)