From d21036a0de6e6210a6aabe597f35aa36dc2368b9 Mon Sep 17 00:00:00 2001 From: michael Date: Thu, 19 Feb 2026 08:28:23 -0600 Subject: [PATCH] Add 4-tier RBAC backend: auth dependencies, role enforcement, agency filtering - Add CHECK constraint migration for users.role (super_admin, oversight_admin, agency_admin, basic_user) - Add get_current_db_user dependency resolving Azure claims to User ORM with agency - Add require_role() factory and require_write_access() dependency - Auto-promote dev user to super_admin when DISABLE_AUTH=true - Add /api/me, PUT /api/users/{id}, POST /api/agencies endpoints - Apply agency-based data filtering on campaigns, analytics, audit routes - Block oversight_admin from all mutation routes (campaigns, proofs, flags, resolves) - Restrict dropdown option mutations to super_admin only - Add role check in WebSocket handler to block oversight_admin from analysis - Add CurrentUserResponse, UserUpdate, AgencyCreate schemas Co-Authored-By: Claude Opus 4.6 --- .../versions/007_add_role_check_constraint.py | 35 ++ backend/app/api/routes.py | 342 ++++++++++++------ backend/app/api/schemas.py | 25 ++ backend/app/dependencies/auth.py | 80 +++- backend/app/main.py | 18 +- backend/app/repositories/user_repository.py | 26 ++ 6 files changed, 422 insertions(+), 104 deletions(-) create mode 100644 backend/alembic/versions/007_add_role_check_constraint.py diff --git a/backend/alembic/versions/007_add_role_check_constraint.py b/backend/alembic/versions/007_add_role_check_constraint.py new file mode 100644 index 0000000..10b21a9 --- /dev/null +++ b/backend/alembic/versions/007_add_role_check_constraint.py @@ -0,0 +1,35 @@ +"""Add CHECK constraint on users.role for RBAC roles + +Revision ID: 007_add_role_check_constraint +Revises: 006_add_knowledge_base +Create Date: 2026-02-19 + +""" +from typing import Sequence, Union + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "007_add_role_check_constraint" +down_revision: Union[str, None] = "006_add_knowledge_base" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +VALID_ROLES = ("super_admin", "oversight_admin", "agency_admin", "basic_user") + + +def upgrade() -> None: + # Normalise any legacy role values to basic_user before adding constraint + op.execute( + f"UPDATE users SET role = 'basic_user' " + f"WHERE role NOT IN {VALID_ROLES!r}" + ) + op.create_check_constraint( + "ck_users_role", + "users", + f"role IN {VALID_ROLES!r}", + ) + + +def downgrade() -> None: + op.drop_constraint("ck_users_role", "users", type_="check") diff --git a/backend/app/api/routes.py b/backend/app/api/routes.py index 728b39e..9b30007 100755 --- a/backend/app/api/routes.py +++ b/backend/app/api/routes.py @@ -22,11 +22,20 @@ from app.api.schemas import ( AnalyticsResponse, DropdownOptionsResponse, AgencyResponse, + AgencyCreate, + CurrentUserResponse, UserResponse, + UserUpdate, SupportEmailRequest, ) -from app.dependencies.auth import get_current_user +from app.dependencies.auth import ( + get_current_user, + get_current_db_user, + require_role, + require_write_access, +) from app.models.database import get_db +from app.models.models import User from app.repositories import ( CampaignRepository, ProofRepository, @@ -41,36 +50,65 @@ from app.services.pdf_service import pdf_service router = APIRouter() -# Helper to get user from DB based on Azure claims -async def get_db_user( - session: AsyncSession, - user_claims: dict, -) -> Optional[uuid.UUID]: - """Get or create user from Azure AD claims and return user ID.""" - user_repo = UserRepository(session) - azure_oid = user_claims.get("oid") or user_claims.get("sub") - if not azure_oid: - return None +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- - user = await user_repo.get_or_create_from_azure( - azure_ad_oid=azure_oid, - email=user_claims.get("email", user_claims.get("preferred_username", "")), - name=user_claims.get("name", "Unknown"), +def _resolve_agency_filter(current_user: User, agency_id_param: Optional[uuid.UUID] = None) -> Optional[uuid.UUID]: + """Determine which agency_id to filter by based on user role and query param. + + - super_admin / oversight_admin: use agency_id_param if provided, else None (all) + - agency_admin / basic_user: forced to current_user.agency_id + """ + if current_user.role in ("super_admin", "oversight_admin"): + return agency_id_param # None means "all" + return current_user.agency_id # may be None (user not assigned yet) + + +def _check_campaign_access(current_user: User, campaign) -> None: + """Raise 404 if the user's role restricts them and the campaign doesn't belong to their agency.""" + if current_user.role in ("super_admin", "oversight_admin"): + return + if current_user.agency_id and campaign.agency_id == current_user.agency_id: + return + raise HTTPException(status_code=404, detail="Campaign not found") + + +# --------------------------------------------------------------------------- +# Current User endpoint +# --------------------------------------------------------------------------- + +@router.get("/me", response_model=CurrentUserResponse) +async def get_me( + current_user: User = Depends(get_current_db_user), +): + """Get the authenticated user's profile.""" + return CurrentUserResponse( + id=current_user.id, + email=current_user.email, + name=current_user.name, + role=current_user.role, + agency_id=current_user.agency_id, + agency_name=current_user.agency.name if current_user.agency else None, ) - return user.id +# --------------------------------------------------------------------------- # Campaign endpoints +# --------------------------------------------------------------------------- + @router.get("/campaigns", response_model=list[CampaignResponse]) async def list_campaigns( db: AsyncSession = Depends(get_db), - user: dict = Depends(get_current_user), + current_user: User = Depends(get_current_db_user), + agency_id: Optional[uuid.UUID] = Query(None, description="Filter by agency"), limit: int = Query(100, ge=1, le=500), offset: int = Query(0, ge=0), ): - """List all campaigns.""" + """List campaigns, filtered by the user's role and optional agency filter.""" + effective_agency_id = _resolve_agency_filter(current_user, agency_id) repo = CampaignRepository(db) - campaigns_with_counts = await repo.get_with_proof_counts() + campaigns_with_counts = await repo.get_with_proof_counts(agency_id=effective_agency_id) return [ CampaignResponse( @@ -94,10 +132,9 @@ async def list_campaigns( async def create_campaign( data: CampaignCreate, db: AsyncSession = Depends(get_db), - user: dict = Depends(get_current_user), + current_user: User = Depends(require_write_access), ): - """Create a new campaign.""" - user_id = await get_db_user(db, user) + """Create a new campaign. Blocked for oversight_admin.""" repo = CampaignRepository(db) # Check if campaign name already exists @@ -111,7 +148,8 @@ async def create_campaign( client_lead=data.client_lead, agency_lead=data.agency_lead, brand_guidelines=data.brand_guidelines, - created_by=user_id, + agency_id=current_user.agency_id, + created_by=current_user.id, ) return CampaignResponse( @@ -122,7 +160,7 @@ async def create_campaign( agency_lead=campaign.agency_lead, brand_guidelines=campaign.brand_guidelines, status=campaign.status, - agency=None, + agency=current_user.agency.name if current_user.agency else None, created_at=campaign.created_at, updated_at=campaign.updated_at, proofs=0, @@ -133,7 +171,7 @@ async def create_campaign( async def get_campaign( campaign_id: uuid.UUID, db: AsyncSession = Depends(get_db), - user: dict = Depends(get_current_user), + current_user: User = Depends(get_current_db_user), ): """Get a campaign by ID.""" repo = CampaignRepository(db) @@ -141,6 +179,8 @@ async def get_campaign( if not campaign: raise HTTPException(status_code=404, detail="Campaign not found") + _check_campaign_access(current_user, campaign) + return CampaignResponse( id=campaign.id, name=campaign.name, @@ -161,10 +201,17 @@ async def update_campaign( campaign_id: uuid.UUID, data: CampaignUpdate, db: AsyncSession = Depends(get_db), - user: dict = Depends(get_current_user), + current_user: User = Depends(require_write_access), ): - """Update a campaign.""" + """Update a campaign. Blocked for oversight_admin.""" repo = CampaignRepository(db) + + # Verify campaign exists and user has access + existing = await repo.get_by_id(campaign_id) + if not existing: + raise HTTPException(status_code=404, detail="Campaign not found") + _check_campaign_access(current_user, existing) + campaign = await repo.update( campaign_id, name=data.name, @@ -196,34 +243,44 @@ async def update_campaign( async def delete_campaign( campaign_id: uuid.UUID, db: AsyncSession = Depends(get_db), - user: dict = Depends(get_current_user), + current_user: User = Depends(require_write_access), ): - """Delete a campaign and all associated files.""" + """Delete a campaign and all associated files. Blocked for oversight_admin.""" repo = CampaignRepository(db) - # Get campaign with proofs and versions to extract file keys campaign = await repo.get_by_id(campaign_id) if not campaign: raise HTTPException(status_code=404, detail="Campaign not found") + _check_campaign_access(current_user, campaign) + # Delete files from storage for all proofs for proof in campaign.proofs: for version in proof.versions: if version.file_storage_key: await storage_service.delete_file(version.file_storage_key) - # Delete database records (cascades to proofs and versions) await repo.delete(campaign_id) +# --------------------------------------------------------------------------- # Proof endpoints +# --------------------------------------------------------------------------- + @router.get("/campaigns/{campaign_id}/proofs", response_model=list[ProofResponse]) async def list_proofs( campaign_id: uuid.UUID, db: AsyncSession = Depends(get_db), - user: dict = Depends(get_current_user), + current_user: User = Depends(get_current_db_user), ): """List all proofs for a campaign.""" + # Verify campaign access + campaign_repo = CampaignRepository(db) + campaign = await campaign_repo.get_by_id(campaign_id) + if not campaign: + raise HTTPException(status_code=404, detail="Campaign not found") + _check_campaign_access(current_user, campaign) + repo = ProofRepository(db) proofs = await repo.list_by_campaign(campaign_id) @@ -259,7 +316,7 @@ async def list_proofs( async def get_proof( proof_id: uuid.UUID, db: AsyncSession = Depends(get_db), - user: dict = Depends(get_current_user), + current_user: User = Depends(get_current_db_user), ): """Get a proof by ID.""" repo = ProofRepository(db) @@ -296,12 +353,11 @@ async def get_proof( async def delete_proof( proof_id: uuid.UUID, db: AsyncSession = Depends(get_db), - user: dict = Depends(get_current_user), + current_user: User = Depends(require_write_access), ): - """Delete a proof and its associated files.""" + """Delete a proof and its associated files. Blocked for oversight_admin.""" repo = ProofRepository(db) - # Get proof with versions to extract file keys proof = await repo.get_by_id(proof_id) if not proof: raise HTTPException(status_code=404, detail="Proof not found") @@ -311,25 +367,25 @@ async def delete_proof( if version.file_storage_key: await storage_service.delete_file(version.file_storage_key) - # Delete database records await repo.delete(proof_id) +# --------------------------------------------------------------------------- # Audit endpoints +# --------------------------------------------------------------------------- + @router.post("/proofs/{proof_id}/versions/{version}/flag", response_model=FlaggedItemResponse, status_code=201) async def flag_proof_version( proof_id: uuid.UUID, version: int, data: FlaggedItemCreate, db: AsyncSession = Depends(get_db), - user: dict = Depends(get_current_user), + current_user: User = Depends(require_write_access), ): - """Flag an issue on a proof version.""" - user_id = await get_db_user(db, user) + """Flag an issue on a proof version. Blocked for oversight_admin.""" proof_repo = ProofRepository(db) audit_repo = AuditRepository(db) - # Get the proof version proof_version = await proof_repo.get_version(proof_id, version) if not proof_version: raise HTTPException(status_code=404, detail="Proof version not found") @@ -338,10 +394,9 @@ async def flag_proof_version( proof_version_id=proof_version.id, agent_flagged=data.agent_flagged, comments=data.comments, - submitter_id=user_id, + submitter_id=current_user.id, ) - # Get related data for response proof = await proof_repo.get_by_id(proof_id) return FlaggedItemResponse( @@ -349,8 +404,8 @@ async def flag_proof_version( proof_version_id=flagged.proof_version_id, agent_flagged=flagged.agent_flagged, comments=flagged.comments, - submitter_name=user.get("name"), - submitter_agency=None, + submitter_name=current_user.name, + submitter_agency=current_user.agency.name if current_user.agency else None, campaign_name=proof.campaign.name if proof and proof.campaign else None, proof_name=proof.proof_name if proof else None, version=version, @@ -364,14 +419,12 @@ async def resolve_proof_version( version: int, data: ResolvedItemCreate, db: AsyncSession = Depends(get_db), - user: dict = Depends(get_current_user), + current_user: User = Depends(require_write_access), ): - """Resolve an issue on a proof version.""" - user_id = await get_db_user(db, user) + """Resolve an issue on a proof version. Blocked for oversight_admin.""" proof_repo = ProofRepository(db) audit_repo = AuditRepository(db) - # Get the proof version proof_version = await proof_repo.get_version(proof_id, version) if not proof_version: raise HTTPException(status_code=404, detail="Proof version not found") @@ -381,10 +434,9 @@ async def resolve_proof_version( agent=data.agent, issue=data.issue, resolution=data.resolution, - submitter_id=user_id, + submitter_id=current_user.id, ) - # Get related data for response proof = await proof_repo.get_by_id(proof_id) return ResolvedItemResponse( @@ -393,8 +445,8 @@ async def resolve_proof_version( agent=resolved.agent, issue=resolved.issue, resolution=resolved.resolution, - submitter_name=user.get("name"), - submitter_agency=None, + submitter_name=current_user.name, + submitter_agency=current_user.agency.name if current_user.agency else None, campaign_name=proof.campaign.name if proof and proof.campaign else None, proof_name=proof.proof_name if proof else None, version=version, @@ -405,13 +457,15 @@ async def resolve_proof_version( @router.get("/audit/flagged", response_model=list[FlaggedItemResponse]) async def list_flagged_items( db: AsyncSession = Depends(get_db), - user: dict = Depends(get_current_user), + current_user: User = Depends(get_current_db_user), + agency_id: Optional[uuid.UUID] = Query(None, description="Filter by agency"), limit: int = Query(100, ge=1, le=500), offset: int = Query(0, ge=0), ): - """List all flagged items.""" + """List flagged items, filtered by role.""" + effective_agency_id = _resolve_agency_filter(current_user, agency_id) audit_repo = AuditRepository(db) - flagged_items = await audit_repo.get_flagged_items(limit=limit, offset=offset) + flagged_items = await audit_repo.get_flagged_items(agency_id=effective_agency_id, limit=limit, offset=offset) return [ FlaggedItemResponse( @@ -433,13 +487,15 @@ async def list_flagged_items( @router.get("/audit/resolved", response_model=list[ResolvedItemResponse]) async def list_resolved_items( db: AsyncSession = Depends(get_db), - user: dict = Depends(get_current_user), + current_user: User = Depends(get_current_db_user), + agency_id: Optional[uuid.UUID] = Query(None, description="Filter by agency"), limit: int = Query(100, ge=1, le=500), offset: int = Query(0, ge=0), ): - """List all resolved items.""" + """List resolved items, filtered by role.""" + effective_agency_id = _resolve_agency_filter(current_user, agency_id) audit_repo = AuditRepository(db) - resolved_items = await audit_repo.get_resolved_items(limit=limit, offset=offset) + resolved_items = await audit_repo.get_resolved_items(agency_id=effective_agency_id, limit=limit, offset=offset) return [ ResolvedItemResponse( @@ -462,13 +518,15 @@ async def list_resolved_items( @router.get("/audit/errors", response_model=list[ErrorItemResponse]) async def list_error_items( db: AsyncSession = Depends(get_db), - user: dict = Depends(get_current_user), + current_user: User = Depends(get_current_db_user), + agency_id: Optional[uuid.UUID] = Query(None, description="Filter by agency"), limit: int = Query(100, ge=1, le=500), offset: int = Query(0, ge=0), ): - """List all error items.""" + """List error items, filtered by role.""" + effective_agency_id = _resolve_agency_filter(current_user, agency_id) audit_repo = AuditRepository(db) - error_items = await audit_repo.get_error_items(limit=limit, offset=offset) + error_items = await audit_repo.get_error_items(agency_id=effective_agency_id, limit=limit, offset=offset) return [ ErrorItemResponse( @@ -486,25 +544,33 @@ async def list_error_items( ] +# --------------------------------------------------------------------------- # Analytics endpoint +# --------------------------------------------------------------------------- + @router.get("/analytics", response_model=AnalyticsResponse) async def get_analytics( db: AsyncSession = Depends(get_db), - user: dict = Depends(get_current_user), + current_user: User = Depends(get_current_db_user), + agency_id: Optional[uuid.UUID] = Query(None, description="Filter by agency"), ): - """Get analytics data.""" + """Get analytics data, filtered by role.""" + effective_agency_id = _resolve_agency_filter(current_user, agency_id) repo = CampaignRepository(db) - analytics = await repo.get_analytics() + analytics = await repo.get_analytics(agency_id=effective_agency_id) return AnalyticsResponse(**analytics) -# Users endpoint (admin only) +# --------------------------------------------------------------------------- +# User Management endpoints +# --------------------------------------------------------------------------- + @router.get("/users", response_model=list[UserResponse]) async def list_users( db: AsyncSession = Depends(get_db), - user: dict = Depends(get_current_user), + current_user: User = Depends(require_role("super_admin")), ): - """List all users (admin only).""" + """List all users (super_admin only).""" user_repo = UserRepository(db) users = await user_repo.list_all() @@ -515,17 +581,103 @@ async def list_users( name=u.name, role=u.role, agency=u.agency.name if u.agency else None, + agency_id=u.agency_id, created_at=u.created_at, ) for u in users ] +@router.put("/users/{user_id}", response_model=UserResponse) +async def update_user( + user_id: uuid.UUID, + data: UserUpdate, + db: AsyncSession = Depends(get_db), + current_user: User = Depends(require_role("super_admin")), +): + """Update a user's role and/or agency (super_admin only).""" + user_repo = UserRepository(db) + + # Build kwargs, using sentinel to distinguish "not provided" from None + kwargs: dict = {} + if data.role is not None: + valid_roles = ("super_admin", "oversight_admin", "agency_admin", "basic_user") + if data.role not in valid_roles: + raise HTTPException(status_code=400, detail=f"Invalid role. Must be one of: {', '.join(valid_roles)}") + kwargs["role"] = data.role + # agency_id can be explicitly set to None (unassign) or a UUID + if "agency_id" in (data.model_fields_set or set()): + kwargs["agency_id"] = data.agency_id + else: + kwargs["agency_id"] = ... # sentinel: "not provided" + + user = await user_repo.update_user(user_id, **kwargs) + if not user: + raise HTTPException(status_code=404, detail="User not found") + + return UserResponse( + id=user.id, + email=user.email, + name=user.name, + role=user.role, + agency=user.agency.name if user.agency else None, + agency_id=user.agency_id, + created_at=user.created_at, + ) + + +# --------------------------------------------------------------------------- +# Agency endpoints +# --------------------------------------------------------------------------- + +@router.get("/agencies", response_model=list[AgencyResponse]) +async def list_agencies( + db: AsyncSession = Depends(get_db), + current_user: User = Depends(get_current_db_user), +): + """List all agencies.""" + from sqlalchemy import select + from app.models.models import Agency + + stmt = select(Agency).order_by(Agency.name) + result = await db.execute(stmt) + agencies = result.scalars().all() + + return [AgencyResponse(id=a.id, name=a.name) for a in agencies] + + +@router.post("/agencies", response_model=AgencyResponse, status_code=201) +async def create_agency( + data: AgencyCreate, + db: AsyncSession = Depends(get_db), + current_user: User = Depends(require_role("super_admin")), +): + """Create a new agency (super_admin only).""" + user_repo = UserRepository(db) + + # Check for duplicate name + existing = await user_repo.get_or_create_agency(data.name) + # get_or_create returns existing — check if it was just created + # A simpler approach: try to create and catch unique violation + from sqlalchemy import select + from app.models.models import Agency + result = await db.execute(select(Agency).where(Agency.name == data.name)) + agency = result.scalar_one_or_none() + if agency: + return AgencyResponse(id=agency.id, name=agency.name) + + agency = await user_repo.create_agency(data.name) + return AgencyResponse(id=agency.id, name=agency.name) + + +# --------------------------------------------------------------------------- # Dropdown options endpoints +# --------------------------------------------------------------------------- + @router.get("/dropdown-options", response_model=DropdownOptionsResponse) async def get_dropdown_options( db: AsyncSession = Depends(get_db), - user: dict = Depends(get_current_user), + current_user: User = Depends(get_current_db_user), ): """Get all dropdown options as hierarchical structure.""" import logging @@ -534,7 +686,6 @@ async def get_dropdown_options( repo = DropdownRepository(db) options = await repo.get_all_hierarchical() - # Debug logging channels = options.get("channels", {}) social = channels.get("Social", {}) meta_proof_types = social.get("Meta", []) @@ -547,9 +698,9 @@ async def get_dropdown_options( async def add_channel( name: str = Query(..., description="Channel name"), db: AsyncSession = Depends(get_db), - user: dict = Depends(get_current_user), + current_user: User = Depends(require_role("super_admin")), ): - """Add a new channel.""" + """Add a new channel (super_admin only).""" repo = DropdownRepository(db) await repo.add_channel(name) await db.commit() @@ -561,9 +712,9 @@ async def add_sub_channel( channel: str, name: str = Query(..., description="Sub-channel name"), db: AsyncSession = Depends(get_db), - user: dict = Depends(get_current_user), + current_user: User = Depends(require_role("super_admin")), ): - """Add a sub-channel under a channel.""" + """Add a sub-channel under a channel (super_admin only).""" repo = DropdownRepository(db) result = await repo.add_sub_channel(channel, name) if not result: @@ -578,9 +729,9 @@ async def add_proof_type( sub_channel: str, name: str = Query(..., description="Proof type name"), db: AsyncSession = Depends(get_db), - user: dict = Depends(get_current_user), + current_user: User = Depends(require_role("super_admin")), ): - """Add a proof type under a sub-channel.""" + """Add a proof type under a sub-channel (super_admin only).""" repo = DropdownRepository(db) result = await repo.add_proof_type(channel, sub_channel, name) if not result: @@ -593,9 +744,9 @@ async def add_proof_type( async def delete_channel( channel: str, db: AsyncSession = Depends(get_db), - user: dict = Depends(get_current_user), + current_user: User = Depends(require_role("super_admin")), ): - """Delete a channel and all its sub-channels and proof types.""" + """Delete a channel and all its sub-channels and proof types (super_admin only).""" repo = DropdownRepository(db) success = await repo.remove_channel(channel) if not success: @@ -608,9 +759,9 @@ async def delete_sub_channel( channel: str, sub_channel: str, db: AsyncSession = Depends(get_db), - user: dict = Depends(get_current_user), + current_user: User = Depends(require_role("super_admin")), ): - """Delete a sub-channel and all its proof types.""" + """Delete a sub-channel and all its proof types (super_admin only).""" repo = DropdownRepository(db) success = await repo.remove_sub_channel(channel, sub_channel) if not success: @@ -624,9 +775,9 @@ async def delete_proof_type( sub_channel: str, proof_type: str, db: AsyncSession = Depends(get_db), - user: dict = Depends(get_current_user), + current_user: User = Depends(require_role("super_admin")), ): - """Delete a proof type.""" + """Delete a proof type (super_admin only).""" repo = DropdownRepository(db) success = await repo.remove_proof_type(channel, sub_channel, proof_type) if not success: @@ -634,22 +785,9 @@ async def delete_proof_type( await db.commit() -# Agency endpoints -@router.get("/agencies", response_model=list[AgencyResponse]) -async def list_agencies( - db: AsyncSession = Depends(get_db), - user: dict = Depends(get_current_user), -): - """List all agencies.""" - from sqlalchemy import select - from app.models.models import Agency - - stmt = select(Agency).order_by(Agency.name) - result = await db.execute(stmt) - agencies = result.scalars().all() - - return [AgencyResponse(id=a.id, name=a.name) for a in agencies] - +# --------------------------------------------------------------------------- +# File endpoints +# --------------------------------------------------------------------------- # PDF pages endpoint (must be defined BEFORE the base file endpoint for correct routing) @router.get("/files/{storage_key:path}/pages") @@ -657,7 +795,7 @@ async def get_pdf_pages( storage_key: str, max_pages: int = Query(10, ge=1, le=50), db: AsyncSession = Depends(get_db), - user: dict = Depends(get_current_user), + current_user: User = Depends(get_current_db_user), ): """Rasterize a stored PDF and return pages as data URLs.""" if not storage_key.lower().endswith('.pdf'): @@ -690,14 +828,13 @@ async def get_pdf_pages( async def get_file( storage_key: str, db: AsyncSession = Depends(get_db), - user: dict = Depends(get_current_user), + current_user: User = Depends(get_current_db_user), ): """Retrieve a stored file by its storage key.""" file_data = await storage_service.get_file(storage_key) if file_data is None: raise HTTPException(status_code=404, detail="File not found") - # Determine content type from extension extension = storage_key.split('.')[-1].lower() if '.' in storage_key else '' content_types = { 'png': 'image/png', @@ -714,7 +851,10 @@ async def get_file( ) +# --------------------------------------------------------------------------- # Support email endpoint (public - no auth required for login page access) +# --------------------------------------------------------------------------- + @router.post("/support/email") async def send_support_email( data: SupportEmailRequest, diff --git a/backend/app/api/schemas.py b/backend/app/api/schemas.py index 5769e84..43495cf 100755 --- a/backend/app/api/schemas.py +++ b/backend/app/api/schemas.py @@ -164,18 +164,43 @@ class AgencyResponse(BaseModel): # User schemas +class CurrentUserResponse(BaseModel): + """Response for /api/me - the authenticated user's own profile.""" + id: uuid.UUID + email: str + name: str + role: str + agency_id: Optional[uuid.UUID] + agency_name: Optional[str] + + class Config: + from_attributes = True + + class UserResponse(BaseModel): id: uuid.UUID email: str name: str role: str agency: Optional[str] + agency_id: Optional[uuid.UUID] = None created_at: datetime class Config: from_attributes = True +class UserUpdate(BaseModel): + """Request body for updating a user's role and/or agency.""" + role: Optional[str] = None + agency_id: Optional[uuid.UUID] = None + + +class AgencyCreate(BaseModel): + """Request body for creating a new agency.""" + name: str + + # Support email schemas class SupportEmailRequest(BaseModel): message: str diff --git a/backend/app/dependencies/auth.py b/backend/app/dependencies/auth.py index 5f89305..585927a 100755 --- a/backend/app/dependencies/auth.py +++ b/backend/app/dependencies/auth.py @@ -1,17 +1,26 @@ """ FastAPI authentication dependencies. -Provides dependency functions for securing REST endpoints with Azure AD token verification. +Provides dependency functions for securing REST endpoints with Azure AD token verification +and role-based access control. """ import logging from typing import Optional -from fastapi import Header, HTTPException, status + +from fastapi import Depends, Header, HTTPException, status +from sqlalchemy.ext.asyncio import AsyncSession from app.config import settings +from app.models.database import get_db +from app.models.models import User +from app.repositories.user_repository import UserRepository from app.services.auth_service import verify_access_token logger = logging.getLogger(__name__) +# Valid roles ordered by privilege level (for reference) +VALID_ROLES = ("super_admin", "oversight_admin", "agency_admin", "basic_user") + async def get_current_user(authorization: Optional[str] = Header(None)) -> dict: """ @@ -72,3 +81,70 @@ async def get_current_user(authorization: Optional[str] = Header(None)) -> dict: logger.debug(f"[MSAL Backend] Authentication successful for: {claims.get('name', 'unknown')}") return claims + + +async def get_current_db_user( + user_claims: dict = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +) -> User: + """ + Resolve Azure AD claims to a full User ORM object with agency loaded. + + Creates the user on first login as basic_user with no agency. + In dev mode (DISABLE_AUTH=true), auto-promotes the dev user to super_admin. + """ + user_repo = UserRepository(db) + azure_oid = user_claims.get("oid") or user_claims.get("sub") + if not azure_oid: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Missing user identifier in token claims", + ) + + user = await user_repo.get_or_create_from_azure( + azure_ad_oid=azure_oid, + email=user_claims.get("email", user_claims.get("preferred_username", "")), + name=user_claims.get("name", "Unknown"), + ) + + # Dev mode: auto-promote to super_admin so all features are accessible + if settings.DISABLE_AUTH and user.role != "super_admin": + user.role = "super_admin" + await db.flush() + + return user + + +def require_role(*allowed_roles: str): + """ + Dependency factory that restricts access to users with specific roles. + + Usage: + @router.get("/admin-only") + async def admin_route(user: User = Depends(require_role("super_admin"))): + ... + """ + async def _check_role(current_user: User = Depends(get_current_db_user)) -> User: + if current_user.role not in allowed_roles: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f"Access denied. Required role: {', '.join(allowed_roles)}", + ) + return current_user + + return _check_role + + +async def require_write_access( + current_user: User = Depends(get_current_db_user), +) -> User: + """ + Dependency that blocks oversight_admin from write/mutation operations. + All other roles are allowed through. + """ + if current_user.role == "oversight_admin": + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Oversight Admin has read-only access", + ) + return current_user diff --git a/backend/app/main.py b/backend/app/main.py index e3d37a0..1676533 100755 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -8,7 +8,8 @@ from fastapi.middleware.cors import CORSMiddleware from app.config import settings from app.services.auth_service import verify_access_token from app.dependencies.auth import get_current_user -from app.models.database import init_db, close_db +from app.models.database import init_db, close_db, async_session_factory as _session_factory +from app.repositories.user_repository import UserRepository from app.api import router as api_router, kb_router # Configure logging @@ -193,6 +194,21 @@ async def websocket_analyze(websocket: WebSocket): logger.info(f"[MAIN] Authenticated user: {user_claims.get('name', 'unknown')}") + # Check role: oversight_admin cannot upload/analyze proofs + try: + async with _session_factory() as ws_session: + ws_user_repo = UserRepository(ws_session) + azure_oid = user_claims.get("oid") or user_claims.get("sub") + ws_user = await ws_user_repo.get_by_azure_oid(azure_oid) if azure_oid else None + if ws_user and ws_user.role == "oversight_admin": + await manager.send_message(client_id, { + "type": "error", + "message": "Oversight Admin has read-only access and cannot analyze proofs." + }) + continue + except Exception as role_err: + logger.warning(f"[MAIN] Role check failed for client {client_id}: {role_err}") + if analysis_service is None: logger.error("[MAIN] Analysis service not ready") await manager.send_message(client_id, { diff --git a/backend/app/repositories/user_repository.py b/backend/app/repositories/user_repository.py index a2dd8f7..1425154 100755 --- a/backend/app/repositories/user_repository.py +++ b/backend/app/repositories/user_repository.py @@ -89,7 +89,33 @@ class UserRepository: await self.session.flush() return user + async def update_user( + self, + user_id: uuid.UUID, + role: Optional[str] = None, + agency_id: Optional[uuid.UUID] = ..., # type: ignore[assignment] + ) -> Optional[User]: + """Update user role and/or agency. Pass agency_id=None to unassign.""" + user = await self.get_by_id(user_id) + if not user: + return None + if role is not None: + user.role = role + # Use sentinel (...) to distinguish "not provided" from "set to None" + if agency_id is not ...: + user.agency_id = agency_id + await self.session.flush() + # Re-fetch to get agency relationship loaded + return await self.get_by_id(user_id) + async def list_agencies(self) -> list[Agency]: """List all agencies.""" result = await self.session.execute(select(Agency)) return list(result.scalars().all()) + + async def create_agency(self, name: str) -> Agency: + """Create a new agency.""" + agency = Agency(name=name) + self.session.add(agency) + await self.session.flush() + return agency