[REFACTOR] Clean and move components and initializers into core
This commit is contained in:
@@ -23,7 +23,10 @@ def mesh2d(x=None, border: float = 1.0, resolution: int = 100):
|
||||
return mesh, xx, yy
|
||||
|
||||
|
||||
def parse_distribution(user_distribution: Union[dict, list, tuple]):
|
||||
def parse_distribution(
|
||||
user_distribution: Union[dict[int, int], dict[str, str], list[int],
|
||||
tuple[int]]
|
||||
) -> dict[int, int]:
|
||||
"""Parse user-provided distribution.
|
||||
|
||||
Return a dictionary with integer keys that represent the class labels and
|
||||
@@ -51,14 +54,15 @@ def parse_distribution(user_distribution: Union[dict, list, tuple]):
|
||||
|
||||
if isinstance(user_distribution, dict):
|
||||
if "num_classes" in user_distribution.keys():
|
||||
num_classes = user_distribution["num_classes"]
|
||||
per_class = user_distribution["per_class"]
|
||||
num_classes = int(user_distribution["num_classes"])
|
||||
per_class = int(user_distribution["per_class"])
|
||||
return from_list([per_class] * num_classes)
|
||||
else:
|
||||
return user_distribution
|
||||
elif isinstance(user_distribution, tuple):
|
||||
assert len(user_distribution) == 2
|
||||
num_classes, per_class = user_distribution
|
||||
num_classes, per_class = int(num_classes), int(per_class)
|
||||
return from_list([per_class] * num_classes)
|
||||
elif isinstance(user_distribution, list):
|
||||
return from_list(user_distribution)
|
||||
|
Reference in New Issue
Block a user