cutt/security.py
2025-03-17 19:26:09 +01:00

300 lines
9.2 KiB
Python

from datetime import timedelta, timezone, datetime
from typing import Annotated
from fastapi import Depends, HTTPException, Request, Response, status
from fastapi.responses import PlainTextResponse
from pydantic import BaseModel, ValidationError
import jwt
from jwt.exceptions import ExpiredSignatureError, InvalidTokenError
from sqlmodel import Session, select
from db import TokenDB, engine, Player
from fastapi.security import (
OAuth2PasswordBearer,
OAuth2PasswordRequestForm,
SecurityScopes,
)
from pydantic_settings import BaseSettings, SettingsConfigDict
from passlib.context import CryptContext
from sqlalchemy.exc import OperationalError
class Config(BaseSettings):
    """Security settings, read from the environment and an optional ``.env`` file."""

    # Key used to sign and verify the HS256 tokens issued by this module.
    # NOTE(review): defaults to an empty string when nothing is configured --
    # confirm that every deployment supplies a real secret.
    secret_key: str = ""

    # Lifetime, in minutes, of a login token and of the cookie that carries it.
    access_token_expire_minutes: int = 15

    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        extra="ignore",
    )
# Module-wide settings instance, created once when this module is imported.
config = Config()
class Token(BaseModel):
    """Body returned by the login endpoint: the encoded access token."""

    access_token: str
class TokenData(BaseModel):
    """Claims read from a decoded access token (subject and granted scopes)."""

    # Value of the token's ``sub`` claim.
    username: str | None = None

    # Value of the token's ``scopes`` claim; empty when the claim is absent.
    scopes: list[str] = []
# Shared password-hashing context; argon2 is the only configured scheme.
pwd_context = CryptContext(schemes=["argon2"], deprecated="auto")
class CookieOAuth2(OAuth2PasswordBearer):
    """Bearer scheme that looks for the token in a cookie before the header.

    The ``access_token`` cookie wins when it is present and non-empty;
    otherwise the lookup is delegated to ``OAuth2PasswordBearer``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    async def __call__(self, request: Request):
        """Return the access token carried by ``request``.

        Raises:
            HTTPException: 401 when neither the cookie nor the parent scheme
                yields a token.
        """
        token = request.cookies.get("access_token")
        if not token:
            # No usable cookie: fall back to the parent's header-based lookup.
            token = await super().__call__(request)
        if not token:
            raise HTTPException(status_code=401)
        return token
# Token extractor used as a dependency by the authentication helpers below.
# ``tokenUrl`` and ``scopes`` are the values advertised for the login flow.
oauth2_scheme = CookieOAuth2(
    tokenUrl="api/token",
    scopes={
        "analysis": "Access the results.",
        "admin": "Maintain DB etc.",
    },
)
def verify_password(plain_password, hashed_password):
    """Return the result of checking ``plain_password`` against ``hashed_password``."""
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match
def get_password_hash(password):
    """Hash ``password`` with the module's password context and return the hash."""
    hashed = pwd_context.hash(password)
    return hashed
def get_user(username: str | None):
if username:
try:
with Session(engine) as session:
return session.exec(
select(Player).where(Player.username == username)
).one_or_none()
except OperationalError:
return
def authenticate_user(username: str, password: str):
    """Check a username/password pair.

    Returns:
        The user object when the user exists and the password verifies,
        otherwise ``False``.
    """
    candidate = get_user(username)
    if candidate and verify_password(password, candidate.hashed_password):
        return candidate
    return False
def create_access_token(data: dict, expires_delta: timedelta | None = None):
    """Encode ``data`` plus an ``exp`` claim as an HS256-signed JWT.

    Args:
        data: Claims to embed; the mapping itself is not modified.
        expires_delta: Token lifetime. When omitted (or falsy) the configured
            ``access_token_expire_minutes`` is used instead.

    Returns:
        The encoded token.
    """
    lifetime = expires_delta or timedelta(
        minutes=config.access_token_expire_minutes
    )
    claims = {**data, "exp": datetime.now(timezone.utc) + lifetime}
    return jwt.encode(claims, config.secret_key, algorithm="HS256")
async def get_current_user(
    token: Annotated[str, Depends(oauth2_scheme)],
    security_scopes: SecurityScopes,
):
    """Resolve the user an access token belongs to and check required scopes.

    Args:
        token: Encoded JWT supplied by ``oauth2_scheme``.
        security_scopes: Scopes required by the dependant.

    Returns:
        The user record named by the token's ``sub`` claim.

    Raises:
        HTTPException: 401 when the token is missing, expired, invalid, names
            an unknown user, or lacks one of the required scopes.
    """
    authenticate_value = (
        f'Bearer scope="{security_scopes.scope_str}"'
        if security_scopes.scopes
        else "Bearer"
    )
    auth_headers = {"WWW-Authenticate": authenticate_value}
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers=auth_headers,
    )

    if not token:
        raise credentials_exception

    try:
        payload = jwt.decode(token, config.secret_key, algorithms=["HS256"])
        subject = payload.get("sub")
        if subject is None:
            raise credentials_exception
        token_data = TokenData(
            username=subject,
            scopes=payload.get("scopes", []),
        )
    except ExpiredSignatureError:
        # Expiry gets its own message so clients can tell it from a bad token.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Access token expired",
            headers=auth_headers,
        )
    except (InvalidTokenError, ValidationError):
        raise credentials_exception

    user = get_user(username=token_data.username)
    if user is None:
        raise credentials_exception

    # Every scope the dependant requires must be present in the token.
    if any(scope not in token_data.scopes for scope in security_scopes.scopes):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Not enough permissions",
            headers=auth_headers,
        )
    return user
async def get_current_active_user(
    current_user: Annotated[Player, Depends(get_current_user)],
):
    """Return the authenticated user unless the account is disabled.

    Raises:
        HTTPException: 400 when ``current_user.disabled`` is truthy.
    """
    if not current_user.disabled:
        return current_user
    raise HTTPException(status_code=400, detail="Inactive user")
async def login_for_access_token(
    form_data: Annotated[OAuth2PasswordRequestForm, Depends()], response: Response
) -> Token:
    """Authenticate a username/password form and issue an access token.

    The token is returned in the response body and is also stored in an
    HTTP-only ``access_token`` cookie whose ``max_age`` equals the configured
    token lifetime.

    Args:
        form_data: OAuth2 password form carrying the username and password.
        response: Outgoing response that receives the cookie.

    Returns:
        A ``Token`` wrapping the encoded JWT.

    Raises:
        HTTPException: 401 when the credentials do not authenticate.
    """
    user = authenticate_user(form_data.username, form_data.password)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # The token carries every scope stored on the user; scopes requested in
    # the form do not narrow the grant. ``or ""`` keeps a user without a
    # scopes value from crashing the login, and sorting makes the claim
    # order deterministic.
    granted_scopes = sorted(set((user.scopes or "").split()))
    access_token = create_access_token(
        data={"sub": user.username, "scopes": granted_scopes}
    )

    # NOTE(review): the cookie is not flagged ``secure`` -- confirm that the
    # deployment only serves it over HTTPS.
    response.set_cookie(
        "access_token",
        value=access_token,
        httponly=True,
        samesite="strict",
        max_age=config.access_token_expire_minutes * 60,
    )
    return Token(access_token=access_token)
async def logout(response: Response):
    """Overwrite the ``access_token`` cookie with an empty, expired value."""
    response.set_cookie(
        key="access_token",
        value="",
        expires=0,
        httponly=True,
        samesite="strict",
    )
    return {"message": "Successfully logged out"}
def generate_one_time_token(username):
    """Create a seven-day token for ``username``.

    The token carries the username as ``sub`` and the user's display name as
    ``name``. It is not stored here; presumably the caller records it so that
    ``set_first_password`` can find it -- verify against the caller.

    NOTE(review): the token is signed with the same key and uses the same
    ``sub`` claim as login tokens -- confirm it cannot be replayed as an
    access token.

    Returns:
        The encoded token, or ``None`` when no such user exists.
    """
    user = get_user(username)
    if not user:
        return None

    claims = {"sub": username, "name": user.display_name}
    return create_access_token(data=claims, expires_delta=timedelta(days=7))
class FirstPassword(BaseModel):
    """Request body for setting a password with a one-time token."""

    # The one-time token being redeemed.
    token: str

    # The password to store for the token's user.
    password: str
async def set_first_password(req: FirstPassword):
    """Set a user's password by redeeming a one-time token.

    The token must exist in ``TokenDB`` with ``used == False`` and must decode
    as a valid, unexpired JWT whose ``sub`` names an existing user. On success
    the password hash is stored and the token is marked as used in the same
    commit.

    Args:
        req: The one-time token and the password to set.

    Returns:
        A 200 ``Response`` with the text "Password set successfully".

    Raises:
        HTTPException: 401 when the token is unknown, already used, expired,
            invalid, or refers to a user that does not exist.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate token",
    )

    with Session(engine) as session:
        token_in_db = session.exec(
            select(TokenDB)
            .where(TokenDB.token == req.token)
            .where(TokenDB.used == False)  # noqa: E712 - builds a SQL expression
        ).one_or_none()

        if not token_in_db:
            # Tell an already-redeemed token apart from an unknown one.
            already_used = session.exec(
                select(TokenDB)
                .where(TokenDB.token == req.token)
                .where(TokenDB.used == True)  # noqa: E712 - builds a SQL expression
            ).one_or_none()
            if already_used:
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="Token already used",
                )
            raise credentials_exception

        try:
            payload = jwt.decode(req.token, config.secret_key, algorithms=["HS256"])
        except ExpiredSignatureError as err:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Access token expired",
            ) from err
        except (InvalidTokenError, ValidationError) as err:
            raise credentials_exception from err

        username = payload.get("sub")
        if username is None:
            raise credentials_exception

        user = get_user(username)
        if not user:
            # Previously this case fell through and answered 200 with an empty
            # body; a token for a missing user is rejected like any bad token.
            raise credentials_exception

        user.hashed_password = get_password_hash(req.password)
        session.add(user)
        token_in_db.used = True
        session.add(token_in_db)
        session.commit()
        return Response("Password set successfully", status_code=status.HTTP_200_OK)
class ChangedPassword(BaseModel):
    """Request body for changing the password of the logged-in user."""

    # Password the user currently has; it is verified before any change.
    current_password: str

    # Password to store instead.
    new_password: str
async def change_password(
    request: ChangedPassword,
    user: Annotated[Player, Depends(get_current_active_user)],
):
    """Replace the current user's password after verifying the old one.

    Args:
        request: The current password and the new password.
        user: The authenticated, active user.

    Returns:
        A 200 plain-text response reading "Password changed successfully".

    Raises:
        HTTPException: 400 when the new password is empty, the user has no
            stored hash, or the current password does not verify.
    """
    may_change = (
        request.new_password
        and user.hashed_password
        and verify_password(request.current_password, user.hashed_password)
    )
    if not may_change:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Wrong password",
        )

    with Session(engine) as session:
        user.hashed_password = get_password_hash(request.new_password)
        session.add(user)
        session.commit()

    return PlainTextResponse(
        "Password changed successfully",
        status_code=status.HTTP_200_OK,
        media_type="text/plain",
    )
async def read_player_me(
    current_user: Annotated[Player, Depends(get_current_active_user)],
):
    """Return the current user's fields, omitting ``hashed_password`` and ``disabled``."""
    hidden_fields = {"hashed_password", "disabled"}
    return current_user.model_dump(exclude=hidden_fields)
async def read_own_items(
    current_user: Annotated[Player, Depends(get_current_active_user)],
):
    """Return a one-element list with a fixed item owned by the current user."""
    item = {"item_id": "Foo", "owner": current_user.username}
    return [item]