2025-02-14 20:10:21 +01:00
|
|
|
from datetime import timedelta, timezone, datetime
|
2025-02-10 16:12:31 +01:00
|
|
|
import io
|
2025-02-14 20:10:21 +01:00
|
|
|
from typing import Annotated
|
2025-02-10 16:12:31 +01:00
|
|
|
import base64
|
2025-02-14 20:10:21 +01:00
|
|
|
from fastapi import APIRouter, Depends, HTTPException, status
|
2025-02-10 16:12:31 +01:00
|
|
|
from fastapi.responses import JSONResponse
|
|
|
|
from pydantic import BaseModel, Field
|
2025-02-14 20:10:21 +01:00
|
|
|
import jwt
|
|
|
|
from jwt.exceptions import InvalidTokenError
|
2025-02-10 16:12:31 +01:00
|
|
|
from sqlmodel import Session, func, select
|
|
|
|
from sqlmodel.sql.expression import SelectOfScalar
|
2025-02-14 20:10:21 +01:00
|
|
|
from db import Chemistry, Player, engine, User
|
2025-02-10 16:12:31 +01:00
|
|
|
import networkx as nx
|
2025-02-14 20:10:21 +01:00
|
|
|
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
|
|
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
|
|
from passlib.context import CryptContext
|
2025-02-10 16:12:31 +01:00
|
|
|
import matplotlib
|
|
|
|
|
|
|
|
matplotlib.use("agg")
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
2025-02-14 20:10:21 +01:00
|
|
|
|
|
|
|
class Config(BaseSettings):
    """Runtime settings for the auth layer.

    Values are populated by pydantic-settings; the model is configured to
    read a UTF-8 ``.env`` file and to ignore keys it does not declare.
    """

    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        extra="ignore",
    )

    # Key used to sign and verify the HS256 JWTs issued by this module.
    # NOTE(review): the default is the empty string, so a deployment without
    # a configured secret would sign tokens with an empty key - confirm that
    # a real value is always supplied.
    secret_key: str = ""

    # Lifetime, in minutes, of tokens handed out by the login endpoint.
    access_token_expire_minutes: int = 30


# Module-wide settings instance, built once at import time.
config = Config()
|
|
|
|
|
|
|
|
|
|
|
|
class Token(BaseModel):
    """Response body returned by the login endpoint."""

    # The encoded JWT.
    access_token: str
    # Token scheme; the login endpoint in this module sets it to "bearer".
    token_type: str
|
|
|
|
|
|
|
|
|
|
|
|
class TokenData(BaseModel):
    """Identity extracted from a decoded access token."""

    # Value of the token's "sub" claim; None when no subject is present.
    username: str | None = None
|
|
|
|
|
|
|
|
|
|
|
|
# Password hashing backend: argon2 only.
pwd_context = CryptContext(
    schemes=["argon2"],
    deprecated="auto",
)

# Dependency that pulls the bearer token out of the Authorization header.
# NOTE(review): tokenUrl is "token", while the login route below is declared
# on a router with prefix "/analysis" - confirm the URL advertised in the
# OpenAPI docs matches where the route is actually mounted.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

# Router for all /analysis endpoints.
# NOTE(review): the router-level dependency is the bare oauth2_scheme, which
# only extracts a bearer token and does not decode or validate it; it also
# applies to every route registered on this router, including the login
# route. Confirm whether a validating dependency was intended here.
analysis_router = APIRouter(
    prefix="/analysis",
    dependencies=[Depends(oauth2_scheme)],
)
|
|
|
|
|
|
|
|
|
|
|
|
def verify_password(plain_password, hashed_password):
    """Check a clear-text password against a stored hash.

    Delegates to the module-level ``pwd_context`` and returns its result.
    """
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match
|
|
|
|
|
|
|
|
|
|
|
|
def get_password_hash(password):
    """Hash ``password`` with the module-level ``pwd_context`` and return the hash."""
    hashed = pwd_context.hash(password)
    return hashed
|
|
|
|
|
|
|
|
|
|
|
|
def get_user(username: str):
    """Fetch the ``User`` row whose username equals ``username``.

    Opens a short-lived session on the shared engine for the lookup.

    Returns:
        The matching ``User``, or ``None`` when there is no such row
        (``one_or_none`` semantics).
    """
    lookup = select(User).where(User.username == username)
    with Session(engine) as db:
        return db.exec(lookup).one_or_none()
|
|
|
|
|
|
|
|
|
|
|
|
def authenticate_user(username: str, password: str):
    """Validate a username/password pair.

    Returns:
        The user object when the user exists and the password matches the
        stored hash; ``False`` otherwise.
    """
    candidate = get_user(username)
    # Short-circuits, so the hash is only checked for an existing user.
    if candidate and verify_password(password, candidate.hashed_password):
        return candidate
    return False
|
|
|
|
|
|
|
|
|
|
|
|
def create_access_token(data: dict, expires_delta: timedelta | None = None):
    """Build a signed HS256 JWT carrying ``data`` plus an ``exp`` claim.

    Args:
        data: Claims to embed. The mapping is copied, never mutated.
        expires_delta: Token lifetime measured from now (UTC). When ``None``
            a 15-minute lifetime is used.

    Returns:
        The encoded token as produced by ``jwt.encode``.
    """
    # Compare against None explicitly: timedelta(0) is falsy, and a caller
    # asking for a zero lifetime must not silently receive the 15-minute
    # default.
    lifetime = timedelta(minutes=15) if expires_delta is None else expires_delta
    claims = {**data, "exp": datetime.now(timezone.utc) + lifetime}
    return jwt.encode(claims, config.secret_key, algorithm="HS256")
|
|
|
|
|
|
|
|
|
|
|
|
async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]):
    """Resolve the bearer token of the current request to a stored user.

    Args:
        token: Bearer token supplied by the ``oauth2_scheme`` dependency.

    Returns:
        The user returned by ``get_user`` for the token's ``sub`` claim.

    Raises:
        HTTPException: 401 when the token cannot be decoded, carries no
            ``sub`` claim, or names a user that does not exist.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    # Only the decode call can raise InvalidTokenError, so keep the try
    # body limited to it.
    try:
        payload = jwt.decode(token, config.secret_key, algorithms=["HS256"])
    except InvalidTokenError:
        raise credentials_exception

    username = payload.get("sub")
    if username is None:
        raise credentials_exception
    token_data = TokenData(username=username)

    # get_user takes the username only; the previous call passed an
    # undefined ``fake_users_db`` as an extra first argument.
    user = get_user(username=token_data.username)
    if user is None:
        raise credentials_exception
    return user
|
|
|
|
|
|
|
|
|
|
|
|
async def get_current_active_user(
    current_user: Annotated[User, Depends(get_current_user)],
):
    """Return the authenticated user unless it is flagged as disabled.

    Raises:
        HTTPException: 400 with detail "Inactive user" when
            ``current_user.disabled`` is truthy.
    """
    if not current_user.disabled:
        return current_user
    raise HTTPException(status_code=400, detail="Inactive user")
|
|
|
|
|
|
|
|
|
|
|
|
@analysis_router.post("/token")
async def login_for_access_token(
    form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
) -> Token:
    """Exchange form-posted credentials for a bearer access token.

    The token's ``sub`` claim is the username and its lifetime comes from
    ``config.access_token_expire_minutes``.

    Raises:
        HTTPException: 401 when ``authenticate_user`` rejects the
            credentials.
    """
    # NOTE(review): this route is registered on analysis_router, whose
    # router-level dependency already demands a bearer token - confirm that
    # a client without a token can actually reach this endpoint.
    account = authenticate_user(form_data.username, form_data.password)
    if not account:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    jwt_token = create_access_token(
        data={"sub": account.username},
        expires_delta=timedelta(minutes=config.access_token_expire_minutes),
    )
    return Token(access_token=jwt_token, token_type="bearer")
|
2025-02-10 16:12:31 +01:00
|
|
|
|
|
|
|
|
|
|
|
# Short aliases for the ORM models, used by the query code below.
C, P = Chemistry, Player
|
|
|
|
|
|
|
|
|
|
|
|
def sociogram_json():
    """Return the sociogram as a node/link JSON document.

    Nodes are all players, each as ``{"id": <name>, "appearance": 1}``.
    Links are built from each user's most recent ``Chemistry`` row recorded
    after 2025-02-01 10:00: one ``{"source": user, "target": p}`` entry per
    item ``p`` in that row's ``love`` collection.

    Returns:
        A ``JSONResponse`` with the keys ``"nodes"`` and ``"links"``.
    """
    cutoff = datetime(2025, 2, 1, 10)
    links = []
    with Session(engine) as session:
        nodes = [
            {"id": player.name, "appearance": 1}
            for player in session.exec(select(P)).fetchall()
        ]

        # Latest submission time per user, restricted to rows after the cutoff.
        latest = (
            select(C.user, func.max(C.time).label("latest"))
            .where(C.time > cutoff)
            .group_by(C.user)
            .subquery()
        )
        # Join back to the full rows so only each user's newest entry remains.
        latest_rows = select(C).join(
            latest, (C.user == latest.c.user) & (C.time == latest.c.latest)
        )
        for row in session.exec(latest_rows):
            links.extend({"source": row.user, "target": target} for target in row.love)

    return JSONResponse({"nodes": nodes, "links": links})
|
|
|
|
|
|
|
|
|
2025-02-12 17:54:07 +01:00
|
|
|
def sociogram_data(show: int | None = 2):
    """Build the directed sociogram graph from the database.

    Every player becomes a node. Edges come from each user's most recent
    ``Chemistry`` row recorded after 2025-02-01 10:00.

    Args:
        show: Selects which edge groups are added. "love" edges are added
            when ``show >= 1`` and "hate" edges when ``show <= 1``, so
            2 gives love only, 1 gives both and 0 gives hate only.
            ``None`` is treated as the default, 2.

    Returns:
        An ``nx.DiGraph`` whose edges carry the attributes ``group``,
        ``rank`` and ``popularity``.
    """
    # The signature (and Params.show) allow None, which cannot be compared
    # with an int; fall back to the default instead of raising TypeError.
    if show is None:
        show = 2

    G = nx.DiGraph()
    with Session(engine) as session:
        for player in session.exec(select(P)).fetchall():
            G.add_node(player.name)

        # Latest submission time per user, restricted to rows after the cutoff.
        subquery = (
            select(C.user, func.max(C.time).label("latest"))
            .where(C.time > datetime(2025, 2, 1, 10))
            .group_by(C.user)
            .subquery()
        )
        # Join back to the full rows so only each user's newest entry remains.
        statement2 = select(C).join(
            subquery, (C.user == subquery.c.user) & (C.time == subquery.c.latest)
        )

        for c in session.exec(statement2):
            if show >= 1:
                # Position in the list is the rank; each step down lowers
                # the popularity weight by 0.08.
                for i, p in enumerate(c.love):
                    G.add_edge(c.user, p, group="love", rank=i, popularity=1 - 0.08 * i)
            if show <= 1:
                # Hate edges all share one fixed rank and a negative weight.
                for p in c.hate:
                    G.add_edge(c.user, p, group="hate", rank=8, popularity=-0.16)
    return G
|
|
|
|
|
|
|
|
|
|
|
|
class Params(BaseModel):
    """Request body accepted by the sociogram image endpoint.

    The first four fields are declared with camelCase aliases.
    """

    # Marker size, passed as node_size when drawing nodes and edges.
    node_size: int | None = Field(default=2400, alias="nodeSize")
    # Font size of the node labels.
    font_size: int | None = Field(default=10, alias="fontSize")
    # Arrow head size of the edges.
    arrow_size: int | None = Field(default=20, alias="arrowSize")
    # Line width of the edges.
    edge_width: float | None = Field(default=1, alias="edgeWidth")
    # Passed as ``k`` to the spring layout.
    distance: float | None = 0.2
    # When true, edge alpha follows edge rank and in-degree is weighted by
    # the "popularity" edge attribute.
    weighting: bool | None = True
    # When true, nodes are coloured by in-degree and a colorbar is drawn.
    popularity: bool | None = True
    # Edge-group selector forwarded to sociogram_data.
    show: int | None = 2
|
2025-02-12 17:23:18 +01:00
|
|
|
|
|
|
|
|
|
|
|
# Per-group drawing attributes for sociogram edges, keyed by the edge's
# "group" attribute ("love" or "hate").
ARROWSTYLE = dict(love="-|>", hate="-|>")  # arrow head style
EDGESTYLE = dict(love="-", hate=":")  # line style: solid vs. dotted
EDGECOLOR = dict(love="#404040", hate="#cc0000")  # line colour
|
2025-02-10 16:12:31 +01:00
|
|
|
|
|
|
|
|
2025-02-12 12:23:17 +01:00
|
|
|
async def render_sociogram(params: Params):
    """Render the sociogram to a PNG and return it base64-encoded.

    Args:
        params: Drawing options (sizes, layout distance, weighting and
            popularity switches, edge-group selector).

    Returns:
        ``{"image": <base64-encoded PNG as str>}``.
    """
    fig = plt.figure(figsize=(16, 10), facecolor="none")
    # Close the figure even when drawing or saving fails; otherwise every
    # failed request would leave an open figure behind in pyplot.
    try:
        ax = fig.gca()
        ax.set_facecolor("none")  # transparent axes background
        ax.axis("off")  # no ticks or frame

        G = sociogram_data(show=params.show)
        pos = nx.spring_layout(G, scale=2, k=params.distance, iterations=50, seed=None)

        if params.popularity:
            # Colour nodes by in-degree, optionally weighted by the edges'
            # "popularity" attribute.
            degree_weight = "popularity" if params.weighting else None
            node_color = [degree for _, degree in G.in_degree(weight=degree_weight)]
        else:
            node_color = "#99ccff"

        nodes = nx.draw_networkx_nodes(
            G,
            pos,
            ax=ax,
            node_color=node_color,
            edgecolors="#404040",
            linewidths=0,
            node_size=params.node_size,
            cmap="coolwarm",
            alpha=0.86,
        )
        if params.popularity:
            cbar = fig.colorbar(nodes, ax=ax)
            cbar.ax.set_xlabel("popularity")

        nx.draw_networkx_labels(G, pos, ax=ax, font_size=params.font_size)

        # Read the edge attributes once, in the same order as G.edges().
        edge_attrs = [attrs for _, _, attrs in G.edges(data=True)]
        groups = [attrs["group"] for attrs in edge_attrs]
        if params.weighting:
            # Lower-ranked edges are drawn more transparently.
            edge_alpha = [1 - 0.08 * attrs["rank"] for attrs in edge_attrs]
        else:
            edge_alpha = 1

        nx.draw_networkx_edges(
            G,
            pos,
            ax=ax,
            arrows=True,
            edge_color=[EDGECOLOR[group] for group in groups],
            arrowsize=params.arrow_size,
            node_size=params.node_size,
            width=params.edge_width,
            style=[EDGESTYLE[group] for group in groups],
            arrowstyle=[ARROWSTYLE[group] for group in groups],
            connectionstyle="arc3,rad=0.12",
            alpha=edge_alpha,
        )

        buf = io.BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight", dpi=300, transparent=True)
        encoded_image = base64.b64encode(buf.getvalue()).decode("UTF-8")
    finally:
        plt.close(fig)

    return {"image": encoded_image}
|
|
|
|
|
|
|
|
|
|
|
|
# Register the sociogram endpoints on the analysis router:
# GET <prefix>/json returns the node/link document,
# POST <prefix>/image returns the rendered PNG.
analysis_router.add_api_route(
    "/json",
    sociogram_json,
    methods=["GET"],
)
analysis_router.add_api_route(
    "/image",
    render_sociogram,
    methods=["POST"],
)
|
2025-02-10 16:12:31 +01:00
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: print the player count, then build the graph and
    # compute a reproducible (seeded) layout for it.
    with Session(engine) as session:
        statement: SelectOfScalar[P] = select(func.count(P.id))
        player_count = session.exec(statement).first()
        print("players in DB: ", player_count)
        G = sociogram_data()
        pos = nx.spring_layout(G, scale=1, k=2, iterations=50, seed=42)
|