feat: visualise neural net

This commit is contained in:
2024-12-02 16:48:35 +01:00
parent 4201899545
commit fe428f1e05
6 changed files with 233 additions and 20 deletions

View File

@@ -6,5 +6,6 @@ readme = "README.md"
requires-python = ">=3.12" requires-python = ">=3.12"
dependencies = [ dependencies = [
"numpy>=2.1.3", "numpy>=2.1.3",
"pyment>=0.3.3",
"svg-py>=1.5.0", "svg-py>=1.5.0",
] ]

0
src/vectise/__init__.py Normal file
View File

View File

@@ -4,14 +4,28 @@ import numpy as np
# https://bsouthga.dev/posts/colour-gradients-with-python # https://bsouthga.dev/posts/colour-gradients-with-python
def hex_to_RGB(hex): def hex_to_RGB(hex_string: str):
""" "#FFFFFF" -> [255,255,255]""" """ "#FFFFFF" -> [255,255,255]
:param hex_string:
:type hex_string: str :
:param hex_string: str:
"""
# Pass 16 to the integer function for change of base # Pass 16 to the integer function for change of base
return [int(hex[i : i + 2], 16) for i in range(1, 6, 2)] return [int(hex_string[i : i + 2], 16) for i in range(1, 6, 2)]
def RGB_to_hex(RGB): def RGB_to_hex(RGB: list[int]):
"""[255,255,255] -> "#FFFFFF" """ """[255,255,255] -> "#FFFFFF"
:param RGB:
:type RGB: list[int] :
:param RGB: list[int]:
"""
# Components need to be integers for hex to make sense # Components need to be integers for hex to make sense
RGB = [int(x) for x in RGB] RGB = [int(x) for x in RGB]
return "#" + "".join( return "#" + "".join(
@@ -19,10 +33,17 @@ def RGB_to_hex(RGB):
) )
def colour_dict(gradient): def colour_dict(gradient: dict):
"""Takes in a list of RGB sub-lists and returns dictionary of """Takes in a list of RGB sub-lists and returns dictionary of
colours in RGB and hex form for use in a graphing function colours in RGB and hex form for use in a graphing function
defined later on.""" defined later on.
:param gradient:
:type gradient: dict :
:param gradient: dict:
"""
return { return {
"hex": [RGB_to_hex(RGB) for RGB in gradient], "hex": [RGB_to_hex(RGB) for RGB in gradient],
"r": [RGB[0] for RGB in gradient], "r": [RGB[0] for RGB in gradient],
@@ -31,11 +52,24 @@ def colour_dict(gradient):
} }
def linear_gradient(start_hex, finish_hex="#FFFFFF", n=10): def linear_gradient(start_hex: str, finish_hex: str = "#FFFFFF", n: int = 10):
"""returns a gradient list of (n) colours between """returns a gradient list of (n) colours between
two hex colours. start_hex and finish_hex two hex colours. start_hex and finish_hex
should be the full six-digit colour string, should be the full six-digit colour string,
inlcuding the number sign ("#FFFFFF")""" inlcuding the number sign ("#FFFFFF")
:param start_hex:
:type start_hex: str :
:param finish_hex: (Default value = "#FFFFFF")
:type finish_hex: str :
:param n: (Default value = 10)
:type n: int :
:param start_hex: str:
:param finish_hex: str: (Default value = "#FFFFFF")
:param n: int: (Default value = 10)
"""
# Starting and ending colours in RGB form # Starting and ending colours in RGB form
s = hex_to_RGB(start_hex) s = hex_to_RGB(start_hex)
f = hex_to_RGB(finish_hex) f = hex_to_RGB(finish_hex)
@@ -53,10 +87,20 @@ def linear_gradient(start_hex, finish_hex="#FFFFFF", n=10):
return colour_dict(RGB_list) return colour_dict(RGB_list)
def polylinear_gradient(colours, n): def polylinear_gradient(colours: list[str], n: int):
"""returns a list of colours forming linear gradients between """returns a list of colours forming linear gradients between
all sequential pairs of colours. "n" specifies the total all sequential pairs of colours. "n" specifies the total
number of desired output colours""" number of desired output colours
:param colours:
:type colours: list[str] :
:param n:
:type n: int :
:param colours: list[str]:
:param n: int:
"""
# The number of colours per individual linear gradient # The number of colours per individual linear gradient
n_out = int(float(n) / (len(colours) - 1)) n_out = int(float(n) / (len(colours) - 1))
# returns dictionary defined by colour_dict() # returns dictionary defined by colour_dict()
@@ -73,16 +117,20 @@ def polylinear_gradient(colours, n):
class ColourMap(ABC): class ColourMap(ABC):
""" """
@abstractmethod @abstractmethod
def __call__(self, v: float): ... def __call__(self, v: float): ...
class LinearGradientColourMap(ColourMap): class LinearGradientColourMap(ColourMap):
""" """
def __init__( def __init__(
self, self,
colours: list[str] | None = ["#ff0000", "#ffffff", "#0000ff"], colours: list[str] | None = ["#ff0000", "#ffffff", "#0000ff"],
min_value: float | None = 0, min_value: float = 0,
max_value: float | None = 1, max_value: float = 1,
bins: int = 100, bins: int = 100,
): ):
self.colours = polylinear_gradient(colours, bins) self.colours = polylinear_gradient(colours, bins)
@@ -96,8 +144,29 @@ class LinearGradientColourMap(ColourMap):
class RandomColourMap(ColourMap): class RandomColourMap(ColourMap):
def __init__(self, random_state: int | list[int] | None = [2, 3, 4, 5, 6]): """ """
def __init__(self, random_state: int | list[int] | None = None):
self.rgen = np.random.default_rng(random_state) self.rgen = np.random.default_rng(random_state)
def __call__(self, v: float): def __call__(self, v: float):
return RGB_to_hex([x * 255 for x in self.rgen.random(3)]) return RGB_to_hex([x * 255 for x in self.rgen.random(3)])
class BinaryColourMap(ColourMap):
def __init__(self, colours: list[str]):
self.colours = colours
def __call__(self, v: float):
return self.colours[np.isclose(v, 0)]
class RandomChoiceColourMap(ColourMap):
""" """
def __init__(self, colours: list[str], random_state: int | list[int] | None = None):
self.colours = colours
self.rgen = np.random.default_rng(random_state)
def __call__(self, v: float):
return self.colours[int(self.rgen.random() * len(self.colours) - 1)]

View File

@@ -1,6 +1,6 @@
import numpy as np import numpy as np
import svg import svg
from colours import ColourMap from .colours import ColourMap
class MatrixVisualisation: class MatrixVisualisation:
@@ -43,12 +43,12 @@ class MatrixVisualisation:
if text: if text:
self.elements.append( self.elements.append(
svg.Text( svg.Text(
x=x + width / 5, x=x + width / 3,
y=y + 3 * height / 4, y=y + 3 * height / 4,
textLength=width / 2, # textLength=width / 2,
lengthAdjust="spacingAndGlyphs", # lengthAdjust="spacingAndGlyphs",
class_=["mono"], class_=["mono"],
text=f"{self.matrix[i, j]:.02f}", text=f"{int(self.matrix[i, j])}",
) )
) )
@@ -60,7 +60,7 @@ class MatrixVisualisation:
width: int = 20, width: int = 20,
resolution: int = 256, resolution: int = 256,
border: int | bool = 1, border: int | bool = 1,
labels: int | list[str] | bool = False, labels: int | list[str | int] | bool = False,
): ):
if height is None: if height is None:
height = int(self.total_height * 2 / 3) height = int(self.total_height * 2 / 3)

132
src/vectise/neural_net.py Normal file
View File

@@ -0,0 +1,132 @@
import svg
import logging
import itertools
import numpy as np
class NeuralNet:
def __init__(
self,
matrix: np.typing.NDArray | list[np.typing.NDArray] | None = None,
shape: tuple[int] | None = None,
):
"""Initialise the visualisation object.
Parameters
----------
self :
self
matrix : np.typing.NDArray | None
matrix
shape : tuple[int] | None
shape
"""
if matrix is None and shape is None:
logging.error("supply a matrix or at least its shape!")
elif matrix is None:
self.matrix = np.ones(shape)
self.shape = shape
elif shape is None:
self.shape = matrix.shape
self.matrix = matrix
self.elms: list[svg.Element] = []
self.r = 12
self.d_y = 32
self.d_x = 200
self.max_card = max(self.shape)
self.max_height = self.max_card * self.d_y + 2 * self.d_y
node_elms = []
underlying_rectangles: list[svg.Element] = []
node_coords = []
for i_layer, card in enumerate(self.shape):
node_coord = []
offset = abs(self.max_card - card) / 2 * self.d_y
underlying_rectangles.append(
svg.Rect(
x=self.d_y - 3 / 2 * self.r + i_layer * self.d_x,
y=offset + self.d_y - 3 / 2 * self.r,
width=3 * self.r,
height=(card - 1) * self.d_y + 3 * self.r,
rx=self.r,
fill="lightgrey",
)
)
for i in range(card):
node_elms.append(
svg.Circle(
cx=self.d_y + i_layer * self.d_x,
cy=offset + self.d_y + self.d_y * i,
r=self.r,
stroke="black",
fill="blue",
stroke_width=1,
)
)
node_coord.append(
(self.d_y + i_layer * self.d_x, offset + self.d_y + self.d_y * i)
)
node_coords.append(node_coord)
arrowhead = svg.Defs(
elements=[
svg.Marker(
id="arrowhead",
elements=[
svg.Path(
d=[svg.M(0, 0), svg.L(10, 5), svg.L(0, 10), svg.Z()],
fill="green",
)
],
markerWidth=6,
markerHeight=6,
orient="auto-start-reverse",
refX=8,
refY=5,
viewBox=svg.ViewBoxSpec(0, 0, 10, 10),
)
]
)
edges = []
for i_layer in range(len(node_coords) - 1):
for (i, p1), (j, p2) in itertools.product(
enumerate(node_coords[i_layer]), enumerate(node_coords[i_layer + 1])
):
if (self.matrix[i, j] == 0).all():
continue
edges.append(
svg.Path(
d=[
svg.M(p1[0] + self.r, p1[1]),
svg.L(p2[0] - self.r, p2[1]),
],
stroke="green",
# marker_end="url(#arrowhead)",
)
)
self.elms.append(arrowhead)
self.elms.append(svg.Style(text=".mono {font: monospace; text-align: center;}"))
self.elms.append(svg.Style(text=".small {font-size: 25%;}"))
self.elms.append(svg.Style(text=".normal {font-size: 12px;}"))
self.elms.extend(underlying_rectangles)
self.elms.extend(edges)
self.elms.extend(node_elms)
@property
def svg(self):
return str(
svg.SVG(
width=(len(self.shape) + 1) * self.d_x,
height=self.max_height,
elements=self.elms,
)
)
def __repr__(self):
return f"""Neural Net Visualisation:
shape: {self.shape}
"""

11
uv.lock generated
View File

@@ -39,6 +39,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/86/09/a5ab407bd7f5f5599e6a9261f964ace03a73e7c6928de906981c31c38082/numpy-2.1.3-cp313-cp313t-win_amd64.whl", hash = "sha256:2564fbdf2b99b3f815f2107c1bbc93e2de8ee655a69c261363a1172a79a257d4", size = 12644098 }, { url = "https://files.pythonhosted.org/packages/86/09/a5ab407bd7f5f5599e6a9261f964ace03a73e7c6928de906981c31c38082/numpy-2.1.3-cp313-cp313t-win_amd64.whl", hash = "sha256:2564fbdf2b99b3f815f2107c1bbc93e2de8ee655a69c261363a1172a79a257d4", size = 12644098 },
] ]
[[package]]
name = "pyment"
version = "0.3.3"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/dd/9e/c58a151c7020f6fdd48eea0085a9d1c91a57da19fa4e7bff0daf930c9900/Pyment-0.3.3.tar.gz", hash = "sha256:951a4c52d6791ccec55bc739811169eed69917d3874f5fe722866623a697f39d", size = 21003 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/52/01/810e174c28a7dcf5f91c048faf69c84eafee60c9a844e4ce21671b2e99bb/Pyment-0.3.3-py2.py3-none-any.whl", hash = "sha256:a0c6ec59d06d24aeec3eaecb22115d0dc95d09e14209b2df838381fdf47a78cc", size = 21924 },
]
[[package]] [[package]]
name = "svg-py" name = "svg-py"
version = "1.5.0" version = "1.5.0"
@@ -54,11 +63,13 @@ version = "0.1.0"
source = { virtual = "." } source = { virtual = "." }
dependencies = [ dependencies = [
{ name = "numpy" }, { name = "numpy" },
{ name = "pyment" },
{ name = "svg-py" }, { name = "svg-py" },
] ]
[package.metadata] [package.metadata]
requires-dist = [ requires-dist = [
{ name = "numpy", specifier = ">=2.1.3" }, { name = "numpy", specifier = ">=2.1.3" },
{ name = "pyment", specifier = ">=0.3.3" },
{ name = "svg-py", specifier = ">=1.5.0" }, { name = "svg-py", specifier = ">=1.5.0" },
] ]