cutt/analysis.py

155 lines
5.1 KiB
Python
Raw Normal View History

2025-02-16 16:38:55 +01:00
from datetime import datetime
import io
import base64
2025-02-16 16:38:55 +01:00
from fastapi import APIRouter
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from sqlmodel import Session, func, select
from sqlmodel.sql.expression import SelectOfScalar
2025-02-16 16:38:55 +01:00
from db import Chemistry, Player, engine
import networkx as nx
import matplotlib
matplotlib.use("agg")
import matplotlib.pyplot as plt
2025-02-14 20:10:21 +01:00
2025-02-16 16:38:55 +01:00
# Router that groups the analysis endpoints under the "/analysis" URL prefix.
analysis_router = APIRouter(prefix="/analysis")
# Short aliases for the ORM models imported from `db`, used by the queries in this module.
C = Chemistry
P = Player
def sociogram_json():
    """Return the sociogram as JSON node and link lists.

    Every player in the database becomes a node of the form
    ``{"id": <player name>, "appearance": 1}``.  For each user, only the most
    recent ``Chemistry`` row with a time after 2025-02-01 10:00 is considered,
    and every entry of its ``love`` list becomes a link
    ``{"source": <user>, "target": <entry>}``.

    Nodes are not filtered: players that take part in no link are still
    included in the node list.

    Returns:
        JSONResponse: body of the form ``{"nodes": [...], "links": [...]}``.
    """
    links = []
    with Session(engine) as session:
        nodes = [
            {"id": player.name, "appearance": 1}
            for player in session.exec(select(P))
        ]

        # Latest submission time per user, restricted to rows after the cutoff.
        latest_per_user = (
            select(C.user, func.max(C.time).label("latest"))
            .where(C.time > datetime(2025, 2, 1, 10))
            .group_by(C.user)
            .subquery()
        )
        # Join back onto the table to fetch the full row for each latest time.
        latest_rows = select(C).join(
            latest_per_user,
            (C.user == latest_per_user.c.user)
            & (C.time == latest_per_user.c.latest),
        )
        for row in session.exec(latest_rows):
            links.extend(
                {"source": row.user, "target": target} for target in row.love
            )
    return JSONResponse({"nodes": nodes, "links": links})
def sociogram_data(show: int | None = 2):
    """Build the directed sociogram graph from the latest chemistry entries.

    Every player name is added as a node.  For each user, only the most recent
    ``Chemistry`` row with a time after 2025-02-01 10:00 contributes edges.

    Args:
        show: Selects which relations become edges.  ``show >= 1`` adds the
            "love" edges and ``show <= 1`` adds the "hate" edges, so ``1``
            adds both.  ``None`` is treated like the default ``2`` rather
            than raising ``TypeError`` on the comparison.

    Returns:
        nx.DiGraph: graph whose edges carry the attributes ``group``
        ("love" or "hate"), ``rank`` and ``popularity``.
    """
    if show is None:
        show = 2
    include_love = show >= 1
    include_hate = show <= 1

    G = nx.DiGraph()
    with Session(engine) as session:
        for player in session.exec(select(P)):
            G.add_node(player.name)

        # Latest submission time per user, restricted to rows after the cutoff.
        latest_per_user = (
            select(C.user, func.max(C.time).label("latest"))
            .where(C.time > datetime(2025, 2, 1, 10))
            .group_by(C.user)
            .subquery()
        )
        # Join back onto the table to fetch the full row for each latest time.
        latest_rows = select(C).join(
            latest_per_user,
            (C.user == latest_per_user.c.user)
            & (C.time == latest_per_user.c.latest),
        )
        for row in session.exec(latest_rows):
            if include_love:
                # Position in the list is the rank; popularity drops by 0.08
                # per rank.
                for rank, target in enumerate(row.love):
                    G.add_edge(
                        row.user,
                        target,
                        group="love",
                        rank=rank,
                        popularity=1 - 0.08 * rank,
                    )
            if include_hate:
                # Added after the love edges, so when both are shown a hate
                # edge overwrites the attributes of a love edge between the
                # same ordered pair.
                for target in row.hate:
                    G.add_edge(
                        row.user,
                        target,
                        group="hate",
                        rank=8,
                        popularity=-0.16,
                    )
    return G
class Params(BaseModel):
    """Rendering options accepted by the sociogram image endpoint.

    The first four fields are populated from camelCase keys through their
    aliases; the remaining fields use their attribute name as the key.
    """

    # Sizes and widths handed to the networkx drawing calls.
    node_size: int | None = Field(2400, alias="nodeSize")
    font_size: int | None = Field(10, alias="fontSize")
    arrow_size: int | None = Field(20, alias="arrowSize")
    edge_width: float | None = Field(1, alias="edgeWidth")

    # Passed as `k` to the spring layout.
    distance: float | None = 0.2
    # When true, in-degree is weighted by edge popularity and edge alpha
    # depends on edge rank.
    weighting: bool | None = True
    # When true, nodes are coloured by in-degree and a colorbar is drawn.
    popularity: bool | None = True
    # Relation selector forwarded to `sociogram_data`.
    show: int | None = 2
2025-02-12 17:23:18 +01:00
# Per-relation drawing styles for sociogram edges, keyed by the edge "group" attribute.
# Arrow head style (both relations currently use the same arrow head).
ARROWSTYLE = {"love": "-|>", "hate": "-|>"}
# Line style: solid for "love", dotted for "hate".
EDGESTYLE = {"love": "-", "hate": ":"}
# Edge colour: dark grey for "love", red for "hate".
EDGECOLOR = {"love": "#404040", "hate": "#cc0000"}
2025-02-12 12:23:17 +01:00
async def render_sociogram(params: Params):
    """Render the sociogram as a transparent PNG and return it base64-encoded.

    Args:
        params: Rendering options (sizes, layout distance and the
            ``weighting`` / ``popularity`` / ``show`` switches).

    Returns:
        dict: ``{"image": <base64-encoded PNG as str>}``.
    """
    # NOTE: this coroutine never awaits; all rendering below is synchronous.
    fig = plt.figure(figsize=(16, 10), facecolor="none")
    try:
        ax = fig.gca()
        ax.set_facecolor("none")  # transparent axes background
        ax.axis("off")  # no ticks or frame

        G = sociogram_data(show=params.show)
        pos = nx.spring_layout(
            G, scale=2, k=params.distance, iterations=50, seed=None
        )

        if params.popularity:
            # Colour nodes by (optionally popularity-weighted) in-degree.
            weight = "popularity" if params.weighting else None
            node_color = [degree for _, degree in G.in_degree(weight=weight)]
        else:
            node_color = "#99ccff"
        nodes = nx.draw_networkx_nodes(
            G,
            pos,
            ax=ax,
            node_color=node_color,
            edgecolors="#404040",
            linewidths=0,
            node_size=params.node_size,
            cmap="coolwarm",
            alpha=0.86,
        )
        if params.popularity:
            cbar = fig.colorbar(nodes, ax=ax)
            cbar.ax.set_xlabel("popularity")

        nx.draw_networkx_labels(G, pos, ax=ax, font_size=params.font_size)

        # Read each edge's attributes once, in the order the edges are drawn.
        edge_attrs = [attrs for _, _, attrs in G.edges(data=True)]
        groups = [attrs["group"] for attrs in edge_attrs]
        if params.weighting:
            # Fade edges by rank; clamp at 0 because matplotlib rejects alpha
            # values outside the 0-1 range (reached for rank >= 13).
            edge_alpha = [
                max(0.0, 1 - 0.08 * attrs["rank"]) for attrs in edge_attrs
            ]
        else:
            edge_alpha = 1
        nx.draw_networkx_edges(
            G,
            pos,
            ax=ax,
            arrows=True,
            edge_color=[EDGECOLOR[group] for group in groups],
            arrowsize=params.arrow_size,
            node_size=params.node_size,
            width=params.edge_width,
            style=[EDGESTYLE[group] for group in groups],
            arrowstyle=[ARROWSTYLE[group] for group in groups],
            connectionstyle="arc3,rad=0.12",
            alpha=edge_alpha,
        )

        buf = io.BytesIO()
        fig.savefig(
            buf, format="png", bbox_inches="tight", dpi=300, transparent=True
        )
    finally:
        # Always release the figure, even if layout or drawing raised.
        plt.close(fig)

    encoded_image = base64.b64encode(buf.getvalue()).decode("UTF-8")
    return {"image": encoded_image}
# GET <router prefix>/json -> node/link lists produced by sociogram_json.
analysis_router.add_api_route("/json", endpoint=sociogram_json, methods=["GET"])
2025-02-12 12:23:17 +01:00
# POST <router prefix>/image -> base64-encoded PNG produced by render_sociogram.
analysis_router.add_api_route("/image", endpoint=render_sociogram, methods=["POST"])
# Manual check when the module is executed directly as a script.
if __name__ == "__main__":
    with Session(engine) as session:
        # COUNT(id) yields a single integer, hence the scalar-of-int annotation.
        statement: SelectOfScalar[int] = select(func.count(P.id))
        print("players in DB: ", session.exec(statement).first())
    G = sociogram_data()
    # Fixed seed for a reproducible layout; `pos` is not used in the lines visible here.
    pos = nx.spring_layout(G, scale=1, k=2, iterations=50, seed=42)