"""Analysis endpoints exposing the players' sociogram as JSON and as a PNG image."""

import asyncio
import base64
import io
import threading
from datetime import datetime

import matplotlib
import networkx as nx
from fastapi import APIRouter
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from sqlmodel import Session, func, select
from sqlmodel.sql.expression import SelectOfScalar

from db import Chemistry, Player, engine

# Select the non-interactive Agg backend before pyplot is imported.
matplotlib.use("agg")

import matplotlib.pyplot as plt  # noqa: E402,F401
from matplotlib.figure import Figure  # noqa: E402

analysis_router = APIRouter(prefix="/analysis")

C = Chemistry
P = Player

# Only Chemistry rows recorded strictly after this (naive) timestamp are used.
_CHEMISTRY_CUTOFF = datetime(2025, 2, 1, 10)

# Each step down a user's love list costs this much weight (rank 0 -> 1.0).
_LOVE_DECAY = 0.08
# Hate edges are stored with this pseudo-rank, giving an edge alpha of 0.36.
_HATE_RANK = 8
# In-degree weight a single hate edge contributes to its target's popularity.
_HATE_POPULARITY = -0.16

# Matplotlib is not thread-safe and renders now run in worker threads, so only
# one render may draw at a time.
_RENDER_LOCK = threading.Lock()


def _love_weight(rank: int) -> float:
    """Return the weight of a love edge at position *rank* (0 = first choice).

    Computes ``1 - 0.08 * rank`` floored at 0.0, so long love lists can neither
    subtract popularity nor produce a negative alpha (Matplotlib rejects alpha
    values outside [0, 1]).
    """
    return max(0.0, 1 - _LOVE_DECAY * rank)


def _latest_chemistry_statement():
    """Build a query for each user's most recent Chemistry row after the cutoff.

    If a user has several rows sharing the same latest timestamp, all of them
    are returned by the join.
    """
    latest = (
        select(C.user, func.max(C.time).label("latest"))
        .where(C.time > _CHEMISTRY_CUTOFF)
        .group_by(C.user)
        .subquery()
    )
    return select(C).join(
        latest, (C.user == latest.c.user) & (C.time == latest.c.latest)
    )


def sociogram_json() -> JSONResponse:
    """Return the sociogram as ``{"nodes": [...], "links": [...]}`` JSON.

    Every player becomes a node ``{"id": name, "appearance": 1}``. Every entry
    of a user's latest ``love`` list becomes a link
    ``{"source": user, "target": entry}``. Hate entries are not included, and
    link targets are not checked against the node list.
    """
    with Session(engine) as session:
        nodes = [{"id": p.name, "appearance": 1} for p in session.exec(select(P))]
        links = [
            {"source": c.user, "target": target}
            for c in session.exec(_latest_chemistry_statement())
            for target in c.love
        ]
    return JSONResponse({"nodes": nodes, "links": links})


def sociogram_data(show: int | None = 2) -> nx.DiGraph:
    """Build the directed sociogram graph from the database.

    Nodes are all player names, plus any name referenced on a love/hate list.
    Edges point from a user to the entries of their latest love/hate lists and
    carry ``group`` ("love"/"hate"), ``rank`` and ``popularity`` attributes.

    Args:
        show: Edge selection. ``>= 1`` adds love edges and ``<= 1`` adds hate
            edges, so 2 (default) means love only, 1 means both and 0 means
            hate only. ``None`` is treated as the default 2.

    Returns:
        A freshly built ``nx.DiGraph``.
    """
    if show is None:
        show = 2
    graph = nx.DiGraph()
    with Session(engine) as session:
        graph.add_nodes_from(p.name for p in session.exec(select(P)))
        for c in session.exec(_latest_chemistry_statement()):
            if show >= 1:
                for rank, target in enumerate(c.love):
                    graph.add_edge(
                        c.user,
                        target,
                        group="love",
                        rank=rank,
                        popularity=_love_weight(rank),
                    )
            if show <= 1:
                # Added after love edges: re-adding an existing edge updates
                # its attributes, so a name on both lists ends up as "hate".
                for target in c.hate:
                    graph.add_edge(
                        c.user,
                        target,
                        group="hate",
                        rank=_HATE_RANK,
                        popularity=_HATE_POPULARITY,
                    )
    return graph


class Params(BaseModel):
    """Rendering options for ``POST /analysis/image``.

    Aliased fields are read from their camelCase alias in the request body.
    """

    # Node marker size; also given to edge drawing so arrows stop at nodes.
    node_size: int | None = Field(default=2400, alias="nodeSize")
    font_size: int | None = Field(default=10, alias="fontSize")
    arrow_size: int | None = Field(default=20, alias="arrowSize")
    edge_width: float | None = Field(default=1, alias="edgeWidth")
    # Passed as ``k`` to ``nx.spring_layout`` (preferred node spacing).
    distance: float | None = 0.2
    # Weight node colors by edge popularity and fade edges by their rank.
    weighting: bool | None = True
    # Color nodes by (weighted) in-degree and draw a colorbar.
    popularity: bool | None = True
    # Edge selection forwarded to ``sociogram_data``: 2 love, 1 both, 0 hate.
    show: int | None = 2


# Per-group edge styling, keyed by the edge's ``group`` attribute.
ARROWSTYLE = {"love": "-|>", "hate": "-|>"}
EDGESTYLE = {"love": "-", "hate": ":"}
EDGECOLOR = {"love": "#404040", "hate": "#cc0000"}


def _render_png(params: Params) -> str:
    """Draw the sociogram described by *params* and return base64 PNG text.

    Runs synchronously (DB query plus rendering) and is meant to be called
    from a worker thread.
    """
    graph = sociogram_data(show=params.show)
    pos = nx.spring_layout(
        graph, scale=2, k=params.distance, iterations=50, seed=None
    )
    # Same iteration order as graph.edges(), which networkx draws in.
    groups = [group for _, _, group in graph.edges(data="group")]
    ranks = [rank for _, _, rank in graph.edges(data="rank")]

    with _RENDER_LOCK:
        # A standalone Figure is never registered with pyplot, so nothing
        # leaks if drawing raises; it is garbage-collected with this frame.
        fig = Figure(figsize=(16, 10), facecolor="none")
        ax = fig.add_subplot()
        ax.set_facecolor("none")  # transparent plot area
        ax.axis("off")  # no ticks or frame

        if params.popularity:
            weight = "popularity" if params.weighting else None
            node_color = [deg for _, deg in graph.in_degree(weight=weight)]
        else:
            node_color = "#99ccff"
        nodes = nx.draw_networkx_nodes(
            graph,
            pos,
            ax=ax,
            node_color=node_color,
            edgecolors="#404040",
            linewidths=0,
            node_size=params.node_size,
            cmap="coolwarm",
            alpha=0.86,
        )
        if params.popularity:
            cbar = fig.colorbar(nodes, ax=ax)
            cbar.ax.set_xlabel("popularity")

        nx.draw_networkx_labels(graph, pos, ax=ax, font_size=params.font_size)
        nx.draw_networkx_edges(
            graph,
            pos,
            ax=ax,
            arrows=True,
            edge_color=[EDGECOLOR[g] for g in groups],
            arrowsize=params.arrow_size,
            node_size=params.node_size,
            width=params.edge_width,
            style=[EDGESTYLE[g] for g in groups],
            arrowstyle=[ARROWSTYLE[g] for g in groups],
            connectionstyle="arc3,rad=0.12",
            # Floored at 0: matplotlib raises on negative alpha (rank >= 13).
            alpha=[_love_weight(r) for r in ranks] if params.weighting else 1,
        )

        buf = io.BytesIO()
        fig.savefig(
            buf, format="png", bbox_inches="tight", dpi=300, transparent=True
        )
    return base64.b64encode(buf.getvalue()).decode("UTF-8")


async def render_sociogram(params: Params) -> dict[str, str]:
    """Render the sociogram as a PNG and return ``{"image": <base64 PNG>}``.

    The database query and the slow 300-dpi render run in a worker thread,
    so the event loop keeps serving other requests in the meantime.
    """
    encoded_image = await asyncio.to_thread(_render_png, params)
    return {"image": encoded_image}


analysis_router.add_api_route("/json", endpoint=sociogram_json, methods=["GET"])
analysis_router.add_api_route(
    "/image", endpoint=render_sociogram, methods=["POST"]
)


if __name__ == "__main__":
    with Session(engine) as session:
        statement: SelectOfScalar[int] = select(func.count(P.id))
        print("players in DB: ", session.exec(statement).first())
    G = sociogram_data()
    pos = nx.spring_layout(G, scale=1, k=2, iterations=50, seed=42)