"""
Central authorization module (Phase 3 SaaS).

Provides:
- MembershipContext — per-request resolved memberships dict
- get_membership_context — FastAPI dependency
- require_org_role(min_role) — dependency factory for org-scoped endpoints
- require_platform_admin() — dependency factory for platform-only endpoints
  (use as ``Depends(require_platform_admin())``)
- OrgScopedQuery — helper to add organization_id filter to MongoDB queries
- bump_user_membership_cache — invalidate a user's cached memberships
"""
import json
import logging
from dataclasses import dataclass
from typing import Optional

from fastapi import Depends, HTTPException, status
from motor.motor_asyncio import AsyncIOMotorDatabase

from ..models.organization import OrgRole
from ..models.user import User, UserRole
from .database import get_database
from .dependencies import get_current_user
from .redis import get_redis

logger = logging.getLogger(__name__)

# Platform roles that bypass org membership checks entirely.
# Kill-switch: to restore the old, broader staff bypass, change this set back to:
# {UserRole.ADMIN, UserRole.REVIEWER, UserRole.LINGUIST, UserRole.PRODUCTION}
PLATFORM_ADMIN_ROLES = {UserRole.ADMIN}

MEMBERSHIP_CACHE_TTL = 60  # seconds


def _membership_cache_key(user_id: str) -> str:
    """Return the Redis key under which a user's memberships are cached."""
    return f"mem:user:{user_id}"


@dataclass
class MembershipContext:
    """Authorization facts resolved once per request for the current user."""

    user: User
    # True when the user's platform role is in PLATFORM_ADMIN_ROLES;
    # such users pass every org check.
    is_platform_admin: bool
    # org_id → OrgRole mapping from the memberships collection
    # (get_membership_context leaves it empty for platform admins).
    memberships: dict[str, OrgRole]

    def can_access_org(self, org_id: str, min_role: OrgRole = OrgRole.VIEWER) -> bool:
        """Return True if the user holds at least ``min_role`` in ``org_id``.

        Platform admins always pass; otherwise the user must have a membership
        in the org whose role compares ``>=`` ``min_role``.
        """
        if self.is_platform_admin:
            return True
        role = self.memberships.get(org_id)
        # NOTE(review): relies on OrgRole defining an ordering that matches
        # privilege level; a plain ``str`` Enum would compare alphabetically.
        # TODO confirm in models.organization.
        return role is not None and role >= min_role

    def accessible_org_ids(self) -> list[str]:
        """Return all org IDs the user has any membership in."""
        return list(self.memberships)


async def _load_memberships(user_id: str, db: AsyncIOMotorDatabase) -> dict[str, OrgRole]:
    """Load a user's memberships from the ``memberships`` collection.

    Returns a mapping of organization id (as ``str``) to ``OrgRole``.
    Documents missing a required field or carrying an unknown role are
    skipped (fail closed: no access is granted for them) and logged.
    """
    result: dict[str, OrgRole] = {}
    async for doc in db.memberships.find({"user_id": user_id}):
        try:
            # str(): keys must match what a JSON cache round-trip yields
            # (JSON object keys are always strings), so cached and uncached
            # lookups agree.
            result[str(doc["organization_id"])] = OrgRole(doc["role_in_org"])
        except (KeyError, ValueError):
            logger.warning(
                "Skipping malformed membership %r for user %s",
                doc.get("_id"),
                user_id,
            )
    return result


async def _cached_memberships(
    user_id: str,
    db: AsyncIOMotorDatabase,
) -> dict[str, OrgRole]:
    """Load memberships, with a Redis cache (MEMBERSHIP_CACHE_TTL seconds).

    The cache is best-effort: any Redis error, or a cached value that can no
    longer be decoded, falls back to the database; a failed cache write is
    logged and ignored.
    """
    cache_key = _membership_cache_key(user_id)

    try:
        redis = get_redis()
        if redis:
            cached = await redis.get(cache_key)
            if cached:
                raw = json.loads(cached)
                return {org_id: OrgRole(role) for org_id, role in raw.items()}
    except Exception:
        # Best-effort cache: Redis outages or stale/corrupt entries fall back
        # to the DB (and the write below overwrites a corrupt entry).
        logger.debug("Membership cache read failed for user %s", user_id, exc_info=True)

    memberships = await _load_memberships(user_id, db)

    try:
        redis = get_redis()
        if redis:
            await redis.setex(
                cache_key,
                MEMBERSHIP_CACHE_TTL,
                json.dumps({k: v.value for k, v in memberships.items()}),
            )
    except Exception:
        logger.debug("Membership cache write failed for user %s", user_id, exc_info=True)

    return memberships


async def get_membership_context(
    current_user: User = Depends(get_current_user),
    db: AsyncIOMotorDatabase = Depends(get_database),
) -> MembershipContext:
    """FastAPI dependency: resolve the current user's authorization context.

    Platform admins short-circuit with an empty membership map (they are
    never looked up); everyone else gets their memberships via the Redis
    cache or the database.
    """
    is_platform_admin = current_user.role in PLATFORM_ADMIN_ROLES
    if is_platform_admin:
        return MembershipContext(
            user=current_user,
            is_platform_admin=True,
            memberships={},
        )

    memberships = await _cached_memberships(str(current_user.id), db)
    return MembershipContext(
        user=current_user,
        is_platform_admin=False,
        memberships=memberships,
    )


def require_org_role(min_role: OrgRole):
    """
    Dependency factory: ensures the current user has at least `min_role`
    in the organization identified by the `org_id` request parameter
    (normally the route's ``{org_id}`` path segment).

    Platform admins always pass.

    Raises:
        HTTPException(403): when the user lacks the required role.
    """

    async def checker(
        org_id: str,
        ctx: MembershipContext = Depends(get_membership_context),
    ) -> MembershipContext:
        if not ctx.can_access_org(org_id, min_role):
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"Requires {min_role.value} role in this organization",
            )
        return ctx

    return checker


def require_platform_admin():
    """Dependency factory: platform admin only.

    Call it to obtain the dependency: ``Depends(require_platform_admin())``.

    Raises:
        HTTPException(403): when the current user is not a platform admin.
    """

    async def checker(ctx: MembershipContext = Depends(get_membership_context)) -> MembershipContext:
        if not ctx.is_platform_admin:
            raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Platform admin required")
        return ctx

    return checker


class OrgScopedQuery:
    """
    Helper that adds organization_id filters to MongoDB queries.

    Usage:
        scoped = OrgScopedQuery(ctx)
        query = scoped.filter({"status": "completed"}, org_id_from_request)

    If the user is a platform admin, the query is returned unchanged unless a
    specific org_id is given, in which case it is scoped to that org.
    If a specific org_id is given, it's validated against the user's memberships.
    If no org_id is given, the query is scoped to all orgs the user belongs to.
    Note that any existing ``org_field`` key in ``base_query`` is overwritten.
    """

    def __init__(self, ctx: MembershipContext):
        self.ctx = ctx

    def filter(
        self,
        base_query: dict,
        org_id: Optional[str] = None,
        org_field: str = "organization_id",
    ) -> dict:
        """Return a copy of ``base_query`` restricted to orgs the user may see.

        Raises:
            HTTPException(403): when a non-admin requests an org they are not
                a member of.
        """
        if self.ctx.is_platform_admin:
            if org_id:
                return {**base_query, org_field: org_id}
            return base_query

        if org_id:
            if not self.ctx.can_access_org(org_id):
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail="Access to this organization is not permitted",
                )
            return {**base_query, org_field: org_id}

        # With no memberships this is {"$in": []}, which matches nothing —
        # an intentionally impossible query.
        accessible = self.ctx.accessible_org_ids()
        return {**base_query, org_field: {"$in": accessible}}


async def bump_user_membership_cache(user_id: str) -> None:
    """Invalidate the Redis membership cache for a user (call on any membership write).

    Best-effort: errors are logged, not raised. If invalidation fails, the
    stale entry remains until it expires (up to MEMBERSHIP_CACHE_TTL seconds).
    """
    try:
        redis = get_redis()
        if redis:
            await redis.delete(_membership_cache_key(user_id))
    except Exception:
        # Warning level: a missed invalidation can keep revoked access cached.
        logger.warning(
            "Failed to invalidate membership cache for user %s", user_id, exc_info=True
        )