feat: visualise neural net

This commit is contained in:
2024-12-02 16:48:35 +01:00
parent 4201899545
commit fe428f1e05
6 changed files with 233 additions and 20 deletions

View File

@@ -6,5 +6,6 @@ readme = "README.md"
requires-python = ">=3.12"
dependencies = [
"numpy>=2.1.3",
"pyment>=0.3.3",
"svg-py>=1.5.0",
]

0
src/vectise/__init__.py Normal file
View File

View File

@@ -4,14 +4,28 @@ import numpy as np
# https://bsouthga.dev/posts/colour-gradients-with-python
def hex_to_RGB(hex):
""" "#FFFFFF" -> [255,255,255]"""
def hex_to_RGB(hex_string: str):
""" "#FFFFFF" -> [255,255,255]
:param hex_string:
:type hex_string: str :
:param hex_string: str:
"""
# Pass 16 to the integer function for change of base
return [int(hex[i : i + 2], 16) for i in range(1, 6, 2)]
return [int(hex_string[i : i + 2], 16) for i in range(1, 6, 2)]
def RGB_to_hex(RGB):
"""[255,255,255] -> "#FFFFFF" """
def RGB_to_hex(RGB: list[int]):
"""[255,255,255] -> "#FFFFFF"
:param RGB:
:type RGB: list[int] :
:param RGB: list[int]:
"""
# Components need to be integers for hex to make sense
RGB = [int(x) for x in RGB]
return "#" + "".join(
@@ -19,10 +33,17 @@ def RGB_to_hex(RGB):
)
def colour_dict(gradient):
def colour_dict(gradient: dict):
"""Takes in a list of RGB sub-lists and returns dictionary of
colours in RGB and hex form for use in a graphing function
defined later on."""
defined later on.
:param gradient:
:type gradient: dict :
:param gradient: dict:
"""
return {
"hex": [RGB_to_hex(RGB) for RGB in gradient],
"r": [RGB[0] for RGB in gradient],
@@ -31,11 +52,24 @@ def colour_dict(gradient):
}
def linear_gradient(start_hex, finish_hex="#FFFFFF", n=10):
def linear_gradient(start_hex: str, finish_hex: str = "#FFFFFF", n: int = 10):
"""returns a gradient list of (n) colours between
two hex colours. start_hex and finish_hex
should be the full six-digit colour string,
inlcuding the number sign ("#FFFFFF")"""
inlcuding the number sign ("#FFFFFF")
:param start_hex:
:type start_hex: str :
:param finish_hex: (Default value = "#FFFFFF")
:type finish_hex: str :
:param n: (Default value = 10)
:type n: int :
:param start_hex: str:
:param finish_hex: str: (Default value = "#FFFFFF")
:param n: int: (Default value = 10)
"""
# Starting and ending colours in RGB form
s = hex_to_RGB(start_hex)
f = hex_to_RGB(finish_hex)
@@ -53,10 +87,20 @@ def linear_gradient(start_hex, finish_hex="#FFFFFF", n=10):
return colour_dict(RGB_list)
def polylinear_gradient(colours, n):
def polylinear_gradient(colours: list[str], n: int):
"""returns a list of colours forming linear gradients between
all sequential pairs of colours. "n" specifies the total
number of desired output colours"""
number of desired output colours
:param colours:
:type colours: list[str] :
:param n:
:type n: int :
:param colours: list[str]:
:param n: int:
"""
# The number of colours per individual linear gradient
n_out = int(float(n) / (len(colours) - 1))
# returns dictionary defined by colour_dict()
@@ -73,16 +117,20 @@ def polylinear_gradient(colours, n):
class ColourMap(ABC):
""" """
@abstractmethod
def __call__(self, v: float): ...
class LinearGradientColourMap(ColourMap):
""" """
def __init__(
self,
colours: list[str] | None = ["#ff0000", "#ffffff", "#0000ff"],
min_value: float | None = 0,
max_value: float | None = 1,
min_value: float = 0,
max_value: float = 1,
bins: int = 100,
):
self.colours = polylinear_gradient(colours, bins)
@@ -96,8 +144,29 @@ class LinearGradientColourMap(ColourMap):
class RandomColourMap(ColourMap):
def __init__(self, random_state: int | list[int] | None = [2, 3, 4, 5, 6]):
""" """
def __init__(self, random_state: int | list[int] | None = None):
self.rgen = np.random.default_rng(random_state)
def __call__(self, v: float):
return RGB_to_hex([x * 255 for x in self.rgen.random(3)])
class BinaryColourMap(ColourMap):
def __init__(self, colours: list[str]):
self.colours = colours
def __call__(self, v: float):
return self.colours[np.isclose(v, 0)]
class RandomChoiceColourMap(ColourMap):
""" """
def __init__(self, colours: list[str], random_state: int | list[int] | None = None):
self.colours = colours
self.rgen = np.random.default_rng(random_state)
def __call__(self, v: float):
return self.colours[int(self.rgen.random() * len(self.colours) - 1)]

View File

@@ -1,6 +1,6 @@
import numpy as np
import svg
from colours import ColourMap
from .colours import ColourMap
class MatrixVisualisation:
@@ -43,12 +43,12 @@ class MatrixVisualisation:
if text:
self.elements.append(
svg.Text(
x=x + width / 5,
x=x + width / 3,
y=y + 3 * height / 4,
textLength=width / 2,
lengthAdjust="spacingAndGlyphs",
# textLength=width / 2,
# lengthAdjust="spacingAndGlyphs",
class_=["mono"],
text=f"{self.matrix[i, j]:.02f}",
text=f"{int(self.matrix[i, j])}",
)
)
@@ -60,7 +60,7 @@ class MatrixVisualisation:
width: int = 20,
resolution: int = 256,
border: int | bool = 1,
labels: int | list[str] | bool = False,
labels: int | list[str | int] | bool = False,
):
if height is None:
height = int(self.total_height * 2 / 3)

132
src/vectise/neural_net.py Normal file
View File

@@ -0,0 +1,132 @@
import svg
import logging
import itertools
import numpy as np
class NeuralNet:
def __init__(
self,
matrix: np.typing.NDArray | list[np.typing.NDArray] | None = None,
shape: tuple[int] | None = None,
):
"""Initialise the visualisation object.
Parameters
----------
self :
self
matrix : np.typing.NDArray | None
matrix
shape : tuple[int] | None
shape
"""
if matrix is None and shape is None:
logging.error("supply a matrix or at least its shape!")
elif matrix is None:
self.matrix = np.ones(shape)
self.shape = shape
elif shape is None:
self.shape = matrix.shape
self.matrix = matrix
self.elms: list[svg.Element] = []
self.r = 12
self.d_y = 32
self.d_x = 200
self.max_card = max(self.shape)
self.max_height = self.max_card * self.d_y + 2 * self.d_y
node_elms = []
underlying_rectangles: list[svg.Element] = []
node_coords = []
for i_layer, card in enumerate(self.shape):
node_coord = []
offset = abs(self.max_card - card) / 2 * self.d_y
underlying_rectangles.append(
svg.Rect(
x=self.d_y - 3 / 2 * self.r + i_layer * self.d_x,
y=offset + self.d_y - 3 / 2 * self.r,
width=3 * self.r,
height=(card - 1) * self.d_y + 3 * self.r,
rx=self.r,
fill="lightgrey",
)
)
for i in range(card):
node_elms.append(
svg.Circle(
cx=self.d_y + i_layer * self.d_x,
cy=offset + self.d_y + self.d_y * i,
r=self.r,
stroke="black",
fill="blue",
stroke_width=1,
)
)
node_coord.append(
(self.d_y + i_layer * self.d_x, offset + self.d_y + self.d_y * i)
)
node_coords.append(node_coord)
arrowhead = svg.Defs(
elements=[
svg.Marker(
id="arrowhead",
elements=[
svg.Path(
d=[svg.M(0, 0), svg.L(10, 5), svg.L(0, 10), svg.Z()],
fill="green",
)
],
markerWidth=6,
markerHeight=6,
orient="auto-start-reverse",
refX=8,
refY=5,
viewBox=svg.ViewBoxSpec(0, 0, 10, 10),
)
]
)
edges = []
for i_layer in range(len(node_coords) - 1):
for (i, p1), (j, p2) in itertools.product(
enumerate(node_coords[i_layer]), enumerate(node_coords[i_layer + 1])
):
if (self.matrix[i, j] == 0).all():
continue
edges.append(
svg.Path(
d=[
svg.M(p1[0] + self.r, p1[1]),
svg.L(p2[0] - self.r, p2[1]),
],
stroke="green",
# marker_end="url(#arrowhead)",
)
)
self.elms.append(arrowhead)
self.elms.append(svg.Style(text=".mono {font: monospace; text-align: center;}"))
self.elms.append(svg.Style(text=".small {font-size: 25%;}"))
self.elms.append(svg.Style(text=".normal {font-size: 12px;}"))
self.elms.extend(underlying_rectangles)
self.elms.extend(edges)
self.elms.extend(node_elms)
@property
def svg(self):
return str(
svg.SVG(
width=(len(self.shape) + 1) * self.d_x,
height=self.max_height,
elements=self.elms,
)
)
def __repr__(self):
return f"""Neural Net Visualisation:
shape: {self.shape}
"""

11
uv.lock generated
View File

@@ -39,6 +39,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/86/09/a5ab407bd7f5f5599e6a9261f964ace03a73e7c6928de906981c31c38082/numpy-2.1.3-cp313-cp313t-win_amd64.whl", hash = "sha256:2564fbdf2b99b3f815f2107c1bbc93e2de8ee655a69c261363a1172a79a257d4", size = 12644098 },
]
[[package]]
name = "pyment"
version = "0.3.3"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/dd/9e/c58a151c7020f6fdd48eea0085a9d1c91a57da19fa4e7bff0daf930c9900/Pyment-0.3.3.tar.gz", hash = "sha256:951a4c52d6791ccec55bc739811169eed69917d3874f5fe722866623a697f39d", size = 21003 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/52/01/810e174c28a7dcf5f91c048faf69c84eafee60c9a844e4ce21671b2e99bb/Pyment-0.3.3-py2.py3-none-any.whl", hash = "sha256:a0c6ec59d06d24aeec3eaecb22115d0dc95d09e14209b2df838381fdf47a78cc", size = 21924 },
]
[[package]]
name = "svg-py"
version = "1.5.0"
@@ -54,11 +63,13 @@ version = "0.1.0"
source = { virtual = "." }
dependencies = [
{ name = "numpy" },
{ name = "pyment" },
{ name = "svg-py" },
]
[package.metadata]
requires-dist = [
{ name = "numpy", specifier = ">=2.1.3" },
{ name = "pyment", specifier = ">=0.3.3" },
{ name = "svg-py", specifier = ">=1.5.0" },
]