Add 4-tier RBAC backend: auth dependencies, role enforcement, agency filtering

- Add CHECK constraint migration for users.role (super_admin, oversight_admin, agency_admin, basic_user)
- Add get_current_db_user dependency resolving Azure claims to User ORM with agency
- Add require_role() factory and require_write_access() dependency
- Auto-promote dev user to super_admin when DISABLE_AUTH=true
- Add /api/me, PUT /api/users/{id}, POST /api/agencies endpoints
- Apply agency-based data filtering on campaigns, analytics, audit routes
- Block oversight_admin from all mutation routes (campaigns, proofs, flags, resolves)
- Restrict dropdown option mutations to super_admin only
- Add role check in WebSocket handler to block oversight_admin from analysis
- Add CurrentUserResponse, UserUpdate, AgencyCreate schemas

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
michael 2026-02-19 08:28:23 -06:00
parent 6bd8a03a15
commit d21036a0de
6 changed files with 422 additions and 104 deletions

View file

@ -0,0 +1,35 @@
"""Add CHECK constraint on users.role for RBAC roles
Revision ID: 007_add_role_check_constraint
Revises: 006_add_knowledge_base
Create Date: 2026-02-19
"""
from typing import Sequence, Union
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "007_add_role_check_constraint"
down_revision: Union[str, None] = "006_add_knowledge_base"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
VALID_ROLES = ("super_admin", "oversight_admin", "agency_admin", "basic_user")
def upgrade() -> None:
# Normalise any legacy role values to basic_user before adding constraint
op.execute(
f"UPDATE users SET role = 'basic_user' "
f"WHERE role NOT IN {VALID_ROLES!r}"
)
op.create_check_constraint(
"ck_users_role",
"users",
f"role IN {VALID_ROLES!r}",
)
def downgrade() -> None:
op.drop_constraint("ck_users_role", "users", type_="check")

View file

@ -22,11 +22,20 @@ from app.api.schemas import (
AnalyticsResponse,
DropdownOptionsResponse,
AgencyResponse,
AgencyCreate,
CurrentUserResponse,
UserResponse,
UserUpdate,
SupportEmailRequest,
)
from app.dependencies.auth import get_current_user
from app.dependencies.auth import (
get_current_user,
get_current_db_user,
require_role,
require_write_access,
)
from app.models.database import get_db
from app.models.models import User
from app.repositories import (
CampaignRepository,
ProofRepository,
@ -41,36 +50,65 @@ from app.services.pdf_service import pdf_service
router = APIRouter()
# Helper to get user from DB based on Azure claims
async def get_db_user(
session: AsyncSession,
user_claims: dict,
) -> Optional[uuid.UUID]:
"""Get or create user from Azure AD claims and return user ID."""
user_repo = UserRepository(session)
azure_oid = user_claims.get("oid") or user_claims.get("sub")
if not azure_oid:
return None
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
user = await user_repo.get_or_create_from_azure(
azure_ad_oid=azure_oid,
email=user_claims.get("email", user_claims.get("preferred_username", "")),
name=user_claims.get("name", "Unknown"),
def _resolve_agency_filter(current_user: User, agency_id_param: Optional[uuid.UUID] = None) -> Optional[uuid.UUID]:
"""Determine which agency_id to filter by based on user role and query param.
- super_admin / oversight_admin: use agency_id_param if provided, else None (all)
- agency_admin / basic_user: forced to current_user.agency_id
"""
if current_user.role in ("super_admin", "oversight_admin"):
return agency_id_param # None means "all"
return current_user.agency_id # may be None (user not assigned yet)
def _check_campaign_access(current_user: User, campaign) -> None:
"""Raise 404 if the user's role restricts them and the campaign doesn't belong to their agency."""
if current_user.role in ("super_admin", "oversight_admin"):
return
if current_user.agency_id and campaign.agency_id == current_user.agency_id:
return
raise HTTPException(status_code=404, detail="Campaign not found")
# ---------------------------------------------------------------------------
# Current User endpoint
# ---------------------------------------------------------------------------
@router.get("/me", response_model=CurrentUserResponse)
async def get_me(
current_user: User = Depends(get_current_db_user),
):
"""Get the authenticated user's profile."""
return CurrentUserResponse(
id=current_user.id,
email=current_user.email,
name=current_user.name,
role=current_user.role,
agency_id=current_user.agency_id,
agency_name=current_user.agency.name if current_user.agency else None,
)
return user.id
# ---------------------------------------------------------------------------
# Campaign endpoints
# ---------------------------------------------------------------------------
@router.get("/campaigns", response_model=list[CampaignResponse])
async def list_campaigns(
db: AsyncSession = Depends(get_db),
user: dict = Depends(get_current_user),
current_user: User = Depends(get_current_db_user),
agency_id: Optional[uuid.UUID] = Query(None, description="Filter by agency"),
limit: int = Query(100, ge=1, le=500),
offset: int = Query(0, ge=0),
):
"""List all campaigns."""
"""List campaigns, filtered by the user's role and optional agency filter."""
effective_agency_id = _resolve_agency_filter(current_user, agency_id)
repo = CampaignRepository(db)
campaigns_with_counts = await repo.get_with_proof_counts()
campaigns_with_counts = await repo.get_with_proof_counts(agency_id=effective_agency_id)
return [
CampaignResponse(
@ -94,10 +132,9 @@ async def list_campaigns(
async def create_campaign(
data: CampaignCreate,
db: AsyncSession = Depends(get_db),
user: dict = Depends(get_current_user),
current_user: User = Depends(require_write_access),
):
"""Create a new campaign."""
user_id = await get_db_user(db, user)
"""Create a new campaign. Blocked for oversight_admin."""
repo = CampaignRepository(db)
# Check if campaign name already exists
@ -111,7 +148,8 @@ async def create_campaign(
client_lead=data.client_lead,
agency_lead=data.agency_lead,
brand_guidelines=data.brand_guidelines,
created_by=user_id,
agency_id=current_user.agency_id,
created_by=current_user.id,
)
return CampaignResponse(
@ -122,7 +160,7 @@ async def create_campaign(
agency_lead=campaign.agency_lead,
brand_guidelines=campaign.brand_guidelines,
status=campaign.status,
agency=None,
agency=current_user.agency.name if current_user.agency else None,
created_at=campaign.created_at,
updated_at=campaign.updated_at,
proofs=0,
@ -133,7 +171,7 @@ async def create_campaign(
async def get_campaign(
campaign_id: uuid.UUID,
db: AsyncSession = Depends(get_db),
user: dict = Depends(get_current_user),
current_user: User = Depends(get_current_db_user),
):
"""Get a campaign by ID."""
repo = CampaignRepository(db)
@ -141,6 +179,8 @@ async def get_campaign(
if not campaign:
raise HTTPException(status_code=404, detail="Campaign not found")
_check_campaign_access(current_user, campaign)
return CampaignResponse(
id=campaign.id,
name=campaign.name,
@ -161,10 +201,17 @@ async def update_campaign(
campaign_id: uuid.UUID,
data: CampaignUpdate,
db: AsyncSession = Depends(get_db),
user: dict = Depends(get_current_user),
current_user: User = Depends(require_write_access),
):
"""Update a campaign."""
"""Update a campaign. Blocked for oversight_admin."""
repo = CampaignRepository(db)
# Verify campaign exists and user has access
existing = await repo.get_by_id(campaign_id)
if not existing:
raise HTTPException(status_code=404, detail="Campaign not found")
_check_campaign_access(current_user, existing)
campaign = await repo.update(
campaign_id,
name=data.name,
@ -196,34 +243,44 @@ async def update_campaign(
async def delete_campaign(
campaign_id: uuid.UUID,
db: AsyncSession = Depends(get_db),
user: dict = Depends(get_current_user),
current_user: User = Depends(require_write_access),
):
"""Delete a campaign and all associated files."""
"""Delete a campaign and all associated files. Blocked for oversight_admin."""
repo = CampaignRepository(db)
# Get campaign with proofs and versions to extract file keys
campaign = await repo.get_by_id(campaign_id)
if not campaign:
raise HTTPException(status_code=404, detail="Campaign not found")
_check_campaign_access(current_user, campaign)
# Delete files from storage for all proofs
for proof in campaign.proofs:
for version in proof.versions:
if version.file_storage_key:
await storage_service.delete_file(version.file_storage_key)
# Delete database records (cascades to proofs and versions)
await repo.delete(campaign_id)
# ---------------------------------------------------------------------------
# Proof endpoints
# ---------------------------------------------------------------------------
@router.get("/campaigns/{campaign_id}/proofs", response_model=list[ProofResponse])
async def list_proofs(
campaign_id: uuid.UUID,
db: AsyncSession = Depends(get_db),
user: dict = Depends(get_current_user),
current_user: User = Depends(get_current_db_user),
):
"""List all proofs for a campaign."""
# Verify campaign access
campaign_repo = CampaignRepository(db)
campaign = await campaign_repo.get_by_id(campaign_id)
if not campaign:
raise HTTPException(status_code=404, detail="Campaign not found")
_check_campaign_access(current_user, campaign)
repo = ProofRepository(db)
proofs = await repo.list_by_campaign(campaign_id)
@ -259,7 +316,7 @@ async def list_proofs(
async def get_proof(
proof_id: uuid.UUID,
db: AsyncSession = Depends(get_db),
user: dict = Depends(get_current_user),
current_user: User = Depends(get_current_db_user),
):
"""Get a proof by ID."""
repo = ProofRepository(db)
@ -296,12 +353,11 @@ async def get_proof(
async def delete_proof(
proof_id: uuid.UUID,
db: AsyncSession = Depends(get_db),
user: dict = Depends(get_current_user),
current_user: User = Depends(require_write_access),
):
"""Delete a proof and its associated files."""
"""Delete a proof and its associated files. Blocked for oversight_admin."""
repo = ProofRepository(db)
# Get proof with versions to extract file keys
proof = await repo.get_by_id(proof_id)
if not proof:
raise HTTPException(status_code=404, detail="Proof not found")
@ -311,25 +367,25 @@ async def delete_proof(
if version.file_storage_key:
await storage_service.delete_file(version.file_storage_key)
# Delete database records
await repo.delete(proof_id)
# ---------------------------------------------------------------------------
# Audit endpoints
# ---------------------------------------------------------------------------
@router.post("/proofs/{proof_id}/versions/{version}/flag", response_model=FlaggedItemResponse, status_code=201)
async def flag_proof_version(
proof_id: uuid.UUID,
version: int,
data: FlaggedItemCreate,
db: AsyncSession = Depends(get_db),
user: dict = Depends(get_current_user),
current_user: User = Depends(require_write_access),
):
"""Flag an issue on a proof version."""
user_id = await get_db_user(db, user)
"""Flag an issue on a proof version. Blocked for oversight_admin."""
proof_repo = ProofRepository(db)
audit_repo = AuditRepository(db)
# Get the proof version
proof_version = await proof_repo.get_version(proof_id, version)
if not proof_version:
raise HTTPException(status_code=404, detail="Proof version not found")
@ -338,10 +394,9 @@ async def flag_proof_version(
proof_version_id=proof_version.id,
agent_flagged=data.agent_flagged,
comments=data.comments,
submitter_id=user_id,
submitter_id=current_user.id,
)
# Get related data for response
proof = await proof_repo.get_by_id(proof_id)
return FlaggedItemResponse(
@ -349,8 +404,8 @@ async def flag_proof_version(
proof_version_id=flagged.proof_version_id,
agent_flagged=flagged.agent_flagged,
comments=flagged.comments,
submitter_name=user.get("name"),
submitter_agency=None,
submitter_name=current_user.name,
submitter_agency=current_user.agency.name if current_user.agency else None,
campaign_name=proof.campaign.name if proof and proof.campaign else None,
proof_name=proof.proof_name if proof else None,
version=version,
@ -364,14 +419,12 @@ async def resolve_proof_version(
version: int,
data: ResolvedItemCreate,
db: AsyncSession = Depends(get_db),
user: dict = Depends(get_current_user),
current_user: User = Depends(require_write_access),
):
"""Resolve an issue on a proof version."""
user_id = await get_db_user(db, user)
"""Resolve an issue on a proof version. Blocked for oversight_admin."""
proof_repo = ProofRepository(db)
audit_repo = AuditRepository(db)
# Get the proof version
proof_version = await proof_repo.get_version(proof_id, version)
if not proof_version:
raise HTTPException(status_code=404, detail="Proof version not found")
@ -381,10 +434,9 @@ async def resolve_proof_version(
agent=data.agent,
issue=data.issue,
resolution=data.resolution,
submitter_id=user_id,
submitter_id=current_user.id,
)
# Get related data for response
proof = await proof_repo.get_by_id(proof_id)
return ResolvedItemResponse(
@ -393,8 +445,8 @@ async def resolve_proof_version(
agent=resolved.agent,
issue=resolved.issue,
resolution=resolved.resolution,
submitter_name=user.get("name"),
submitter_agency=None,
submitter_name=current_user.name,
submitter_agency=current_user.agency.name if current_user.agency else None,
campaign_name=proof.campaign.name if proof and proof.campaign else None,
proof_name=proof.proof_name if proof else None,
version=version,
@ -405,13 +457,15 @@ async def resolve_proof_version(
@router.get("/audit/flagged", response_model=list[FlaggedItemResponse])
async def list_flagged_items(
db: AsyncSession = Depends(get_db),
user: dict = Depends(get_current_user),
current_user: User = Depends(get_current_db_user),
agency_id: Optional[uuid.UUID] = Query(None, description="Filter by agency"),
limit: int = Query(100, ge=1, le=500),
offset: int = Query(0, ge=0),
):
"""List all flagged items."""
"""List flagged items, filtered by role."""
effective_agency_id = _resolve_agency_filter(current_user, agency_id)
audit_repo = AuditRepository(db)
flagged_items = await audit_repo.get_flagged_items(limit=limit, offset=offset)
flagged_items = await audit_repo.get_flagged_items(agency_id=effective_agency_id, limit=limit, offset=offset)
return [
FlaggedItemResponse(
@ -433,13 +487,15 @@ async def list_flagged_items(
@router.get("/audit/resolved", response_model=list[ResolvedItemResponse])
async def list_resolved_items(
db: AsyncSession = Depends(get_db),
user: dict = Depends(get_current_user),
current_user: User = Depends(get_current_db_user),
agency_id: Optional[uuid.UUID] = Query(None, description="Filter by agency"),
limit: int = Query(100, ge=1, le=500),
offset: int = Query(0, ge=0),
):
"""List all resolved items."""
"""List resolved items, filtered by role."""
effective_agency_id = _resolve_agency_filter(current_user, agency_id)
audit_repo = AuditRepository(db)
resolved_items = await audit_repo.get_resolved_items(limit=limit, offset=offset)
resolved_items = await audit_repo.get_resolved_items(agency_id=effective_agency_id, limit=limit, offset=offset)
return [
ResolvedItemResponse(
@ -462,13 +518,15 @@ async def list_resolved_items(
@router.get("/audit/errors", response_model=list[ErrorItemResponse])
async def list_error_items(
db: AsyncSession = Depends(get_db),
user: dict = Depends(get_current_user),
current_user: User = Depends(get_current_db_user),
agency_id: Optional[uuid.UUID] = Query(None, description="Filter by agency"),
limit: int = Query(100, ge=1, le=500),
offset: int = Query(0, ge=0),
):
"""List all error items."""
"""List error items, filtered by role."""
effective_agency_id = _resolve_agency_filter(current_user, agency_id)
audit_repo = AuditRepository(db)
error_items = await audit_repo.get_error_items(limit=limit, offset=offset)
error_items = await audit_repo.get_error_items(agency_id=effective_agency_id, limit=limit, offset=offset)
return [
ErrorItemResponse(
@ -486,25 +544,33 @@ async def list_error_items(
]
# ---------------------------------------------------------------------------
# Analytics endpoint
# ---------------------------------------------------------------------------
@router.get("/analytics", response_model=AnalyticsResponse)
async def get_analytics(
db: AsyncSession = Depends(get_db),
user: dict = Depends(get_current_user),
current_user: User = Depends(get_current_db_user),
agency_id: Optional[uuid.UUID] = Query(None, description="Filter by agency"),
):
"""Get analytics data."""
"""Get analytics data, filtered by role."""
effective_agency_id = _resolve_agency_filter(current_user, agency_id)
repo = CampaignRepository(db)
analytics = await repo.get_analytics()
analytics = await repo.get_analytics(agency_id=effective_agency_id)
return AnalyticsResponse(**analytics)
# Users endpoint (admin only)
# ---------------------------------------------------------------------------
# User Management endpoints
# ---------------------------------------------------------------------------
@router.get("/users", response_model=list[UserResponse])
async def list_users(
db: AsyncSession = Depends(get_db),
user: dict = Depends(get_current_user),
current_user: User = Depends(require_role("super_admin")),
):
"""List all users (admin only)."""
"""List all users (super_admin only)."""
user_repo = UserRepository(db)
users = await user_repo.list_all()
@ -515,17 +581,103 @@ async def list_users(
name=u.name,
role=u.role,
agency=u.agency.name if u.agency else None,
agency_id=u.agency_id,
created_at=u.created_at,
)
for u in users
]
@router.put("/users/{user_id}", response_model=UserResponse)
async def update_user(
user_id: uuid.UUID,
data: UserUpdate,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(require_role("super_admin")),
):
"""Update a user's role and/or agency (super_admin only)."""
user_repo = UserRepository(db)
# Build kwargs, using sentinel to distinguish "not provided" from None
kwargs: dict = {}
if data.role is not None:
valid_roles = ("super_admin", "oversight_admin", "agency_admin", "basic_user")
if data.role not in valid_roles:
raise HTTPException(status_code=400, detail=f"Invalid role. Must be one of: {', '.join(valid_roles)}")
kwargs["role"] = data.role
# agency_id can be explicitly set to None (unassign) or a UUID
if "agency_id" in (data.model_fields_set or set()):
kwargs["agency_id"] = data.agency_id
else:
kwargs["agency_id"] = ... # sentinel: "not provided"
user = await user_repo.update_user(user_id, **kwargs)
if not user:
raise HTTPException(status_code=404, detail="User not found")
return UserResponse(
id=user.id,
email=user.email,
name=user.name,
role=user.role,
agency=user.agency.name if user.agency else None,
agency_id=user.agency_id,
created_at=user.created_at,
)
# ---------------------------------------------------------------------------
# Agency endpoints
# ---------------------------------------------------------------------------
@router.get("/agencies", response_model=list[AgencyResponse])
async def list_agencies(
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_db_user),
):
"""List all agencies."""
from sqlalchemy import select
from app.models.models import Agency
stmt = select(Agency).order_by(Agency.name)
result = await db.execute(stmt)
agencies = result.scalars().all()
return [AgencyResponse(id=a.id, name=a.name) for a in agencies]
@router.post("/agencies", response_model=AgencyResponse, status_code=201)
async def create_agency(
data: AgencyCreate,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(require_role("super_admin")),
):
"""Create a new agency (super_admin only)."""
user_repo = UserRepository(db)
# Check for duplicate name
existing = await user_repo.get_or_create_agency(data.name)
# get_or_create returns existing — check if it was just created
# A simpler approach: try to create and catch unique violation
from sqlalchemy import select
from app.models.models import Agency
result = await db.execute(select(Agency).where(Agency.name == data.name))
agency = result.scalar_one_or_none()
if agency:
return AgencyResponse(id=agency.id, name=agency.name)
agency = await user_repo.create_agency(data.name)
return AgencyResponse(id=agency.id, name=agency.name)
# ---------------------------------------------------------------------------
# Dropdown options endpoints
# ---------------------------------------------------------------------------
@router.get("/dropdown-options", response_model=DropdownOptionsResponse)
async def get_dropdown_options(
db: AsyncSession = Depends(get_db),
user: dict = Depends(get_current_user),
current_user: User = Depends(get_current_db_user),
):
"""Get all dropdown options as hierarchical structure."""
import logging
@ -534,7 +686,6 @@ async def get_dropdown_options(
repo = DropdownRepository(db)
options = await repo.get_all_hierarchical()
# Debug logging
channels = options.get("channels", {})
social = channels.get("Social", {})
meta_proof_types = social.get("Meta", [])
@ -547,9 +698,9 @@ async def get_dropdown_options(
async def add_channel(
name: str = Query(..., description="Channel name"),
db: AsyncSession = Depends(get_db),
user: dict = Depends(get_current_user),
current_user: User = Depends(require_role("super_admin")),
):
"""Add a new channel."""
"""Add a new channel (super_admin only)."""
repo = DropdownRepository(db)
await repo.add_channel(name)
await db.commit()
@ -561,9 +712,9 @@ async def add_sub_channel(
channel: str,
name: str = Query(..., description="Sub-channel name"),
db: AsyncSession = Depends(get_db),
user: dict = Depends(get_current_user),
current_user: User = Depends(require_role("super_admin")),
):
"""Add a sub-channel under a channel."""
"""Add a sub-channel under a channel (super_admin only)."""
repo = DropdownRepository(db)
result = await repo.add_sub_channel(channel, name)
if not result:
@ -578,9 +729,9 @@ async def add_proof_type(
sub_channel: str,
name: str = Query(..., description="Proof type name"),
db: AsyncSession = Depends(get_db),
user: dict = Depends(get_current_user),
current_user: User = Depends(require_role("super_admin")),
):
"""Add a proof type under a sub-channel."""
"""Add a proof type under a sub-channel (super_admin only)."""
repo = DropdownRepository(db)
result = await repo.add_proof_type(channel, sub_channel, name)
if not result:
@ -593,9 +744,9 @@ async def add_proof_type(
async def delete_channel(
channel: str,
db: AsyncSession = Depends(get_db),
user: dict = Depends(get_current_user),
current_user: User = Depends(require_role("super_admin")),
):
"""Delete a channel and all its sub-channels and proof types."""
"""Delete a channel and all its sub-channels and proof types (super_admin only)."""
repo = DropdownRepository(db)
success = await repo.remove_channel(channel)
if not success:
@ -608,9 +759,9 @@ async def delete_sub_channel(
channel: str,
sub_channel: str,
db: AsyncSession = Depends(get_db),
user: dict = Depends(get_current_user),
current_user: User = Depends(require_role("super_admin")),
):
"""Delete a sub-channel and all its proof types."""
"""Delete a sub-channel and all its proof types (super_admin only)."""
repo = DropdownRepository(db)
success = await repo.remove_sub_channel(channel, sub_channel)
if not success:
@ -624,9 +775,9 @@ async def delete_proof_type(
sub_channel: str,
proof_type: str,
db: AsyncSession = Depends(get_db),
user: dict = Depends(get_current_user),
current_user: User = Depends(require_role("super_admin")),
):
"""Delete a proof type."""
"""Delete a proof type (super_admin only)."""
repo = DropdownRepository(db)
success = await repo.remove_proof_type(channel, sub_channel, proof_type)
if not success:
@ -634,22 +785,9 @@ async def delete_proof_type(
await db.commit()
# Agency endpoints
@router.get("/agencies", response_model=list[AgencyResponse])
async def list_agencies(
db: AsyncSession = Depends(get_db),
user: dict = Depends(get_current_user),
):
"""List all agencies."""
from sqlalchemy import select
from app.models.models import Agency
stmt = select(Agency).order_by(Agency.name)
result = await db.execute(stmt)
agencies = result.scalars().all()
return [AgencyResponse(id=a.id, name=a.name) for a in agencies]
# ---------------------------------------------------------------------------
# File endpoints
# ---------------------------------------------------------------------------
# PDF pages endpoint (must be defined BEFORE the base file endpoint for correct routing)
@router.get("/files/{storage_key:path}/pages")
@ -657,7 +795,7 @@ async def get_pdf_pages(
storage_key: str,
max_pages: int = Query(10, ge=1, le=50),
db: AsyncSession = Depends(get_db),
user: dict = Depends(get_current_user),
current_user: User = Depends(get_current_db_user),
):
"""Rasterize a stored PDF and return pages as data URLs."""
if not storage_key.lower().endswith('.pdf'):
@ -690,14 +828,13 @@ async def get_pdf_pages(
async def get_file(
storage_key: str,
db: AsyncSession = Depends(get_db),
user: dict = Depends(get_current_user),
current_user: User = Depends(get_current_db_user),
):
"""Retrieve a stored file by its storage key."""
file_data = await storage_service.get_file(storage_key)
if file_data is None:
raise HTTPException(status_code=404, detail="File not found")
# Determine content type from extension
extension = storage_key.split('.')[-1].lower() if '.' in storage_key else ''
content_types = {
'png': 'image/png',
@ -714,7 +851,10 @@ async def get_file(
)
# ---------------------------------------------------------------------------
# Support email endpoint (public - no auth required for login page access)
# ---------------------------------------------------------------------------
@router.post("/support/email")
async def send_support_email(
data: SupportEmailRequest,

View file

@ -164,18 +164,43 @@ class AgencyResponse(BaseModel):
# User schemas
class CurrentUserResponse(BaseModel):
"""Response for /api/me - the authenticated user's own profile."""
id: uuid.UUID
email: str
name: str
role: str
agency_id: Optional[uuid.UUID]
agency_name: Optional[str]
class Config:
from_attributes = True
class UserResponse(BaseModel):
id: uuid.UUID
email: str
name: str
role: str
agency: Optional[str]
agency_id: Optional[uuid.UUID] = None
created_at: datetime
class Config:
from_attributes = True
class UserUpdate(BaseModel):
"""Request body for updating a user's role and/or agency."""
role: Optional[str] = None
agency_id: Optional[uuid.UUID] = None
class AgencyCreate(BaseModel):
"""Request body for creating a new agency."""
name: str
# Support email schemas
class SupportEmailRequest(BaseModel):
message: str

View file

@ -1,17 +1,26 @@
"""
FastAPI authentication dependencies.
Provides dependency functions for securing REST endpoints with Azure AD token verification.
Provides dependency functions for securing REST endpoints with Azure AD token verification
and role-based access control.
"""
import logging
from typing import Optional
from fastapi import Header, HTTPException, status
from fastapi import Depends, Header, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings
from app.models.database import get_db
from app.models.models import User
from app.repositories.user_repository import UserRepository
from app.services.auth_service import verify_access_token
logger = logging.getLogger(__name__)
# Valid roles ordered by privilege level (for reference)
VALID_ROLES = ("super_admin", "oversight_admin", "agency_admin", "basic_user")
async def get_current_user(authorization: Optional[str] = Header(None)) -> dict:
"""
@ -72,3 +81,70 @@ async def get_current_user(authorization: Optional[str] = Header(None)) -> dict:
logger.debug(f"[MSAL Backend] Authentication successful for: {claims.get('name', 'unknown')}")
return claims
async def get_current_db_user(
user_claims: dict = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> User:
"""
Resolve Azure AD claims to a full User ORM object with agency loaded.
Creates the user on first login as basic_user with no agency.
In dev mode (DISABLE_AUTH=true), auto-promotes the dev user to super_admin.
"""
user_repo = UserRepository(db)
azure_oid = user_claims.get("oid") or user_claims.get("sub")
if not azure_oid:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing user identifier in token claims",
)
user = await user_repo.get_or_create_from_azure(
azure_ad_oid=azure_oid,
email=user_claims.get("email", user_claims.get("preferred_username", "")),
name=user_claims.get("name", "Unknown"),
)
# Dev mode: auto-promote to super_admin so all features are accessible
if settings.DISABLE_AUTH and user.role != "super_admin":
user.role = "super_admin"
await db.flush()
return user
def require_role(*allowed_roles: str):
"""
Dependency factory that restricts access to users with specific roles.
Usage:
@router.get("/admin-only")
async def admin_route(user: User = Depends(require_role("super_admin"))):
...
"""
async def _check_role(current_user: User = Depends(get_current_db_user)) -> User:
if current_user.role not in allowed_roles:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Access denied. Required role: {', '.join(allowed_roles)}",
)
return current_user
return _check_role
async def require_write_access(
current_user: User = Depends(get_current_db_user),
) -> User:
"""
Dependency that blocks oversight_admin from write/mutation operations.
All other roles are allowed through.
"""
if current_user.role == "oversight_admin":
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Oversight Admin has read-only access",
)
return current_user

View file

@ -8,7 +8,8 @@ from fastapi.middleware.cors import CORSMiddleware
from app.config import settings
from app.services.auth_service import verify_access_token
from app.dependencies.auth import get_current_user
from app.models.database import init_db, close_db
from app.models.database import init_db, close_db, async_session_factory as _session_factory
from app.repositories.user_repository import UserRepository
from app.api import router as api_router, kb_router
# Configure logging
@ -193,6 +194,21 @@ async def websocket_analyze(websocket: WebSocket):
logger.info(f"[MAIN] Authenticated user: {user_claims.get('name', 'unknown')}")
# Check role: oversight_admin cannot upload/analyze proofs
try:
async with _session_factory() as ws_session:
ws_user_repo = UserRepository(ws_session)
azure_oid = user_claims.get("oid") or user_claims.get("sub")
ws_user = await ws_user_repo.get_by_azure_oid(azure_oid) if azure_oid else None
if ws_user and ws_user.role == "oversight_admin":
await manager.send_message(client_id, {
"type": "error",
"message": "Oversight Admin has read-only access and cannot analyze proofs."
})
continue
except Exception as role_err:
logger.warning(f"[MAIN] Role check failed for client {client_id}: {role_err}")
if analysis_service is None:
logger.error("[MAIN] Analysis service not ready")
await manager.send_message(client_id, {

View file

@ -89,7 +89,33 @@ class UserRepository:
await self.session.flush()
return user
async def update_user(
self,
user_id: uuid.UUID,
role: Optional[str] = None,
agency_id: Optional[uuid.UUID] = ..., # type: ignore[assignment]
) -> Optional[User]:
"""Update user role and/or agency. Pass agency_id=None to unassign."""
user = await self.get_by_id(user_id)
if not user:
return None
if role is not None:
user.role = role
# Use sentinel (...) to distinguish "not provided" from "set to None"
if agency_id is not ...:
user.agency_id = agency_id
await self.session.flush()
# Re-fetch to get agency relationship loaded
return await self.get_by_id(user_id)
async def list_agencies(self) -> list[Agency]:
"""List all agencies."""
result = await self.session.execute(select(Agency))
return list(result.scalars().all())
async def create_agency(self, name: str) -> Agency:
"""Create a new agency."""
agency = Agency(name=name)
self.session.add(agency)
await self.session.flush()
return agency