Compare commits

..

No commits in common. "b7c8136b1ece403c7b1412b1f18558b325894519" and "d61bea3c8600f9f2250b65e4a2636410642253d1" have entirely different histories.

4 changed files with 40 additions and 82 deletions

1
db.py
View File

@ -69,7 +69,6 @@ class User(SQLModel, table=True):
disabled: bool | None = None disabled: bool | None = None
hashed_password: str | None = None hashed_password: str | None = None
player_id: int | None = Field(default=None, foreign_key="player.id") player_id: int | None = Field(default=None, foreign_key="player.id")
scopes: str = ""
SQLModel.metadata.create_all(engine) SQLModel.metadata.create_all(engine)

View File

@ -1,4 +1,4 @@
from fastapi import APIRouter, Depends, FastAPI, Security, status from fastapi import APIRouter, Depends, FastAPI, status
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from db import Player, Team, Chemistry, MVPRanking, engine from db import Player, Team, Chemistry, MVPRanking, engine
from sqlmodel import ( from sqlmodel import (
@ -106,8 +106,7 @@ class SPAStaticFiles(StaticFiles):
api_router.include_router(player_router) api_router.include_router(player_router)
api_router.include_router(team_router) api_router.include_router(team_router)
api_router.include_router( api_router.include_router(
analysis_router, analysis_router, dependencies=[Depends(get_current_active_user)]
dependencies=[Security(get_current_active_user, scopes=["analysis"])],
) )
api_router.add_api_route("/token", endpoint=login_for_access_token, methods=["POST"]) api_router.add_api_route("/token", endpoint=login_for_access_token, methods=["POST"])
api_router.add_api_route("/users/me/", endpoint=read_users_me, methods=["GET"]) api_router.add_api_route("/users/me/", endpoint=read_users_me, methods=["GET"])

View File

@ -1,16 +1,12 @@
from datetime import timedelta, timezone, datetime from datetime import timedelta, timezone, datetime
from typing import Annotated from typing import Annotated
from fastapi import Depends, HTTPException, Response, status from fastapi import Depends, HTTPException, Response, status
from pydantic import BaseModel, ValidationError from pydantic import BaseModel
import jwt import jwt
from jwt.exceptions import InvalidTokenError from jwt.exceptions import InvalidTokenError
from sqlmodel import Session, select from sqlmodel import Session, select
from db import engine, User from db import engine, User
from fastapi.security import ( from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
OAuth2PasswordBearer,
OAuth2PasswordRequestForm,
SecurityScopes,
)
from pydantic_settings import BaseSettings, SettingsConfigDict from pydantic_settings import BaseSettings, SettingsConfigDict
from passlib.context import CryptContext from passlib.context import CryptContext
from sqlalchemy.exc import OperationalError from sqlalchemy.exc import OperationalError
@ -34,18 +30,12 @@ class Token(BaseModel):
class TokenData(BaseModel): class TokenData(BaseModel):
username: str | None = None username: str | None = None
scopes: list[str] = []
pwd_context = CryptContext(schemes=["argon2"], deprecated="auto") pwd_context = CryptContext(schemes=["argon2"], deprecated="auto")
oauth2_scheme = OAuth2PasswordBearer( oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/token")
tokenUrl="api/token",
scopes={
"analysis": "Access the results.",
},
)
def verify_password(plain_password, hashed_password): def verify_password(plain_password, hashed_password):
@ -87,37 +77,23 @@ def create_access_token(data: dict, expires_delta: timedelta | None = None):
return encoded_jwt return encoded_jwt
async def get_current_user( async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]):
security_scopes: SecurityScopes, token: Annotated[str, Depends(oauth2_scheme)]
):
if security_scopes.scopes:
authenticate_value = f'Bearer scope="{security_scopes.scope_str}"'
else:
authenticate_value = "Bearer"
credentials_exception = HTTPException( credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials", detail="Could not validate credentials",
headers={"WWW-Authenticate": authenticate_value}, headers={"WWW-Authenticate": "Bearer"},
) )
try: try:
payload = jwt.decode(token, config.secret_key, algorithms=["HS256"]) payload = jwt.decode(token, config.secret_key, algorithms=["HS256"])
username: str = payload.get("sub") username: str = payload.get("sub")
if username is None: if username is None:
raise credentials_exception raise credentials_exception
token_scopes = payload.get("scopes", []) token_data = TokenData(username=username)
token_data = TokenData(username=username, scopes=token_scopes) except InvalidTokenError:
except (InvalidTokenError, ValidationError):
raise credentials_exception raise credentials_exception
user = get_user(username=token_data.username) user = get_user(username=token_data.username)
if user is None: if user is None:
raise credentials_exception raise credentials_exception
for scope in security_scopes.scopes:
if scope not in token_data.scopes:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Not enough permissions",
headers={"WWW-Authenticate": authenticate_value},
)
return user return user
@ -140,11 +116,8 @@ async def login_for_access_token(
headers={"WWW-Authenticate": "Bearer"}, headers={"WWW-Authenticate": "Bearer"},
) )
access_token_expires = timedelta(minutes=config.access_token_expire_minutes) access_token_expires = timedelta(minutes=config.access_token_expire_minutes)
allowed_scopes = set(user.scopes.split())
requested_scopes = set(form_data.scopes)
access_token = create_access_token( access_token = create_access_token(
data={"sub": user.username, "scopes": list(allowed_scopes & requested_scopes)}, data={"sub": user.username}, expires_delta=access_token_expires
expires_delta=access_token_expires,
) )
response.set_cookie( response.set_cookie(
"Authorization", value=f"Bearer {access_token}", httponly=True, samesite="none" "Authorization", value=f"Bearer {access_token}", httponly=True, samesite="none"

View File

@ -18,19 +18,18 @@ export default async function api(path: string, data: any): Promise<any> {
return response; return response;
} }
export async function apiAuth( export async function apiAuth(path: string, data: any, method: string = "GET"): Promise<any> {
path: string,
data: any, const req = new Request(`${baseUrl}api/${path}`,
method: string = "GET" {
): Promise<any> {
const req = new Request(`${baseUrl}api/${path}`, {
method: method, method: method,
headers: { headers: {
Authorization: `Bearer ${token()} `, "Authorization": `Bearer ${token()} `,
"Content-Type": "application/json", 'Content-Type': 'application/json'
}, },
...(data && { body: JSON.stringify(data) }), ...(data && { body: JSON.stringify(data) })
}); }
);
let resp: Response; let resp: Response;
try { try {
resp = await fetch(req); resp = await fetch(req);
@ -40,28 +39,25 @@ export async function apiAuth(
if (!resp.ok) { if (!resp.ok) {
if (resp.status === 401) { if (resp.status === 401) {
logout(); logout()
throw new Error("Unauthorized"); throw new Error('Unauthorized');
} }
} }
return resp.json(); return resp.json()
} }
export type User = { export type User = {
username: string; username: string;
full_name: string; fullName: string;
email: string; }
player_id: number;
};
export async function currentUser(): Promise<User> { export async function currentUser(): Promise<User> {
if (!token()) throw new Error("you have no access token"); if (!token()) throw new Error("you have no access token")
const req = new Request(`${baseUrl}api/users/me/`, { const req = new Request(`${baseUrl}api/users/me/`, {
method: "GET", method: "GET", headers: {
headers: { "Authorization": `Bearer ${token()} `,
Authorization: `Bearer ${token()} `, 'Content-Type': 'application/json'
"Content-Type": "application/json", }
},
}); });
let resp: Response; let resp: Response;
try { try {
@ -72,8 +68,8 @@ export async function currentUser(): Promise<User> {
if (!resp.ok) { if (!resp.ok) {
if (resp.status === 401) { if (resp.status === 401) {
logout(); logout()
throw new Error("Unauthorized"); throw new Error('Unauthorized');
} }
} }
return resp.json() as Promise<User>; return resp.json() as Promise<User>;
@ -90,20 +86,11 @@ export type Token = {
export const login = (req: LoginRequest) => { export const login = (req: LoginRequest) => {
fetch(`${baseUrl}api/token`, { fetch(`${baseUrl}api/token`, {
method: "POST", method: "POST", headers: {
headers: { 'Content-Type': 'application/x-www-form-urlencoded',
"Content-Type": "application/x-www-form-urlencoded", }, body: new URLSearchParams(req).toString()
}, }).then(resp => resp.json() as Promise<Token>).then(token => token ? localStorage.setItem("access_token", token.access_token) : console.log("token not acquired")).catch((e) => console.log("catch error " + e + " in login"));
body: new URLSearchParams(req).toString(), return Promise<void>
}) }
.then((resp) => resp.json() as Promise<Token>)
.then((token) =>
token
? localStorage.setItem("access_token", token.access_token)
: console.log("token not acquired")
)
.catch((e) => console.log("catch error " + e + " in login"));
return Promise<void>;
};
export const logout = () => localStorage.removeItem("access_token"); export const logout = () => localStorage.removeItem("access_token");