import numpy as np
import numpy.typing as npt
import svg

from colours import ColourMap


class MatrixVisualisation:
    """Render a 2-D array as a grid of coloured squares in an SVG document."""

    def __init__(
        self,
        matrix: npt.NDArray,
        cmap: ColourMap,
        text: bool = False,
        labels: int | list[str] | bool = False,  # accepted but not used yet
    ):
        self.matrix = matrix
        self.m, self.n = self.matrix.shape

        # Cell size and spacing between cells, in SVG user units.
        width = 20
        height = 20
        gap = 1

        self.text = text
        self.total_width = (gap + width) * self.n + gap
        self.total_height = (gap + height) * self.m + gap
        self.cmap = cmap

        self.elements = []
        # `font: monospace` is not a valid shorthand (it also needs a size),
        # so set the family directly. Cell text is positioned explicitly via
        # x/textLength below, so no alignment rule is needed here.
        self.elements.append(svg.Style(text=".mono { font-family: monospace; }"))
        self.elements.append(svg.Style(text=".small { font-size: 25%; }"))
        self.elements.append(svg.Style(text=".normal { font-size: 12px; }"))

        # One rectangle per matrix entry, laid out row by row.
        for i, y in enumerate(range(gap, self.total_height, gap + height)):
            for j, x in enumerate(range(gap, self.total_width, gap + width)):
                self.elements.append(
                    svg.Rect(
                        x=x,
                        y=y,
                        width=width,
                        height=height,
                        stroke="transparent",
                        fill=cmap(self.matrix[i, j]),
                    )
                )
                if text:
                    # Overlay the value, squeezed to half the cell width.
                    self.elements.append(
                        svg.Text(
                            x=x + width / 5,
                            y=y + 3 * height / 4,
                            textLength=width / 2,
                            lengthAdjust="spacingAndGlyphs",
                            class_=["mono"],
                            text=f"{self.matrix[i, j]:.02f}",
                        )
                    )

    def colourbar(
        self,
        min_value: float = 0,
        max_value: float = 1,
        height: int | None = None,
        width: int = 20,
        resolution: int = 256,
        border: int | bool = 1,
        labels: int | list[str] | list[float] | bool | None = False,
    ):
        """Append a vertical colour bar to the right of the matrix.

        ``labels`` may be an int (that many evenly spaced numeric ticks), a
        list of strings (evenly spaced text ticks), a list of numbers (one
        tick at the position of each value), or False/None for no ticks.
        """
        if height is None:
            height = int(self.total_height * 2 / 3)

        # Draw the gradient as `resolution` thin horizontal bands, one value
        # per band. Each band is 10% taller than its spacing so neighbouring
        # bands overlap slightly and leave no gaps.
        lines = [
            svg.Rect(
                fill=self.cmap(v),
                x=0,
                y=y,
                width=width,
                height=1.1 * height / resolution,
                stroke="none",
            )
            for y, v in zip(
                np.linspace(0, height, resolution, endpoint=False),
                np.linspace(min_value, max_value, resolution),
            )
        ]

        if labels is None or labels is False:
            label_elements = []
        elif isinstance(labels, int):
            # `labels` evenly spaced numeric ticks from min_value to max_value.
            label_elements = [
                svg.Text(
                    text=f"— {v:.02f}",
                    class_=["normal"],
                    x=width,
                    y=y,
                    dy=3,
                )
                for y, v in zip(
                    np.linspace(0, height, labels),
                    np.linspace(min_value, max_value, labels),
                )
            ]
        elif isinstance(labels, list) and all(isinstance(n, str) for n in labels):
            # Text ticks, spaced evenly along the bar.
            label_elements = [
                svg.Text(
                    text=f"— {v}",
                    class_=["normal"],
                    x=width,
                    y=y,
                    dy=3,
                )
                for y, v in zip(np.linspace(0, height, len(labels)), labels)
            ]
        elif isinstance(labels, list) and all(
            isinstance(n, (int, float, np.number)) for n in labels
        ):
            # Numeric ticks, each placed at the position of its value on the bar.
            label_elements = [
                svg.Text(
                    text=f"— {v:.02f}",
                    class_=["normal"],
                    x=width,
                    y=(v - min_value) / (max_value - min_value) * height,
                    dy=3,
                )
                for v in labels
            ]
        else:
            raise TypeError(
                "labels must be an int, a list of str, a list of numbers, or False/None"
            )

        cbar = svg.G(
            id="colourbar",
            elements=[
                # Unpack the bands so `elements` is a flat list of elements.
                *lines,
                svg.G(id="colourbar labels", elements=label_elements),
                svg.Rect(
                    x=0,
                    y=0,
                    width=width,
                    height=height,
                    fill="none",
                    stroke_width=int(border),
                    stroke="black",
                ),
            ],
            transform=[
                svg.Translate(
                    x=int(self.total_width + width / 2),
                    y=int((self.total_height - height) / 2),
                )
            ],
        )
        self.elements.append(cbar)
        # Widen the canvas to fit the bar, plus room for tick labels if any.
        self.total_width = self.total_width + 2 * width + 40 * bool(labels)

    @property
    def svg(self) -> str:
        """The whole visualisation serialised as an SVG document string."""
        return str(
            svg.SVG(
                width=self.total_width,
                height=self.total_height,
                elements=self.elements,
            )
        )

    def __repr__(self):
        return (
            "Matrix Visualisation:\n"
            f"    shape: {self.matrix.shape}\n"
            f"    size: {self.total_width}x{self.total_height}\n"
        )