diff --git a/prototorch/utils/__init__.py b/prototorch/utils/__init__.py index 26ccedd..ac89f24 100644 --- a/prototorch/utils/__init__.py +++ b/prototorch/utils/__init__.py @@ -1,6 +1,11 @@ -"""ProtoFlow utils module""" +"""ProtoTorch utils module""" -from .colors import hex_to_rgb, rgb_to_hex +from .colors import ( + get_colors, + get_legend_handles, + hex_to_rgb, + rgb_to_hex, +) from .utils import ( mesh2d, parse_data_arg, diff --git a/prototorch/utils/colors.py b/prototorch/utils/colors.py index 4b2d838..169ff65 100644 --- a/prototorch/utils/colors.py +++ b/prototorch/utils/colors.py @@ -1,5 +1,14 @@ """ProtoTorch color utilities""" +import matplotlib.lines as mlines +import torch +from matplotlib import cm +from matplotlib.colors import ( + Normalize, + to_hex, + to_rgb, +) + def hex_to_rgb(hex_values): for v in hex_values: @@ -13,3 +22,39 @@ def rgb_to_hex(rgb_values): for v in rgb_values: c = "%02x%02x%02x" % tuple(v) yield c + + +def get_colors(vmax, vmin=0, cmap="viridis"): + cmap = cm.get_cmap(cmap) + colornorm = Normalize(vmin=vmin, vmax=vmax) + colors = dict() + for c in range(vmin, vmax + 1): + colors[c] = to_hex(cmap(colornorm(c))) + return colors + + +def get_legend_handles(colors, labels, marker="dots", zero_indexed=False): + handles = list() + for color, label in zip(colors.values(), labels): + if marker == "dots": + handle = mlines.Line2D( + xdata=[], + ydata=[], + label=label, + color="white", + markerfacecolor=color, + marker="o", + markersize=10, + markeredgecolor="k", + ) + else: + handle = mlines.Line2D( + xdata=[], + ydata=[], + label=label, + color=color, + marker="", + markersize=15, + ) + handles.append(handle) + return handles