Add 4-tier RBAC backend: auth dependencies, role enforcement, agency filtering
- Add CHECK constraint migration for users.role (super_admin, oversight_admin, agency_admin, basic_user)
- Add get_current_db_user dependency resolving Azure claims to User ORM with agency
- Add require_role() factory and require_write_access() dependency
- Auto-promote dev user to super_admin when DISABLE_AUTH=true
- Add /api/me, PUT /api/users/{id}, POST /api/agencies endpoints
- Apply agency-based data filtering on campaigns, analytics, audit routes
- Block oversight_admin from all mutation routes (campaigns, proofs, flags, resolves)
- Restrict dropdown option mutations to super_admin only
- Add role check in WebSocket handler to block oversight_admin from analysis
- Add CurrentUserResponse, UserUpdate, AgencyCreate schemas
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
6bd8a03a15
commit
d21036a0de
6 changed files with 422 additions and 104 deletions
35
backend/alembic/versions/007_add_role_check_constraint.py
Normal file
35
backend/alembic/versions/007_add_role_check_constraint.py
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
"""Add CHECK constraint on users.role for RBAC roles
|
||||
|
||||
Revision ID: 007_add_role_check_constraint
|
||||
Revises: 006_add_knowledge_base
|
||||
Create Date: 2026-02-19
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "007_add_role_check_constraint"
|
||||
down_revision: Union[str, None] = "006_add_knowledge_base"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
VALID_ROLES = ("super_admin", "oversight_admin", "agency_admin", "basic_user")
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Normalise any legacy role values to basic_user before adding constraint
|
||||
op.execute(
|
||||
f"UPDATE users SET role = 'basic_user' "
|
||||
f"WHERE role NOT IN {VALID_ROLES!r}"
|
||||
)
|
||||
op.create_check_constraint(
|
||||
"ck_users_role",
|
||||
"users",
|
||||
f"role IN {VALID_ROLES!r}",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_constraint("ck_users_role", "users", type_="check")
|
||||
|
|
@ -22,11 +22,20 @@ from app.api.schemas import (
|
|||
AnalyticsResponse,
|
||||
DropdownOptionsResponse,
|
||||
AgencyResponse,
|
||||
AgencyCreate,
|
||||
CurrentUserResponse,
|
||||
UserResponse,
|
||||
UserUpdate,
|
||||
SupportEmailRequest,
|
||||
)
|
||||
from app.dependencies.auth import get_current_user
|
||||
from app.dependencies.auth import (
|
||||
get_current_user,
|
||||
get_current_db_user,
|
||||
require_role,
|
||||
require_write_access,
|
||||
)
|
||||
from app.models.database import get_db
|
||||
from app.models.models import User
|
||||
from app.repositories import (
|
||||
CampaignRepository,
|
||||
ProofRepository,
|
||||
|
|
@ -41,36 +50,65 @@ from app.services.pdf_service import pdf_service
|
|||
router = APIRouter()
|
||||
|
||||
|
||||
# Helper to get user from DB based on Azure claims
|
||||
async def get_db_user(
|
||||
session: AsyncSession,
|
||||
user_claims: dict,
|
||||
) -> Optional[uuid.UUID]:
|
||||
"""Get or create user from Azure AD claims and return user ID."""
|
||||
user_repo = UserRepository(session)
|
||||
azure_oid = user_claims.get("oid") or user_claims.get("sub")
|
||||
if not azure_oid:
|
||||
return None
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
user = await user_repo.get_or_create_from_azure(
|
||||
azure_ad_oid=azure_oid,
|
||||
email=user_claims.get("email", user_claims.get("preferred_username", "")),
|
||||
name=user_claims.get("name", "Unknown"),
|
||||
def _resolve_agency_filter(current_user: User, agency_id_param: Optional[uuid.UUID] = None) -> Optional[uuid.UUID]:
|
||||
"""Determine which agency_id to filter by based on user role and query param.
|
||||
|
||||
- super_admin / oversight_admin: use agency_id_param if provided, else None (all)
|
||||
- agency_admin / basic_user: forced to current_user.agency_id
|
||||
"""
|
||||
if current_user.role in ("super_admin", "oversight_admin"):
|
||||
return agency_id_param # None means "all"
|
||||
return current_user.agency_id # may be None (user not assigned yet)
|
||||
|
||||
|
||||
def _check_campaign_access(current_user: User, campaign) -> None:
|
||||
"""Raise 404 if the user's role restricts them and the campaign doesn't belong to their agency."""
|
||||
if current_user.role in ("super_admin", "oversight_admin"):
|
||||
return
|
||||
if current_user.agency_id and campaign.agency_id == current_user.agency_id:
|
||||
return
|
||||
raise HTTPException(status_code=404, detail="Campaign not found")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Current User endpoint
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/me", response_model=CurrentUserResponse)
|
||||
async def get_me(
|
||||
current_user: User = Depends(get_current_db_user),
|
||||
):
|
||||
"""Get the authenticated user's profile."""
|
||||
return CurrentUserResponse(
|
||||
id=current_user.id,
|
||||
email=current_user.email,
|
||||
name=current_user.name,
|
||||
role=current_user.role,
|
||||
agency_id=current_user.agency_id,
|
||||
agency_name=current_user.agency.name if current_user.agency else None,
|
||||
)
|
||||
return user.id
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Campaign endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/campaigns", response_model=list[CampaignResponse])
|
||||
async def list_campaigns(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user: dict = Depends(get_current_user),
|
||||
current_user: User = Depends(get_current_db_user),
|
||||
agency_id: Optional[uuid.UUID] = Query(None, description="Filter by agency"),
|
||||
limit: int = Query(100, ge=1, le=500),
|
||||
offset: int = Query(0, ge=0),
|
||||
):
|
||||
"""List all campaigns."""
|
||||
"""List campaigns, filtered by the user's role and optional agency filter."""
|
||||
effective_agency_id = _resolve_agency_filter(current_user, agency_id)
|
||||
repo = CampaignRepository(db)
|
||||
campaigns_with_counts = await repo.get_with_proof_counts()
|
||||
campaigns_with_counts = await repo.get_with_proof_counts(agency_id=effective_agency_id)
|
||||
|
||||
return [
|
||||
CampaignResponse(
|
||||
|
|
@ -94,10 +132,9 @@ async def list_campaigns(
|
|||
async def create_campaign(
|
||||
data: CampaignCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user: dict = Depends(get_current_user),
|
||||
current_user: User = Depends(require_write_access),
|
||||
):
|
||||
"""Create a new campaign."""
|
||||
user_id = await get_db_user(db, user)
|
||||
"""Create a new campaign. Blocked for oversight_admin."""
|
||||
repo = CampaignRepository(db)
|
||||
|
||||
# Check if campaign name already exists
|
||||
|
|
@ -111,7 +148,8 @@ async def create_campaign(
|
|||
client_lead=data.client_lead,
|
||||
agency_lead=data.agency_lead,
|
||||
brand_guidelines=data.brand_guidelines,
|
||||
created_by=user_id,
|
||||
agency_id=current_user.agency_id,
|
||||
created_by=current_user.id,
|
||||
)
|
||||
|
||||
return CampaignResponse(
|
||||
|
|
@ -122,7 +160,7 @@ async def create_campaign(
|
|||
agency_lead=campaign.agency_lead,
|
||||
brand_guidelines=campaign.brand_guidelines,
|
||||
status=campaign.status,
|
||||
agency=None,
|
||||
agency=current_user.agency.name if current_user.agency else None,
|
||||
created_at=campaign.created_at,
|
||||
updated_at=campaign.updated_at,
|
||||
proofs=0,
|
||||
|
|
@ -133,7 +171,7 @@ async def create_campaign(
|
|||
async def get_campaign(
|
||||
campaign_id: uuid.UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user: dict = Depends(get_current_user),
|
||||
current_user: User = Depends(get_current_db_user),
|
||||
):
|
||||
"""Get a campaign by ID."""
|
||||
repo = CampaignRepository(db)
|
||||
|
|
@ -141,6 +179,8 @@ async def get_campaign(
|
|||
if not campaign:
|
||||
raise HTTPException(status_code=404, detail="Campaign not found")
|
||||
|
||||
_check_campaign_access(current_user, campaign)
|
||||
|
||||
return CampaignResponse(
|
||||
id=campaign.id,
|
||||
name=campaign.name,
|
||||
|
|
@ -161,10 +201,17 @@ async def update_campaign(
|
|||
campaign_id: uuid.UUID,
|
||||
data: CampaignUpdate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user: dict = Depends(get_current_user),
|
||||
current_user: User = Depends(require_write_access),
|
||||
):
|
||||
"""Update a campaign."""
|
||||
"""Update a campaign. Blocked for oversight_admin."""
|
||||
repo = CampaignRepository(db)
|
||||
|
||||
# Verify campaign exists and user has access
|
||||
existing = await repo.get_by_id(campaign_id)
|
||||
if not existing:
|
||||
raise HTTPException(status_code=404, detail="Campaign not found")
|
||||
_check_campaign_access(current_user, existing)
|
||||
|
||||
campaign = await repo.update(
|
||||
campaign_id,
|
||||
name=data.name,
|
||||
|
|
@ -196,34 +243,44 @@ async def update_campaign(
|
|||
async def delete_campaign(
|
||||
campaign_id: uuid.UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user: dict = Depends(get_current_user),
|
||||
current_user: User = Depends(require_write_access),
|
||||
):
|
||||
"""Delete a campaign and all associated files."""
|
||||
"""Delete a campaign and all associated files. Blocked for oversight_admin."""
|
||||
repo = CampaignRepository(db)
|
||||
|
||||
# Get campaign with proofs and versions to extract file keys
|
||||
campaign = await repo.get_by_id(campaign_id)
|
||||
if not campaign:
|
||||
raise HTTPException(status_code=404, detail="Campaign not found")
|
||||
|
||||
_check_campaign_access(current_user, campaign)
|
||||
|
||||
# Delete files from storage for all proofs
|
||||
for proof in campaign.proofs:
|
||||
for version in proof.versions:
|
||||
if version.file_storage_key:
|
||||
await storage_service.delete_file(version.file_storage_key)
|
||||
|
||||
# Delete database records (cascades to proofs and versions)
|
||||
await repo.delete(campaign_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Proof endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/campaigns/{campaign_id}/proofs", response_model=list[ProofResponse])
|
||||
async def list_proofs(
|
||||
campaign_id: uuid.UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user: dict = Depends(get_current_user),
|
||||
current_user: User = Depends(get_current_db_user),
|
||||
):
|
||||
"""List all proofs for a campaign."""
|
||||
# Verify campaign access
|
||||
campaign_repo = CampaignRepository(db)
|
||||
campaign = await campaign_repo.get_by_id(campaign_id)
|
||||
if not campaign:
|
||||
raise HTTPException(status_code=404, detail="Campaign not found")
|
||||
_check_campaign_access(current_user, campaign)
|
||||
|
||||
repo = ProofRepository(db)
|
||||
proofs = await repo.list_by_campaign(campaign_id)
|
||||
|
||||
|
|
@ -259,7 +316,7 @@ async def list_proofs(
|
|||
async def get_proof(
|
||||
proof_id: uuid.UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user: dict = Depends(get_current_user),
|
||||
current_user: User = Depends(get_current_db_user),
|
||||
):
|
||||
"""Get a proof by ID."""
|
||||
repo = ProofRepository(db)
|
||||
|
|
@ -296,12 +353,11 @@ async def get_proof(
|
|||
async def delete_proof(
|
||||
proof_id: uuid.UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user: dict = Depends(get_current_user),
|
||||
current_user: User = Depends(require_write_access),
|
||||
):
|
||||
"""Delete a proof and its associated files."""
|
||||
"""Delete a proof and its associated files. Blocked for oversight_admin."""
|
||||
repo = ProofRepository(db)
|
||||
|
||||
# Get proof with versions to extract file keys
|
||||
proof = await repo.get_by_id(proof_id)
|
||||
if not proof:
|
||||
raise HTTPException(status_code=404, detail="Proof not found")
|
||||
|
|
@ -311,25 +367,25 @@ async def delete_proof(
|
|||
if version.file_storage_key:
|
||||
await storage_service.delete_file(version.file_storage_key)
|
||||
|
||||
# Delete database records
|
||||
await repo.delete(proof_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Audit endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/proofs/{proof_id}/versions/{version}/flag", response_model=FlaggedItemResponse, status_code=201)
|
||||
async def flag_proof_version(
|
||||
proof_id: uuid.UUID,
|
||||
version: int,
|
||||
data: FlaggedItemCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user: dict = Depends(get_current_user),
|
||||
current_user: User = Depends(require_write_access),
|
||||
):
|
||||
"""Flag an issue on a proof version."""
|
||||
user_id = await get_db_user(db, user)
|
||||
"""Flag an issue on a proof version. Blocked for oversight_admin."""
|
||||
proof_repo = ProofRepository(db)
|
||||
audit_repo = AuditRepository(db)
|
||||
|
||||
# Get the proof version
|
||||
proof_version = await proof_repo.get_version(proof_id, version)
|
||||
if not proof_version:
|
||||
raise HTTPException(status_code=404, detail="Proof version not found")
|
||||
|
|
@ -338,10 +394,9 @@ async def flag_proof_version(
|
|||
proof_version_id=proof_version.id,
|
||||
agent_flagged=data.agent_flagged,
|
||||
comments=data.comments,
|
||||
submitter_id=user_id,
|
||||
submitter_id=current_user.id,
|
||||
)
|
||||
|
||||
# Get related data for response
|
||||
proof = await proof_repo.get_by_id(proof_id)
|
||||
|
||||
return FlaggedItemResponse(
|
||||
|
|
@ -349,8 +404,8 @@ async def flag_proof_version(
|
|||
proof_version_id=flagged.proof_version_id,
|
||||
agent_flagged=flagged.agent_flagged,
|
||||
comments=flagged.comments,
|
||||
submitter_name=user.get("name"),
|
||||
submitter_agency=None,
|
||||
submitter_name=current_user.name,
|
||||
submitter_agency=current_user.agency.name if current_user.agency else None,
|
||||
campaign_name=proof.campaign.name if proof and proof.campaign else None,
|
||||
proof_name=proof.proof_name if proof else None,
|
||||
version=version,
|
||||
|
|
@ -364,14 +419,12 @@ async def resolve_proof_version(
|
|||
version: int,
|
||||
data: ResolvedItemCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user: dict = Depends(get_current_user),
|
||||
current_user: User = Depends(require_write_access),
|
||||
):
|
||||
"""Resolve an issue on a proof version."""
|
||||
user_id = await get_db_user(db, user)
|
||||
"""Resolve an issue on a proof version. Blocked for oversight_admin."""
|
||||
proof_repo = ProofRepository(db)
|
||||
audit_repo = AuditRepository(db)
|
||||
|
||||
# Get the proof version
|
||||
proof_version = await proof_repo.get_version(proof_id, version)
|
||||
if not proof_version:
|
||||
raise HTTPException(status_code=404, detail="Proof version not found")
|
||||
|
|
@ -381,10 +434,9 @@ async def resolve_proof_version(
|
|||
agent=data.agent,
|
||||
issue=data.issue,
|
||||
resolution=data.resolution,
|
||||
submitter_id=user_id,
|
||||
submitter_id=current_user.id,
|
||||
)
|
||||
|
||||
# Get related data for response
|
||||
proof = await proof_repo.get_by_id(proof_id)
|
||||
|
||||
return ResolvedItemResponse(
|
||||
|
|
@ -393,8 +445,8 @@ async def resolve_proof_version(
|
|||
agent=resolved.agent,
|
||||
issue=resolved.issue,
|
||||
resolution=resolved.resolution,
|
||||
submitter_name=user.get("name"),
|
||||
submitter_agency=None,
|
||||
submitter_name=current_user.name,
|
||||
submitter_agency=current_user.agency.name if current_user.agency else None,
|
||||
campaign_name=proof.campaign.name if proof and proof.campaign else None,
|
||||
proof_name=proof.proof_name if proof else None,
|
||||
version=version,
|
||||
|
|
@ -405,13 +457,15 @@ async def resolve_proof_version(
|
|||
@router.get("/audit/flagged", response_model=list[FlaggedItemResponse])
|
||||
async def list_flagged_items(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user: dict = Depends(get_current_user),
|
||||
current_user: User = Depends(get_current_db_user),
|
||||
agency_id: Optional[uuid.UUID] = Query(None, description="Filter by agency"),
|
||||
limit: int = Query(100, ge=1, le=500),
|
||||
offset: int = Query(0, ge=0),
|
||||
):
|
||||
"""List all flagged items."""
|
||||
"""List flagged items, filtered by role."""
|
||||
effective_agency_id = _resolve_agency_filter(current_user, agency_id)
|
||||
audit_repo = AuditRepository(db)
|
||||
flagged_items = await audit_repo.get_flagged_items(limit=limit, offset=offset)
|
||||
flagged_items = await audit_repo.get_flagged_items(agency_id=effective_agency_id, limit=limit, offset=offset)
|
||||
|
||||
return [
|
||||
FlaggedItemResponse(
|
||||
|
|
@ -433,13 +487,15 @@ async def list_flagged_items(
|
|||
@router.get("/audit/resolved", response_model=list[ResolvedItemResponse])
|
||||
async def list_resolved_items(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user: dict = Depends(get_current_user),
|
||||
current_user: User = Depends(get_current_db_user),
|
||||
agency_id: Optional[uuid.UUID] = Query(None, description="Filter by agency"),
|
||||
limit: int = Query(100, ge=1, le=500),
|
||||
offset: int = Query(0, ge=0),
|
||||
):
|
||||
"""List all resolved items."""
|
||||
"""List resolved items, filtered by role."""
|
||||
effective_agency_id = _resolve_agency_filter(current_user, agency_id)
|
||||
audit_repo = AuditRepository(db)
|
||||
resolved_items = await audit_repo.get_resolved_items(limit=limit, offset=offset)
|
||||
resolved_items = await audit_repo.get_resolved_items(agency_id=effective_agency_id, limit=limit, offset=offset)
|
||||
|
||||
return [
|
||||
ResolvedItemResponse(
|
||||
|
|
@ -462,13 +518,15 @@ async def list_resolved_items(
|
|||
@router.get("/audit/errors", response_model=list[ErrorItemResponse])
|
||||
async def list_error_items(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user: dict = Depends(get_current_user),
|
||||
current_user: User = Depends(get_current_db_user),
|
||||
agency_id: Optional[uuid.UUID] = Query(None, description="Filter by agency"),
|
||||
limit: int = Query(100, ge=1, le=500),
|
||||
offset: int = Query(0, ge=0),
|
||||
):
|
||||
"""List all error items."""
|
||||
"""List error items, filtered by role."""
|
||||
effective_agency_id = _resolve_agency_filter(current_user, agency_id)
|
||||
audit_repo = AuditRepository(db)
|
||||
error_items = await audit_repo.get_error_items(limit=limit, offset=offset)
|
||||
error_items = await audit_repo.get_error_items(agency_id=effective_agency_id, limit=limit, offset=offset)
|
||||
|
||||
return [
|
||||
ErrorItemResponse(
|
||||
|
|
@ -486,25 +544,33 @@ async def list_error_items(
|
|||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Analytics endpoint
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/analytics", response_model=AnalyticsResponse)
|
||||
async def get_analytics(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user: dict = Depends(get_current_user),
|
||||
current_user: User = Depends(get_current_db_user),
|
||||
agency_id: Optional[uuid.UUID] = Query(None, description="Filter by agency"),
|
||||
):
|
||||
"""Get analytics data."""
|
||||
"""Get analytics data, filtered by role."""
|
||||
effective_agency_id = _resolve_agency_filter(current_user, agency_id)
|
||||
repo = CampaignRepository(db)
|
||||
analytics = await repo.get_analytics()
|
||||
analytics = await repo.get_analytics(agency_id=effective_agency_id)
|
||||
return AnalyticsResponse(**analytics)
|
||||
|
||||
|
||||
# Users endpoint (admin only)
|
||||
# ---------------------------------------------------------------------------
|
||||
# User Management endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/users", response_model=list[UserResponse])
|
||||
async def list_users(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user: dict = Depends(get_current_user),
|
||||
current_user: User = Depends(require_role("super_admin")),
|
||||
):
|
||||
"""List all users (admin only)."""
|
||||
"""List all users (super_admin only)."""
|
||||
user_repo = UserRepository(db)
|
||||
users = await user_repo.list_all()
|
||||
|
||||
|
|
@ -515,17 +581,103 @@ async def list_users(
|
|||
name=u.name,
|
||||
role=u.role,
|
||||
agency=u.agency.name if u.agency else None,
|
||||
agency_id=u.agency_id,
|
||||
created_at=u.created_at,
|
||||
)
|
||||
for u in users
|
||||
]
|
||||
|
||||
|
||||
@router.put("/users/{user_id}", response_model=UserResponse)
|
||||
async def update_user(
|
||||
user_id: uuid.UUID,
|
||||
data: UserUpdate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(require_role("super_admin")),
|
||||
):
|
||||
"""Update a user's role and/or agency (super_admin only)."""
|
||||
user_repo = UserRepository(db)
|
||||
|
||||
# Build kwargs, using sentinel to distinguish "not provided" from None
|
||||
kwargs: dict = {}
|
||||
if data.role is not None:
|
||||
valid_roles = ("super_admin", "oversight_admin", "agency_admin", "basic_user")
|
||||
if data.role not in valid_roles:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid role. Must be one of: {', '.join(valid_roles)}")
|
||||
kwargs["role"] = data.role
|
||||
# agency_id can be explicitly set to None (unassign) or a UUID
|
||||
if "agency_id" in (data.model_fields_set or set()):
|
||||
kwargs["agency_id"] = data.agency_id
|
||||
else:
|
||||
kwargs["agency_id"] = ... # sentinel: "not provided"
|
||||
|
||||
user = await user_repo.update_user(user_id, **kwargs)
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
return UserResponse(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
name=user.name,
|
||||
role=user.role,
|
||||
agency=user.agency.name if user.agency else None,
|
||||
agency_id=user.agency_id,
|
||||
created_at=user.created_at,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Agency endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/agencies", response_model=list[AgencyResponse])
|
||||
async def list_agencies(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_db_user),
|
||||
):
|
||||
"""List all agencies."""
|
||||
from sqlalchemy import select
|
||||
from app.models.models import Agency
|
||||
|
||||
stmt = select(Agency).order_by(Agency.name)
|
||||
result = await db.execute(stmt)
|
||||
agencies = result.scalars().all()
|
||||
|
||||
return [AgencyResponse(id=a.id, name=a.name) for a in agencies]
|
||||
|
||||
|
||||
@router.post("/agencies", response_model=AgencyResponse, status_code=201)
|
||||
async def create_agency(
|
||||
data: AgencyCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(require_role("super_admin")),
|
||||
):
|
||||
"""Create a new agency (super_admin only)."""
|
||||
user_repo = UserRepository(db)
|
||||
|
||||
# Check for duplicate name
|
||||
existing = await user_repo.get_or_create_agency(data.name)
|
||||
# get_or_create returns existing — check if it was just created
|
||||
# A simpler approach: try to create and catch unique violation
|
||||
from sqlalchemy import select
|
||||
from app.models.models import Agency
|
||||
result = await db.execute(select(Agency).where(Agency.name == data.name))
|
||||
agency = result.scalar_one_or_none()
|
||||
if agency:
|
||||
return AgencyResponse(id=agency.id, name=agency.name)
|
||||
|
||||
agency = await user_repo.create_agency(data.name)
|
||||
return AgencyResponse(id=agency.id, name=agency.name)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dropdown options endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/dropdown-options", response_model=DropdownOptionsResponse)
|
||||
async def get_dropdown_options(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user: dict = Depends(get_current_user),
|
||||
current_user: User = Depends(get_current_db_user),
|
||||
):
|
||||
"""Get all dropdown options as hierarchical structure."""
|
||||
import logging
|
||||
|
|
@ -534,7 +686,6 @@ async def get_dropdown_options(
|
|||
repo = DropdownRepository(db)
|
||||
options = await repo.get_all_hierarchical()
|
||||
|
||||
# Debug logging
|
||||
channels = options.get("channels", {})
|
||||
social = channels.get("Social", {})
|
||||
meta_proof_types = social.get("Meta", [])
|
||||
|
|
@ -547,9 +698,9 @@ async def get_dropdown_options(
|
|||
async def add_channel(
|
||||
name: str = Query(..., description="Channel name"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user: dict = Depends(get_current_user),
|
||||
current_user: User = Depends(require_role("super_admin")),
|
||||
):
|
||||
"""Add a new channel."""
|
||||
"""Add a new channel (super_admin only)."""
|
||||
repo = DropdownRepository(db)
|
||||
await repo.add_channel(name)
|
||||
await db.commit()
|
||||
|
|
@ -561,9 +712,9 @@ async def add_sub_channel(
|
|||
channel: str,
|
||||
name: str = Query(..., description="Sub-channel name"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user: dict = Depends(get_current_user),
|
||||
current_user: User = Depends(require_role("super_admin")),
|
||||
):
|
||||
"""Add a sub-channel under a channel."""
|
||||
"""Add a sub-channel under a channel (super_admin only)."""
|
||||
repo = DropdownRepository(db)
|
||||
result = await repo.add_sub_channel(channel, name)
|
||||
if not result:
|
||||
|
|
@ -578,9 +729,9 @@ async def add_proof_type(
|
|||
sub_channel: str,
|
||||
name: str = Query(..., description="Proof type name"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user: dict = Depends(get_current_user),
|
||||
current_user: User = Depends(require_role("super_admin")),
|
||||
):
|
||||
"""Add a proof type under a sub-channel."""
|
||||
"""Add a proof type under a sub-channel (super_admin only)."""
|
||||
repo = DropdownRepository(db)
|
||||
result = await repo.add_proof_type(channel, sub_channel, name)
|
||||
if not result:
|
||||
|
|
@ -593,9 +744,9 @@ async def add_proof_type(
|
|||
async def delete_channel(
|
||||
channel: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user: dict = Depends(get_current_user),
|
||||
current_user: User = Depends(require_role("super_admin")),
|
||||
):
|
||||
"""Delete a channel and all its sub-channels and proof types."""
|
||||
"""Delete a channel and all its sub-channels and proof types (super_admin only)."""
|
||||
repo = DropdownRepository(db)
|
||||
success = await repo.remove_channel(channel)
|
||||
if not success:
|
||||
|
|
@ -608,9 +759,9 @@ async def delete_sub_channel(
|
|||
channel: str,
|
||||
sub_channel: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user: dict = Depends(get_current_user),
|
||||
current_user: User = Depends(require_role("super_admin")),
|
||||
):
|
||||
"""Delete a sub-channel and all its proof types."""
|
||||
"""Delete a sub-channel and all its proof types (super_admin only)."""
|
||||
repo = DropdownRepository(db)
|
||||
success = await repo.remove_sub_channel(channel, sub_channel)
|
||||
if not success:
|
||||
|
|
@ -624,9 +775,9 @@ async def delete_proof_type(
|
|||
sub_channel: str,
|
||||
proof_type: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user: dict = Depends(get_current_user),
|
||||
current_user: User = Depends(require_role("super_admin")),
|
||||
):
|
||||
"""Delete a proof type."""
|
||||
"""Delete a proof type (super_admin only)."""
|
||||
repo = DropdownRepository(db)
|
||||
success = await repo.remove_proof_type(channel, sub_channel, proof_type)
|
||||
if not success:
|
||||
|
|
@ -634,22 +785,9 @@ async def delete_proof_type(
|
|||
await db.commit()
|
||||
|
||||
|
||||
# Agency endpoints
|
||||
@router.get("/agencies", response_model=list[AgencyResponse])
|
||||
async def list_agencies(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user: dict = Depends(get_current_user),
|
||||
):
|
||||
"""List all agencies."""
|
||||
from sqlalchemy import select
|
||||
from app.models.models import Agency
|
||||
|
||||
stmt = select(Agency).order_by(Agency.name)
|
||||
result = await db.execute(stmt)
|
||||
agencies = result.scalars().all()
|
||||
|
||||
return [AgencyResponse(id=a.id, name=a.name) for a in agencies]
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# File endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# PDF pages endpoint (must be defined BEFORE the base file endpoint for correct routing)
|
||||
@router.get("/files/{storage_key:path}/pages")
|
||||
|
|
@ -657,7 +795,7 @@ async def get_pdf_pages(
|
|||
storage_key: str,
|
||||
max_pages: int = Query(10, ge=1, le=50),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user: dict = Depends(get_current_user),
|
||||
current_user: User = Depends(get_current_db_user),
|
||||
):
|
||||
"""Rasterize a stored PDF and return pages as data URLs."""
|
||||
if not storage_key.lower().endswith('.pdf'):
|
||||
|
|
@ -690,14 +828,13 @@ async def get_pdf_pages(
|
|||
async def get_file(
|
||||
storage_key: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user: dict = Depends(get_current_user),
|
||||
current_user: User = Depends(get_current_db_user),
|
||||
):
|
||||
"""Retrieve a stored file by its storage key."""
|
||||
file_data = await storage_service.get_file(storage_key)
|
||||
if file_data is None:
|
||||
raise HTTPException(status_code=404, detail="File not found")
|
||||
|
||||
# Determine content type from extension
|
||||
extension = storage_key.split('.')[-1].lower() if '.' in storage_key else ''
|
||||
content_types = {
|
||||
'png': 'image/png',
|
||||
|
|
@ -714,7 +851,10 @@ async def get_file(
|
|||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Support email endpoint (public - no auth required for login page access)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/support/email")
|
||||
async def send_support_email(
|
||||
data: SupportEmailRequest,
|
||||
|
|
|
|||
|
|
@ -164,18 +164,43 @@ class AgencyResponse(BaseModel):
|
|||
|
||||
|
||||
# User schemas
|
||||
class CurrentUserResponse(BaseModel):
|
||||
"""Response for /api/me - the authenticated user's own profile."""
|
||||
id: uuid.UUID
|
||||
email: str
|
||||
name: str
|
||||
role: str
|
||||
agency_id: Optional[uuid.UUID]
|
||||
agency_name: Optional[str]
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class UserResponse(BaseModel):
|
||||
id: uuid.UUID
|
||||
email: str
|
||||
name: str
|
||||
role: str
|
||||
agency: Optional[str]
|
||||
agency_id: Optional[uuid.UUID] = None
|
||||
created_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class UserUpdate(BaseModel):
|
||||
"""Request body for updating a user's role and/or agency."""
|
||||
role: Optional[str] = None
|
||||
agency_id: Optional[uuid.UUID] = None
|
||||
|
||||
|
||||
class AgencyCreate(BaseModel):
|
||||
"""Request body for creating a new agency."""
|
||||
name: str
|
||||
|
||||
|
||||
# Support email schemas
|
||||
class SupportEmailRequest(BaseModel):
|
||||
message: str
|
||||
|
|
|
|||
|
|
@ -1,17 +1,26 @@
|
|||
"""
|
||||
FastAPI authentication dependencies.
|
||||
|
||||
Provides dependency functions for securing REST endpoints with Azure AD token verification.
|
||||
Provides dependency functions for securing REST endpoints with Azure AD token verification
|
||||
and role-based access control.
|
||||
"""
|
||||
import logging
|
||||
from typing import Optional
|
||||
from fastapi import Header, HTTPException, status
|
||||
|
||||
from fastapi import Depends, Header, HTTPException, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
from app.models.database import get_db
|
||||
from app.models.models import User
|
||||
from app.repositories.user_repository import UserRepository
|
||||
from app.services.auth_service import verify_access_token
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Valid roles ordered by privilege level (for reference)
|
||||
VALID_ROLES = ("super_admin", "oversight_admin", "agency_admin", "basic_user")
|
||||
|
||||
|
||||
async def get_current_user(authorization: Optional[str] = Header(None)) -> dict:
|
||||
"""
|
||||
|
|
@ -72,3 +81,70 @@ async def get_current_user(authorization: Optional[str] = Header(None)) -> dict:
|
|||
|
||||
logger.debug(f"[MSAL Backend] Authentication successful for: {claims.get('name', 'unknown')}")
|
||||
return claims
|
||||
|
||||
|
||||
async def get_current_db_user(
|
||||
user_claims: dict = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> User:
|
||||
"""
|
||||
Resolve Azure AD claims to a full User ORM object with agency loaded.
|
||||
|
||||
Creates the user on first login as basic_user with no agency.
|
||||
In dev mode (DISABLE_AUTH=true), auto-promotes the dev user to super_admin.
|
||||
"""
|
||||
user_repo = UserRepository(db)
|
||||
azure_oid = user_claims.get("oid") or user_claims.get("sub")
|
||||
if not azure_oid:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Missing user identifier in token claims",
|
||||
)
|
||||
|
||||
user = await user_repo.get_or_create_from_azure(
|
||||
azure_ad_oid=azure_oid,
|
||||
email=user_claims.get("email", user_claims.get("preferred_username", "")),
|
||||
name=user_claims.get("name", "Unknown"),
|
||||
)
|
||||
|
||||
# Dev mode: auto-promote to super_admin so all features are accessible
|
||||
if settings.DISABLE_AUTH and user.role != "super_admin":
|
||||
user.role = "super_admin"
|
||||
await db.flush()
|
||||
|
||||
return user
|
||||
|
||||
|
||||
def require_role(*allowed_roles: str):
|
||||
"""
|
||||
Dependency factory that restricts access to users with specific roles.
|
||||
|
||||
Usage:
|
||||
@router.get("/admin-only")
|
||||
async def admin_route(user: User = Depends(require_role("super_admin"))):
|
||||
...
|
||||
"""
|
||||
async def _check_role(current_user: User = Depends(get_current_db_user)) -> User:
|
||||
if current_user.role not in allowed_roles:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Access denied. Required role: {', '.join(allowed_roles)}",
|
||||
)
|
||||
return current_user
|
||||
|
||||
return _check_role
|
||||
|
||||
|
||||
async def require_write_access(
|
||||
current_user: User = Depends(get_current_db_user),
|
||||
) -> User:
|
||||
"""
|
||||
Dependency that blocks oversight_admin from write/mutation operations.
|
||||
All other roles are allowed through.
|
||||
"""
|
||||
if current_user.role == "oversight_admin":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Oversight Admin has read-only access",
|
||||
)
|
||||
return current_user
|
||||
|
|
|
|||
|
|
@ -8,7 +8,8 @@ from fastapi.middleware.cors import CORSMiddleware
|
|||
from app.config import settings
|
||||
from app.services.auth_service import verify_access_token
|
||||
from app.dependencies.auth import get_current_user
|
||||
from app.models.database import init_db, close_db
|
||||
from app.models.database import init_db, close_db, async_session_factory as _session_factory
|
||||
from app.repositories.user_repository import UserRepository
|
||||
from app.api import router as api_router, kb_router
|
||||
|
||||
# Configure logging
|
||||
|
|
@ -193,6 +194,21 @@ async def websocket_analyze(websocket: WebSocket):
|
|||
|
||||
logger.info(f"[MAIN] Authenticated user: {user_claims.get('name', 'unknown')}")
|
||||
|
||||
# Check role: oversight_admin cannot upload/analyze proofs
|
||||
try:
|
||||
async with _session_factory() as ws_session:
|
||||
ws_user_repo = UserRepository(ws_session)
|
||||
azure_oid = user_claims.get("oid") or user_claims.get("sub")
|
||||
ws_user = await ws_user_repo.get_by_azure_oid(azure_oid) if azure_oid else None
|
||||
if ws_user and ws_user.role == "oversight_admin":
|
||||
await manager.send_message(client_id, {
|
||||
"type": "error",
|
||||
"message": "Oversight Admin has read-only access and cannot analyze proofs."
|
||||
})
|
||||
continue
|
||||
except Exception as role_err:
|
||||
logger.warning(f"[MAIN] Role check failed for client {client_id}: {role_err}")
|
||||
|
||||
if analysis_service is None:
|
||||
logger.error("[MAIN] Analysis service not ready")
|
||||
await manager.send_message(client_id, {
|
||||
|
|
|
|||
|
|
@ -89,7 +89,33 @@ class UserRepository:
|
|||
await self.session.flush()
|
||||
return user
|
||||
|
||||
async def update_user(
|
||||
self,
|
||||
user_id: uuid.UUID,
|
||||
role: Optional[str] = None,
|
||||
agency_id: Optional[uuid.UUID] = ..., # type: ignore[assignment]
|
||||
) -> Optional[User]:
|
||||
"""Update user role and/or agency. Pass agency_id=None to unassign."""
|
||||
user = await self.get_by_id(user_id)
|
||||
if not user:
|
||||
return None
|
||||
if role is not None:
|
||||
user.role = role
|
||||
# Use sentinel (...) to distinguish "not provided" from "set to None"
|
||||
if agency_id is not ...:
|
||||
user.agency_id = agency_id
|
||||
await self.session.flush()
|
||||
# Re-fetch to get agency relationship loaded
|
||||
return await self.get_by_id(user_id)
|
||||
|
||||
async def list_agencies(self) -> list[Agency]:
|
||||
"""List all agencies."""
|
||||
result = await self.session.execute(select(Agency))
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def create_agency(self, name: str) -> Agency:
|
||||
"""Create a new agency."""
|
||||
agency = Agency(name=name)
|
||||
self.session.add(agency)
|
||||
await self.session.flush()
|
||||
return agency
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue