works so far

This commit is contained in:
2024-12-05 14:03:26 +01:00
parent fe428f1e05
commit dfca46b5c3

View File

@@ -10,32 +10,49 @@ class NeuralNet:
self, self,
matrix: np.typing.NDArray | list[np.typing.NDArray] | None = None, matrix: np.typing.NDArray | list[np.typing.NDArray] | None = None,
shape: tuple[int] | None = None, shape: tuple[int] | None = None,
node_colour: str = "blue",
node_border: str = "black",
edge_colours: list[str] = ["green"],
): ):
"""Initialise the visualisation object. """Initialise the visualisation object.
Parameters Parameters
---------- ----------
self :
self
matrix : np.typing.NDArray | None matrix : np.typing.NDArray | None
matrix matrix or list of matrices
shape : tuple[int] | None shape : tuple[int] | None
shape shape
""" """
if matrix is None and shape is None: if matrix is None and shape is None:
logging.error("supply a matrix or at least its shape!") logging.error("supply a matrix or at least its shape!")
elif matrix is None: elif matrix is None:
self.matrix = np.ones(shape) self.matrix = [
np.ones((shape[i], shape[i + 1])) for i in range(len(shape) - 1)
]
self.shape = shape self.shape = shape
elif shape is None: elif shape is None:
self.shape = matrix.shape if not isinstance(matrix, list):
self.matrix = [matrix]
else:
self.matrix = matrix self.matrix = matrix
shapes = []
for m in self.matrix:
shapes.extend(m.shape)
self.shape = tuple(shapes[:-1:2] + [shapes[-1]])
else:
logging.error(
"not implemented. best to drop the shapes. they will"
"get inferred from the matrices."
)
if len(edge_colours) != len(self.matrix):
edge_colours = edge_colours * len(self.matrix)
self.elms: list[svg.Element] = [] self.elms: list[svg.Element] = []
self.r = 12 self.r = 12
self.d_y = 32 self.d_y = 32
self.d_x = 200 self.d_x = 200
self.max_card = max(self.shape) self.max_card = max(self.shape)
self.max_height = self.max_card * self.d_y + 2 * self.d_y self.max_height = (self.max_card - 1) * self.d_y + 2 * self.d_y
node_elms = [] node_elms = []
underlying_rectangles: list[svg.Element] = [] underlying_rectangles: list[svg.Element] = []
@@ -59,8 +76,8 @@ class NeuralNet:
cx=self.d_y + i_layer * self.d_x, cx=self.d_y + i_layer * self.d_x,
cy=offset + self.d_y + self.d_y * i, cy=offset + self.d_y + self.d_y * i,
r=self.r, r=self.r,
stroke="black", stroke=node_border,
fill="blue", fill=node_colour,
stroke_width=1, stroke_width=1,
) )
) )
@@ -94,7 +111,7 @@ class NeuralNet:
for (i, p1), (j, p2) in itertools.product( for (i, p1), (j, p2) in itertools.product(
enumerate(node_coords[i_layer]), enumerate(node_coords[i_layer + 1]) enumerate(node_coords[i_layer]), enumerate(node_coords[i_layer + 1])
): ):
if (self.matrix[i, j] == 0).all(): if (self.matrix[i_layer][i, j] == 0).all():
continue continue
edges.append( edges.append(
svg.Path( svg.Path(
@@ -102,7 +119,7 @@ class NeuralNet:
svg.M(p1[0] + self.r, p1[1]), svg.M(p1[0] + self.r, p1[1]),
svg.L(p2[0] - self.r, p2[1]), svg.L(p2[0] - self.r, p2[1]),
], ],
stroke="green", stroke=edge_colours[i_layer],
# marker_end="url(#arrowhead)", # marker_end="url(#arrowhead)",
) )
) )
@@ -120,7 +137,7 @@ class NeuralNet:
def svg(self): def svg(self):
return str( return str(
svg.SVG( svg.SVG(
width=(len(self.shape) + 1) * self.d_x, width=(len(self.shape) - 1) * self.d_x + 2 * self.d_y,
height=self.max_height, height=self.max_height,
elements=self.elms, elements=self.elms,
) )