""" JWT Authentication Replaces Flask session-based auth with JWT tokens + Redis refresh tokens. """ from datetime import datetime, timedelta from typing import Optional from jose import JWTError, jwt from passlib.context import CryptContext from fastapi import Depends, HTTPException, status from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials import os # Password hashing pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") # JWT Configuration SECRET_KEY = os.getenv("SECRET_KEY", "your-secret-key-change-in-production") ALGORITHM = "HS256" ACCESS_TOKEN_EXPIRE_MINUTES = 30 REFRESH_TOKEN_EXPIRE_DAYS = 7 # Security scheme security = HTTPBearer() # ===== Password Hashing ===== def hash_password(password: str) -> str: """ Hash a password using bcrypt. Args: password: Plain text password Returns: Hashed password """ return pwd_context.hash(password) def verify_password(plain_password: str, hashed_password: str) -> bool: """ Verify a password against its hash. Args: plain_password: Plain text password hashed_password: Hashed password from database Returns: True if password matches, False otherwise """ return pwd_context.verify(plain_password, hashed_password) # ===== JWT Token Creation ===== def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str: """ Create JWT access token (short-lived, 30 minutes). Args: data: Payload data (typically {"sub": user_id}) expires_delta: Optional custom expiration time Returns: JWT token string """ to_encode = data.copy() if expires_delta: expire = datetime.utcnow() + expires_delta else: expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) to_encode.update({ "exp": expire, "type": "access" }) encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) return encoded_jwt def create_refresh_token(user_id: int) -> str: """ Create JWT refresh token (long-lived, 7 days). Stored in Redis for validation. Args: user_id: User ID from database Returns: JWT refresh token string """ expire = datetime.utcnow() + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS) to_encode = { "sub": str(user_id), "exp": expire, "type": "refresh" } encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) return encoded_jwt # ===== JWT Token Validation ===== def decode_token(token: str) -> dict: """ Decode and validate JWT token. Args: token: JWT token string Returns: Decoded payload Raises: HTTPException: If token is invalid or expired """ try: payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) return payload except JWTError as e: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=f"Invalid token: {str(e)}", headers={"WWW-Authenticate": "Bearer"}, ) def verify_access_token(token: str) -> int: """ Verify access token and extract user ID. Args: token: JWT access token Returns: user_id: User ID from token Raises: HTTPException: If token is invalid or not an access token """ payload = decode_token(token) # Check token type if payload.get("type") != "access": raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token type", headers={"WWW-Authenticate": "Bearer"}, ) # Extract user ID user_id = payload.get("sub") if user_id is None: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token payload", headers={"WWW-Authenticate": "Bearer"}, ) return int(user_id) def verify_refresh_token(token: str) -> int: """ Verify refresh token and extract user ID. Args: token: JWT refresh token Returns: user_id: User ID from token Raises: HTTPException: If token is invalid or not a refresh token """ payload = decode_token(token) # Check token type if payload.get("type") != "refresh": raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token type", headers={"WWW-Authenticate": "Bearer"}, ) # Extract user ID user_id = payload.get("sub") if user_id is None: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token payload", headers={"WWW-Authenticate": "Bearer"}, ) return int(user_id) # ===== FastAPI Dependencies ===== async def get_current_user_id( credentials: HTTPAuthorizationCredentials = Depends(security) ) -> int: """ FastAPI dependency to get current user ID from JWT token. Use this to protect endpoints: @router.get("/protected", dependencies=[Depends(get_current_user_id)]) Args: credentials: HTTP Bearer credentials from Authorization header Returns: user_id: Current user's ID Raises: HTTPException: If token is invalid """ token = credentials.credentials user_id = verify_access_token(token) return user_id # ===== Helper Functions ===== def create_tokens_response(user_id: int) -> dict: """ Create both access and refresh tokens for login response. Args: user_id: User ID from database Returns: Dict with access_token, refresh_token, token_type """ access_token = create_access_token({"sub": str(user_id)}) refresh_token = create_refresh_token(user_id) return { "access_token": access_token, "refresh_token": refresh_token, "token_type": "bearer", "expires_in": ACCESS_TOKEN_EXPIRE_MINUTES * 60 # seconds }