Add more utils
This commit is contained in:
parent
dfefd128c4
commit
b8969347b1
@ -1,4 +1,8 @@
|
|||||||
"""ProtoFlow utils module"""
|
"""ProtoFlow utils module"""
|
||||||
|
|
||||||
from .colors import hex_to_rgb, rgb_to_hex
|
from .colors import hex_to_rgb, rgb_to_hex
|
||||||
from .utils import mesh2d
|
from .utils import (
|
||||||
|
mesh2d,
|
||||||
|
parse_data_arg,
|
||||||
|
parse_distribution,
|
||||||
|
)
|
||||||
|
@ -1,6 +1,11 @@
|
|||||||
"""ProtoFlow utilities"""
|
"""ProtoFlow utilities"""
|
||||||
|
|
||||||
|
import warnings
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
|
||||||
|
|
||||||
def mesh2d(x=None, border: float = 1.0, resolution: int = 100):
|
def mesh2d(x=None, border: float = 1.0, resolution: int = 100):
|
||||||
@ -16,3 +21,73 @@ def mesh2d(x=None, border: float = 1.0, resolution: int = 100):
|
|||||||
np.linspace(y_min, y_max, resolution))
|
np.linspace(y_min, y_max, resolution))
|
||||||
mesh = np.c_[xx.ravel(), yy.ravel()]
|
mesh = np.c_[xx.ravel(), yy.ravel()]
|
||||||
return mesh, xx, yy
|
return mesh, xx, yy
|
||||||
|
|
||||||
|
|
||||||
|
def parse_distribution(user_distribution: Union[dict, list, tuple]):
|
||||||
|
"""Parse user-provided distribution.
|
||||||
|
|
||||||
|
Return a dictionary with integer keys that represent the class labels and
|
||||||
|
values that denote the number of components/prototypes with that class
|
||||||
|
label.
|
||||||
|
|
||||||
|
The argument `user_distribution` could be any one of a number of allowed
|
||||||
|
formats. If it is a Python list, it is assumed that there are as many
|
||||||
|
entries in this list as there are classes, and the value at each index of
|
||||||
|
this list describes the number of prototypes for that particular class. So,
|
||||||
|
[1, 1, 1] implies that we have three classes with one prototype per class.
|
||||||
|
If it is a Python tuple, a shorthand of (num_classes, prototypes_per_class)
|
||||||
|
is assumed. If it is a Python dictionary, the key-value pairs describe the
|
||||||
|
class label and the number of prototypes for that class respectively. So,
|
||||||
|
{0: 2, 1: 2, 2: 2} implies that we have three classes with labels {1, 2,
|
||||||
|
3}, each equipped with two prototypes. If however, the dictionary contains
|
||||||
|
the keys "num_classes" and "per_class", they are parsed to use their values
|
||||||
|
as one might expect.
|
||||||
|
|
||||||
|
"""
|
||||||
|
def from_list(list_dist):
|
||||||
|
clabels = list(range(len(list_dist)))
|
||||||
|
distribution = dict(zip(clabels, list_dist))
|
||||||
|
return distribution
|
||||||
|
|
||||||
|
if isinstance(user_distribution, dict):
|
||||||
|
if "num_classes" in user_distribution.keys():
|
||||||
|
num_classes = user_distribution["num_classes"]
|
||||||
|
per_class = user_distribution["per_class"]
|
||||||
|
return from_list([per_class] * num_classes)
|
||||||
|
else:
|
||||||
|
return user_distribution
|
||||||
|
elif isinstance(user_distribution, tuple):
|
||||||
|
assert len(user_distribution) == 2
|
||||||
|
num_classes, per_class = user_distribution
|
||||||
|
return from_list([per_class] * num_classes)
|
||||||
|
elif isinstance(user_distribution, list):
|
||||||
|
return from_list(user_distribution)
|
||||||
|
else:
|
||||||
|
msg = f"`distribution` not understood." \
|
||||||
|
f"You have provided: {user_distribution}."
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_data_arg(data_arg: Union[Dataset, DataLoader, list, tuple]):
|
||||||
|
if isinstance(data_arg, Dataset):
|
||||||
|
ds_size = len(data_arg)
|
||||||
|
data_arg = DataLoader(data_arg, batch_size=ds_size)
|
||||||
|
|
||||||
|
if isinstance(data_arg, DataLoader):
|
||||||
|
data = torch.tensor([])
|
||||||
|
targets = torch.tensor([])
|
||||||
|
for x, y in data_arg:
|
||||||
|
data = torch.cat([data, x])
|
||||||
|
targets = torch.cat([targets, y])
|
||||||
|
else:
|
||||||
|
assert len(data_arg) == 2
|
||||||
|
data, targets = data_arg
|
||||||
|
if not isinstance(data, torch.Tensor):
|
||||||
|
wmsg = f"Converting data to {torch.Tensor}."
|
||||||
|
warnings.warn(wmsg)
|
||||||
|
data = torch.Tensor(data)
|
||||||
|
if not isinstance(targets, torch.LongTensor):
|
||||||
|
wmsg = f"Converting targets to {torch.LongTensor}."
|
||||||
|
warnings.warn(wmsg)
|
||||||
|
targets = torch.LongTensor(targets)
|
||||||
|
return data, targets
|
||||||
|
Loading…
Reference in New Issue
Block a user