Azure AD v1 access tokens (sts.windows.net issuer) use the 'upn' claim for the user principal name/email, not 'email' or 'preferred_username'. Add 'upn' as a fallback so email is correctly resolved on login. Also add debug logging to show which claims are present. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
150 lines
5.2 KiB
Python
Executable file
150 lines
5.2 KiB
Python
Executable file
"""
|
|
FastAPI authentication dependencies.
|
|
|
|
Provides dependency functions for securing REST endpoints with Azure AD token verification
|
|
and role-based access control.
|
|
"""
|
|
import logging
|
|
from typing import Optional
|
|
|
|
from fastapi import Depends, Header, HTTPException, status
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.config import settings
|
|
from app.models.database import get_db
|
|
from app.models.models import User
|
|
from app.repositories.user_repository import UserRepository
|
|
from app.services.auth_service import verify_access_token
|
|
|
|
# Roles the application knows about, listed from most to least privileged.
# Kept as a reference list; nothing in this module validates against it.
VALID_ROLES = (
    "super_admin",
    "oversight_admin",
    "agency_admin",
    "basic_user",
)

# Module-scoped logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
async def get_current_user(authorization: Optional[str] = Header(None)) -> dict:
    """
    FastAPI dependency to verify the access token and return user claims.

    Use as a dependency on protected endpoints:

        @app.get("/protected")
        async def protected_route(user: dict = Depends(get_current_user)):
            return {"message": f"Hello {user.get('name')}"}

    When ``settings.DISABLE_AUTH`` is set, the header is not inspected and a
    fixed development identity is returned instead.

    Args:
        authorization: The Authorization header value (``Bearer <token>``).

    Returns:
        The claims dict returned by ``verify_access_token`` (or the dev-mode
        mock claims when auth is disabled).

    Raises:
        HTTPException: 401 if the header is missing or malformed, or if the
            token fails verification.
    """
    logger.debug("[MSAL Backend] get_current_user dependency called")

    # If auth is disabled, return mock user immediately
    if settings.DISABLE_AUTH:
        logger.debug("[MSAL Backend] Auth disabled - returning mock user")
        return {"sub": "dev-user", "name": "Development User", "preferred_username": "dev@localhost"}

    if not authorization:
        logger.warning("[MSAL Backend] Missing authorization header")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing authorization header",
            headers={"WWW-Authenticate": "Bearer"},
        )

    logger.debug("[MSAL Backend] Authorization header present, length: %d", len(authorization))

    # Extract token from "Bearer <token>" format
    parts = authorization.split()
    if len(parts) != 2 or parts[0].lower() != "bearer":
        # Never log parts[0] unconditionally: for a header with no scheme
        # (e.g. "Authorization: <token>") the single part IS the credential.
        # Only report a (truncated) scheme when a second part exists.
        scheme = parts[0][:16] if len(parts) > 1 else None
        logger.warning(
            "[MSAL Backend] Invalid auth header format (scheme=%r, parts=%d)",
            scheme,
            len(parts),
        )
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authorization header format. Expected: Bearer <token>",
            headers={"WWW-Authenticate": "Bearer"},
        )

    token = parts[1]
    logger.debug("[MSAL Backend] Extracted Bearer token, calling verify_access_token...")
    claims = await verify_access_token(token)

    # A falsy result from verify_access_token signals a rejected token.
    if not claims:
        logger.warning("[MSAL Backend] Token verification failed - returning 401")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or expired token",
            headers={"WWW-Authenticate": "Bearer"},
        )

    logger.debug("[MSAL Backend] Authentication successful for: %s", claims.get("name", "unknown"))
    return claims
|
|
|
|
|
|
async def get_current_db_user(
    user_claims: dict = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> User:
    """
    Resolve Azure AD claims to a full User ORM object with agency loaded.

    Creates the user on first login as basic_user with no agency.
    In dev mode (DISABLE_AUTH=true), auto-promotes the dev user to super_admin.

    Args:
        user_claims: Verified token claims produced by ``get_current_user``.
        db: Async database session.

    Returns:
        The ``User`` returned by ``UserRepository.get_or_create_from_azure``.

    Raises:
        HTTPException: 401 if the claims carry neither ``oid`` nor ``sub``.
    """
    user_repo = UserRepository(db)

    # Prefer 'oid'; fall back to 'sub' (the dev-mode mock claims returned by
    # get_current_user only contain 'sub').
    azure_oid = user_claims.get("oid") or user_claims.get("sub")
    if not azure_oid:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing user identifier in token claims",
            # RFC 7235: a 401 must carry a WWW-Authenticate challenge; this
            # matches the 401 responses raised by get_current_user.
            headers={"WWW-Authenticate": "Bearer"},
        )

    # Azure AD v1 access tokens use 'upn'; v2/ID tokens use 'email' or
    # 'preferred_username'. Take the first non-empty claim in this order.
    email_claim = next(
        (key for key in ("email", "preferred_username", "upn") if user_claims.get(key)),
        None,
    )
    email = user_claims[email_claim] if email_claim else ""
    # Record which claim supplied the email (not the address itself) and the
    # claim names present, so v1/v2 token differences stay diagnosable without
    # writing PII to the logs.
    logger.debug(
        "[Auth] Email resolved from claim %r; claim keys present: %s",
        email_claim,
        list(user_claims),
    )

    user = await user_repo.get_or_create_from_azure(
        azure_ad_oid=azure_oid,
        email=email,
        # 'or' rather than a .get() default so a present-but-empty/None
        # 'name' claim still falls back to "Unknown".
        name=user_claims.get("name") or "Unknown",
    )

    # Dev mode: auto-promote to super_admin so all features are accessible
    if settings.DISABLE_AUTH and user.role != "super_admin":
        user.role = "super_admin"
        await db.flush()

    return user
|
|
|
|
|
|
def require_role(*allowed_roles: str):
    """
    Build a dependency that only admits users holding one of ``allowed_roles``.

    Usage:
        @router.get("/admin-only")
        async def admin_route(user: User = Depends(require_role("super_admin"))):
            ...

    Args:
        *allowed_roles: Role names permitted to reach the endpoint.

    Returns:
        An async dependency. It resolves the caller through
        ``get_current_db_user`` and returns that ``User``, or raises
        HTTPException(403) when ``user.role`` is not in ``allowed_roles``.
    """

    async def _check_role(current_user: User = Depends(get_current_db_user)) -> User:
        # Guard clause: a permitted role passes straight through.
        if current_user.role in allowed_roles:
            return current_user

        required = ", ".join(allowed_roles)
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"Access denied. Required role: {required}",
        )

    return _check_role
|
|
|
|
|
|
async def require_write_access(current_user: User = Depends(get_current_db_user)) -> User:
    """
    Dependency intended for endpoints that create, update, or delete data.

    Currently applies no check beyond what ``get_current_db_user`` already
    performs: the resolved user is handed back unchanged.
    """
    # NOTE(review): no role restriction is enforced here — confirm that every
    # authenticated role is meant to have write access.
    return current_user
|