diff --git a/matrix.py b/matrix.py index 23d8cef..7d210d4 100644 --- a/matrix.py +++ b/matrix.py @@ -1,4 +1,5 @@ import numpy as np +from abc import ABC, abstractmethod import svg rgen = np.random.default_rng() @@ -86,7 +87,12 @@ def polylinear_gradient(colours, n): return gradient_dict -class LinearGradientColourMap: +class ColourMap(ABC): + @abstractmethod + def __call__(self, v: float): ... + + +class LinearGradientColourMap(ColourMap): def __init__( self, colours: list[str] | None = ["#ff0000", "#ffffff", "#0000ff"], @@ -104,75 +110,193 @@ class LinearGradientColourMap: return self.colours["hex"][v] -class RandomColourMap: +class RandomColourMap(ColourMap): def __init__(self, random_state: int | list[int] | None = [2, 3, 4, 5, 6]): - self.rgen = np.random.default_rng([2, 3, 4, 5, 6]) + self.rgen = np.random.default_rng(random_state) def __call__(self, v: float): return RGB_to_hex([x * 255 for x in self.rgen.random(3)]) -def colourbar( - cmap, - min_value: float | None = 0, - max_value: float | None = 1, - height=100, - width=10, - resolution=256, -): - items = [ - svg.Rect(fill=cmap(v), x=0, y=i, width=width, height=1, stroke="none") - for i, v in enumerate(np.linspace(min_value, max_value, resolution)) - ] - return svg.G(elements=items, transform=[svg.Scale(1, height / resolution)]) +class MatrixVisualisation: + def __init__( + self, + matrix: np.typing.NDArray, + cmap: ColourMap, + text: bool = False, + labels: int | list[str] | bool = False, + ): + self.m, self.n = matrix.shape + width = 20 + height = 20 + gap = 1 + self.text = text + self.total_width = (gap + width) * n + gap + self.total_height = (gap + height) * m + gap + self.cmap = cmap + + self.elements = [] + self.elements.append( + svg.Style(text=".mono { font: monospace; text-align: center;}") + ) + self.elements.append(svg.Style(text=".small { font-size: 25%; }")) + self.elements.append(svg.Style(text=".normal { font-size: 12px; }")) + + for i, y in enumerate(range(gap, self.total_height, gap + height)): + for j, x in enumerate(range(gap, self.total_width, gap + width)): + self.elements.append( + svg.Rect( + x=x, + y=y, + width=width, + height=height, + stroke="transparent", + fill=cmap(matrix[i, j]), + ) + ) + if text: + self.elements.append( + svg.Text( + x=x + width / 5, + y=y + 3 * height / 4, + textLength=width / 2, + lengthAdjust="spacingAndGlyphs", + class_=["mono"], + text=f"{matrix[i, j]:.02f}", + ) + ) + + def colourbar( + self, + min_value: float = 0, + max_value: float = 1, + height: int | None = None, + width: int = 20, + resolution: int = 256, + border: int | bool = 1, + labels: int | list[str] | bool = False, + ): + if height is None: + height = int(self.total_height * 2 / 3) + + lines = [ + svg.Rect( + fill=self.cmap(v), + x=0, + y=y, + width=width, + height=1.1 * height / resolution, + stroke="none", + ) + for y, v in zip( + np.linspace(0, height, resolution), + np.linspace(min_value, max_value, resolution - 1), + ) + ] + + if labels is None: + label = [] + elif isinstance(labels, int): + label = svg.G( + id="colourbar labels", + elements=[ + svg.Text( + text=f"— {v:.02f}", + class_=["normal"], + x=width, + y=y, + dy=3, + ) + for y, v in zip( + np.linspace(0, height, labels), + np.linspace(min_value, max_value, labels), + ) + ], + ) + elif isinstance(labels, list): + if all(isinstance(n, str) for n in labels): + label = svg.G( + id="colourbar labels", + elements=[ + svg.Text( + text=f"— {v}", + class_=["normal"], + x=width, + y=y, + dy=3, + ) + for y, v in zip(np.linspace(0, height, len(labels)), labels) + ], + ) + if all(isinstance(n, float) or isinstance(n, int) for n in labels): + label = svg.G( + id="colourbar labels", + elements=[ + svg.Text( + text=f"— {v:.02f}", + class_=["normal"], + x=width, + y=(v - min_value) / (max_value - min_value) * height, + dy=3, + ) + for v in labels + ], + ) + + cbar = svg.G( + id="colourbar", + elements=[ + lines, + label, + svg.Rect( + x=0, + y=0, + width=width, + height=height, + fill="none", + stroke_width=border, + stroke="black", + ), + ], + transform=[ + svg.Translate( + x=int(self.total_width + width / 2), + y=int((self.total_height - height) / 2), + ) + ], + ) + self.elements.append(cbar) + self.total_width = self.total_width + 2 * width + 40 * bool(labels) + + @property + def svg(self): + return str( + svg.SVG( + width=self.total_width, height=self.total_height, elements=self.elements + ) + ) + + def __repr__(self): + return f"""Matrix Visualisation: + shape: {matrix.shape} + size: {self.total_width}x{self.total_height} + """ if __name__ == "__main__": - m, n = 20, 20 - width = 20 - height = 20 - gap = 1 - text = False - total_width = (gap + width) * n + gap - total_height = (gap + height) * m + gap - - filename = "matrix.svg" + m, n = 30, 20 matrix = rgen.random(size=(m, n)) + colours = ["#f5d72a", "#ffffff", "#2182af"] # colours = ["#ff0000", "#00ff00", "#0000ff"] cmap = LinearGradientColourMap(colours, matrix.min(), matrix.max()) # cmap = RandomColourMap() - items = [] - items.append(svg.Style(text=".mono { font: monospace; text-align: center;}")) - items.append(svg.Style(text=".small { font-size: 25%; }")) - - for i, y in enumerate(range(gap, total_height, gap + height)): - for j, x in enumerate(range(gap, total_width, gap + width)): - items.append( - svg.Rect( - x=x, - y=y, - width=width, - height=height, - stroke="transparent", - fill=cmap(matrix[i, j]), - ) - ) - if text: - items.append( - svg.Text( - x=x + width / 5, - y=y + 3 * height / 4, - textLength=width / 2, - lengthAdjust="spacingAndGlyphs", - class_=["mono"], - text=f"{matrix[i, j]:.02f}", - ) - ) - - items.append(colourbar(cmap, matrix.min(), matrix.max(), height=100)) - - content = svg.SVG(width=total_width, height=total_height, elements=items) + fig = MatrixVisualisation(matrix, cmap=cmap) + fig.colourbar(labels=["yellow", "white", "blue"]) + fig.colourbar(labels=5) + fig.colourbar(labels=[0.2, 0.5, 0.55, 0.66, 1]) + filename = "matrix.svg" + print(fig) with open(filename, "w") as f: - f.write(str(content)) + f.write(fig.svg)