feat: set first password with token
This commit is contained in:
		
							
								
								
									
										71
									
								
								security.py
									
									
									
									
									
								
							
							
						
						
									
										71
									
								
								security.py
									
									
									
									
									
								
							@@ -1,6 +1,7 @@
 | 
				
			|||||||
from datetime import timedelta, timezone, datetime
 | 
					from datetime import timedelta, timezone, datetime
 | 
				
			||||||
from typing import Annotated
 | 
					from typing import Annotated
 | 
				
			||||||
from fastapi import Depends, HTTPException, Request, Response, status
 | 
					from fastapi import Depends, HTTPException, Request, Response, status
 | 
				
			||||||
 | 
					from fastapi.responses import PlainTextResponse
 | 
				
			||||||
from pydantic import BaseModel, ValidationError
 | 
					from pydantic import BaseModel, ValidationError
 | 
				
			||||||
import jwt
 | 
					import jwt
 | 
				
			||||||
from jwt.exceptions import ExpiredSignatureError, InvalidTokenError
 | 
					from jwt.exceptions import ExpiredSignatureError, InvalidTokenError
 | 
				
			||||||
@@ -76,7 +77,7 @@ def get_user(username: str | None):
 | 
				
			|||||||
        try:
 | 
					        try:
 | 
				
			||||||
            with Session(engine) as session:
 | 
					            with Session(engine) as session:
 | 
				
			||||||
                return session.exec(
 | 
					                return session.exec(
 | 
				
			||||||
                    select(Player).where(User.username == username)
 | 
					                    select(Player).where(Player.username == username)
 | 
				
			||||||
                ).one_or_none()
 | 
					                ).one_or_none()
 | 
				
			||||||
        except OperationalError:
 | 
					        except OperationalError:
 | 
				
			||||||
            return
 | 
					            return
 | 
				
			||||||
@@ -97,7 +98,7 @@ def create_access_token(data: dict, expires_delta: timedelta | None = None):
 | 
				
			|||||||
        expire = datetime.now(timezone.utc) + expires_delta
 | 
					        expire = datetime.now(timezone.utc) + expires_delta
 | 
				
			||||||
    else:
 | 
					    else:
 | 
				
			||||||
        expire = datetime.now(timezone.utc) + timedelta(
 | 
					        expire = datetime.now(timezone.utc) + timedelta(
 | 
				
			||||||
            seconds=config.access_token_expire_minutes
 | 
					            minutes=config.access_token_expire_minutes
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
    to_encode.update({"exp": expire})
 | 
					    to_encode.update({"exp": expire})
 | 
				
			||||||
    encoded_jwt = jwt.encode(to_encode, config.secret_key, algorithm="HS256")
 | 
					    encoded_jwt = jwt.encode(to_encode, config.secret_key, algorithm="HS256")
 | 
				
			||||||
@@ -154,7 +155,7 @@ async def get_current_active_user(
 | 
				
			|||||||
):
 | 
					):
 | 
				
			||||||
    if current_user.disabled:
 | 
					    if current_user.disabled:
 | 
				
			||||||
        raise HTTPException(status_code=400, detail="Inactive user")
 | 
					        raise HTTPException(status_code=400, detail="Inactive user")
 | 
				
			||||||
    return current_user.model_dump(exclude={"hashed_password", "disabled"})
 | 
					    return current_user
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def login_for_access_token(
 | 
					async def login_for_access_token(
 | 
				
			||||||
@@ -187,7 +188,69 @@ async def logout(response: Response):
 | 
				
			|||||||
    return {"message": "Successfully logged out"}
 | 
					    return {"message": "Successfully logged out"}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def read_users_me(
 | 
					def generate_one_time_token(username):
 | 
				
			||||||
 | 
					    expire = timedelta(days=7)
 | 
				
			||||||
 | 
					    token = create_access_token(data={"sub": username}, expires_delta=expire)
 | 
				
			||||||
 | 
					    return token
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class FirstPassword(BaseModel):
 | 
				
			||||||
 | 
					    token: str
 | 
				
			||||||
 | 
					    password: str
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					async def set_first_password(req: FirstPassword):
 | 
				
			||||||
 | 
					    credentials_exception = HTTPException(
 | 
				
			||||||
 | 
					        status_code=status.HTTP_401_UNAUTHORIZED,
 | 
				
			||||||
 | 
					        detail="Could not validate token",
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    try:
 | 
				
			||||||
 | 
					        payload = jwt.decode(req.token, config.secret_key, algorithms=["HS256"])
 | 
				
			||||||
 | 
					        username: str = payload.get("sub")
 | 
				
			||||||
 | 
					        if username is None:
 | 
				
			||||||
 | 
					            raise credentials_exception
 | 
				
			||||||
 | 
					    except ExpiredSignatureError:
 | 
				
			||||||
 | 
					        raise HTTPException(
 | 
				
			||||||
 | 
					            status_code=status.HTTP_401_UNAUTHORIZED,
 | 
				
			||||||
 | 
					            detail="Access token expired",
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					    except (InvalidTokenError, ValidationError):
 | 
				
			||||||
 | 
					        raise credentials_exception
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    user = get_user(username)
 | 
				
			||||||
 | 
					    if user:
 | 
				
			||||||
 | 
					        with Session(engine) as session:
 | 
				
			||||||
 | 
					            user.hashed_password = get_password_hash(req.password)
 | 
				
			||||||
 | 
					            session.add(user)
 | 
				
			||||||
 | 
					            session.commit()
 | 
				
			||||||
 | 
					            return Response("Password set successfully", status_code=status.HTTP_200_OK)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					async def change_password(
 | 
				
			||||||
 | 
					    current_password: str,
 | 
				
			||||||
 | 
					    new_password: str,
 | 
				
			||||||
 | 
					    user: Annotated[Player, Depends(get_current_active_user)],
 | 
				
			||||||
 | 
					):
 | 
				
			||||||
 | 
					    if (
 | 
				
			||||||
 | 
					        new_password
 | 
				
			||||||
 | 
					        and user.hashed_password
 | 
				
			||||||
 | 
					        and verify_password(current_password, user.hashed_password)
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        with Session(engine) as session:
 | 
				
			||||||
 | 
					            user.hashed_password = get_password_hash(new_password)
 | 
				
			||||||
 | 
					            session.add(user)
 | 
				
			||||||
 | 
					            session.commit()
 | 
				
			||||||
 | 
					            return PlainTextResponse(
 | 
				
			||||||
 | 
					                "Password changed successfully", status_code=status.HTTP_200_OK
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        raise HTTPException(
 | 
				
			||||||
 | 
					            status_code=status.HTTP_400_BAD_REQUEST,
 | 
				
			||||||
 | 
					            detail="Wrong password",
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					async def read_player_me(
 | 
				
			||||||
    current_user: Annotated[Player, Depends(get_current_active_user)],
 | 
					    current_user: Annotated[Player, Depends(get_current_active_user)],
 | 
				
			||||||
):
 | 
					):
 | 
				
			||||||
    return current_user
 | 
					    return current_user
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user