works so far
This commit is contained in:
@@ -10,32 +10,49 @@ class NeuralNet:
|
||||
self,
|
||||
matrix: np.typing.NDArray | list[np.typing.NDArray] | None = None,
|
||||
shape: tuple[int] | None = None,
|
||||
node_colour: str = "blue",
|
||||
node_border: str = "black",
|
||||
edge_colours: list[str] = ["green"],
|
||||
):
|
||||
"""Initialise the visualisation object.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
self :
|
||||
self
|
||||
matrix : np.typing.NDArray | None
|
||||
matrix
|
||||
matrix or list of matrices
|
||||
shape : tuple[int] | None
|
||||
shape
|
||||
"""
|
||||
if matrix is None and shape is None:
|
||||
logging.error("supply a matrix or at least its shape!")
|
||||
elif matrix is None:
|
||||
self.matrix = np.ones(shape)
|
||||
self.matrix = [
|
||||
np.ones((shape[i], shape[i + 1])) for i in range(len(shape) - 1)
|
||||
]
|
||||
self.shape = shape
|
||||
elif shape is None:
|
||||
self.shape = matrix.shape
|
||||
self.matrix = matrix
|
||||
if not isinstance(matrix, list):
|
||||
self.matrix = [matrix]
|
||||
else:
|
||||
self.matrix = matrix
|
||||
shapes = []
|
||||
for m in self.matrix:
|
||||
shapes.extend(m.shape)
|
||||
self.shape = tuple(shapes[:-1:2] + [shapes[-1]])
|
||||
else:
|
||||
logging.error(
|
||||
"not implemented. best to drop the shapes. they will"
|
||||
"get inferred from the matrices."
|
||||
)
|
||||
if len(edge_colours) != len(self.matrix):
|
||||
edge_colours = edge_colours * len(self.matrix)
|
||||
|
||||
self.elms: list[svg.Element] = []
|
||||
self.r = 12
|
||||
self.d_y = 32
|
||||
self.d_x = 200
|
||||
self.max_card = max(self.shape)
|
||||
self.max_height = self.max_card * self.d_y + 2 * self.d_y
|
||||
self.max_height = (self.max_card - 1) * self.d_y + 2 * self.d_y
|
||||
|
||||
node_elms = []
|
||||
underlying_rectangles: list[svg.Element] = []
|
||||
@@ -59,8 +76,8 @@ class NeuralNet:
|
||||
cx=self.d_y + i_layer * self.d_x,
|
||||
cy=offset + self.d_y + self.d_y * i,
|
||||
r=self.r,
|
||||
stroke="black",
|
||||
fill="blue",
|
||||
stroke=node_border,
|
||||
fill=node_colour,
|
||||
stroke_width=1,
|
||||
)
|
||||
)
|
||||
@@ -94,7 +111,7 @@ class NeuralNet:
|
||||
for (i, p1), (j, p2) in itertools.product(
|
||||
enumerate(node_coords[i_layer]), enumerate(node_coords[i_layer + 1])
|
||||
):
|
||||
if (self.matrix[i, j] == 0).all():
|
||||
if (self.matrix[i_layer][i, j] == 0).all():
|
||||
continue
|
||||
edges.append(
|
||||
svg.Path(
|
||||
@@ -102,7 +119,7 @@ class NeuralNet:
|
||||
svg.M(p1[0] + self.r, p1[1]),
|
||||
svg.L(p2[0] - self.r, p2[1]),
|
||||
],
|
||||
stroke="green",
|
||||
stroke=edge_colours[i_layer],
|
||||
# marker_end="url(#arrowhead)",
|
||||
)
|
||||
)
|
||||
@@ -120,7 +137,7 @@ class NeuralNet:
|
||||
def svg(self):
|
||||
return str(
|
||||
svg.SVG(
|
||||
width=(len(self.shape) + 1) * self.d_x,
|
||||
width=(len(self.shape) - 1) * self.d_x + 2 * self.d_y,
|
||||
height=self.max_height,
|
||||
elements=self.elms,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user