diff --git a/db.py b/db.py
index ac0bde1..e070cda 100644
--- a/db.py
+++ b/db.py
@@ -12,7 +12,9 @@ from sqlmodel import (
with open("db.secrets", "r") as f:
db_secrets = f.readline().strip()
-engine = create_engine(db_secrets, connect_args={"connect_timeout": 8})
+engine = create_engine(
+ db_secrets, pool_timeout=20, pool_size=2, connect_args={"connect_timeout": 8}
+)
del db_secrets
diff --git a/main.py b/main.py
index 57aa01a..959a1da 100644
--- a/main.py
+++ b/main.py
@@ -10,6 +10,7 @@ from analysis import analysis_router
from security import (
get_current_active_user,
login_for_access_token,
+ logout,
read_users_me,
read_own_items,
)
@@ -64,21 +65,11 @@ def list_teams():
player_router = APIRouter(prefix="/player")
player_router.add_api_route("/list", endpoint=list_players, methods=["GET"])
-player_router.add_api_route(
- "/add",
- endpoint=add_player,
- methods=["POST"],
- dependencies=[Depends(get_current_active_user)],
-)
+player_router.add_api_route("/add", endpoint=add_player, methods=["POST"])
team_router = APIRouter(prefix="/team")
team_router.add_api_route("/list", endpoint=list_teams, methods=["GET"])
-team_router.add_api_route(
- "/add",
- endpoint=add_team,
- methods=["POST"],
- dependencies=[Depends(get_current_active_user)],
-)
+team_router.add_api_route("/add", endpoint=add_team, methods=["POST"])
@app.post("/mvps/", status_code=status.HTTP_200_OK)
@@ -103,13 +94,16 @@ class SPAStaticFiles(StaticFiles):
return response
-api_router.include_router(player_router)
-api_router.include_router(team_router)
+api_router.include_router(
+ player_router, dependencies=[Depends(get_current_active_user)]
+)
+api_router.include_router(team_router, dependencies=[Depends(get_current_active_user)])
api_router.include_router(
analysis_router,
dependencies=[Security(get_current_active_user, scopes=["analysis"])],
)
api_router.add_api_route("/token", endpoint=login_for_access_token, methods=["POST"])
+api_router.add_api_route("/logout", endpoint=logout, methods=["POST"])
api_router.add_api_route("/users/me/", endpoint=read_users_me, methods=["GET"])
api_router.add_api_route("/users/me/items/", endpoint=read_own_items, methods=["GET"])
app.include_router(api_router)
diff --git a/security.py b/security.py
index 0c6f5d1..a88b271 100644
--- a/security.py
+++ b/security.py
@@ -1,9 +1,9 @@
from datetime import timedelta, timezone, datetime
from typing import Annotated
-from fastapi import Depends, HTTPException, Response, status
+from fastapi import Depends, HTTPException, Request, Response, status
from pydantic import BaseModel, ValidationError
import jwt
-from jwt.exceptions import InvalidTokenError
+from jwt.exceptions import ExpiredSignatureError, InvalidTokenError
from sqlmodel import Session, select
from db import engine, User
from fastapi.security import (
@@ -18,7 +18,7 @@ from sqlalchemy.exc import OperationalError
class Config(BaseSettings):
secret_key: str = ""
- access_token_expire_minutes: int = 30
+ access_token_expire_minutes: int = 15
model_config = SettingsConfigDict(
env_file=".env", env_file_encoding="utf-8", extra="ignore"
)
@@ -29,7 +29,6 @@ config = Config()
class Token(BaseModel):
access_token: str
- token_type: str
class TokenData(BaseModel):
@@ -81,15 +80,15 @@ def create_access_token(data: dict, expires_delta: timedelta | None = None):
if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta
else:
- expire = datetime.now(timezone.utc) + timedelta(minutes=15)
+ expire = datetime.now(timezone.utc) + timedelta(
+ seconds=config.access_token_expire_minutes
+ )
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, config.secret_key, algorithm="HS256")
return encoded_jwt
-async def get_current_user(
- security_scopes: SecurityScopes, token: Annotated[str, Depends(oauth2_scheme)]
-):
+async def get_current_user(security_scopes: SecurityScopes, request: Request):
if security_scopes.scopes:
authenticate_value = f'Bearer scope="{security_scopes.scope_str}"'
else:
@@ -99,13 +98,22 @@ async def get_current_user(
detail="Could not validate credentials",
headers={"WWW-Authenticate": authenticate_value},
)
+ access_token = request.cookies.get("access_token")
+ if not access_token:
+ raise credentials_exception
try:
- payload = jwt.decode(token, config.secret_key, algorithms=["HS256"])
+ payload = jwt.decode(access_token, config.secret_key, algorithms=["HS256"])
username: str = payload.get("sub")
if username is None:
raise credentials_exception
token_scopes = payload.get("scopes", [])
token_data = TokenData(username=username, scopes=token_scopes)
+ except ExpiredSignatureError:
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail="Token expired",
+ headers={"WWW-Authenticate": authenticate_value},
+ )
except (InvalidTokenError, ValidationError):
raise credentials_exception
user = get_user(username=token_data.username)
@@ -139,17 +147,24 @@ async def login_for_access_token(
detail="Incorrect username or password",
headers={"WWW-Authenticate": "Bearer"},
)
- access_token_expires = timedelta(minutes=config.access_token_expire_minutes)
allowed_scopes = set(user.scopes.split())
requested_scopes = set(form_data.scopes)
access_token = create_access_token(
- data={"sub": user.username, "scopes": list(allowed_scopes)},
- expires_delta=access_token_expires,
+ data={"sub": user.username, "scopes": list(allowed_scopes)}
)
response.set_cookie(
- "Authorization", value=f"Bearer {access_token}", httponly=True, samesite="none"
+ "access_token",
+ value=access_token,
+ httponly=True,
+ samesite="strict",
+ max_age=15,
)
- return Token(access_token=access_token, token_type="bearer")
+ return Token(access_token=access_token)
+
+
+async def logout(response: Response):
+ response.set_cookie("access_token", "", expires=0, httponly=True, samesite="strict")
+ return
async def read_users_me(
diff --git a/src/Rankings.tsx b/src/Rankings.tsx
index 409f907..306c07c 100644
--- a/src/Rankings.tsx
+++ b/src/Rankings.tsx
@@ -1,6 +1,6 @@
import { Dispatch, SetStateAction, useEffect, useMemo, useState } from "react";
import { ReactSortable, ReactSortableProps } from "react-sortablejs";
-import api, { baseUrl } from "./api";
+import { apiAuth } from "./api";
interface Player {
id: number;
@@ -124,7 +124,7 @@ export function Chemistry({ user, players }: PlayerInfoProps) {
let middle = playersMiddle.map(({ name }) => name);
let right = playersRight.map(({ name }) => name);
const data = { user: _user, hate: left, undecided: middle, love: right };
- const response = await api("chemistry", data);
+ const response = await apiAuth("chemistry", data);
response.ok ? setDialog("success!") : setDialog("try sending again");
}
}
@@ -203,7 +203,7 @@ export function MVP({ user, players }: PlayerInfoProps) {
let _user = user.map(({ name }) => name)[0];
let mvps = rankedPlayers.map(({ name }) => name);
const data = { user: _user, mvps: mvps };
- const response = await api("mvps", data);
+ const response = await apiAuth("mvps", data);
response.ok ? setDialog("success!") : setDialog("try sending again");
}
}
@@ -272,10 +272,7 @@ export default function Rankings() {
const [openTab, setOpenTab] = useState("Chemistry");
async function loadPlayers() {
- const response = await fetch(`${baseUrl}api/player/list`, {
- method: "GET",
- });
- const data = await response.json();
+ const data = await apiAuth("player/list", null, "GET");
setPlayers(data as Player[]);
}
@@ -334,8 +331,11 @@ export default function Rankings() {
- assign as many or as few players as you want
- and don't forget to submit (💾) when you're done :)
+
+ assign as many or as few players as you want
+
+ and don't forget to submit (💾) when you're done :)
+