""" FastAPI authentication dependencies. Provides dependency functions for securing REST endpoints with Azure AD token verification and role-based access control. """ import logging from typing import Optional from fastapi import Depends, Header, HTTPException, status from sqlalchemy.ext.asyncio import AsyncSession from app.config import settings from app.models.database import get_db from app.models.models import User from app.repositories.user_repository import UserRepository from app.services.auth_service import verify_access_token logger = logging.getLogger(__name__) # Valid roles ordered by privilege level (for reference) VALID_ROLES = ("super_admin", "oversight_admin", "agency_admin", "basic_user") async def get_current_user(authorization: Optional[str] = Header(None)) -> dict: """ FastAPI dependency to verify the access token and return user claims. Use as a dependency on protected endpoints: @app.get("/protected") async def protected_route(user: dict = Depends(get_current_user)): return {"message": f"Hello {user.get('name')}"} Args: authorization: The Authorization header value (Bearer ) Returns: The token claims dict containing user information Raises: HTTPException: 401 if token is missing or invalid """ logger.debug("[MSAL Backend] get_current_user dependency called") # If auth is disabled, return mock user immediately if settings.DISABLE_AUTH: logger.debug("[MSAL Backend] Auth disabled - returning mock user") return {"sub": "dev-user", "name": "Development User", "preferred_username": "dev@localhost"} if not authorization: logger.warning("[MSAL Backend] Missing authorization header") raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing authorization header", headers={"WWW-Authenticate": "Bearer"}, ) logger.debug(f"[MSAL Backend] Authorization header present, length: {len(authorization)}") # Extract token from "Bearer " format parts = authorization.split() if len(parts) != 2 or parts[0].lower() != "bearer": logger.warning(f"[MSAL Backend] Invalid auth header format: {parts[0] if parts else 'empty'}") raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid authorization header format. Expected: Bearer ", headers={"WWW-Authenticate": "Bearer"}, ) token = parts[1] logger.debug("[MSAL Backend] Extracted Bearer token, calling verify_access_token...") claims = await verify_access_token(token) if not claims: logger.warning("[MSAL Backend] Token verification failed - returning 401") raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired token", headers={"WWW-Authenticate": "Bearer"}, ) logger.debug(f"[MSAL Backend] Authentication successful for: {claims.get('name', 'unknown')}") return claims async def get_current_db_user( user_claims: dict = Depends(get_current_user), db: AsyncSession = Depends(get_db), ) -> User: """ Resolve Azure AD claims to a full User ORM object with agency loaded. Creates the user on first login as basic_user with no agency. In dev mode (DISABLE_AUTH=true), auto-promotes the dev user to super_admin. """ user_repo = UserRepository(db) azure_oid = user_claims.get("oid") or user_claims.get("sub") if not azure_oid: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing user identifier in token claims", ) # Azure AD v1 access tokens use 'upn'; v2/ID tokens use 'email' or 'preferred_username' email = ( user_claims.get("email") or user_claims.get("preferred_username") or user_claims.get("upn") or "" ) logger.debug(f"[Auth] Resolved email='{email}' from claims keys: {list(user_claims.keys())}") user = await user_repo.get_or_create_from_azure( azure_ad_oid=azure_oid, email=email, name=user_claims.get("name", "Unknown"), ) # Dev mode: auto-promote to super_admin so all features are accessible if settings.DISABLE_AUTH and user.role != "super_admin": user.role = "super_admin" await db.flush() return user def require_role(*allowed_roles: str): """ Dependency factory that restricts access to users with specific roles. Usage: @router.get("/admin-only") async def admin_route(user: User = Depends(require_role("super_admin"))): ... """ async def _check_role(current_user: User = Depends(get_current_db_user)) -> User: if current_user.role not in allowed_roles: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=f"Access denied. Required role: {', '.join(allowed_roles)}", ) return current_user return _check_role async def require_write_access( current_user: User = Depends(get_current_db_user), ) -> User: """Dependency for write/mutation operations.""" return current_user