diff --git a/db.py b/db.py index dc61f42..ac0bde1 100644 --- a/db.py +++ b/db.py @@ -69,6 +69,7 @@ class User(SQLModel, table=True): disabled: bool | None = None hashed_password: str | None = None player_id: int | None = Field(default=None, foreign_key="player.id") + scopes: str = "" SQLModel.metadata.create_all(engine) diff --git a/main.py b/main.py index 8ae66d4..57aa01a 100644 --- a/main.py +++ b/main.py @@ -1,4 +1,4 @@ -from fastapi import APIRouter, Depends, FastAPI, status +from fastapi import APIRouter, Depends, FastAPI, Security, status from fastapi.staticfiles import StaticFiles from db import Player, Team, Chemistry, MVPRanking, engine from sqlmodel import ( @@ -106,7 +106,8 @@ class SPAStaticFiles(StaticFiles): api_router.include_router(player_router) api_router.include_router(team_router) api_router.include_router( - analysis_router, dependencies=[Depends(get_current_active_user)] + analysis_router, + dependencies=[Security(get_current_active_user, scopes=["analysis"])], ) api_router.add_api_route("/token", endpoint=login_for_access_token, methods=["POST"]) api_router.add_api_route("/users/me/", endpoint=read_users_me, methods=["GET"]) diff --git a/security.py b/security.py index 85d76c5..d3abe56 100644 --- a/security.py +++ b/security.py @@ -1,12 +1,16 @@ from datetime import timedelta, timezone, datetime from typing import Annotated from fastapi import Depends, HTTPException, Response, status -from pydantic import BaseModel +from pydantic import BaseModel, ValidationError import jwt from jwt.exceptions import InvalidTokenError from sqlmodel import Session, select from db import engine, User -from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm +from fastapi.security import ( + OAuth2PasswordBearer, + OAuth2PasswordRequestForm, + SecurityScopes, +) from pydantic_settings import BaseSettings, SettingsConfigDict from passlib.context import CryptContext from sqlalchemy.exc import OperationalError @@ -30,12 +34,18 @@ class Token(BaseModel): class TokenData(BaseModel): username: str | None = None + scopes: list[str] = [] pwd_context = CryptContext(schemes=["argon2"], deprecated="auto") -oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/token") +oauth2_scheme = OAuth2PasswordBearer( + tokenUrl="api/token", + scopes={ + "analysis": "Access the results.", + }, +) def verify_password(plain_password, hashed_password): @@ -77,23 +87,37 @@ def create_access_token(data: dict, expires_delta: timedelta | None = None): return encoded_jwt -async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]): +async def get_current_user( + security_scopes: SecurityScopes, token: Annotated[str, Depends(oauth2_scheme)] +): + if security_scopes.scopes: + authenticate_value = f'Bearer scope="{security_scopes.scope_str}"' + else: + authenticate_value = "Bearer" credentials_exception = HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials", - headers={"WWW-Authenticate": "Bearer"}, + headers={"WWW-Authenticate": authenticate_value}, ) try: payload = jwt.decode(token, config.secret_key, algorithms=["HS256"]) username: str = payload.get("sub") if username is None: raise credentials_exception - token_data = TokenData(username=username) - except InvalidTokenError: + token_scopes = payload.get("scopes", []) + token_data = TokenData(username=username, scopes=token_scopes) + except (InvalidTokenError, ValidationError): raise credentials_exception user = get_user(username=token_data.username) if user is None: raise credentials_exception + for scope in security_scopes.scopes: + if scope not in token_data.scopes: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Not enough permissions", + headers={"WWW-Authenticate": authenticate_value}, + ) return user @@ -116,8 +140,11 @@ async def login_for_access_token( headers={"WWW-Authenticate": "Bearer"}, ) access_token_expires = timedelta(minutes=config.access_token_expire_minutes) + allowed_scopes = set(user.scopes.split()) + requested_scopes = set(form_data.scopes) access_token = create_access_token( - data={"sub": user.username}, expires_delta=access_token_expires + data={"sub": user.username, "scopes": list(allowed_scopes & requested_scopes)}, + expires_delta=access_token_expires, ) response.set_cookie( "Authorization", value=f"Bearer {access_token}", httponly=True, samesite="none"