diff --git a/security.py b/security.py index 9e116e8..0cdebce 100644 --- a/security.py +++ b/security.py @@ -39,7 +39,23 @@ class TokenData(BaseModel): pwd_context = CryptContext(schemes=["argon2"], deprecated="auto") -oauth2_scheme = OAuth2PasswordBearer( +class CookieOAuth2(OAuth2PasswordBearer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + async def __call__(self, request: Request): + cookie_token = request.cookies.get("access_token") + if cookie_token: + return cookie_token + else: + header_token = await super().__call__(request) + if header_token: + return header_token + else: + raise HTTPException(status_code=401) + + +oauth2_scheme = CookieOAuth2( tokenUrl="api/token", scopes={ "analysis": "Access the results.", @@ -88,7 +104,10 @@ def create_access_token(data: dict, expires_delta: timedelta | None = None): return encoded_jwt -async def get_current_user(security_scopes: SecurityScopes, request: Request): +async def get_current_user( + token: Annotated[str, Depends(oauth2_scheme)], + security_scopes: SecurityScopes, +): if security_scopes.scopes: authenticate_value = f'Bearer scope="{security_scopes.scope_str}"' else: @@ -98,7 +117,8 @@ async def get_current_user(security_scopes: SecurityScopes, request: Request): detail="Could not validate credentials", headers={"WWW-Authenticate": authenticate_value}, ) - access_token = request.cookies.get("access_token") + # access_token = request.cookies.get("access_token") + access_token = token if not access_token: raise credentials_exception try: