[REFACTOR] Refactor parse_distribution

This commit is contained in:
Jensun Ravichandran 2021-06-14 17:20:22 +02:00
parent 083cc929be
commit 9241475570

View File

@ -23,10 +23,15 @@ def mesh2d(x=None, border: float = 1.0, resolution: int = 100):
return mesh, xx, yy return mesh, xx, yy
def parse_distribution( def distribution_from_list(list_dist: list[int], clabels: list[int] = []):
user_distribution: Union[dict[int, int], dict[str, str], list[int], clabels = clabels or list(range(len(list_dist)))
tuple[int]] distribution = dict(zip(clabels, list_dist))
) -> dict[int, int]: return distribution
def parse_distribution(user_distribution: Union[dict[int, int], dict[str, str],
list[int], tuple[int]],
clabels: list[int] = []) -> dict[int, int]:
"""Parse user-provided distribution. """Parse user-provided distribution.
Return a dictionary with integer keys that represent the class labels and Return a dictionary with integer keys that represent the class labels and
@ -47,25 +52,20 @@ def parse_distribution(
as one might expect. as one might expect.
""" """
def from_list(list_dist):
clabels = list(range(len(list_dist)))
distribution = dict(zip(clabels, list_dist))
return distribution
if isinstance(user_distribution, dict): if isinstance(user_distribution, dict):
if "num_classes" in user_distribution.keys(): if "num_classes" in user_distribution.keys():
num_classes = int(user_distribution["num_classes"]) num_classes = int(user_distribution["num_classes"])
per_class = int(user_distribution["per_class"]) per_class = int(user_distribution["per_class"])
return from_list([per_class] * num_classes) return distribution_from_list([per_class] * num_classes, clabels)
else: else:
return user_distribution return user_distribution
elif isinstance(user_distribution, tuple): elif isinstance(user_distribution, tuple):
assert len(user_distribution) == 2 assert len(user_distribution) == 2
num_classes, per_class = user_distribution num_classes, per_class = user_distribution
num_classes, per_class = int(num_classes), int(per_class) num_classes, per_class = int(num_classes), int(per_class)
return from_list([per_class] * num_classes) return distribution_from_list([per_class] * num_classes, clabels)
elif isinstance(user_distribution, list): elif isinstance(user_distribution, list):
return from_list(user_distribution) return distribution_from_list(user_distribution, clabels)
else: else:
msg = f"`distribution` not understood." \ msg = f"`distribution` not understood." \
f"You have provided: {user_distribution}." f"You have provided: {user_distribution}."