From 6ad665f8c248b4e7f9ece5f64d4c13e0f225d4b0 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Sun, 13 Jun 2021 23:04:07 +0000 Subject: [PATCH] [REFACTOR] Simplify initializer validation --- prototorch/core/components.py | 28 ++++++++++------------------ 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/prototorch/core/components.py b/prototorch/core/components.py index d2b0f40..a243318 100644 --- a/prototorch/core/components.py +++ b/prototorch/core/components.py @@ -29,18 +29,6 @@ def validate_initializer(initializer, instanceof): return True -def validate_components_initializer(initializer): - return validate_initializer(initializer, AbstractComponentsInitializer) - - -def validate_labels_initializer(initializer): - return validate_initializer(initializer, AbstractLabelsInitializer) - - -def validate_reasonings_initializer(initializer): - return validate_initializer(initializer, AbstractReasoningsInitializer) - - def gencat(ins, attr, init, *iargs, **ikwargs): """Generate new items and concatenate with existing items.""" new_items = init.generate(*iargs, **ikwargs) @@ -91,7 +79,7 @@ class Components(AbstractComponents): def add_components(self, num_components: int, initializer: AbstractComponentsInitializer): """Generate and add new components.""" - assert validate_components_initializer(initializer) + assert validate_initializer(initializer, AbstractComponentsInitializer) _components, new_components = gencat(self, "_components", initializer, num_components) self._register_components(_components) @@ -162,7 +150,7 @@ class Labels(AbstractLabels): distribution: Union[dict, tuple, list], initializer: AbstractLabelsInitializer = LabelsInitializer()): """Generate and add new labels.""" - assert validate_labels_initializer(initializer) + assert validate_initializer(initializer, AbstractLabelsInitializer) _labels, new_labels = gencat(self, "_labels", initializer, distribution) self._register_labels(_labels) @@ -202,8 +190,10 @@ class LabeledComponents(AbstractComponents): components_initializer, labels_initializer: AbstractLabelsInitializer = LabelsInitializer()): """Generate and add new components and labels.""" - assert validate_components_initializer(components_initializer) - assert validate_labels_initializer(labels_initializer) + assert validate_initializer(components_initializer, + AbstractComponentsInitializer) + assert validate_initializer(labels_initializer, + AbstractLabelsInitializer) if isinstance(components_initializer, ClassAwareCompInitializer): cikwargs = dict(distribution=distribution) else: @@ -270,8 +260,10 @@ class ReasoningComponents(AbstractComponents): def add_components(self, distribution, components_initializer, reasonings_initializer: AbstractReasoningsInitializer): # Checks - assert validate_components_initializer(components_initializer) - assert validate_reasonings_initializer(reasonings_initializer) + assert validate_initializer(components_initializer, + AbstractComponentsInitializer) + assert validate_initializer(reasonings_initializer, + AbstractReasoningsInitializer) distribution = parse_distribution(distribution)