refactor(api)!: merge the new api changes into dev

BREAKING CHANGE: remove the following
`prototorch/functions/*`
`prototorch/components/*`
`prototorch/modules/*`
BREAKING CHANGE: move `initializers` into the `prototorch.initializers`
namespace from the `prototorch.components` namespace
BREAKING CHANGE: `functions` and `modules` are moved into `core` and `nn`
This commit is contained in:
Jensun Ravichandran
2021-06-18 18:20:30 +02:00
49 changed files with 2465 additions and 2201 deletions

View File

@@ -0,0 +1,8 @@
"""ProtoFlow utils module"""
from .colors import hex_to_rgb, rgb_to_hex
from .utils import (
mesh2d,
parse_data_arg,
parse_distribution,
)

View File

@@ -1,46 +0,0 @@
"""Easy matplotlib animation. From https://github.com/jwkvam/celluloid."""
from collections import defaultdict
from typing import Dict, List
from matplotlib.animation import ArtistAnimation
from matplotlib.artist import Artist
from matplotlib.figure import Figure
__version__ = "0.2.0"
class Camera:
    """Record successive states of a matplotlib figure as animation frames."""

    def __init__(self, figure: Figure) -> None:
        """Attach the camera to *figure* with no frames recorded yet."""
        self._figure = figure
        # For every artist container of an axis, remember (keyed by the axis
        # index) how many artists already belong to earlier frames.
        artist_kinds = (
            "collections",
            "patches",
            "lines",
            "texts",
            "artists",
            "images",
        )
        self._offsets: Dict[str, Dict[int, int]] = {}
        for kind in artist_kinds:
            self._offsets[kind] = defaultdict(int)
        self._photos: List[List[Artist]] = []

    def snap(self) -> List[Artist]:
        """Store the artists added since the previous snap as a new frame.

        Returns
        -------
        list of Artist
            The artists that make up the frame that was just recorded.
        """
        captured: List[Artist] = []
        for axis_index, axis in enumerate(self._figure.axes):
            # Re-add an existing legend as a plain artist so that it is
            # picked up through the "artists" container below.
            if axis.legend_ is not None:
                axis.add_artist(axis.legend_)
            for kind, seen in self._offsets.items():
                start = seen[axis_index]
                fresh = getattr(axis, kind)[start:]
                captured += fresh
                seen[axis_index] = start + len(fresh)
        self._photos.append(captured)
        return captured

    def animate(self, *args, **kwargs) -> ArtistAnimation:
        """Build an animation from all recorded frames.

        Positional and keyword arguments are forwarded unchanged to
        matplotlib.animation.ArtistAnimation.

        Returns
        -------
        ArtistAnimation
        """
        return ArtistAnimation(self._figure, self._photos, *args, **kwargs)

View File

@@ -1,78 +1,15 @@
"""ProtoFlow color utilities."""
import matplotlib.lines as mlines
from matplotlib import cm
from matplotlib.colors import Normalize, to_hex, to_rgb
"""ProtoFlow color utilities"""
def color_scheme(n,
                 cmap="viridis",
                 form="hex",
                 tikz=False,
                 zero_indexed=False):
    """Sample *n* colors from a matplotlib colormap.

    The indices 1..n are normalized onto the colormap range, so the first
    index maps to the start of the colormap and the n-th index to its end.

    Arguments:
        n (int): number of colors to return

    Keyword Arguments:
        cmap (str): Name of a matplotlib `colormap
            <https://matplotlib.org/3.1.1/gallery/color/colormap_reference.html>`_.
        form (str): Color format of the returned values; "rgb" selects RGB
            tuples, any other value selects hex strings.
        tikz (bool): Additionally print one `TikZ
            <https://github.com/pgf-tikz/pgf>`_ ``\\definecolor`` command per
            color.
        zero_indexed (bool): Key the result by 0..n-1 instead of 1..n.

    Returns:
        (dict): Mapping from color index to color.

    """
    colormap = cm.get_cmap(cmap)
    normalize = Normalize(vmin=1, vmax=n)
    # Keys start at 1 unless zero-based indexing was requested.
    shift = 1 if zero_indexed else 0

    hex_map = {}
    rgb_map = {}
    for cl in range(1, n + 1):
        rgba = colormap(normalize(cl))
        hex_map[cl - shift] = to_hex(rgba)
        rgb_map[cl - shift] = to_rgb(rgba)

    if tikz:
        for k, v in rgb_map.items():
            print(f"\\definecolor{{color-{k}}}{{rgb}}{{{v[0]},{v[1]},{v[2]}}}")

    # Anything other than "rgb" (including unknown formats) yields hex.
    return rgb_map if form == "rgb" else hex_map
def hex_to_rgb(hex_values):
    """Yield the integer channel values of each hex color string.

    A leading ``#`` is optional. The digits are split into chunks of one
    third of their total length, so ``"ff0000"`` yields ``[255, 0, 0]`` and
    an eight-digit value with alpha yields four channels. Shorthand colors
    with three or four digits (``"#fff"``, ``"#fff8"``) are expanded by
    doubling each digit first, so ``"#fff"`` yields ``[255, 255, 255]``.

    Arguments:
        hex_values (Iterable[str]): Hex color strings.

    Yields:
        (list): Integer channel values of one color.

    Raises:
        ValueError: If a value has fewer than three hex digits or contains
            characters that are not hex digits. As this is a generator, the
            error surfaces when the offending item is consumed.

    """
    for value in hex_values:
        digits = value.lstrip('#')
        if len(digits) in (3, 4):
            # Shorthand notation: each digit abbreviates a doubled pair.
            digits = "".join(ch * 2 for ch in digits)
        width = len(digits) // 3
        if width == 0:
            # Would otherwise fail with an opaque "range() arg 3 must not
            # be zero" error.
            raise ValueError(f"Cannot parse hex color {value!r}: "
                             "expected at least three hex digits.")
        yield [
            int(digits[i:i + width], 16)
            for i in range(0, len(digits), width)
        ]
def get_legend_handles(labels, marker="dots", zero_indexed=False):
    """Create one matplotlib legend handle per label.

    Colors are taken from the "viridis" color scheme, one per label.

    Arguments:
        labels: Sized iterable of legend labels.

    Keyword Arguments:
        marker (str): "dots" draws filled circles with a black edge; any
            other value draws a plain colored line entry.
        zero_indexed (bool): Forwarded to `color_scheme` to select the
            indexing of the returned color mapping.

    Returns:
        (tuple): The list of handles and the color mapping used.

    """
    colors = color_scheme(len(labels),
                          cmap="viridis",
                          form="hex",
                          zero_indexed=zero_indexed)
    use_dots = marker == "dots"

    handles = []
    for label, color in zip(labels, colors.values()):
        if use_dots:
            style = dict(
                color="white",
                markerfacecolor=color,
                marker="o",
                markersize=10,
                markeredgecolor="k",
                label=label,
            )
        else:
            style = dict(color=color, marker="", markersize=15, label=label)
        # Empty data: the line only exists to be shown in a legend.
        handles.append(mlines.Line2D([], [], **style))
    return handles, colors
def rgb_to_hex(rgb_values):
    """Yield a six-digit lowercase hex string for each RGB triple.

    The yielded strings carry no leading ``#``.

    Arguments:
        rgb_values: Iterable of length-three sequences of integers.

    Yields:
        (str): Hex representation of one triple, e.g. ``"ff0000"``.

    """
    for triple in rgb_values:
        yield "%02x%02x%02x" % tuple(triple)

104
prototorch/utils/utils.py Normal file
View File

@@ -0,0 +1,104 @@
"""ProtoFlow utilities"""
import warnings
from collections.abc import Iterable
from typing import Union
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
def mesh2d(x=None, border: float = 1.0, resolution: int = 100):
    """Build a regular 2D grid of points.

    Without data the grid spans ``[-border, border]`` in both directions.
    With data, the grid covers the value range of the first two columns of
    *x*, extended on each side by ``border`` times that range.

    Arguments:
        x: Optional 2D array whose first two columns define the extent.

    Keyword Arguments:
        border (float): Half-width of the grid (no data) or relative
            padding around the data range (with data).
        resolution (int): Number of grid points per direction.

    Returns:
        (tuple): ``(mesh, xx, yy)`` where ``mesh`` has one grid point per
        row and ``xx``/``yy`` are the coordinate matrices from
        ``np.meshgrid``.

    """
    if x is None:
        x_lo, x_hi = -border, border
        y_lo, y_hi = -border, border
    else:
        x_pad = border * np.ptp(x[:, 0])
        y_pad = border * np.ptp(x[:, 1])
        first, second = x[:, 0], x[:, 1]
        x_lo, x_hi = first.min() - x_pad, first.max() + x_pad
        y_lo, y_hi = second.min() - y_pad, second.max() + y_pad

    x_ticks = np.linspace(x_lo, x_hi, resolution)
    y_ticks = np.linspace(y_lo, y_hi, resolution)
    xx, yy = np.meshgrid(x_ticks, y_ticks)
    # One (x, y) pair per row.
    mesh = np.column_stack((xx.ravel(), yy.ravel()))
    return mesh, xx, yy
def distribution_from_list(list_dist: list[int],
                           clabels: Iterable[int] = None):
    """Pair class labels with their number of prototypes.

    Arguments:
        list_dist (list[int]): Number of prototypes per class, one entry per
            class.

    Keyword Arguments:
        clabels (Iterable[int]): Class labels to use as keys. When omitted
            or empty, the labels ``0..len(list_dist) - 1`` are used.

    Returns:
        (dict): Mapping from class label to number of prototypes. Pairing
        stops at the shorter of the two inputs.

    """
    # Materialize once instead of testing `clabels` for truthiness: a
    # multi-element NumPy array has no unambiguous truth value.
    labels = [] if clabels is None else list(clabels)
    if not labels:
        labels = list(range(len(list_dist)))
    return dict(zip(labels, list_dist))
def parse_distribution(user_distribution,
                       clabels: Iterable[int] = None) -> dict[int, int]:
    """Parse user-provided distribution.

    Return a dictionary with integer keys that represent the class labels and
    values that denote the number of components/prototypes with that class
    label.

    The argument `user_distribution` may use any of the following formats:

    - list: one entry per class, where the value at each index is the number
      of prototypes for that class. So, [1, 1, 1] implies three classes with
      one prototype per class.
    - tuple: the shorthand (num_classes, prototypes_per_class).
    - dict with the keys "num_classes" and "per_class": the same shorthand,
      spelled out.
    - any other dict: key-value pairs of class label and number of
      prototypes for that class. So, {0: 2, 1: 2, 2: 2} implies three classes
      with labels {0, 1, 2}, each equipped with two prototypes. Such a
      dictionary is returned as-is.

    Arguments:
        user_distribution: Distribution in one of the formats above.

    Keyword Arguments:
        clabels (Iterable[int]): Class labels for the list and shorthand
            formats; ignored when a label dictionary is passed through.

    Returns:
        (dict): Mapping from class label to number of prototypes.

    Raises:
        ValueError: If a tuple does not have exactly two entries or the
            format of `user_distribution` is not supported.

    """
    if isinstance(user_distribution, dict):
        if "num_classes" in user_distribution:
            num_classes = int(user_distribution["num_classes"])
            per_class = int(user_distribution["per_class"])
            return distribution_from_list([per_class] * num_classes, clabels)
        return user_distribution

    if isinstance(user_distribution, tuple):
        # Raise explicitly rather than assert: asserts vanish under `-O`.
        if len(user_distribution) != 2:
            raise ValueError(
                "A tuple `distribution` must be of the form "
                "(num_classes, per_class). "
                f"You have provided: {user_distribution}.")
        num_classes, per_class = (int(entry) for entry in user_distribution)
        return distribution_from_list([per_class] * num_classes, clabels)

    if isinstance(user_distribution, list):
        return distribution_from_list(user_distribution, clabels)

    msg = f"`distribution` was not understood. " \
          f"You have provided: {user_distribution}."
    raise ValueError(msg)
def parse_data_arg(data_arg: Union[Dataset, DataLoader, list, tuple]):
    """Return data and targets as torch tensors.

    Arguments:
        data_arg: One of

            - a sized `Dataset`, which is loaded as one single batch,
            - a `DataLoader`, whose batches are concatenated in iteration
              order,
            - a pair ``(data, targets)``.

    Returns:
        (tuple): ``(data, targets)`` where `data` is a `torch.Tensor` and
        `targets` is a tensor of dtype `torch.long`. A warning is issued
        whenever one of them had to be converted.

    Raises:
        TypeError: If a `Dataset` without `__len__` is given.
        ValueError: If `data_arg` is a sequence that does not hold exactly
            two entries.

    """
    if isinstance(data_arg, Dataset):
        if not hasattr(data_arg, "__len__"):
            emsg = f"Dataset {data_arg} is not sized (`__len__` unimplemented)."
            raise TypeError(emsg)
        ds_size = len(data_arg)  # type: ignore
        # A single batch as large as the dataset yields everything at once.
        loader = DataLoader(data_arg, batch_size=ds_size)
        data, targets = next(iter(loader))
    elif isinstance(data_arg, DataLoader):
        # Collect first and concatenate once. This keeps the dtype of the
        # batches (no promotion against a float seed tensor) and avoids
        # re-copying the accumulated data on every iteration.
        data_batches, target_batches = [], []
        for x, y in data_arg:
            data_batches.append(x)
            target_batches.append(y)
        if data_batches:
            data = torch.cat(data_batches)
            targets = torch.cat(target_batches)
        else:
            data = torch.tensor([])
            targets = torch.tensor([], dtype=torch.long)
    else:
        # Raise explicitly rather than assert: asserts vanish under `-O`.
        if len(data_arg) != 2:
            raise ValueError("Expected a pair `(data, targets)` but got "
                             f"{len(data_arg)} entries.")
        data, targets = data_arg

    if not isinstance(data, torch.Tensor):
        wmsg = f"Converting data to {torch.Tensor}..."
        warnings.warn(wmsg)
        data = torch.Tensor(data)

    # Compare dtypes instead of using `isinstance(..., torch.LongTensor)`.
    if not isinstance(targets, torch.Tensor):
        wmsg = f"Converting targets to {torch.LongTensor}..."
        warnings.warn(wmsg)
        targets = torch.as_tensor(targets, dtype=torch.long)
    elif targets.dtype != torch.long:
        wmsg = f"Converting targets to {torch.LongTensor}..."
        warnings.warn(wmsg)
        targets = targets.long()
    return data, targets