add colourbar with different labeling options
This commit is contained in:
		
							
								
								
									
										202
									
								
								matrix.py
									
									
									
									
									
								
							
							
						
						
									
										202
									
								
								matrix.py
									
									
									
									
									
								
							@@ -1,4 +1,5 @@
 | 
			
		||||
import numpy as np
 | 
			
		||||
from abc import ABC, abstractmethod
 | 
			
		||||
import svg
 | 
			
		||||
 | 
			
		||||
rgen = np.random.default_rng()
 | 
			
		||||
@@ -86,7 +87,12 @@ def polylinear_gradient(colours, n):
 | 
			
		||||
    return gradient_dict
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class LinearGradientColourMap:
 | 
			
		||||
class ColourMap(ABC):
 | 
			
		||||
    @abstractmethod
 | 
			
		||||
    def __call__(self, v: float): ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class LinearGradientColourMap(ColourMap):
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        colours: list[str] | None = ["#ff0000", "#ffffff", "#0000ff"],
 | 
			
		||||
@@ -104,52 +110,41 @@ class LinearGradientColourMap:
 | 
			
		||||
        return self.colours["hex"][v]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RandomColourMap:
 | 
			
		||||
class RandomColourMap(ColourMap):
 | 
			
		||||
    def __init__(self, random_state: int | list[int] | None = [2, 3, 4, 5, 6]):
 | 
			
		||||
        self.rgen = np.random.default_rng([2, 3, 4, 5, 6])
 | 
			
		||||
        self.rgen = np.random.default_rng(random_state)
 | 
			
		||||
 | 
			
		||||
    def __call__(self, v: float):
 | 
			
		||||
        return RGB_to_hex([x * 255 for x in self.rgen.random(3)])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def colourbar(
 | 
			
		||||
    cmap,
 | 
			
		||||
    min_value: float | None = 0,
 | 
			
		||||
    max_value: float | None = 1,
 | 
			
		||||
    height=100,
 | 
			
		||||
    width=10,
 | 
			
		||||
    resolution=256,
 | 
			
		||||
class MatrixVisualisation:
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        matrix: np.typing.NDArray,
 | 
			
		||||
        cmap: ColourMap,
 | 
			
		||||
        text: bool = False,
 | 
			
		||||
        labels: int | list[str] | bool = False,
 | 
			
		||||
    ):
 | 
			
		||||
    items = [
 | 
			
		||||
        svg.Rect(fill=cmap(v), x=0, y=i, width=width, height=1, stroke="none")
 | 
			
		||||
        for i, v in enumerate(np.linspace(min_value, max_value, resolution))
 | 
			
		||||
    ]
 | 
			
		||||
    return svg.G(elements=items, transform=[svg.Scale(1, height / resolution)])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    m, n = 20, 20
 | 
			
		||||
        self.m, self.n = matrix.shape
 | 
			
		||||
        width = 20
 | 
			
		||||
        height = 20
 | 
			
		||||
        gap = 1
 | 
			
		||||
    text = False
 | 
			
		||||
    total_width = (gap + width) * n + gap
 | 
			
		||||
    total_height = (gap + height) * m + gap
 | 
			
		||||
        self.text = text
 | 
			
		||||
        self.total_width = (gap + width) * n + gap
 | 
			
		||||
        self.total_height = (gap + height) * m + gap
 | 
			
		||||
        self.cmap = cmap
 | 
			
		||||
 | 
			
		||||
    filename = "matrix.svg"
 | 
			
		||||
    matrix = rgen.random(size=(m, n))
 | 
			
		||||
    colours = ["#f5d72a", "#ffffff", "#2182af"]
 | 
			
		||||
    # colours = ["#ff0000", "#00ff00", "#0000ff"]
 | 
			
		||||
    cmap = LinearGradientColourMap(colours, matrix.min(), matrix.max())
 | 
			
		||||
    # cmap = RandomColourMap()
 | 
			
		||||
        self.elements = []
 | 
			
		||||
        self.elements.append(
 | 
			
		||||
            svg.Style(text=".mono { font: monospace; text-align: center;}")
 | 
			
		||||
        )
 | 
			
		||||
        self.elements.append(svg.Style(text=".small { font-size: 25%; }"))
 | 
			
		||||
        self.elements.append(svg.Style(text=".normal { font-size: 12px; }"))
 | 
			
		||||
 | 
			
		||||
    items = []
 | 
			
		||||
    items.append(svg.Style(text=".mono { font: monospace; text-align: center;}"))
 | 
			
		||||
    items.append(svg.Style(text=".small { font-size: 25%; }"))
 | 
			
		||||
 | 
			
		||||
    for i, y in enumerate(range(gap, total_height, gap + height)):
 | 
			
		||||
        for j, x in enumerate(range(gap, total_width, gap + width)):
 | 
			
		||||
            items.append(
 | 
			
		||||
        for i, y in enumerate(range(gap, self.total_height, gap + height)):
 | 
			
		||||
            for j, x in enumerate(range(gap, self.total_width, gap + width)):
 | 
			
		||||
                self.elements.append(
 | 
			
		||||
                    svg.Rect(
 | 
			
		||||
                        x=x,
 | 
			
		||||
                        y=y,
 | 
			
		||||
@@ -160,7 +155,7 @@ if __name__ == "__main__":
 | 
			
		||||
                    )
 | 
			
		||||
                )
 | 
			
		||||
                if text:
 | 
			
		||||
                items.append(
 | 
			
		||||
                    self.elements.append(
 | 
			
		||||
                        svg.Text(
 | 
			
		||||
                            x=x + width / 5,
 | 
			
		||||
                            y=y + 3 * height / 4,
 | 
			
		||||
@@ -171,8 +166,137 @@ if __name__ == "__main__":
 | 
			
		||||
                        )
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
    items.append(colourbar(cmap, matrix.min(), matrix.max(), height=100))
 | 
			
		||||
    def colourbar(
 | 
			
		||||
        self,
 | 
			
		||||
        min_value: float = 0,
 | 
			
		||||
        max_value: float = 1,
 | 
			
		||||
        height: int | None = None,
 | 
			
		||||
        width: int = 20,
 | 
			
		||||
        resolution: int = 256,
 | 
			
		||||
        border: int | bool = 1,
 | 
			
		||||
        labels: int | list[str] | bool = False,
 | 
			
		||||
    ):
 | 
			
		||||
        if height is None:
 | 
			
		||||
            height = int(self.total_height * 2 / 3)
 | 
			
		||||
 | 
			
		||||
    content = svg.SVG(width=total_width, height=total_height, elements=items)
 | 
			
		||||
        lines = [
 | 
			
		||||
            svg.Rect(
 | 
			
		||||
                fill=self.cmap(v),
 | 
			
		||||
                x=0,
 | 
			
		||||
                y=y,
 | 
			
		||||
                width=width,
 | 
			
		||||
                height=1.1 * height / resolution,
 | 
			
		||||
                stroke="none",
 | 
			
		||||
            )
 | 
			
		||||
            for y, v in zip(
 | 
			
		||||
                np.linspace(0, height, resolution),
 | 
			
		||||
                np.linspace(min_value, max_value, resolution - 1),
 | 
			
		||||
            )
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
        if labels is None:
 | 
			
		||||
            label = []
 | 
			
		||||
        elif isinstance(labels, int):
 | 
			
		||||
            label = svg.G(
 | 
			
		||||
                id="colourbar labels",
 | 
			
		||||
                elements=[
 | 
			
		||||
                    svg.Text(
 | 
			
		||||
                        text=f"— {v:.02f}",
 | 
			
		||||
                        class_=["normal"],
 | 
			
		||||
                        x=width,
 | 
			
		||||
                        y=y,
 | 
			
		||||
                        dy=3,
 | 
			
		||||
                    )
 | 
			
		||||
                    for y, v in zip(
 | 
			
		||||
                        np.linspace(0, height, labels),
 | 
			
		||||
                        np.linspace(min_value, max_value, labels),
 | 
			
		||||
                    )
 | 
			
		||||
                ],
 | 
			
		||||
            )
 | 
			
		||||
        elif isinstance(labels, list):
 | 
			
		||||
            if all(isinstance(n, str) for n in labels):
 | 
			
		||||
                label = svg.G(
 | 
			
		||||
                    id="colourbar labels",
 | 
			
		||||
                    elements=[
 | 
			
		||||
                        svg.Text(
 | 
			
		||||
                            text=f"— {v}",
 | 
			
		||||
                            class_=["normal"],
 | 
			
		||||
                            x=width,
 | 
			
		||||
                            y=y,
 | 
			
		||||
                            dy=3,
 | 
			
		||||
                        )
 | 
			
		||||
                        for y, v in zip(np.linspace(0, height, len(labels)), labels)
 | 
			
		||||
                    ],
 | 
			
		||||
                )
 | 
			
		||||
            if all(isinstance(n, float) or isinstance(n, int) for n in labels):
 | 
			
		||||
                label = svg.G(
 | 
			
		||||
                    id="colourbar labels",
 | 
			
		||||
                    elements=[
 | 
			
		||||
                        svg.Text(
 | 
			
		||||
                            text=f"— {v:.02f}",
 | 
			
		||||
                            class_=["normal"],
 | 
			
		||||
                            x=width,
 | 
			
		||||
                            y=(v - min_value) / (max_value - min_value) * height,
 | 
			
		||||
                            dy=3,
 | 
			
		||||
                        )
 | 
			
		||||
                        for v in labels
 | 
			
		||||
                    ],
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
        cbar = svg.G(
 | 
			
		||||
            id="colourbar",
 | 
			
		||||
            elements=[
 | 
			
		||||
                lines,
 | 
			
		||||
                label,
 | 
			
		||||
                svg.Rect(
 | 
			
		||||
                    x=0,
 | 
			
		||||
                    y=0,
 | 
			
		||||
                    width=width,
 | 
			
		||||
                    height=height,
 | 
			
		||||
                    fill="none",
 | 
			
		||||
                    stroke_width=border,
 | 
			
		||||
                    stroke="black",
 | 
			
		||||
                ),
 | 
			
		||||
            ],
 | 
			
		||||
            transform=[
 | 
			
		||||
                svg.Translate(
 | 
			
		||||
                    x=int(self.total_width + width / 2),
 | 
			
		||||
                    y=int((self.total_height - height) / 2),
 | 
			
		||||
                )
 | 
			
		||||
            ],
 | 
			
		||||
        )
 | 
			
		||||
        self.elements.append(cbar)
 | 
			
		||||
        self.total_width = self.total_width + 2 * width + 40 * bool(labels)
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def svg(self):
 | 
			
		||||
        return str(
 | 
			
		||||
            svg.SVG(
 | 
			
		||||
                width=self.total_width, height=self.total_height, elements=self.elements
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
        return f"""Matrix Visualisation: 
 | 
			
		||||
    shape: {matrix.shape}
 | 
			
		||||
    size:  {self.total_width}x{self.total_height}
 | 
			
		||||
            """
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    m, n = 30, 20
 | 
			
		||||
    matrix = rgen.random(size=(m, n))
 | 
			
		||||
 | 
			
		||||
    colours = ["#f5d72a", "#ffffff", "#2182af"]
 | 
			
		||||
    # colours = ["#ff0000", "#00ff00", "#0000ff"]
 | 
			
		||||
    cmap = LinearGradientColourMap(colours, matrix.min(), matrix.max())
 | 
			
		||||
    # cmap = RandomColourMap()
 | 
			
		||||
 | 
			
		||||
    fig = MatrixVisualisation(matrix, cmap=cmap)
 | 
			
		||||
    fig.colourbar(labels=["yellow", "white", "blue"])
 | 
			
		||||
    fig.colourbar(labels=5)
 | 
			
		||||
    fig.colourbar(labels=[0.2, 0.5, 0.55, 0.66, 1])
 | 
			
		||||
    filename = "matrix.svg"
 | 
			
		||||
    print(fig)
 | 
			
		||||
    with open(filename, "w") as f:
 | 
			
		||||
        f.write(str(content))
 | 
			
		||||
        f.write(fig.svg)
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user