Add more utils
This commit is contained in:
		@@ -1,4 +1,8 @@
 | 
				
			|||||||
"""ProtoFlow utils module"""
 | 
					"""ProtoFlow utils module"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from .colors import hex_to_rgb, rgb_to_hex
 | 
					from .colors import hex_to_rgb, rgb_to_hex
 | 
				
			||||||
from .utils import mesh2d
 | 
					from .utils import (
 | 
				
			||||||
 | 
					    mesh2d,
 | 
				
			||||||
 | 
					    parse_data_arg,
 | 
				
			||||||
 | 
					    parse_distribution,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,6 +1,11 @@
 | 
				
			|||||||
"""ProtoFlow utilities"""
 | 
					"""ProtoFlow utilities"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import warnings
 | 
				
			||||||
 | 
					from typing import Union
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import numpy as np
 | 
					import numpy as np
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					from torch.utils.data import DataLoader, Dataset
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def mesh2d(x=None, border: float = 1.0, resolution: int = 100):
 | 
					def mesh2d(x=None, border: float = 1.0, resolution: int = 100):
 | 
				
			||||||
@@ -16,3 +21,73 @@ def mesh2d(x=None, border: float = 1.0, resolution: int = 100):
 | 
				
			|||||||
                         np.linspace(y_min, y_max, resolution))
 | 
					                         np.linspace(y_min, y_max, resolution))
 | 
				
			||||||
    mesh = np.c_[xx.ravel(), yy.ravel()]
 | 
					    mesh = np.c_[xx.ravel(), yy.ravel()]
 | 
				
			||||||
    return mesh, xx, yy
 | 
					    return mesh, xx, yy
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def parse_distribution(user_distribution: Union[dict, list, tuple]):
 | 
				
			||||||
 | 
					    """Parse user-provided distribution.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Return a dictionary with integer keys that represent the class labels and
 | 
				
			||||||
 | 
					    values that denote the number of components/prototypes with that class
 | 
				
			||||||
 | 
					    label.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    The argument `user_distribution` could be any one of a number of allowed
 | 
				
			||||||
 | 
					    formats. If it is a Python list, it is assumed that there are as many
 | 
				
			||||||
 | 
					    entries in this list as there are classes, and the value at each index of
 | 
				
			||||||
 | 
					    this list describes the number of prototypes for that particular class. So,
 | 
				
			||||||
 | 
					    [1, 1, 1] implies that we have three classes with one prototype per class.
 | 
				
			||||||
 | 
					    If it is a Python tuple, a shorthand of (num_classes, prototypes_per_class)
 | 
				
			||||||
 | 
					    is assumed. If it is a Python dictionary, the key-value pairs describe the
 | 
				
			||||||
 | 
					    class label and the number of prototypes for that class respectively. So,
 | 
				
			||||||
 | 
					    {0: 2, 1: 2, 2: 2} implies that we have three classes with labels {1, 2,
 | 
				
			||||||
 | 
					    3}, each equipped with two prototypes. If however, the dictionary contains
 | 
				
			||||||
 | 
					    the keys "num_classes" and "per_class", they are parsed to use their values
 | 
				
			||||||
 | 
					    as one might expect.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    def from_list(list_dist):
 | 
				
			||||||
 | 
					        clabels = list(range(len(list_dist)))
 | 
				
			||||||
 | 
					        distribution = dict(zip(clabels, list_dist))
 | 
				
			||||||
 | 
					        return distribution
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if isinstance(user_distribution, dict):
 | 
				
			||||||
 | 
					        if "num_classes" in user_distribution.keys():
 | 
				
			||||||
 | 
					            num_classes = user_distribution["num_classes"]
 | 
				
			||||||
 | 
					            per_class = user_distribution["per_class"]
 | 
				
			||||||
 | 
					            return from_list([per_class] * num_classes)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            return user_distribution
 | 
				
			||||||
 | 
					    elif isinstance(user_distribution, tuple):
 | 
				
			||||||
 | 
					        assert len(user_distribution) == 2
 | 
				
			||||||
 | 
					        num_classes, per_class = user_distribution
 | 
				
			||||||
 | 
					        return from_list([per_class] * num_classes)
 | 
				
			||||||
 | 
					    elif isinstance(user_distribution, list):
 | 
				
			||||||
 | 
					        return from_list(user_distribution)
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        msg = f"`distribution` not understood." \
 | 
				
			||||||
 | 
					            f"You have provided: {user_distribution}."
 | 
				
			||||||
 | 
					        raise ValueError(msg)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def parse_data_arg(data_arg: Union[Dataset, DataLoader, list, tuple]):
 | 
				
			||||||
 | 
					    if isinstance(data_arg, Dataset):
 | 
				
			||||||
 | 
					        ds_size = len(data_arg)
 | 
				
			||||||
 | 
					        data_arg = DataLoader(data_arg, batch_size=ds_size)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if isinstance(data_arg, DataLoader):
 | 
				
			||||||
 | 
					        data = torch.tensor([])
 | 
				
			||||||
 | 
					        targets = torch.tensor([])
 | 
				
			||||||
 | 
					        for x, y in data_arg:
 | 
				
			||||||
 | 
					            data = torch.cat([data, x])
 | 
				
			||||||
 | 
					            targets = torch.cat([targets, y])
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        assert len(data_arg) == 2
 | 
				
			||||||
 | 
					        data, targets = data_arg
 | 
				
			||||||
 | 
					        if not isinstance(data, torch.Tensor):
 | 
				
			||||||
 | 
					            wmsg = f"Converting data to {torch.Tensor}."
 | 
				
			||||||
 | 
					            warnings.warn(wmsg)
 | 
				
			||||||
 | 
					            data = torch.Tensor(data)
 | 
				
			||||||
 | 
					        if not isinstance(targets, torch.LongTensor):
 | 
				
			||||||
 | 
					            wmsg = f"Converting targets to {torch.LongTensor}."
 | 
				
			||||||
 | 
					            warnings.warn(wmsg)
 | 
				
			||||||
 | 
					            targets = torch.LongTensor(targets)
 | 
				
			||||||
 | 
					    return data, targets
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user