From d9ad903798c22e8440d0005ba411e5511e63d2f8 Mon Sep 17 00:00:00 2001 From: julius Date: Fri, 21 Mar 2025 14:48:55 +0100 Subject: [PATCH] feat: add back files in new location --- cutt/__init__.py | 0 cutt/analysis.py | 275 ++++++++++++++++++++++++++++++++++++++++ cutt/db.py | 86 +++++++++++++ cutt/main.py | 174 ++++++++++++++++++++++++++ cutt/player.py | 47 +++++++ cutt/security.py | 318 +++++++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 900 insertions(+) create mode 100644 cutt/__init__.py create mode 100644 cutt/analysis.py create mode 100644 cutt/db.py create mode 100644 cutt/main.py create mode 100644 cutt/player.py create mode 100644 cutt/security.py diff --git a/cutt/__init__.py b/cutt/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cutt/analysis.py b/cutt/analysis.py new file mode 100644 index 0000000..795482f --- /dev/null +++ b/cutt/analysis.py @@ -0,0 +1,275 @@ +import io +import base64 +from typing import Annotated +from fastapi import APIRouter, HTTPException, Security, status +from fastapi.responses import JSONResponse +from pydantic import BaseModel, Field +from sqlmodel import Session, func, select +from sqlmodel.sql.expression import SelectOfScalar +from cutt.db import Chemistry, MVPRanking, Player, Team, engine +import networkx as nx +import numpy as np +import matplotlib + +from cutt.security import TeamScopedRequest, verify_team_scope + +matplotlib.use("agg") +import matplotlib.pyplot as plt + + +analysis_router = APIRouter(prefix="/analysis") + + +C = Chemistry +R = MVPRanking +P = Player + + +def sociogram_json(): + nodes = [] + necessary_nodes = set() + edges = [] + players = {} + with Session(engine) as session: + for p in session.exec(select(P)).fetchall(): + nodes.append({"id": p.display_name, "label": p.display_name}) + players[p.id] = p.display_name + subquery = ( + select(C.user, func.max(C.time).label("latest")).group_by(C.user).subquery() + ) + statement2 = select(C).join( + subquery, (C.user == subquery.c.user) & (C.time == subquery.c.latest) + ) + for c in session.exec(statement2): + # G.add_node(c.user) + necessary_nodes.add(c.user) + for p in [players[p_id] for p_id in c.love]: + # G.add_edge(c.user, p) + # p_id = session.exec(select(P.id).where(P.name == p)).one() + necessary_nodes.add(p) + edges.append({"from": players[c.user], "to": p, "relation": "likes"}) + for p in [players[p_id] for p_id in c.hate]: + edges.append({"from": players[c.user], "to": p, "relation": "dislikes"}) + # nodes = [n for n in nodes if n["name"] in necessary_nodes] + return JSONResponse({"nodes": nodes, "edges": edges}) + + +def graph_json( + request: Annotated[ + TeamScopedRequest, Security(verify_team_scope, scopes=["analysis"]) + ], +): + nodes = [] + edges = [] + player_map = {} + with Session(engine) as session: + statement = select(Team).where(Team.id == request.team_id) + players = [t.players for t in session.exec(statement)][0] + if not players: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) + for p in players: + player_map[p.id] = p.display_name + nodes.append({"id": p.display_name, "label": p.display_name}) + + subquery = ( + select(C.user, func.max(C.time).label("latest")) + .where(C.team == request.team_id) + .group_by(C.user) + .subquery() + ) + statement2 = select(C).join( + subquery, (C.user == subquery.c.user) & (C.time == subquery.c.latest) + ) + + for c in session.exec(statement2): + user = player_map[c.user] + for i, p_id in enumerate(c.love): + p = player_map[p_id] + edges.append( + { + "id": f"{user}->{p}", + "source": user, + "target": p, + "size": max(1.0 - 0.1 * i, 0.3), + "data": { + "relation": 2, + "origSize": max(1.0 - 0.1 * i, 0.3), + "origFill": "#bed4ff", + }, + } + ) + for p_id in c.hate: + p = player_map[p_id] + edges.append( + { + "id": f"{user}-x>{p}", + "source": user, + "target": p, + "size": 0.3, + "data": {"relation": 0, "origSize": 0.3, "origFill": "#ff7c7c"}, + "fill": "#ff7c7c", + } + ) + + if not edges: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="no entries found" + ) + G = nx.DiGraph() + G.add_weighted_edges_from([(e["source"], e["target"], e["size"]) for e in edges]) + in_degrees = G.in_degree(weight="weight") + nodes = [ + dict(node, **{"data": {"inDegree": in_degrees[node["id"]]}}) for node in nodes + ] + return JSONResponse({"nodes": nodes, "edges": edges}) + + +def sociogram_data(show: int | None = 2): + G = nx.DiGraph() + with Session(engine) as session: + players = {} + for p in session.exec(select(P)).fetchall(): + G.add_node(p.display_name) + players[p.id] = p.display_name + subquery = ( + select(C.user, func.max(C.time).label("latest")).group_by(C.user).subquery() + ) + statement2 = ( + select(C) + # .where(C.user.in_(["Kruse", "Franz", "ck"])) + .join(subquery, (C.user == subquery.c.user) & (C.time == subquery.c.latest)) + ) + for c in session.exec(statement2): + if show >= 1: + for i, p_id in enumerate(c.love): + p = players[p_id] + G.add_edge(c.user, p, group="love", rank=i, popularity=1 - 0.08 * i) + if show <= 1: + for i, p_id in enumerate(c.hate): + p = players[p_id] + G.add_edge(c.user, p, group="hate", rank=8, popularity=-0.16) + return G + + +class Params(BaseModel): + node_size: int | None = Field(default=2400, alias="nodeSize") + font_size: int | None = Field(default=10, alias="fontSize") + arrow_size: int | None = Field(default=20, alias="arrowSize") + edge_width: float | None = Field(default=1, alias="edgeWidth") + distance: float | None = 0.2 + weighting: bool | None = True + popularity: bool | None = True + show: int | None = 2 + + +ARROWSTYLE = {"love": "-|>", "hate": "-|>"} +EDGESTYLE = {"love": "-", "hate": ":"} +EDGECOLOR = {"love": "#404040", "hate": "#cc0000"} + + +async def render_sociogram(params: Params): + plt.figure(figsize=(16, 10), facecolor="none") + ax = plt.gca() + ax.set_facecolor("none") # Set the axis face color to none (transparent) + ax.axis("off") # Turn off axis ticks and frames + + G = sociogram_data(show=params.show) + pos = nx.spring_layout(G, scale=2, k=params.distance, iterations=50, seed=None) + nodes = nx.draw_networkx_nodes( + G, + pos, + node_color=[ + v for k, v in G.in_degree(weight="popularity" if params.weighting else None) + ] + if params.popularity + else "#99ccff", + edgecolors="#404040", + linewidths=0, + # node_shape="8", + node_size=params.node_size, + cmap="coolwarm", + alpha=0.86, + ) + if params.popularity: + cbar = plt.colorbar(nodes) + cbar.ax.set_xlabel("popularity") + nx.draw_networkx_labels(G, pos, font_size=params.font_size) + nx.draw_networkx_edges( + G, + pos, + arrows=True, + edge_color=[EDGECOLOR[G.edges()[*edge]["group"]] for edge in G.edges()], + arrowsize=params.arrow_size, + node_size=params.node_size, + width=params.edge_width, + style=[EDGESTYLE[G.edges()[*edge]["group"]] for edge in G.edges()], + arrowstyle=[ARROWSTYLE[G.edges()[*edge]["group"]] for edge in G.edges()], + connectionstyle="arc3,rad=0.12", + alpha=[1 - 0.08 * G.edges()[*edge]["rank"] for edge in G.edges()] + if params.weighting + else 1, + ) + + buf = io.BytesIO() + plt.savefig(buf, format="png", bbox_inches="tight", dpi=300, transparent=True) + buf.seek(0) + encoded_image = base64.b64encode(buf.read()).decode("UTF-8") + plt.close() + + return {"image": encoded_image} + + +def mvp( + request: Annotated[ + TeamScopedRequest, Security(verify_team_scope, scopes=["analysis"]) + ], +): + ranks = dict() + with Session(engine) as session: + statement = select(Team).where(Team.id == request.team_id) + players = [t.players for t in session.exec(statement)][0] + if not players: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) + player_map = {p.id: p.display_name for p in players} + subquery = ( + select(R.user, func.max(R.time).label("latest")) + .where(R.team == request.team_id) + .group_by(R.user) + .subquery() + ) + statement2 = select(R).join( + subquery, (R.user == subquery.c.user) & (R.time == subquery.c.latest) + ) + for r in session.exec(statement2): + for i, p_id in enumerate(r.mvps): + p = player_map[p_id] + ranks[p] = ranks.get(p, []) + [i + 1] + + if not ranks: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="no entries found" + ) + return [ + { + "name": p, + "rank": f"{np.mean(v):.02f}", + "std": f"{np.std(v):.02f}", + "n": len(v), + } + for p, v in ranks.items() + ] + + +# analysis_router.add_api_route("/json", endpoint=sociogram_json, methods=["GET"]) +analysis_router.add_api_route( + "/graph_json/{team_id}", endpoint=graph_json, methods=["GET"] +) +analysis_router.add_api_route("/image", endpoint=render_sociogram, methods=["POST"]) +analysis_router.add_api_route("/mvp/{team_id}", endpoint=mvp, methods=["GET"]) + +if __name__ == "__main__": + with Session(engine) as session: + statement: SelectOfScalar[P] = select(func.count(P.id)) + print("players in DB: ", session.exec(statement).first()) + G = sociogram_data() + pos = nx.spring_layout(G, scale=1, k=2, iterations=50, seed=42) diff --git a/cutt/db.py b/cutt/db.py new file mode 100644 index 0000000..8a02343 --- /dev/null +++ b/cutt/db.py @@ -0,0 +1,86 @@ +from datetime import datetime, timezone +from sqlmodel import ( + ARRAY, + Column, + Integer, + Relationship, + SQLModel, + Field, + create_engine, +) + +with open("db.secrets", "r") as f: + db_secrets = f.readline().strip() + +engine = create_engine( + db_secrets, + pool_timeout=20, + pool_size=2, + connect_args={"connect_timeout": 8}, +) +del db_secrets + + +def utctime(): + return datetime.now(tz=timezone.utc) + + +class PlayerTeamLink(SQLModel, table=True): + team_id: int | None = Field(default=None, foreign_key="team.id", primary_key=True) + player_id: int | None = Field( + default=None, foreign_key="player.id", primary_key=True + ) + + +class Team(SQLModel, table=True): + id: int | None = Field(default=None, primary_key=True) + name: str + location: str | None + country: str | None + players: list["Player"] | None = Relationship( + back_populates="teams", link_model=PlayerTeamLink + ) + + +class Player(SQLModel, table=True): + id: int | None = Field(default=None, primary_key=True) + username: str = Field(default=None, unique=True) + display_name: str + email: str | None = None + full_name: str | None = None + disabled: bool | None = None + hashed_password: str | None = None + number: str | None = None + teams: list[Team] = Relationship( + back_populates="players", link_model=PlayerTeamLink + ) + scopes: str = "" + + +class Chemistry(SQLModel, table=True): + id: int | None = Field(default=None, primary_key=True) + time: datetime | None = Field(default_factory=utctime) + user: int = Field(default=None, foreign_key="player.id") + hate: list[int] = Field(sa_column=Column(ARRAY(Integer))) + undecided: list[int] = Field(sa_column=Column(ARRAY(Integer))) + love: list[int] = Field(sa_column=Column(ARRAY(Integer))) + team: int = Field(default=None, foreign_key="team.id") + + +class MVPRanking(SQLModel, table=True): + id: int | None = Field(default=None, primary_key=True) + time: datetime | None = Field(default_factory=utctime) + user: int = Field(default=None, foreign_key="player.id") + mvps: list[int] = Field(sa_column=Column(ARRAY(Integer))) + team: int = Field(default=None, foreign_key="team.id") + + +class TokenDB(SQLModel, table=True): + token: str = Field(index=True, primary_key=True) + used: bool | None = False + updated_at: datetime | None = Field( + default_factory=utctime, sa_column_kwargs={"onupdate": utctime} + ) + + +SQLModel.metadata.create_all(engine) diff --git a/cutt/main.py b/cutt/main.py new file mode 100644 index 0000000..2fb90f5 --- /dev/null +++ b/cutt/main.py @@ -0,0 +1,174 @@ +from typing import Annotated +from fastapi import APIRouter, Depends, FastAPI, HTTPException, Security, status +from fastapi.responses import JSONResponse +from fastapi.staticfiles import StaticFiles +from cutt.db import Player, Team, Chemistry, MVPRanking, engine +from sqlmodel import ( + Session, + func, + select, +) +from fastapi.middleware.cors import CORSMiddleware +from cutt.analysis import analysis_router +from cutt.security import ( + get_current_active_user, + login_for_access_token, + logout, + set_first_password, +) +from cutt.player import player_router + +C = Chemistry +R = MVPRanking +P = Player + +app = FastAPI(title="cutt") +api_router = APIRouter(prefix="/api") +origins = [ + "https://cutt.0124816.xyz", + "http://localhost:5173", +] + +app.add_middleware( + CORSMiddleware, + allow_origins=origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +def add_team(team: Team): + with Session(engine) as session: + session.add(team) + session.commit() + + +def list_teams(): + with Session(engine) as session: + statement = select(Team) + return session.exec(statement).fetchall() + + +team_router = APIRouter( + prefix="/team", + dependencies=[Security(get_current_active_user, scopes=["admin"])], +) +team_router.add_api_route("/list", endpoint=list_teams, methods=["GET"]) +team_router.add_api_route("/add", endpoint=add_team, methods=["POST"]) + + +wrong_user_id_exception = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="you're not who you think you are...", +) +somethings_fishy = HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail="something up..." +) + + +@api_router.put("/mvps") +def submit_mvps( + mvps: MVPRanking, + user: Annotated[Player, Depends(get_current_active_user)], +): + if user.id == mvps.user: + with Session(engine) as session: + statement = select(Team).where(Team.id == mvps.team) + players = [t.players for t in session.exec(statement)][0] + if players: + player_ids = {p.id for p in players} + if player_ids >= set(mvps.mvps): + session.add(mvps) + session.commit() + return JSONResponse("success!") + raise somethings_fishy + else: + raise wrong_user_id_exception + + +@api_router.get("/mvps") +def get_mvps( + user: Annotated[Player, Depends(get_current_active_user)], +): + with Session(engine) as session: + subquery = ( + select(R.user, func.max(R.time).label("latest")) + .where(R.user == user.id) + .group_by(R.user) + .subquery() + ) + statement2 = select(R).join( + subquery, (R.user == subquery.c.user) & (R.time == subquery.c.latest) + ) + mvps = session.exec(statement2).one_or_none() + if mvps: + return mvps + else: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="no previous state was found", + ) + + +@api_router.put("/chemistry") +def submit_chemistry( + chemistry: Chemistry, user: Annotated[Player, Depends(get_current_active_user)] +): + if user.id == chemistry.user: + with Session(engine) as session: + statement = select(Team).where(Team.id == chemistry.team) + players = [t.players for t in session.exec(statement)][0] + if players: + player_ids = {p.id for p in players} + if player_ids >= ( + set(chemistry.love) | set(chemistry.hate) | set(chemistry.undecided) + ): + session.add(chemistry) + session.commit() + return JSONResponse("success!") + raise somethings_fishy + else: + raise wrong_user_id_exception + + +@api_router.get("/chemistry") +def get_chemistry(user: Annotated[Player, Depends(get_current_active_user)]): + with Session(engine) as session: + subquery = ( + select(C.user, func.max(C.time).label("latest")) + .where(C.user == user.id) + .group_by(C.user) + .subquery() + ) + statement2 = select(C).join( + subquery, (C.user == subquery.c.user) & (C.time == subquery.c.latest) + ) + chemistry = session.exec(statement2).one_or_none() + if chemistry: + return chemistry + else: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="no previous state was found", + ) + + +class SPAStaticFiles(StaticFiles): + async def get_response(self, path: str, scope): + response = await super().get_response(path, scope) + if response.status_code == 404: + response = await super().get_response(".", scope) + return response + + +api_router.include_router( + player_router, dependencies=[Depends(get_current_active_user)] +) +api_router.include_router(team_router, dependencies=[Depends(get_current_active_user)]) +api_router.include_router(analysis_router) +api_router.add_api_route("/token", endpoint=login_for_access_token, methods=["POST"]) +api_router.add_api_route("/set_password", endpoint=set_first_password, methods=["POST"]) +api_router.add_api_route("/logout", endpoint=logout, methods=["POST"]) +app.include_router(api_router) +app.mount("/", SPAStaticFiles(directory="dist", html=True), name="site") diff --git a/cutt/player.py b/cutt/player.py new file mode 100644 index 0000000..b50762e --- /dev/null +++ b/cutt/player.py @@ -0,0 +1,47 @@ +from typing import Annotated +from fastapi import APIRouter, Depends +from sqlmodel import Session, select + +from cutt.db import Player, Team, engine +from cutt.security import change_password, get_current_active_user, read_player_me + +P = Player + + +def add_player(player: P): + with Session(engine) as session: + session.add(player) + session.commit() + + +def add_players(players: list[P]): + with Session(engine) as session: + for player in players: + session.add(player) + session.commit() + + +async def list_players(team_id: int): + with Session(engine) as session: + statement = select(Team).where(Team.id == team_id) + players = [t.players for t in session.exec(statement)][0] + if players: + return [ + player.model_dump(include={"id", "display_name", "number"}) + for player in players + ] + + +async def read_teams_me(user: Annotated[P, Depends(get_current_active_user)]): + with Session(engine) as session: + return [p.teams for p in session.exec(select(P).where(P.id == user.id))][0] + + +player_router = APIRouter(prefix="/player") +player_router.add_api_route("/list/{team_id}", endpoint=list_players, methods=["GET"]) +player_router.add_api_route("/add", endpoint=add_player, methods=["POST"]) +player_router.add_api_route("/me", endpoint=read_player_me, methods=["GET"]) +player_router.add_api_route("/me/teams", endpoint=read_teams_me, methods=["GET"]) +player_router.add_api_route( + "/change_password", endpoint=change_password, methods=["POST"] +) diff --git a/cutt/security.py b/cutt/security.py new file mode 100644 index 0000000..703ec86 --- /dev/null +++ b/cutt/security.py @@ -0,0 +1,318 @@ +from datetime import timedelta, timezone, datetime +from typing import Annotated +from fastapi import Depends, HTTPException, Request, Response, status +from fastapi.responses import PlainTextResponse +from pydantic import BaseModel, ValidationError +import jwt +from jwt.exceptions import ExpiredSignatureError, InvalidTokenError +from sqlmodel import Session, select +from cutt.db import TokenDB, engine, Player +from fastapi.security import ( + OAuth2PasswordBearer, + OAuth2PasswordRequestForm, + SecurityScopes, +) +from pydantic_settings import BaseSettings, SettingsConfigDict +from passlib.context import CryptContext +from sqlalchemy.exc import OperationalError + + +class Config(BaseSettings): + secret_key: str = "" + access_token_expire_minutes: int = 15 + model_config = SettingsConfigDict( + env_file=".env", env_file_encoding="utf-8", extra="ignore" + ) + + +config = Config() + + +class Token(BaseModel): + access_token: str + + +class TokenData(BaseModel): + username: str | None = None + scopes: list[str] = [] + + +pwd_context = CryptContext(schemes=["argon2"], deprecated="auto") + + +class CookieOAuth2(OAuth2PasswordBearer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + async def __call__(self, request: Request): + cookie_token = request.cookies.get("access_token") + if cookie_token: + return cookie_token + else: + header_token = await super().__call__(request) + if header_token: + return header_token + else: + raise HTTPException(status_code=401) + + +oauth2_scheme = CookieOAuth2( + tokenUrl="api/token", + scopes={ + "analysis": "Access the results.", + "admin": "Maintain DB etc.", + }, +) + + +def verify_password(plain_password, hashed_password): + return pwd_context.verify(plain_password, hashed_password) + + +def get_password_hash(password): + return pwd_context.hash(password) + + +def get_user(username: str | None): + if username: + try: + with Session(engine) as session: + return session.exec( + select(Player).where(Player.username == username) + ).one_or_none() + except OperationalError: + return + + +def authenticate_user(username: str, password: str): + user = get_user(username) + if not user: + return False + if not verify_password(password, user.hashed_password): + return False + return user + + +def create_access_token(data: dict, expires_delta: timedelta | None = None): + to_encode = data.copy() + if expires_delta: + expire = datetime.now(timezone.utc) + expires_delta + else: + expire = datetime.now(timezone.utc) + timedelta( + minutes=config.access_token_expire_minutes + ) + to_encode.update({"exp": expire}) + encoded_jwt = jwt.encode(to_encode, config.secret_key, algorithm="HS256") + return encoded_jwt + + +async def get_current_user( + token: Annotated[str, Depends(oauth2_scheme)], + security_scopes: SecurityScopes, +): + if security_scopes.scopes: + authenticate_value = f'Bearer scope="{security_scopes.scope_str}"' + else: + authenticate_value = "Bearer" + credentials_exception = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate credentials", + headers={"WWW-Authenticate": authenticate_value}, + ) + # access_token = request.cookies.get("access_token") + access_token = token + if not access_token: + raise credentials_exception + try: + payload = jwt.decode(access_token, config.secret_key, algorithms=["HS256"]) + username: str = payload.get("sub") + if username is None: + raise credentials_exception + token_scopes = payload.get("scopes", []) + token_data = TokenData(username=username, scopes=token_scopes) + except ExpiredSignatureError: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Access token expired", + headers={"WWW-Authenticate": authenticate_value}, + ) + except (InvalidTokenError, ValidationError): + raise credentials_exception + user = get_user(username=token_data.username) + if user is None: + raise credentials_exception + allowed_scopes = set(user.scopes.split()) + for scope in security_scopes.scopes: + if scope not in allowed_scopes or scope not in token_data.scopes: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Not enough permissions", + headers={"WWW-Authenticate": authenticate_value}, + ) + return user + + +async def get_current_active_user( + current_user: Annotated[Player, Depends(get_current_user)], +): + if current_user.disabled: + raise HTTPException(status_code=400, detail="Inactive user") + return current_user + + +class TeamScopedRequest(BaseModel): + user: Player + team_id: int + + +async def verify_team_scope( + team_id: int, user: Annotated[Player, Depends(get_current_active_user)] +): + allowed_scopes = set(user.scopes.split()) + if f"team:{team_id}" not in allowed_scopes: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Not enough permissions", + ) + else: + return TeamScopedRequest(user=user, team_id=team_id) + + +async def login_for_access_token( + form_data: Annotated[OAuth2PasswordRequestForm, Depends()], response: Response +) -> Token: + user = authenticate_user(form_data.username, form_data.password) + if not user: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Incorrect username or password", + headers={"WWW-Authenticate": "Bearer"}, + ) + allowed_scopes = set(user.scopes.split()) + requested_scopes = set(form_data.scopes) + access_token = create_access_token( + data={"sub": user.username, "scopes": list(allowed_scopes)} + ) + response.set_cookie( + "access_token", + value=access_token, + httponly=True, + samesite="strict", + max_age=config.access_token_expire_minutes * 60, + ) + return Token(access_token=access_token) + + +async def logout(response: Response): + response.set_cookie("access_token", "", expires=0, httponly=True, samesite="strict") + return {"message": "Successfully logged out"} + + +def generate_one_time_token(username): + user = get_user(username) + if user: + expire = timedelta(days=7) + token = create_access_token( + data={"sub": username, "name": user.display_name}, + expires_delta=expire, + ) + return token + + +class FirstPassword(BaseModel): + token: str + password: str + + +async def set_first_password(req: FirstPassword): + credentials_exception = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate token", + ) + with Session(engine) as session: + token_in_db = session.exec( + select(TokenDB) + .where(TokenDB.token == req.token) + .where(TokenDB.used == False) + ).one_or_none() + if token_in_db: + credentials_exception = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate token", + ) + try: + payload = jwt.decode(req.token, config.secret_key, algorithms=["HS256"]) + username: str = payload.get("sub") + if username is None: + raise credentials_exception + except ExpiredSignatureError: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Access token expired", + ) + except (InvalidTokenError, ValidationError): + raise credentials_exception + + user = get_user(username) + if user: + user.hashed_password = get_password_hash(req.password) + session.add(user) + token_in_db.used = True + session.add(token_in_db) + session.commit() + return Response( + "Password set successfully", status_code=status.HTTP_200_OK + ) + elif session.exec( + select(TokenDB) + .where(TokenDB.token == req.token) + .where(TokenDB.used == True) + ).one_or_none(): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Token already used", + ) + else: + raise credentials_exception + + +class ChangedPassword(BaseModel): + current_password: str + new_password: str + + +async def change_password( + request: ChangedPassword, + user: Annotated[Player, Depends(get_current_active_user)], +): + if ( + request.new_password + and user.hashed_password + and verify_password(request.current_password, user.hashed_password) + ): + with Session(engine) as session: + user.hashed_password = get_password_hash(request.new_password) + session.add(user) + session.commit() + return PlainTextResponse( + "Password changed successfully", + status_code=status.HTTP_200_OK, + media_type="text/plain", + ) + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Wrong password", + ) + + +async def read_player_me( + current_user: Annotated[Player, Depends(get_current_active_user)], +): + return current_user.model_dump(exclude={"hashed_password", "disabled"}) + + +async def read_own_items( + current_user: Annotated[Player, Depends(get_current_active_user)], +): + return [{"item_id": "Foo", "owner": current_user.username}]