diff --git a/analysis.py b/analysis.py index 9f0d88e..ea1ed7e 100644 --- a/analysis.py +++ b/analysis.py @@ -1,130 +1,20 @@ -from datetime import timedelta, timezone, datetime +from datetime import datetime import io -from typing import Annotated import base64 -from fastapi import APIRouter, Depends, HTTPException, status +from fastapi import APIRouter from fastapi.responses import JSONResponse from pydantic import BaseModel, Field -import jwt -from jwt.exceptions import InvalidTokenError from sqlmodel import Session, func, select from sqlmodel.sql.expression import SelectOfScalar -from db import Chemistry, Player, engine, User +from db import Chemistry, Player, engine import networkx as nx -from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm -from pydantic_settings import BaseSettings, SettingsConfigDict -from passlib.context import CryptContext import matplotlib matplotlib.use("agg") import matplotlib.pyplot as plt -class Config(BaseSettings): - secret_key: str = "" - access_token_expire_minutes: int = 30 - model_config = SettingsConfigDict( - env_file=".env", env_file_encoding="utf-8", extra="ignore" - ) - - -config = Config() - - -class Token(BaseModel): - access_token: str - token_type: str - - -class TokenData(BaseModel): - username: str | None = None - - -pwd_context = CryptContext(schemes=["argon2"], deprecated="auto") - - -oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") - -analysis_router = APIRouter(prefix="/analysis", dependencies=[Depends(oauth2_scheme)]) - - -def verify_password(plain_password, hashed_password): - return pwd_context.verify(plain_password, hashed_password) - - -def get_password_hash(password): - return pwd_context.hash(password) - - -def get_user(username: str): - with Session(engine) as session: - return session.exec(select(User).where(User.username == username)).one_or_none() - - -def authenticate_user(username: str, password: str): - user = get_user(username) - if not user: - return False - if not verify_password(password, user.hashed_password): - return False - return user - - -def create_access_token(data: dict, expires_delta: timedelta | None = None): - to_encode = data.copy() - if expires_delta: - expire = datetime.now(timezone.utc) + expires_delta - else: - expire = datetime.now(timezone.utc) + timedelta(minutes=15) - to_encode.update({"exp": expire}) - encoded_jwt = jwt.encode(to_encode, config.secret_key, algorithm="HS256") - return encoded_jwt - - -async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]): - credentials_exception = HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Could not validate credentials", - headers={"WWW-Authenticate": "Bearer"}, - ) - try: - payload = jwt.decode(token, config.secret_key, algorithms=["HS256"]) - username: str = payload.get("sub") - if username is None: - raise credentials_exception - token_data = TokenData(username=username) - except InvalidTokenError: - raise credentials_exception - user = get_user(fake_users_db, username=token_data.username) - if user is None: - raise credentials_exception - return user - - -async def get_current_active_user( - current_user: Annotated[User, Depends(get_current_user)], -): - if current_user.disabled: - raise HTTPException(status_code=400, detail="Inactive user") - return current_user - - -@analysis_router.post("/token") -async def login_for_access_token( - form_data: Annotated[OAuth2PasswordRequestForm, Depends()], -) -> Token: - user = authenticate_user(form_data.username, form_data.password) - if not user: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Incorrect username or password", - headers={"WWW-Authenticate": "Bearer"}, - ) - access_token_expires = timedelta(minutes=config.access_token_expire_minutes) - access_token = create_access_token( - data={"sub": user.username}, expires_delta=access_token_expires - ) - return Token(access_token=access_token, token_type="bearer") +analysis_router = APIRouter(prefix="/analysis") C = Chemistry diff --git a/main.py b/main.py index 301239f..3ab6c20 100644 --- a/main.py +++ b/main.py @@ -1,4 +1,4 @@ -from fastapi import APIRouter, FastAPI, status +from fastapi import APIRouter, Depends, FastAPI, status from fastapi.staticfiles import StaticFiles from db import Player, Team, Chemistry, MVPRanking, engine from sqlmodel import ( @@ -7,7 +7,12 @@ from sqlmodel import ( ) from fastapi.middleware.cors import CORSMiddleware from analysis import analysis_router -from security import login_for_access_token, read_users_me, read_own_items +from security import ( + get_current_active_user, + login_for_access_token, + read_users_me, + read_own_items, +) app = FastAPI(title="cutt") @@ -92,7 +97,9 @@ class SPAStaticFiles(StaticFiles): api_router.include_router(player_router) api_router.include_router(team_router) -api_router.include_router(analysis_router) +api_router.include_router( + analysis_router, dependencies=[Depends(get_current_active_user)] +) api_router.add_api_route("/token", endpoint=login_for_access_token, methods=["POST"]) api_router.add_api_route("/users/me/", endpoint=read_users_me, methods=["GET"]) api_router.add_api_route("/users/me/items/", endpoint=read_own_items, methods=["GET"]) diff --git a/security.py b/security.py index df250a1..dca1e37 100644 --- a/security.py +++ b/security.py @@ -34,7 +34,7 @@ class TokenData(BaseModel): pwd_context = CryptContext(schemes=["argon2"], deprecated="auto") -oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/token") +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/token") def verify_password(plain_password, hashed_password):