add colourbar with different labeling options
This commit is contained in:
parent
69a8721292
commit
d390069609
238
matrix.py
238
matrix.py
@ -1,4 +1,5 @@
|
||||
import numpy as np
|
||||
from abc import ABC, abstractmethod
|
||||
import svg
|
||||
|
||||
rgen = np.random.default_rng()
|
||||
@ -86,7 +87,12 @@ def polylinear_gradient(colours, n):
|
||||
return gradient_dict
|
||||
|
||||
|
||||
class LinearGradientColourMap:
|
||||
class ColourMap(ABC):
|
||||
@abstractmethod
|
||||
def __call__(self, v: float): ...
|
||||
|
||||
|
||||
class LinearGradientColourMap(ColourMap):
|
||||
def __init__(
|
||||
self,
|
||||
colours: list[str] | None = ["#ff0000", "#ffffff", "#0000ff"],
|
||||
@ -104,75 +110,193 @@ class LinearGradientColourMap:
|
||||
return self.colours["hex"][v]
|
||||
|
||||
|
||||
class RandomColourMap:
|
||||
class RandomColourMap(ColourMap):
|
||||
def __init__(self, random_state: int | list[int] | None = [2, 3, 4, 5, 6]):
|
||||
self.rgen = np.random.default_rng([2, 3, 4, 5, 6])
|
||||
self.rgen = np.random.default_rng(random_state)
|
||||
|
||||
def __call__(self, v: float):
|
||||
return RGB_to_hex([x * 255 for x in self.rgen.random(3)])
|
||||
|
||||
|
||||
def colourbar(
|
||||
cmap,
|
||||
min_value: float | None = 0,
|
||||
max_value: float | None = 1,
|
||||
height=100,
|
||||
width=10,
|
||||
resolution=256,
|
||||
):
|
||||
items = [
|
||||
svg.Rect(fill=cmap(v), x=0, y=i, width=width, height=1, stroke="none")
|
||||
for i, v in enumerate(np.linspace(min_value, max_value, resolution))
|
||||
]
|
||||
return svg.G(elements=items, transform=[svg.Scale(1, height / resolution)])
|
||||
class MatrixVisualisation:
|
||||
def __init__(
|
||||
self,
|
||||
matrix: np.typing.NDArray,
|
||||
cmap: ColourMap,
|
||||
text: bool = False,
|
||||
labels: int | list[str] | bool = False,
|
||||
):
|
||||
self.m, self.n = matrix.shape
|
||||
width = 20
|
||||
height = 20
|
||||
gap = 1
|
||||
self.text = text
|
||||
self.total_width = (gap + width) * n + gap
|
||||
self.total_height = (gap + height) * m + gap
|
||||
self.cmap = cmap
|
||||
|
||||
self.elements = []
|
||||
self.elements.append(
|
||||
svg.Style(text=".mono { font: monospace; text-align: center;}")
|
||||
)
|
||||
self.elements.append(svg.Style(text=".small { font-size: 25%; }"))
|
||||
self.elements.append(svg.Style(text=".normal { font-size: 12px; }"))
|
||||
|
||||
for i, y in enumerate(range(gap, self.total_height, gap + height)):
|
||||
for j, x in enumerate(range(gap, self.total_width, gap + width)):
|
||||
self.elements.append(
|
||||
svg.Rect(
|
||||
x=x,
|
||||
y=y,
|
||||
width=width,
|
||||
height=height,
|
||||
stroke="transparent",
|
||||
fill=cmap(matrix[i, j]),
|
||||
)
|
||||
)
|
||||
if text:
|
||||
self.elements.append(
|
||||
svg.Text(
|
||||
x=x + width / 5,
|
||||
y=y + 3 * height / 4,
|
||||
textLength=width / 2,
|
||||
lengthAdjust="spacingAndGlyphs",
|
||||
class_=["mono"],
|
||||
text=f"{matrix[i, j]:.02f}",
|
||||
)
|
||||
)
|
||||
|
||||
def colourbar(
|
||||
self,
|
||||
min_value: float = 0,
|
||||
max_value: float = 1,
|
||||
height: int | None = None,
|
||||
width: int = 20,
|
||||
resolution: int = 256,
|
||||
border: int | bool = 1,
|
||||
labels: int | list[str] | bool = False,
|
||||
):
|
||||
if height is None:
|
||||
height = int(self.total_height * 2 / 3)
|
||||
|
||||
lines = [
|
||||
svg.Rect(
|
||||
fill=self.cmap(v),
|
||||
x=0,
|
||||
y=y,
|
||||
width=width,
|
||||
height=1.1 * height / resolution,
|
||||
stroke="none",
|
||||
)
|
||||
for y, v in zip(
|
||||
np.linspace(0, height, resolution),
|
||||
np.linspace(min_value, max_value, resolution - 1),
|
||||
)
|
||||
]
|
||||
|
||||
if labels is None:
|
||||
label = []
|
||||
elif isinstance(labels, int):
|
||||
label = svg.G(
|
||||
id="colourbar labels",
|
||||
elements=[
|
||||
svg.Text(
|
||||
text=f"— {v:.02f}",
|
||||
class_=["normal"],
|
||||
x=width,
|
||||
y=y,
|
||||
dy=3,
|
||||
)
|
||||
for y, v in zip(
|
||||
np.linspace(0, height, labels),
|
||||
np.linspace(min_value, max_value, labels),
|
||||
)
|
||||
],
|
||||
)
|
||||
elif isinstance(labels, list):
|
||||
if all(isinstance(n, str) for n in labels):
|
||||
label = svg.G(
|
||||
id="colourbar labels",
|
||||
elements=[
|
||||
svg.Text(
|
||||
text=f"— {v}",
|
||||
class_=["normal"],
|
||||
x=width,
|
||||
y=y,
|
||||
dy=3,
|
||||
)
|
||||
for y, v in zip(np.linspace(0, height, len(labels)), labels)
|
||||
],
|
||||
)
|
||||
if all(isinstance(n, float) or isinstance(n, int) for n in labels):
|
||||
label = svg.G(
|
||||
id="colourbar labels",
|
||||
elements=[
|
||||
svg.Text(
|
||||
text=f"— {v:.02f}",
|
||||
class_=["normal"],
|
||||
x=width,
|
||||
y=(v - min_value) / (max_value - min_value) * height,
|
||||
dy=3,
|
||||
)
|
||||
for v in labels
|
||||
],
|
||||
)
|
||||
|
||||
cbar = svg.G(
|
||||
id="colourbar",
|
||||
elements=[
|
||||
lines,
|
||||
label,
|
||||
svg.Rect(
|
||||
x=0,
|
||||
y=0,
|
||||
width=width,
|
||||
height=height,
|
||||
fill="none",
|
||||
stroke_width=border,
|
||||
stroke="black",
|
||||
),
|
||||
],
|
||||
transform=[
|
||||
svg.Translate(
|
||||
x=int(self.total_width + width / 2),
|
||||
y=int((self.total_height - height) / 2),
|
||||
)
|
||||
],
|
||||
)
|
||||
self.elements.append(cbar)
|
||||
self.total_width = self.total_width + 2 * width + 40 * bool(labels)
|
||||
|
||||
@property
|
||||
def svg(self):
|
||||
return str(
|
||||
svg.SVG(
|
||||
width=self.total_width, height=self.total_height, elements=self.elements
|
||||
)
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"""Matrix Visualisation:
|
||||
shape: {matrix.shape}
|
||||
size: {self.total_width}x{self.total_height}
|
||||
"""
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
m, n = 20, 20
|
||||
width = 20
|
||||
height = 20
|
||||
gap = 1
|
||||
text = False
|
||||
total_width = (gap + width) * n + gap
|
||||
total_height = (gap + height) * m + gap
|
||||
|
||||
filename = "matrix.svg"
|
||||
m, n = 30, 20
|
||||
matrix = rgen.random(size=(m, n))
|
||||
|
||||
colours = ["#f5d72a", "#ffffff", "#2182af"]
|
||||
# colours = ["#ff0000", "#00ff00", "#0000ff"]
|
||||
cmap = LinearGradientColourMap(colours, matrix.min(), matrix.max())
|
||||
# cmap = RandomColourMap()
|
||||
|
||||
items = []
|
||||
items.append(svg.Style(text=".mono { font: monospace; text-align: center;}"))
|
||||
items.append(svg.Style(text=".small { font-size: 25%; }"))
|
||||
|
||||
for i, y in enumerate(range(gap, total_height, gap + height)):
|
||||
for j, x in enumerate(range(gap, total_width, gap + width)):
|
||||
items.append(
|
||||
svg.Rect(
|
||||
x=x,
|
||||
y=y,
|
||||
width=width,
|
||||
height=height,
|
||||
stroke="transparent",
|
||||
fill=cmap(matrix[i, j]),
|
||||
)
|
||||
)
|
||||
if text:
|
||||
items.append(
|
||||
svg.Text(
|
||||
x=x + width / 5,
|
||||
y=y + 3 * height / 4,
|
||||
textLength=width / 2,
|
||||
lengthAdjust="spacingAndGlyphs",
|
||||
class_=["mono"],
|
||||
text=f"{matrix[i, j]:.02f}",
|
||||
)
|
||||
)
|
||||
|
||||
items.append(colourbar(cmap, matrix.min(), matrix.max(), height=100))
|
||||
|
||||
content = svg.SVG(width=total_width, height=total_height, elements=items)
|
||||
fig = MatrixVisualisation(matrix, cmap=cmap)
|
||||
fig.colourbar(labels=["yellow", "white", "blue"])
|
||||
fig.colourbar(labels=5)
|
||||
fig.colourbar(labels=[0.2, 0.5, 0.55, 0.66, 1])
|
||||
filename = "matrix.svg"
|
||||
print(fig)
|
||||
with open(filename, "w") as f:
|
||||
f.write(str(content))
|
||||
f.write(fig.svg)
|
||||
|
Loading…
Reference in New Issue
Block a user