works so far

This commit is contained in:
2024-12-05 14:03:26 +01:00
parent fe428f1e05
commit dfca46b5c3

View File

@@ -10,32 +10,49 @@ class NeuralNet:
self,
matrix: np.typing.NDArray | list[np.typing.NDArray] | None = None,
shape: tuple[int] | None = None,
node_colour: str = "blue",
node_border: str = "black",
edge_colours: list[str] = ["green"],
):
"""Initialise the visualisation object.
Parameters
----------
self :
self
matrix : np.typing.NDArray | None
matrix
matrix or list of matrices
shape : tuple[int] | None
shape
"""
if matrix is None and shape is None:
logging.error("supply a matrix or at least its shape!")
elif matrix is None:
self.matrix = np.ones(shape)
self.matrix = [
np.ones((shape[i], shape[i + 1])) for i in range(len(shape) - 1)
]
self.shape = shape
elif shape is None:
self.shape = matrix.shape
if not isinstance(matrix, list):
self.matrix = [matrix]
else:
self.matrix = matrix
shapes = []
for m in self.matrix:
shapes.extend(m.shape)
self.shape = tuple(shapes[:-1:2] + [shapes[-1]])
else:
logging.error(
"not implemented. best to drop the shapes. they will"
"get inferred from the matrices."
)
if len(edge_colours) != len(self.matrix):
edge_colours = edge_colours * len(self.matrix)
self.elms: list[svg.Element] = []
self.r = 12
self.d_y = 32
self.d_x = 200
self.max_card = max(self.shape)
self.max_height = self.max_card * self.d_y + 2 * self.d_y
self.max_height = (self.max_card - 1) * self.d_y + 2 * self.d_y
node_elms = []
underlying_rectangles: list[svg.Element] = []
@@ -59,8 +76,8 @@ class NeuralNet:
cx=self.d_y + i_layer * self.d_x,
cy=offset + self.d_y + self.d_y * i,
r=self.r,
stroke="black",
fill="blue",
stroke=node_border,
fill=node_colour,
stroke_width=1,
)
)
@@ -94,7 +111,7 @@ class NeuralNet:
for (i, p1), (j, p2) in itertools.product(
enumerate(node_coords[i_layer]), enumerate(node_coords[i_layer + 1])
):
if (self.matrix[i, j] == 0).all():
if (self.matrix[i_layer][i, j] == 0).all():
continue
edges.append(
svg.Path(
@@ -102,7 +119,7 @@ class NeuralNet:
svg.M(p1[0] + self.r, p1[1]),
svg.L(p2[0] - self.r, p2[1]),
],
stroke="green",
stroke=edge_colours[i_layer],
# marker_end="url(#arrowhead)",
)
)
@@ -120,7 +137,7 @@ class NeuralNet:
def svg(self):
return str(
svg.SVG(
width=(len(self.shape) + 1) * self.d_x,
width=(len(self.shape) - 1) * self.d_x + 2 * self.d_y,
height=self.max_height,
elements=self.elms,
)