diff --git a/src/vectise/neural_net.py b/src/vectise/neural_net.py index b2b7d06..6cd3527 100644 --- a/src/vectise/neural_net.py +++ b/src/vectise/neural_net.py @@ -10,32 +10,49 @@ class NeuralNet: self, matrix: np.typing.NDArray | list[np.typing.NDArray] | None = None, shape: tuple[int] | None = None, + node_colour: str = "blue", + node_border: str = "black", + edge_colours: list[str] = ["green"], ): """Initialise the visualisation object. Parameters ---------- - self : - self matrix : np.typing.NDArray | None - matrix + matrix or list of matrices shape : tuple[int] | None shape """ if matrix is None and shape is None: logging.error("supply a matrix or at least its shape!") elif matrix is None: - self.matrix = np.ones(shape) + self.matrix = [ + np.ones((shape[i], shape[i + 1])) for i in range(len(shape) - 1) + ] self.shape = shape elif shape is None: - self.shape = matrix.shape - self.matrix = matrix + if not isinstance(matrix, list): + self.matrix = [matrix] + else: + self.matrix = matrix + shapes = [] + for m in self.matrix: + shapes.extend(m.shape) + self.shape = tuple(shapes[:-1:2] + [shapes[-1]]) + else: + logging.error( + "not implemented. best to drop the shapes. they will" + "get inferred from the matrices." + ) + if len(edge_colours) != len(self.matrix): + edge_colours = edge_colours * len(self.matrix) + self.elms: list[svg.Element] = [] self.r = 12 self.d_y = 32 self.d_x = 200 self.max_card = max(self.shape) - self.max_height = self.max_card * self.d_y + 2 * self.d_y + self.max_height = (self.max_card - 1) * self.d_y + 2 * self.d_y node_elms = [] underlying_rectangles: list[svg.Element] = [] @@ -59,8 +76,8 @@ class NeuralNet: cx=self.d_y + i_layer * self.d_x, cy=offset + self.d_y + self.d_y * i, r=self.r, - stroke="black", - fill="blue", + stroke=node_border, + fill=node_colour, stroke_width=1, ) ) @@ -94,7 +111,7 @@ class NeuralNet: for (i, p1), (j, p2) in itertools.product( enumerate(node_coords[i_layer]), enumerate(node_coords[i_layer + 1]) ): - if (self.matrix[i, j] == 0).all(): + if (self.matrix[i_layer][i, j] == 0).all(): continue edges.append( svg.Path( @@ -102,7 +119,7 @@ class NeuralNet: svg.M(p1[0] + self.r, p1[1]), svg.L(p2[0] - self.r, p2[1]), ], - stroke="green", + stroke=edge_colours[i_layer], # marker_end="url(#arrowhead)", ) ) @@ -120,7 +137,7 @@ class NeuralNet: def svg(self): return str( svg.SVG( - width=(len(self.shape) + 1) * self.d_x, + width=(len(self.shape) - 1) * self.d_x + 2 * self.d_y, height=self.max_height, elements=self.elms, )