add colourbar with different labeling options
This commit is contained in:
		
							
								
								
									
										202
									
								
								matrix.py
									
									
									
									
									
								
							
							
						
						
									
										202
									
								
								matrix.py
									
									
									
									
									
								
							@@ -1,4 +1,5 @@
 | 
				
			|||||||
import numpy as np
 | 
					import numpy as np
 | 
				
			||||||
 | 
					from abc import ABC, abstractmethod
 | 
				
			||||||
import svg
 | 
					import svg
 | 
				
			||||||
 | 
					
 | 
				
			||||||
rgen = np.random.default_rng()
 | 
					rgen = np.random.default_rng()
 | 
				
			||||||
@@ -86,7 +87,12 @@ def polylinear_gradient(colours, n):
 | 
				
			|||||||
    return gradient_dict
 | 
					    return gradient_dict
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class LinearGradientColourMap:
 | 
					class ColourMap(ABC):
 | 
				
			||||||
 | 
					    @abstractmethod
 | 
				
			||||||
 | 
					    def __call__(self, v: float): ...
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class LinearGradientColourMap(ColourMap):
 | 
				
			||||||
    def __init__(
 | 
					    def __init__(
 | 
				
			||||||
        self,
 | 
					        self,
 | 
				
			||||||
        colours: list[str] | None = ["#ff0000", "#ffffff", "#0000ff"],
 | 
					        colours: list[str] | None = ["#ff0000", "#ffffff", "#0000ff"],
 | 
				
			||||||
@@ -104,52 +110,41 @@ class LinearGradientColourMap:
 | 
				
			|||||||
        return self.colours["hex"][v]
 | 
					        return self.colours["hex"][v]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class RandomColourMap:
 | 
					class RandomColourMap(ColourMap):
 | 
				
			||||||
    def __init__(self, random_state: int | list[int] | None = [2, 3, 4, 5, 6]):
 | 
					    def __init__(self, random_state: int | list[int] | None = [2, 3, 4, 5, 6]):
 | 
				
			||||||
        self.rgen = np.random.default_rng([2, 3, 4, 5, 6])
 | 
					        self.rgen = np.random.default_rng(random_state)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __call__(self, v: float):
 | 
					    def __call__(self, v: float):
 | 
				
			||||||
        return RGB_to_hex([x * 255 for x in self.rgen.random(3)])
 | 
					        return RGB_to_hex([x * 255 for x in self.rgen.random(3)])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def colourbar(
 | 
					class MatrixVisualisation:
 | 
				
			||||||
    cmap,
 | 
					    def __init__(
 | 
				
			||||||
    min_value: float | None = 0,
 | 
					        self,
 | 
				
			||||||
    max_value: float | None = 1,
 | 
					        matrix: np.typing.NDArray,
 | 
				
			||||||
    height=100,
 | 
					        cmap: ColourMap,
 | 
				
			||||||
    width=10,
 | 
					        text: bool = False,
 | 
				
			||||||
    resolution=256,
 | 
					        labels: int | list[str] | bool = False,
 | 
				
			||||||
    ):
 | 
					    ):
 | 
				
			||||||
    items = [
 | 
					        self.m, self.n = matrix.shape
 | 
				
			||||||
        svg.Rect(fill=cmap(v), x=0, y=i, width=width, height=1, stroke="none")
 | 
					 | 
				
			||||||
        for i, v in enumerate(np.linspace(min_value, max_value, resolution))
 | 
					 | 
				
			||||||
    ]
 | 
					 | 
				
			||||||
    return svg.G(elements=items, transform=[svg.Scale(1, height / resolution)])
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
if __name__ == "__main__":
 | 
					 | 
				
			||||||
    m, n = 20, 20
 | 
					 | 
				
			||||||
        width = 20
 | 
					        width = 20
 | 
				
			||||||
        height = 20
 | 
					        height = 20
 | 
				
			||||||
        gap = 1
 | 
					        gap = 1
 | 
				
			||||||
    text = False
 | 
					        self.text = text
 | 
				
			||||||
    total_width = (gap + width) * n + gap
 | 
					        self.total_width = (gap + width) * n + gap
 | 
				
			||||||
    total_height = (gap + height) * m + gap
 | 
					        self.total_height = (gap + height) * m + gap
 | 
				
			||||||
 | 
					        self.cmap = cmap
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    filename = "matrix.svg"
 | 
					        self.elements = []
 | 
				
			||||||
    matrix = rgen.random(size=(m, n))
 | 
					        self.elements.append(
 | 
				
			||||||
    colours = ["#f5d72a", "#ffffff", "#2182af"]
 | 
					            svg.Style(text=".mono { font: monospace; text-align: center;}")
 | 
				
			||||||
    # colours = ["#ff0000", "#00ff00", "#0000ff"]
 | 
					        )
 | 
				
			||||||
    cmap = LinearGradientColourMap(colours, matrix.min(), matrix.max())
 | 
					        self.elements.append(svg.Style(text=".small { font-size: 25%; }"))
 | 
				
			||||||
    # cmap = RandomColourMap()
 | 
					        self.elements.append(svg.Style(text=".normal { font-size: 12px; }"))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    items = []
 | 
					        for i, y in enumerate(range(gap, self.total_height, gap + height)):
 | 
				
			||||||
    items.append(svg.Style(text=".mono { font: monospace; text-align: center;}"))
 | 
					            for j, x in enumerate(range(gap, self.total_width, gap + width)):
 | 
				
			||||||
    items.append(svg.Style(text=".small { font-size: 25%; }"))
 | 
					                self.elements.append(
 | 
				
			||||||
 | 
					 | 
				
			||||||
    for i, y in enumerate(range(gap, total_height, gap + height)):
 | 
					 | 
				
			||||||
        for j, x in enumerate(range(gap, total_width, gap + width)):
 | 
					 | 
				
			||||||
            items.append(
 | 
					 | 
				
			||||||
                    svg.Rect(
 | 
					                    svg.Rect(
 | 
				
			||||||
                        x=x,
 | 
					                        x=x,
 | 
				
			||||||
                        y=y,
 | 
					                        y=y,
 | 
				
			||||||
@@ -160,7 +155,7 @@ if __name__ == "__main__":
 | 
				
			|||||||
                    )
 | 
					                    )
 | 
				
			||||||
                )
 | 
					                )
 | 
				
			||||||
                if text:
 | 
					                if text:
 | 
				
			||||||
                items.append(
 | 
					                    self.elements.append(
 | 
				
			||||||
                        svg.Text(
 | 
					                        svg.Text(
 | 
				
			||||||
                            x=x + width / 5,
 | 
					                            x=x + width / 5,
 | 
				
			||||||
                            y=y + 3 * height / 4,
 | 
					                            y=y + 3 * height / 4,
 | 
				
			||||||
@@ -171,8 +166,137 @@ if __name__ == "__main__":
 | 
				
			|||||||
                        )
 | 
					                        )
 | 
				
			||||||
                    )
 | 
					                    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    items.append(colourbar(cmap, matrix.min(), matrix.max(), height=100))
 | 
					    def colourbar(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        min_value: float = 0,
 | 
				
			||||||
 | 
					        max_value: float = 1,
 | 
				
			||||||
 | 
					        height: int | None = None,
 | 
				
			||||||
 | 
					        width: int = 20,
 | 
				
			||||||
 | 
					        resolution: int = 256,
 | 
				
			||||||
 | 
					        border: int | bool = 1,
 | 
				
			||||||
 | 
					        labels: int | list[str] | bool = False,
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        if height is None:
 | 
				
			||||||
 | 
					            height = int(self.total_height * 2 / 3)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    content = svg.SVG(width=total_width, height=total_height, elements=items)
 | 
					        lines = [
 | 
				
			||||||
 | 
					            svg.Rect(
 | 
				
			||||||
 | 
					                fill=self.cmap(v),
 | 
				
			||||||
 | 
					                x=0,
 | 
				
			||||||
 | 
					                y=y,
 | 
				
			||||||
 | 
					                width=width,
 | 
				
			||||||
 | 
					                height=1.1 * height / resolution,
 | 
				
			||||||
 | 
					                stroke="none",
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            for y, v in zip(
 | 
				
			||||||
 | 
					                np.linspace(0, height, resolution),
 | 
				
			||||||
 | 
					                np.linspace(min_value, max_value, resolution - 1),
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if labels is None:
 | 
				
			||||||
 | 
					            label = []
 | 
				
			||||||
 | 
					        elif isinstance(labels, int):
 | 
				
			||||||
 | 
					            label = svg.G(
 | 
				
			||||||
 | 
					                id="colourbar labels",
 | 
				
			||||||
 | 
					                elements=[
 | 
				
			||||||
 | 
					                    svg.Text(
 | 
				
			||||||
 | 
					                        text=f"— {v:.02f}",
 | 
				
			||||||
 | 
					                        class_=["normal"],
 | 
				
			||||||
 | 
					                        x=width,
 | 
				
			||||||
 | 
					                        y=y,
 | 
				
			||||||
 | 
					                        dy=3,
 | 
				
			||||||
 | 
					                    )
 | 
				
			||||||
 | 
					                    for y, v in zip(
 | 
				
			||||||
 | 
					                        np.linspace(0, height, labels),
 | 
				
			||||||
 | 
					                        np.linspace(min_value, max_value, labels),
 | 
				
			||||||
 | 
					                    )
 | 
				
			||||||
 | 
					                ],
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        elif isinstance(labels, list):
 | 
				
			||||||
 | 
					            if all(isinstance(n, str) for n in labels):
 | 
				
			||||||
 | 
					                label = svg.G(
 | 
				
			||||||
 | 
					                    id="colourbar labels",
 | 
				
			||||||
 | 
					                    elements=[
 | 
				
			||||||
 | 
					                        svg.Text(
 | 
				
			||||||
 | 
					                            text=f"— {v}",
 | 
				
			||||||
 | 
					                            class_=["normal"],
 | 
				
			||||||
 | 
					                            x=width,
 | 
				
			||||||
 | 
					                            y=y,
 | 
				
			||||||
 | 
					                            dy=3,
 | 
				
			||||||
 | 
					                        )
 | 
				
			||||||
 | 
					                        for y, v in zip(np.linspace(0, height, len(labels)), labels)
 | 
				
			||||||
 | 
					                    ],
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					            if all(isinstance(n, float) or isinstance(n, int) for n in labels):
 | 
				
			||||||
 | 
					                label = svg.G(
 | 
				
			||||||
 | 
					                    id="colourbar labels",
 | 
				
			||||||
 | 
					                    elements=[
 | 
				
			||||||
 | 
					                        svg.Text(
 | 
				
			||||||
 | 
					                            text=f"— {v:.02f}",
 | 
				
			||||||
 | 
					                            class_=["normal"],
 | 
				
			||||||
 | 
					                            x=width,
 | 
				
			||||||
 | 
					                            y=(v - min_value) / (max_value - min_value) * height,
 | 
				
			||||||
 | 
					                            dy=3,
 | 
				
			||||||
 | 
					                        )
 | 
				
			||||||
 | 
					                        for v in labels
 | 
				
			||||||
 | 
					                    ],
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        cbar = svg.G(
 | 
				
			||||||
 | 
					            id="colourbar",
 | 
				
			||||||
 | 
					            elements=[
 | 
				
			||||||
 | 
					                lines,
 | 
				
			||||||
 | 
					                label,
 | 
				
			||||||
 | 
					                svg.Rect(
 | 
				
			||||||
 | 
					                    x=0,
 | 
				
			||||||
 | 
					                    y=0,
 | 
				
			||||||
 | 
					                    width=width,
 | 
				
			||||||
 | 
					                    height=height,
 | 
				
			||||||
 | 
					                    fill="none",
 | 
				
			||||||
 | 
					                    stroke_width=border,
 | 
				
			||||||
 | 
					                    stroke="black",
 | 
				
			||||||
 | 
					                ),
 | 
				
			||||||
 | 
					            ],
 | 
				
			||||||
 | 
					            transform=[
 | 
				
			||||||
 | 
					                svg.Translate(
 | 
				
			||||||
 | 
					                    x=int(self.total_width + width / 2),
 | 
				
			||||||
 | 
					                    y=int((self.total_height - height) / 2),
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					            ],
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.elements.append(cbar)
 | 
				
			||||||
 | 
					        self.total_width = self.total_width + 2 * width + 40 * bool(labels)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def svg(self):
 | 
				
			||||||
 | 
					        return str(
 | 
				
			||||||
 | 
					            svg.SVG(
 | 
				
			||||||
 | 
					                width=self.total_width, height=self.total_height, elements=self.elements
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __repr__(self):
 | 
				
			||||||
 | 
					        return f"""Matrix Visualisation: 
 | 
				
			||||||
 | 
					    shape: {matrix.shape}
 | 
				
			||||||
 | 
					    size:  {self.total_width}x{self.total_height}
 | 
				
			||||||
 | 
					            """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
					    m, n = 30, 20
 | 
				
			||||||
 | 
					    matrix = rgen.random(size=(m, n))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    colours = ["#f5d72a", "#ffffff", "#2182af"]
 | 
				
			||||||
 | 
					    # colours = ["#ff0000", "#00ff00", "#0000ff"]
 | 
				
			||||||
 | 
					    cmap = LinearGradientColourMap(colours, matrix.min(), matrix.max())
 | 
				
			||||||
 | 
					    # cmap = RandomColourMap()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    fig = MatrixVisualisation(matrix, cmap=cmap)
 | 
				
			||||||
 | 
					    fig.colourbar(labels=["yellow", "white", "blue"])
 | 
				
			||||||
 | 
					    fig.colourbar(labels=5)
 | 
				
			||||||
 | 
					    fig.colourbar(labels=[0.2, 0.5, 0.55, 0.66, 1])
 | 
				
			||||||
 | 
					    filename = "matrix.svg"
 | 
				
			||||||
 | 
					    print(fig)
 | 
				
			||||||
    with open(filename, "w") as f:
 | 
					    with open(filename, "w") as f:
 | 
				
			||||||
        f.write(str(content))
 | 
					        f.write(fig.svg)
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user