# cutt/analysis.py
#
# 238 lines
# 8.2 KiB
# Python

import io
import base64
from fastapi import APIRouter
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from sqlmodel import Session, func, select
from sqlmodel.sql.expression import SelectOfScalar
from db import Chemistry, MVPRanking, Player, engine
import networkx as nx
import numpy as np
import matplotlib
matplotlib.use("agg")
import matplotlib.pyplot as plt
# Router collecting the analysis endpoints under the "/analysis" URL prefix.
analysis_router = APIRouter(prefix="/analysis")

# One-letter aliases for the models imported from ``db``; they keep the
# ``select(...)`` expressions in this module short.
C, R, P = Chemistry, MVPRanking, Player
def sociogram_json():
    """Return all players and their latest like/dislike relations as JSON.

    Nodes are built from every ``Player`` row; both ``id`` and ``label`` are
    the player's display name.  Edges come from each user's most recent
    ``Chemistry`` row: one ``"likes"`` edge per entry in ``love`` followed by
    one ``"dislikes"`` edge per entry in ``hate``.

    Returns:
        JSONResponse whose body is ``{"nodes": [...], "edges": [...]}``.

    Raises:
        KeyError: if a chemistry row references a player id that has no
            matching ``Player`` row.
    """
    nodes = []
    edges = []
    names_by_id = {}
    with Session(engine) as session:
        for player in session.exec(select(P)):
            names_by_id[player.id] = player.display_name
            nodes.append({"id": player.display_name, "label": player.display_name})

        # Newest submission time per user; joining back on (user, time)
        # keeps only each user's most recent Chemistry row.
        latest = (
            select(C.user, func.max(C.time).label("latest"))
            .group_by(C.user)
            .subquery()
        )
        newest_rows = select(C).join(
            latest, (C.user == latest.c.user) & (C.time == latest.c.latest)
        )

        for row in session.exec(newest_rows):
            for relation, target_ids in (("likes", row.love), ("dislikes", row.hate)):
                for target_id in target_ids:
                    edges.append(
                        {
                            "from": names_by_id[row.user],
                            "to": names_by_id[target_id],
                            "relation": relation,
                        }
                    )
    return JSONResponse({"nodes": nodes, "edges": edges})
def graph_json():
    """Return players and latest like/dislike edges in graph-widget format.

    Every ``Player`` row becomes a node (``id``/``label`` = display name) that
    also carries ``data.inDegree``: the weighted in-degree of that player,
    where each incoming edge contributes its ``size``.  Edges come from each
    user's most recent ``Chemistry`` row:

    * ``love`` entries yield ``"<user>-><player>"`` edges whose size starts at
      1.0 and drops by 0.1 per list position, never below 0.3.
    * ``hate`` entries yield ``"<user>-x><player>"`` edges of fixed size 0.3.

    Returns:
        JSONResponse whose body is ``{"nodes": [...], "edges": [...]}``.

    Raises:
        KeyError: if a chemistry row references a player id that has no
            matching ``Player`` row.
    """
    nodes = []
    edges = []
    names_by_id = {}
    with Session(engine) as session:
        for player in session.exec(select(P)):
            names_by_id[player.id] = player.display_name
            nodes.append({"id": player.display_name, "label": player.display_name})

        # Newest submission time per user; joining back on (user, time)
        # keeps only each user's most recent Chemistry row.
        latest = (
            select(C.user, func.max(C.time).label("latest"))
            .group_by(C.user)
            .subquery()
        )
        newest_rows = select(C).join(
            latest, (C.user == latest.c.user) & (C.time == latest.c.latest)
        )

        for row in session.exec(newest_rows):
            user = names_by_id[row.user]
            for position, target_id in enumerate(row.love):
                target = names_by_id[target_id]
                size = max(1.0 - 0.1 * position, 0.3)
                edges.append(
                    {
                        "id": f"{user}->{target}",
                        "source": user,
                        "target": target,
                        "size": size,
                        "data": {
                            "relation": 2,
                            "origSize": size,
                            "origFill": "#bed4ff",
                        },
                    }
                )
            for target_id in row.hate:
                target = names_by_id[target_id]
                edges.append(
                    {
                        "id": f"{user}-x>{target}",
                        "source": user,
                        "target": target,
                        "size": 0.3,
                        "data": {"relation": 0, "origSize": 0.3, "origFill": "#ff7c7c"},
                        "fill": "#ff7c7c",
                    }
                )

    G = nx.DiGraph()
    # Register every player before adding edges: a player that appears in no
    # edge would otherwise be missing from the graph and the in-degree lookup
    # below would raise KeyError instead of yielding 0.
    G.add_nodes_from(node["id"] for node in nodes)
    G.add_weighted_edges_from(
        (edge["source"], edge["target"], edge["size"]) for edge in edges
    )
    in_degrees = G.in_degree(weight="weight")
    nodes = [
        {**node, "data": {"inDegree": in_degrees[node["id"]]}} for node in nodes
    ]
    return JSONResponse({"nodes": nodes, "edges": edges})
def sociogram_data(show: int | None = 2):
    """Build the directed like/dislike graph from the latest chemistry rows.

    Every ``Player`` row becomes a node named by its display name.  Edges are
    taken from each user's most recent ``Chemistry`` row.

    Args:
        show: selects which relations become edges.  ``show >= 1`` adds the
            "love" edges and ``show <= 1`` adds the "hate" edges, so ``1``
            adds both.  ``None`` is treated like the default ``2``.

    Returns:
        A ``networkx.DiGraph``.  "love" edges carry ``group="love"``,
        ``rank`` (position in the list) and ``popularity = 1 - 0.08 * rank``;
        "hate" edges carry ``group="hate"``, ``rank=8`` and
        ``popularity=-0.16``.

    Raises:
        KeyError: if a ``love``/``hate`` entry references a player id that
            has no matching ``Player`` row.
    """
    if show is None:
        # The signature (and the request model) allow None; comparing None
        # with an int below would raise TypeError.
        show = 2

    G = nx.DiGraph()
    with Session(engine) as session:
        names_by_id = {}
        for player in session.exec(select(P)):
            G.add_node(player.display_name)
            names_by_id[player.id] = player.display_name

        # Newest submission time per user; joining back on (user, time)
        # keeps only each user's most recent Chemistry row.
        latest = (
            select(C.user, func.max(C.time).label("latest"))
            .group_by(C.user)
            .subquery()
        )
        newest_rows = select(C).join(
            latest, (C.user == latest.c.user) & (C.time == latest.c.latest)
        )

        for row in session.exec(newest_rows):
            # Nodes and edge targets are display names, so the edge source
            # has to be one as well; otherwise the submitting user shows up
            # as a second, id-labelled node.  Unknown users keep their raw
            # value rather than raising.
            source = names_by_id.get(row.user, row.user)
            if show >= 1:
                for rank, target_id in enumerate(row.love):
                    G.add_edge(
                        source,
                        names_by_id[target_id],
                        group="love",
                        rank=rank,
                        popularity=1 - 0.08 * rank,
                    )
            if show <= 1:
                for target_id in row.hate:
                    G.add_edge(
                        source,
                        names_by_id[target_id],
                        group="hate",
                        rank=8,
                        popularity=-0.16,
                    )
    return G
class Params(BaseModel):
    """Drawing options accepted as the request body of ``render_sociogram``."""

    # Node size handed to the node and edge drawing calls (JSON key "nodeSize").
    node_size: int | None = Field(default=2400, alias="nodeSize")
    # Font size of the node labels (JSON key "fontSize").
    font_size: int | None = Field(default=10, alias="fontSize")
    # Arrow head size of the drawn edges (JSON key "arrowSize").
    arrow_size: int | None = Field(default=20, alias="arrowSize")
    # Line width of the drawn edges (JSON key "edgeWidth").
    edge_width: float | None = Field(default=1, alias="edgeWidth")
    # Passed as ``k`` to ``nx.spring_layout``.
    distance: float | None = 0.2
    # When truthy, node in-degree is weighted by the edge "popularity"
    # attribute and edge alpha is derived from the edge "rank".
    weighting: bool | None = True
    # When truthy, nodes are coloured by in-degree and a colorbar is drawn;
    # otherwise all nodes use one fixed colour.
    popularity: bool | None = True
    # Forwarded unchanged as ``show`` to ``sociogram_data``.
    show: int | None = 2
# Drawing attributes per relation, looked up by an edge's "group" value
# ("love" or "hate") when the sociogram edges are rendered.
ARROWSTYLE = dict(love="-|>", hate="-|>")
EDGESTYLE = dict(love="-", hate=":")
EDGECOLOR = dict(love="#404040", hate="#cc0000")
async def render_sociogram(params: Params):
    """Render the sociogram as a transparent PNG and return it base64-encoded.

    The graph comes from ``sociogram_data(show=params.show)`` and is laid out
    with ``nx.spring_layout`` (unseeded, so the layout differs between calls).

    Args:
        params: drawing options.  ``popularity`` switches between colouring
            nodes by in-degree (with a colorbar) and one fixed colour;
            ``weighting`` makes the in-degree use the edge "popularity"
            attribute and fades edges by their "rank".

    Returns:
        ``{"image": <base64 string of the PNG bytes>}``.
    """
    fig = plt.figure(figsize=(16, 10), facecolor="none")
    try:
        ax = plt.gca()
        ax.set_facecolor("none")  # transparent axes background
        ax.axis("off")  # no ticks or frame

        G = sociogram_data(show=params.show)
        pos = nx.spring_layout(
            G, scale=2, k=params.distance, iterations=50, seed=None
        )

        if params.popularity:
            degree_weight = "popularity" if params.weighting else None
            node_color = [degree for _, degree in G.in_degree(weight=degree_weight)]
        else:
            node_color = "#99ccff"
        nodes = nx.draw_networkx_nodes(
            G,
            pos,
            node_color=node_color,
            edgecolors="#404040",
            linewidths=0,
            node_size=params.node_size,
            cmap="coolwarm",
            alpha=0.86,
        )
        if params.popularity:
            cbar = plt.colorbar(nodes)
            cbar.ax.set_xlabel("popularity")
        nx.draw_networkx_labels(G, pos, font_size=params.font_size)

        # One pass over the edge attributes, in the same order in which
        # draw_networkx_edges iterates the edges by default.
        edge_attrs = [attrs for _, _, attrs in G.edges(data=True)]
        groups = [attrs["group"] for attrs in edge_attrs]
        if params.weighting:
            # Clamp at 0: ranks above 12 would give a negative alpha, which
            # matplotlib rejects.
            edge_alpha = [max(1 - 0.08 * attrs["rank"], 0.0) for attrs in edge_attrs]
        else:
            edge_alpha = 1
        nx.draw_networkx_edges(
            G,
            pos,
            arrows=True,
            edge_color=[EDGECOLOR[group] for group in groups],
            arrowsize=params.arrow_size,
            node_size=params.node_size,
            width=params.edge_width,
            style=[EDGESTYLE[group] for group in groups],
            arrowstyle=[ARROWSTYLE[group] for group in groups],
            connectionstyle="arc3,rad=0.12",
            alpha=edge_alpha,
        )

        buf = io.BytesIO()
        plt.savefig(buf, format="png", bbox_inches="tight", dpi=300, transparent=True)
    finally:
        # pyplot keeps every open figure in a global registry; close this one
        # even when drawing or saving fails so figures do not accumulate.
        plt.close(fig)

    encoded_image = base64.b64encode(buf.getvalue()).decode("UTF-8")
    return {"image": encoded_image}
def mvp():
    """Summarise each player's position across the latest MVP rankings.

    Only each user's most recent ``MVPRanking`` row is used.  A player's
    position in a row's ``mvps`` list is counted from 1.

    Returns:
        A list with one dict per player that appears in at least one ranking,
        in order of first appearance:

        * ``name``: the player's display name
        * ``rank``: mean position, formatted with two decimals
        * ``std``: ``np.std`` of the positions, formatted with two decimals
        * ``n``: number of rankings the player appears in

    Raises:
        KeyError: if a ranking references a player id that has no matching
            ``Player`` row.
    """
    ranks = {}
    with Session(engine) as session:
        names_by_id = {
            player.id: player.display_name for player in session.exec(select(P))
        }

        # Newest submission time per user; joining back on (user, time)
        # keeps only each user's most recent MVPRanking row.
        latest = (
            select(R.user, func.max(R.time).label("latest"))
            .group_by(R.user)
            .subquery()
        )
        newest_rows = select(R).join(
            latest, (R.user == latest.c.user) & (R.time == latest.c.latest)
        )

        for row in session.exec(newest_rows):
            for position, player_id in enumerate(row.mvps, start=1):
                # Append in place instead of rebuilding the list per entry.
                ranks.setdefault(names_by_id[player_id], []).append(position)

    return [
        {
            "name": name,
            "rank": f"{np.mean(positions):.02f}",
            "std": f"{np.std(positions):.02f}",
            "n": len(positions),
        }
        for name, positions in ranks.items()
    ]
# Register the endpoints on the router: (path, handler, HTTP method),
# in registration order.
for _route_path, _route_endpoint, _route_method in (
    ("/json", sociogram_json, "GET"),
    ("/graph_json", graph_json, "GET"),
    ("/image", render_sociogram, "POST"),
    ("/mvp", mvp, "GET"),
):
    analysis_router.add_api_route(
        _route_path, endpoint=_route_endpoint, methods=[_route_method]
    )
if __name__ == "__main__":
    # Script entry point: print the number of Player rows, then build the
    # sociogram graph and compute a seeded spring layout for it.
    with Session(engine) as session:
        # The statement selects a single COUNT value, i.e. an integer scalar.
        statement: SelectOfScalar[int] = select(func.count(P.id))
        print("players in DB: ", session.exec(statement).first())
    G = sociogram_data()
    # seed=42 makes this layout reproducible between runs.
    pos = nx.spring_layout(G, scale=1, k=2, iterations=50, seed=42)