[REFACTOR] Refactor parse_distribution
This commit is contained in:
parent
083cc929be
commit
9241475570
@ -23,10 +23,15 @@ def mesh2d(x=None, border: float = 1.0, resolution: int = 100):
|
|||||||
return mesh, xx, yy
|
return mesh, xx, yy
|
||||||
|
|
||||||
|
|
||||||
def parse_distribution(
|
def distribution_from_list(list_dist: list[int], clabels: list[int] = []):
|
||||||
user_distribution: Union[dict[int, int], dict[str, str], list[int],
|
clabels = clabels or list(range(len(list_dist)))
|
||||||
tuple[int]]
|
distribution = dict(zip(clabels, list_dist))
|
||||||
) -> dict[int, int]:
|
return distribution
|
||||||
|
|
||||||
|
|
||||||
|
def parse_distribution(user_distribution: Union[dict[int, int], dict[str, str],
|
||||||
|
list[int], tuple[int]],
|
||||||
|
clabels: list[int] = []) -> dict[int, int]:
|
||||||
"""Parse user-provided distribution.
|
"""Parse user-provided distribution.
|
||||||
|
|
||||||
Return a dictionary with integer keys that represent the class labels and
|
Return a dictionary with integer keys that represent the class labels and
|
||||||
@ -47,25 +52,20 @@ def parse_distribution(
|
|||||||
as one might expect.
|
as one might expect.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
def from_list(list_dist):
|
|
||||||
clabels = list(range(len(list_dist)))
|
|
||||||
distribution = dict(zip(clabels, list_dist))
|
|
||||||
return distribution
|
|
||||||
|
|
||||||
if isinstance(user_distribution, dict):
|
if isinstance(user_distribution, dict):
|
||||||
if "num_classes" in user_distribution.keys():
|
if "num_classes" in user_distribution.keys():
|
||||||
num_classes = int(user_distribution["num_classes"])
|
num_classes = int(user_distribution["num_classes"])
|
||||||
per_class = int(user_distribution["per_class"])
|
per_class = int(user_distribution["per_class"])
|
||||||
return from_list([per_class] * num_classes)
|
return distribution_from_list([per_class] * num_classes, clabels)
|
||||||
else:
|
else:
|
||||||
return user_distribution
|
return user_distribution
|
||||||
elif isinstance(user_distribution, tuple):
|
elif isinstance(user_distribution, tuple):
|
||||||
assert len(user_distribution) == 2
|
assert len(user_distribution) == 2
|
||||||
num_classes, per_class = user_distribution
|
num_classes, per_class = user_distribution
|
||||||
num_classes, per_class = int(num_classes), int(per_class)
|
num_classes, per_class = int(num_classes), int(per_class)
|
||||||
return from_list([per_class] * num_classes)
|
return distribution_from_list([per_class] * num_classes, clabels)
|
||||||
elif isinstance(user_distribution, list):
|
elif isinstance(user_distribution, list):
|
||||||
return from_list(user_distribution)
|
return distribution_from_list(user_distribution, clabels)
|
||||||
else:
|
else:
|
||||||
msg = f"`distribution` not understood." \
|
msg = f"`distribution` not understood." \
|
||||||
f"You have provided: {user_distribution}."
|
f"You have provided: {user_distribution}."
|
||||||
|
Loading…
Reference in New Issue
Block a user