diff --git a/db.py b/db.py index ac0bde1..e070cda 100644 --- a/db.py +++ b/db.py @@ -12,7 +12,9 @@ from sqlmodel import ( with open("db.secrets", "r") as f: db_secrets = f.readline().strip() -engine = create_engine(db_secrets, connect_args={"connect_timeout": 8}) +engine = create_engine( + db_secrets, pool_timeout=20, pool_size=2, connect_args={"connect_timeout": 8} +) del db_secrets diff --git a/main.py b/main.py index 57aa01a..959a1da 100644 --- a/main.py +++ b/main.py @@ -10,6 +10,7 @@ from analysis import analysis_router from security import ( get_current_active_user, login_for_access_token, + logout, read_users_me, read_own_items, ) @@ -64,21 +65,11 @@ def list_teams(): player_router = APIRouter(prefix="/player") player_router.add_api_route("/list", endpoint=list_players, methods=["GET"]) -player_router.add_api_route( - "/add", - endpoint=add_player, - methods=["POST"], - dependencies=[Depends(get_current_active_user)], -) +player_router.add_api_route("/add", endpoint=add_player, methods=["POST"]) team_router = APIRouter(prefix="/team") team_router.add_api_route("/list", endpoint=list_teams, methods=["GET"]) -team_router.add_api_route( - "/add", - endpoint=add_team, - methods=["POST"], - dependencies=[Depends(get_current_active_user)], -) +team_router.add_api_route("/add", endpoint=add_team, methods=["POST"]) @app.post("/mvps/", status_code=status.HTTP_200_OK) @@ -103,13 +94,16 @@ class SPAStaticFiles(StaticFiles): return response -api_router.include_router(player_router) -api_router.include_router(team_router) +api_router.include_router( + player_router, dependencies=[Depends(get_current_active_user)] +) +api_router.include_router(team_router, dependencies=[Depends(get_current_active_user)]) api_router.include_router( analysis_router, dependencies=[Security(get_current_active_user, scopes=["analysis"])], ) api_router.add_api_route("/token", endpoint=login_for_access_token, methods=["POST"]) +api_router.add_api_route("/logout", endpoint=logout, methods=["POST"]) api_router.add_api_route("/users/me/", endpoint=read_users_me, methods=["GET"]) api_router.add_api_route("/users/me/items/", endpoint=read_own_items, methods=["GET"]) app.include_router(api_router) diff --git a/security.py b/security.py index 0c6f5d1..a88b271 100644 --- a/security.py +++ b/security.py @@ -1,9 +1,9 @@ from datetime import timedelta, timezone, datetime from typing import Annotated -from fastapi import Depends, HTTPException, Response, status +from fastapi import Depends, HTTPException, Request, Response, status from pydantic import BaseModel, ValidationError import jwt -from jwt.exceptions import InvalidTokenError +from jwt.exceptions import ExpiredSignatureError, InvalidTokenError from sqlmodel import Session, select from db import engine, User from fastapi.security import ( @@ -18,7 +18,7 @@ from sqlalchemy.exc import OperationalError class Config(BaseSettings): secret_key: str = "" - access_token_expire_minutes: int = 30 + access_token_expire_minutes: int = 15 model_config = SettingsConfigDict( env_file=".env", env_file_encoding="utf-8", extra="ignore" ) @@ -29,7 +29,6 @@ config = Config() class Token(BaseModel): access_token: str - token_type: str class TokenData(BaseModel): @@ -81,15 +80,15 @@ def create_access_token(data: dict, expires_delta: timedelta | None = None): if expires_delta: expire = datetime.now(timezone.utc) + expires_delta else: - expire = datetime.now(timezone.utc) + timedelta(minutes=15) + expire = datetime.now(timezone.utc) + timedelta( + seconds=config.access_token_expire_minutes + ) to_encode.update({"exp": expire}) encoded_jwt = jwt.encode(to_encode, config.secret_key, algorithm="HS256") return encoded_jwt -async def get_current_user( - security_scopes: SecurityScopes, token: Annotated[str, Depends(oauth2_scheme)] -): +async def get_current_user(security_scopes: SecurityScopes, request: Request): if security_scopes.scopes: authenticate_value = f'Bearer scope="{security_scopes.scope_str}"' else: @@ -99,13 +98,22 @@ async def get_current_user( detail="Could not validate credentials", headers={"WWW-Authenticate": authenticate_value}, ) + access_token = request.cookies.get("access_token") + if not access_token: + raise credentials_exception try: - payload = jwt.decode(token, config.secret_key, algorithms=["HS256"]) + payload = jwt.decode(access_token, config.secret_key, algorithms=["HS256"]) username: str = payload.get("sub") if username is None: raise credentials_exception token_scopes = payload.get("scopes", []) token_data = TokenData(username=username, scopes=token_scopes) + except ExpiredSignatureError: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Token expired", + headers={"WWW-Authenticate": authenticate_value}, + ) except (InvalidTokenError, ValidationError): raise credentials_exception user = get_user(username=token_data.username) @@ -139,17 +147,24 @@ async def login_for_access_token( detail="Incorrect username or password", headers={"WWW-Authenticate": "Bearer"}, ) - access_token_expires = timedelta(minutes=config.access_token_expire_minutes) allowed_scopes = set(user.scopes.split()) requested_scopes = set(form_data.scopes) access_token = create_access_token( - data={"sub": user.username, "scopes": list(allowed_scopes)}, - expires_delta=access_token_expires, + data={"sub": user.username, "scopes": list(allowed_scopes)} ) response.set_cookie( - "Authorization", value=f"Bearer {access_token}", httponly=True, samesite="none" + "access_token", + value=access_token, + httponly=True, + samesite="strict", + max_age=15, ) - return Token(access_token=access_token, token_type="bearer") + return Token(access_token=access_token) + + +async def logout(response: Response): + response.set_cookie("access_token", "", expires=0, httponly=True, samesite="strict") + return async def read_users_me( diff --git a/src/Rankings.tsx b/src/Rankings.tsx index 409f907..306c07c 100644 --- a/src/Rankings.tsx +++ b/src/Rankings.tsx @@ -1,6 +1,6 @@ import { Dispatch, SetStateAction, useEffect, useMemo, useState } from "react"; import { ReactSortable, ReactSortableProps } from "react-sortablejs"; -import api, { baseUrl } from "./api"; +import { apiAuth } from "./api"; interface Player { id: number; @@ -124,7 +124,7 @@ export function Chemistry({ user, players }: PlayerInfoProps) { let middle = playersMiddle.map(({ name }) => name); let right = playersRight.map(({ name }) => name); const data = { user: _user, hate: left, undecided: middle, love: right }; - const response = await api("chemistry", data); + const response = await apiAuth("chemistry", data); response.ok ? setDialog("success!") : setDialog("try sending again"); } } @@ -203,7 +203,7 @@ export function MVP({ user, players }: PlayerInfoProps) { let _user = user.map(({ name }) => name)[0]; let mvps = rankedPlayers.map(({ name }) => name); const data = { user: _user, mvps: mvps }; - const response = await api("mvps", data); + const response = await apiAuth("mvps", data); response.ok ? setDialog("success!") : setDialog("try sending again"); } } @@ -272,10 +272,7 @@ export default function Rankings() { const [openTab, setOpenTab] = useState("Chemistry"); async function loadPlayers() { - const response = await fetch(`${baseUrl}api/player/list`, { - method: "GET", - }); - const data = await response.json(); + const data = await apiAuth("player/list", null, "GET"); setPlayers(data as Player[]); } @@ -334,8 +331,11 @@ export default function Rankings() { - assign as many or as few players as you want
- and don't forget to submit (💾) when you're done :)
+ + assign as many or as few players as you want +
+ and don't forget to submit (💾) when you're done :) +
diff --git a/src/api.ts b/src/api.ts index e2526d9..949ac82 100644 --- a/src/api.ts +++ b/src/api.ts @@ -1,22 +1,4 @@ export const baseUrl = import.meta.env.VITE_BASE_URL as string; -export const token = () => localStorage.getItem("access_token") as string; - -export default async function api(path: string, data: any): Promise { - const request = new Request(`${baseUrl}${path}/`, { - method: "POST", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify(data), - }); - let response: Response; - try { - response = await fetch(request); - } catch (e) { - throw new Error(`request failed: ${e}`); - } - return response; -} export async function apiAuth( path: string, @@ -26,9 +8,9 @@ export async function apiAuth( const req = new Request(`${baseUrl}api/${path}`, { method: method, headers: { - Authorization: `Bearer ${token()} `, "Content-Type": "application/json", }, + credentials: "include", ...(data && { body: JSON.stringify(data) }), }); let resp: Response; @@ -55,13 +37,12 @@ export type User = { }; export async function currentUser(): Promise { - if (!token()) throw new Error("you have no access token"); const req = new Request(`${baseUrl}api/users/me/`, { method: "GET", headers: { - Authorization: `Bearer ${token()} `, "Content-Type": "application/json", }, + credentials: "include", }); let resp: Response; try { @@ -83,12 +64,7 @@ export type LoginRequest = { username: string; password: string; }; -export type Token = { - access_token: string; - token_type: string; -}; -// api.js export const login = async (req: LoginRequest): Promise => { try { const response = await fetch(`${baseUrl}api/token`, { @@ -97,20 +73,24 @@ export const login = async (req: LoginRequest): Promise => { "Content-Type": "application/x-www-form-urlencoded", }, body: new URLSearchParams(req).toString(), + credentials: "include", }); if (!response.ok) { throw new Error(`HTTP error! status: ${response.status}`); } - const token = (await response.json()) as Token; - if (token && token.access_token) { - localStorage.setItem("access_token", token.access_token); - } else { - console.log("Token not acquired"); - } } catch (e) { console.error(e); throw e; // rethrow the error so it can be caught by the caller } }; -export const logout = () => localStorage.removeItem("access_token"); +export const logout = async () => { + try { + await fetch(`${baseUrl}api/logout`, { + method: "POST", + credentials: "include", + }); + } catch (e) { + console.error(e); + } +};