works so far
This commit is contained in:
@@ -10,32 +10,49 @@ class NeuralNet:
|
|||||||
self,
|
self,
|
||||||
matrix: np.typing.NDArray | list[np.typing.NDArray] | None = None,
|
matrix: np.typing.NDArray | list[np.typing.NDArray] | None = None,
|
||||||
shape: tuple[int] | None = None,
|
shape: tuple[int] | None = None,
|
||||||
|
node_colour: str = "blue",
|
||||||
|
node_border: str = "black",
|
||||||
|
edge_colours: list[str] = ["green"],
|
||||||
):
|
):
|
||||||
"""Initialise the visualisation object.
|
"""Initialise the visualisation object.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
self :
|
|
||||||
self
|
|
||||||
matrix : np.typing.NDArray | None
|
matrix : np.typing.NDArray | None
|
||||||
matrix
|
matrix or list of matrices
|
||||||
shape : tuple[int] | None
|
shape : tuple[int] | None
|
||||||
shape
|
shape
|
||||||
"""
|
"""
|
||||||
if matrix is None and shape is None:
|
if matrix is None and shape is None:
|
||||||
logging.error("supply a matrix or at least its shape!")
|
logging.error("supply a matrix or at least its shape!")
|
||||||
elif matrix is None:
|
elif matrix is None:
|
||||||
self.matrix = np.ones(shape)
|
self.matrix = [
|
||||||
|
np.ones((shape[i], shape[i + 1])) for i in range(len(shape) - 1)
|
||||||
|
]
|
||||||
self.shape = shape
|
self.shape = shape
|
||||||
elif shape is None:
|
elif shape is None:
|
||||||
self.shape = matrix.shape
|
if not isinstance(matrix, list):
|
||||||
self.matrix = matrix
|
self.matrix = [matrix]
|
||||||
|
else:
|
||||||
|
self.matrix = matrix
|
||||||
|
shapes = []
|
||||||
|
for m in self.matrix:
|
||||||
|
shapes.extend(m.shape)
|
||||||
|
self.shape = tuple(shapes[:-1:2] + [shapes[-1]])
|
||||||
|
else:
|
||||||
|
logging.error(
|
||||||
|
"not implemented. best to drop the shapes. they will"
|
||||||
|
"get inferred from the matrices."
|
||||||
|
)
|
||||||
|
if len(edge_colours) != len(self.matrix):
|
||||||
|
edge_colours = edge_colours * len(self.matrix)
|
||||||
|
|
||||||
self.elms: list[svg.Element] = []
|
self.elms: list[svg.Element] = []
|
||||||
self.r = 12
|
self.r = 12
|
||||||
self.d_y = 32
|
self.d_y = 32
|
||||||
self.d_x = 200
|
self.d_x = 200
|
||||||
self.max_card = max(self.shape)
|
self.max_card = max(self.shape)
|
||||||
self.max_height = self.max_card * self.d_y + 2 * self.d_y
|
self.max_height = (self.max_card - 1) * self.d_y + 2 * self.d_y
|
||||||
|
|
||||||
node_elms = []
|
node_elms = []
|
||||||
underlying_rectangles: list[svg.Element] = []
|
underlying_rectangles: list[svg.Element] = []
|
||||||
@@ -59,8 +76,8 @@ class NeuralNet:
|
|||||||
cx=self.d_y + i_layer * self.d_x,
|
cx=self.d_y + i_layer * self.d_x,
|
||||||
cy=offset + self.d_y + self.d_y * i,
|
cy=offset + self.d_y + self.d_y * i,
|
||||||
r=self.r,
|
r=self.r,
|
||||||
stroke="black",
|
stroke=node_border,
|
||||||
fill="blue",
|
fill=node_colour,
|
||||||
stroke_width=1,
|
stroke_width=1,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -94,7 +111,7 @@ class NeuralNet:
|
|||||||
for (i, p1), (j, p2) in itertools.product(
|
for (i, p1), (j, p2) in itertools.product(
|
||||||
enumerate(node_coords[i_layer]), enumerate(node_coords[i_layer + 1])
|
enumerate(node_coords[i_layer]), enumerate(node_coords[i_layer + 1])
|
||||||
):
|
):
|
||||||
if (self.matrix[i, j] == 0).all():
|
if (self.matrix[i_layer][i, j] == 0).all():
|
||||||
continue
|
continue
|
||||||
edges.append(
|
edges.append(
|
||||||
svg.Path(
|
svg.Path(
|
||||||
@@ -102,7 +119,7 @@ class NeuralNet:
|
|||||||
svg.M(p1[0] + self.r, p1[1]),
|
svg.M(p1[0] + self.r, p1[1]),
|
||||||
svg.L(p2[0] - self.r, p2[1]),
|
svg.L(p2[0] - self.r, p2[1]),
|
||||||
],
|
],
|
||||||
stroke="green",
|
stroke=edge_colours[i_layer],
|
||||||
# marker_end="url(#arrowhead)",
|
# marker_end="url(#arrowhead)",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -120,7 +137,7 @@ class NeuralNet:
|
|||||||
def svg(self):
|
def svg(self):
|
||||||
return str(
|
return str(
|
||||||
svg.SVG(
|
svg.SVG(
|
||||||
width=(len(self.shape) + 1) * self.d_x,
|
width=(len(self.shape) - 1) * self.d_x + 2 * self.d_y,
|
||||||
height=self.max_height,
|
height=self.max_height,
|
||||||
elements=self.elms,
|
elements=self.elms,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user