diff --git a/db.py b/db.py index 49a04af..b6e9f57 100644 --- a/db.py +++ b/db.py @@ -73,4 +73,12 @@ class MVPRanking(SQLModel, table=True): mvps: list[int] = Field(sa_column=Column(ARRAY(Integer))) +class TokenDB(SQLModel, table=True): + token: str = Field(index=True, primary_key=True) + used: bool | None = False + updated_at: datetime | None = Field( + default_factory=utctime, sa_column_kwargs={"onupdate": utctime} + ) + + SQLModel.metadata.create_all(engine) diff --git a/security.py b/security.py index ed1bac4..0dc76e6 100644 --- a/security.py +++ b/security.py @@ -6,7 +6,7 @@ from pydantic import BaseModel, ValidationError import jwt from jwt.exceptions import ExpiredSignatureError, InvalidTokenError from sqlmodel import Session, select -from db import engine, Player +from db import TokenDB, engine, Player from fastapi.security import ( OAuth2PasswordBearer, OAuth2PasswordRequestForm, @@ -178,6 +178,7 @@ async def login_for_access_token( value=access_token, httponly=True, samesite="strict", + max_age=config.access_token_expire_minutes * 60, ) return Token(access_token=access_token) @@ -208,26 +209,51 @@ async def set_first_password(req: FirstPassword): status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate token", ) - try: - payload = jwt.decode(req.token, config.secret_key, algorithms=["HS256"]) - username: str = payload.get("sub") - if username is None: - raise credentials_exception - except ExpiredSignatureError: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Access token expired", - ) - except (InvalidTokenError, ValidationError): - raise credentials_exception + with Session(engine) as session: + token_in_db = session.exec( + select(TokenDB) + .where(TokenDB.token == req.token) + .where(TokenDB.used == False) + ).one_or_none() + if token_in_db: + credentials_exception = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate token", + ) + try: + payload = jwt.decode(req.token, config.secret_key, algorithms=["HS256"]) + username: str = payload.get("sub") + if username is None: + raise credentials_exception + except ExpiredSignatureError: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Access token expired", + ) + except (InvalidTokenError, ValidationError): + raise credentials_exception - user = get_user(username) - if user: - with Session(engine) as session: - user.hashed_password = get_password_hash(req.password) - session.add(user) - session.commit() - return Response("Password set successfully", status_code=status.HTTP_200_OK) + user = get_user(username) + if user: + user.hashed_password = get_password_hash(req.password) + session.add(user) + token_in_db.used = True + session.add(token_in_db) + session.commit() + return Response( + "Password set successfully", status_code=status.HTTP_200_OK + ) + elif session.exec( + select(TokenDB) + .where(TokenDB.token == req.token) + .where(TokenDB.used == True) + ).one_or_none(): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Token already used", + ) + else: + raise credentials_exception async def change_password( diff --git a/src/SetPassword.tsx b/src/SetPassword.tsx index f62540e..276b457 100644 --- a/src/SetPassword.tsx +++ b/src/SetPassword.tsx @@ -63,9 +63,9 @@ export const SetPassword = () => { if (!resp.ok) { if (resp.status === 401) { - resp.statusText - ? setError(resp.statusText) - : setError("unauthorized"); + const { detail } = await resp.json(); + if (detail) setError(detail); + else setError("unauthorized"); throw new Error("Unauthorized"); } }