170 lines
5.3 KiB
Python
170 lines
5.3 KiB
Python
import numpy as np
|
|
import svg
|
|
from colours import ColourMap
|
|
|
|
|
|
class MatrixVisualisation:
    """Render a 2-D matrix as an SVG grid of coloured cells.

    Every entry ``matrix[i, j]`` becomes a ``CELL_WIDTH`` x ``CELL_HEIGHT``
    rectangle filled with ``cmap(matrix[i, j])``; cells are separated (and
    surrounded) by a ``CELL_GAP``-unit gutter.  SVG elements accumulate in
    ``self.elements`` and are serialised by the ``svg`` property.
    """

    # Cell geometry in SVG user units. Class attributes so a subclass can
    # change the layout without re-implementing ``__init__``.
    CELL_WIDTH = 20
    CELL_HEIGHT = 20
    CELL_GAP = 1

    def __init__(
        self,
        matrix: np.typing.NDArray,
        cmap: ColourMap,
        text: bool = False,
        labels: int | list[str] | bool = False,
    ):
        """Lay out one coloured cell per matrix entry.

        Args:
            matrix: 2-D array; its shape is unpacked as ``(rows, columns)``.
            cmap: called with each matrix value; the result is used as the
                cell's SVG ``fill``.
            text: if true, write each value (formatted ``.2f``) in its cell.
            labels: currently unused; accepted for API compatibility.
        """
        self.matrix = matrix
        self.m, self.n = self.matrix.shape
        self.text = text
        self.cmap = cmap

        width, height, gap = self.CELL_WIDTH, self.CELL_HEIGHT, self.CELL_GAP
        self.total_width = (gap + width) * self.n + gap
        self.total_height = (gap + height) * self.m + gap

        self.elements = [
            # ``font`` is a shorthand that requires both a size and a family,
            # so ``font: monospace`` was invalid CSS and silently dropped.
            svg.Style(text=".mono { font-family: monospace; text-align: center;}"),
            svg.Style(text=".small { font-size: 25%; }"),
            svg.Style(text=".normal { font-size: 12px; }"),
        ]

        for i in range(self.m):
            y = gap + i * (gap + height)
            for j in range(self.n):
                x = gap + j * (gap + width)
                value = self.matrix[i, j]
                self.elements.append(
                    svg.Rect(
                        x=x,
                        y=y,
                        width=width,
                        height=height,
                        stroke="transparent",
                        fill=cmap(value),
                    )
                )
                if text:
                    self.elements.append(
                        svg.Text(
                            x=x + width / 5,
                            y=y + 3 * height / 4,
                            # Force the number to span exactly half a cell.
                            textLength=width / 2,
                            lengthAdjust="spacingAndGlyphs",
                            class_=["mono"],
                            text=f"{value:.02f}",
                        )
                    )

    def colourbar(
        self,
        min_value: float = 0,
        max_value: float = 1,
        height: int | None = None,
        width: int = 20,
        resolution: int = 256,
        border: int | bool = 1,
        labels: int | list[str] | bool = False,
    ):
        """Append a vertical colour bar to the right of the matrix.

        The bar samples ``self.cmap`` from ``min_value`` (top) to
        ``max_value`` (bottom) in ``resolution`` bands, is vertically centred
        on the matrix, and widens ``self.total_width`` to make room for it.

        Args:
            min_value: value shown at the top of the bar.
            max_value: value shown at the bottom of the bar.
            height: bar height; defaults to two thirds of the matrix height.
            width: bar width.
            resolution: number of colour bands (must be at least 1).
            border: stroke width of the black outline around the bar.
            labels: tick labels drawn to the right of the bar:
                ``False``/``None`` -> no labels;
                an int ``k`` -> ``k`` evenly spaced ticks showing the value
                at that height (``True`` counts as ``1``);
                a list/tuple of strings -> those strings, evenly spaced;
                a list/tuple of numbers -> each number at its own position.

        Raises:
            ValueError: if ``resolution < 1``, or numeric labels are given
                while ``min_value == max_value``.
            TypeError: if ``labels`` has an unsupported type or mixes
                strings and numbers.
        """
        if height is None:
            height = int(self.total_height * 2 / 3)

        parts = self._colourbar_bands(min_value, max_value, height, width, resolution)
        ticks = self._colourbar_ticks(min_value, max_value, height, width, labels)
        if ticks:
            # SVG ids may not contain whitespace, hence the hyphen.
            parts.append(svg.G(id="colourbar-labels", elements=ticks))
        parts.append(
            svg.Rect(
                x=0,
                y=0,
                width=width,
                height=height,
                fill="none",
                stroke_width=border,
                stroke="black",
            )
        )

        self.elements.append(
            svg.G(
                id="colourbar",
                # A flat list of elements: previously the band list (and
                # sometimes an empty list) was nested inside ``elements``.
                elements=parts,
                transform=[
                    svg.Translate(
                        x=int(self.total_width + width / 2),
                        y=int((self.total_height - height) / 2),
                    )
                ],
            )
        )
        # Reserve room for the bar plus, when labelled, ~40 units of text.
        self.total_width = self.total_width + 2 * width + 40 * bool(labels)

    def _colourbar_bands(self, min_value, max_value, height, width, resolution):
        """Return ``resolution`` stacked rectangles coloured by ``self.cmap``."""
        if resolution < 1:
            raise ValueError(f"resolution must be at least 1, got {resolution}")
        step = height / resolution
        bands = []
        # One y offset per band and one value per band: equal lengths, so
        # every band is drawn and spacing matches the band height.
        for y, value in zip(
            np.linspace(0, height, resolution, endpoint=False),
            np.linspace(min_value, max_value, resolution),
        ):
            bands.append(
                svg.Rect(
                    fill=self.cmap(value),
                    x=0,
                    y=y,
                    width=width,
                    # 10% overlap with the next band (presumably to avoid
                    # hairline seams), clipped so nothing spills below the bar.
                    height=min(1.1 * step, height - y),
                    stroke="none",
                )
            )
        return bands

    @staticmethod
    def _colourbar_ticks(min_value, max_value, height, width, labels):
        """Return the ``svg.Text`` tick labels described by ``labels``."""
        if labels is None or labels is False:
            return []

        if isinstance(labels, (int, np.integer)):
            positions = np.linspace(0, height, labels)
            values = np.linspace(min_value, max_value, labels)
            ticks = [(y, f"— {v:.02f}") for y, v in zip(positions, values)]
        elif isinstance(labels, (list, tuple)):
            if all(isinstance(v, str) for v in labels):
                positions = np.linspace(0, height, len(labels))
                ticks = [(y, f"— {v}") for y, v in zip(positions, labels)]
            elif all(
                isinstance(v, (int, float, np.integer, np.floating)) for v in labels
            ):
                span = max_value - min_value
                if span == 0:
                    raise ValueError(
                        "numeric colourbar labels need max_value != min_value"
                    )
                ticks = [
                    ((v - min_value) / span * height, f"— {v:.02f}") for v in labels
                ]
            else:
                raise TypeError("colourbar labels must be all strings or all numbers")
        else:
            raise TypeError(f"unsupported colourbar labels: {labels!r}")

        return [
            svg.Text(text=label, class_=["normal"], x=width, y=y, dy=3)
            for y, label in ticks
        ]

    @property
    def svg(self):
        """Serialise the accumulated elements into an SVG document string."""
        # Inside a method body ``svg`` resolves to the module, not this property.
        document = svg.SVG(
            width=self.total_width, height=self.total_height, elements=self.elements
        )
        return str(document)

    def __repr__(self):
        return (
            "Matrix Visualisation:\n"
            f"shape: {self.matrix.shape}\n"
            f"size: {self.total_width}x{self.total_height}\n"
        )