vectise/matrix.py
2024-11-10 18:54:23 +01:00

170 lines
5.3 KiB
Python

import numpy as np
import svg
from colours import ColourMap
class MatrixVisualisation:
    """Render a matrix as an SVG grid of coloured cells.

    Every entry becomes a 20x20 rectangle (separated by 1 px gaps) filled with
    ``cmap(value)``; optionally the value is printed inside its cell. A colour
    bar can be appended with :meth:`colourbar`, and the finished document is
    available as a string through :attr:`svg`.
    """

    def __init__(
        self,
        matrix: np.typing.NDArray,
        cmap: ColourMap,
        text: bool = False,
        labels: int | list[str] | bool = False,
    ):
        """Lay out one rectangle (and optional value text) per matrix cell.

        Args:
            matrix: Values to draw; at most 2-D. A 1-D input is drawn as a
                single row.
            cmap: Callable mapping a cell value to an SVG fill colour.
            text: If true, print each value with two decimals inside its cell.
            labels: Not used by this method; accepted for API compatibility.
        """
        # atleast_2d returns 2-D arrays unchanged and turns a vector into a
        # single row, so the shape unpacking below also works for 1-D input.
        self.matrix = np.atleast_2d(matrix)
        self.m, self.n = self.matrix.shape
        width = 20
        height = 20
        gap = 1
        self.text = text
        self.cmap = cmap
        self.total_width = (gap + width) * self.n + gap
        self.total_height = (gap + height) * self.m + gap
        self.elements = [
            # ``font: monospace`` alone is invalid CSS (the ``font`` shorthand
            # also requires a size) and would be ignored; set the family only.
            svg.Style(text=".mono { font-family: monospace; text-align: center;}"),
            svg.Style(text=".small { font-size: 25%; }"),
            svg.Style(text=".normal { font-size: 12px; }"),
        ]
        for i in range(self.m):
            y = gap + i * (gap + height)
            for j in range(self.n):
                x = gap + j * (gap + width)
                value = self.matrix[i, j]
                self.elements.append(
                    svg.Rect(
                        x=x,
                        y=y,
                        width=width,
                        height=height,
                        stroke="transparent",
                        fill=cmap(value),
                    )
                )
                if text:
                    self.elements.append(
                        svg.Text(
                            # Start 1/5 into the cell and squeeze the glyphs to
                            # half the cell width so the value always fits.
                            x=x + width / 5,
                            y=y + 3 * height / 4,
                            textLength=width / 2,
                            lengthAdjust="spacingAndGlyphs",
                            class_=["mono"],
                            text=f"{value:.02f}",
                        )
                    )

    def colourbar(
        self,
        min_value: float = 0,
        max_value: float = 1,
        height: int | None = None,
        width: int = 20,
        resolution: int = 256,
        border: int | bool = 1,
        labels: int | list[str] | bool = False,
    ):
        """Append a vertical colour bar to the right of the current drawing.

        The bar runs from ``min_value`` at the top to ``max_value`` at the
        bottom and is vertically centred on the matrix. ``self.total_width``
        grows to make room for it (and for its labels, if any).

        Args:
            min_value: Value passed to ``cmap`` for the top of the bar.
            max_value: Value passed to ``cmap`` for the bottom of the bar.
            height: Bar height in px; defaults to 2/3 of the drawing height.
            width: Bar width in px.
            resolution: Number of colour bands used to draw the gradient.
            border: Stroke width of the black outline; 0/False draws none
                and True means 1.
            labels: ``None``/``False`` for no labels; an int ``n`` for ``n``
                evenly spaced numeric labels (``True`` counts as 1); a list
                of strings, spaced evenly top to bottom; or a list of numbers,
                each placed at its position between ``min_value`` and
                ``max_value``.

        Raises:
            ValueError: If ``resolution`` is smaller than 1.
            TypeError: If ``labels`` is of an unsupported type or a list
                mixing strings and numbers.
        """
        if height is None:
            height = int(self.total_height * 2 / 3)
        children = self._colourbar_bands(min_value, max_value, height, width, resolution)
        label_group = self._colourbar_labels(labels, min_value, max_value, height, width)
        if label_group is not None:
            children.append(label_group)
        if isinstance(border, bool):
            # A bool would otherwise be rendered verbatim as stroke-width.
            border = int(border)
        if border:
            children.append(
                svg.Rect(
                    x=0,
                    y=0,
                    width=width,
                    height=height,
                    fill="none",
                    stroke_width=border,
                    stroke="black",
                )
            )
        self.elements.append(
            svg.G(
                id="colourbar",
                # Kept flat: every child is an svg element, not a nested list.
                elements=children,
                transform=[
                    svg.Translate(
                        x=int(self.total_width + width / 2),
                        y=int((self.total_height - height) / 2),
                    )
                ],
            )
        )
        # Reserve room for the bar itself plus 40 px when labels were drawn.
        self.total_width = self.total_width + 2 * width + 40 * (label_group is not None)

    def _colourbar_bands(self, min_value, max_value, height, width, resolution):
        """Return ``resolution`` stacked rects sweeping ``cmap`` from min to max."""
        if resolution < 1:
            raise ValueError(f"resolution must be at least 1, got {resolution}")
        step = height / resolution
        # One band top per value; endpoint=False so the bands tile [0, height).
        tops = np.linspace(0, height, resolution, endpoint=False)
        values = np.linspace(min_value, max_value, resolution)
        return [
            svg.Rect(
                fill=self.cmap(v),
                x=0,
                y=y,
                width=width,
                # 10% overlap hides anti-aliasing seams between bands; the
                # last band is clipped so it ends exactly at the bar bottom.
                height=min(1.1 * step, height - y),
                stroke="none",
            )
            for y, v in zip(tops, values)
        ]

    @staticmethod
    def _colourbar_text(text, x, y):
        """Return one colour-bar label positioned at (``x``, ``y``)."""
        return svg.Text(text=text, class_=["normal"], x=x, y=y, dy=3)

    def _colourbar_labels(self, labels, min_value, max_value, height, width):
        """Return an ``svg.G`` of colour-bar labels, or ``None`` if there are none.

        See :meth:`colourbar` for the accepted forms of ``labels``.
        """
        if labels is None or labels is False:
            return None
        if isinstance(labels, (int, np.integer)):
            count = int(labels)
            texts = [
                self._colourbar_text(f"{v:.02f}", width, y)
                for y, v in zip(
                    np.linspace(0, height, count),
                    np.linspace(min_value, max_value, count),
                )
            ]
        elif isinstance(labels, (list, tuple)):
            if all(isinstance(item, str) for item in labels):
                texts = [
                    self._colourbar_text(f"{v}", width, y)
                    for y, v in zip(np.linspace(0, height, len(labels)), labels)
                ]
            elif all(isinstance(item, (int, float, np.integer, np.floating)) for item in labels):
                span = max_value - min_value
                texts = [
                    # A zero span has no meaningful position; pin to the top.
                    self._colourbar_text(
                        f"{v:.02f}",
                        width,
                        (v - min_value) / span * height if span else 0,
                    )
                    for v in labels
                ]
            else:
                raise TypeError(
                    "colourbar labels list must contain only strings or only numbers"
                )
        else:
            raise TypeError(
                f"unsupported colourbar labels type: {type(labels).__name__}"
            )
        if not texts:
            return None
        return svg.G(id="colourbar labels", elements=texts)

    @property
    def svg(self):
        """The complete SVG document as a string."""
        return str(
            svg.SVG(
                width=self.total_width, height=self.total_height, elements=self.elements
            )
        )

    def __repr__(self):
        return f"""Matrix Visualisation:
shape: {self.matrix.shape}
size: {self.total_width}x{self.total_height}
"""