"""
Central authorization module (Phase 3 SaaS).

Provides:
- MembershipContext — per-request user, platform-admin flag, and resolved org memberships
- get_membership_context — FastAPI dependency that builds a MembershipContext
- require_org_role(min_role) — dependency factory for org-scoped endpoints
- require_platform_admin() — dependency factory for platform-only endpoints
- OrgScopedQuery — helper to add an organization_id filter to MongoDB queries
- assert_user_in_org — imperative 403 check against a single organization
- get_job_or_403 — load a job and verify org access (reports denials as 404)
- bump_user_membership_cache — invalidate a user's cached memberships
"""
|
|
|
|
import json
import logging
from dataclasses import dataclass

from fastapi import Depends, HTTPException, status
from motor.motor_asyncio import AsyncIOMotorDatabase

from ..models.organization import OrgRole
from ..models.user import User, UserRole
from .database import get_database
from .dependencies import get_current_user
from .redis import get_redis
|
|
|
|
# Global user roles that bypass organization membership checks entirely
# (checked in get_membership_context). Kill-switch: previously every staff
# role bypassed membership; to restore that behavior, change this set to:
#   {UserRole.ADMIN, UserRole.REVIEWER, UserRole.LINGUIST, UserRole.PRODUCTION}
PLATFORM_ADMIN_ROLES: set[UserRole] = {UserRole.ADMIN}

# Lifetime, in seconds, of a user's membership map cached in Redis before it
# is re-read from MongoDB. Membership writes should call
# bump_user_membership_cache so changes take effect before this expires.
MEMBERSHIP_CACHE_TTL: int = 60  # seconds
|
|
|
|
|
|
@dataclass
class MembershipContext:
    """Authorization facts resolved once per request.

    Bundles the authenticated user, whether they are a platform admin (and so
    bypass organization checks), and their role in each organization.
    """

    user: User
    is_platform_admin: bool
    # Organization ID -> the user's OrgRole there, read from the
    # ``memberships`` collection.
    memberships: dict[str, OrgRole]

    def can_access_org(self, org_id: str, min_role: OrgRole = OrgRole.VIEWER) -> bool:
        """Return whether the user holds at least ``min_role`` in ``org_id``.

        Platform admins are granted access to every organization.
        """
        if self.is_platform_admin:
            return True
        granted = self.memberships.get(org_id)
        if granted is None:
            return False
        # NOTE(review): relies on OrgRole defining an ordering in which more
        # privileged roles compare greater — confirm in models.organization.
        return granted >= min_role

    def accessible_org_ids(self) -> list[str]:
        """Return the IDs of all organizations the user is a member of.

        This ignores ``is_platform_admin``; callers must check that flag
        separately to account for admin-wide access.
        """
        return list(self.memberships)
|
|
|
|
|
|
async def _load_memberships(user_id: str, db: AsyncIOMotorDatabase) -> dict[str, OrgRole]:
    """Read a user's organization roles directly from MongoDB.

    Args:
        user_id: String form of the user's ID, matched against
            ``memberships.user_id``.
        db: Motor database handle.

    Returns:
        Mapping of ``organization_id`` -> ``OrgRole``. Documents missing
        either field or carrying an unrecognized role are skipped and logged,
        instead of failing the whole request, so a malformed record can never
        grant access.
    """
    result: dict[str, OrgRole] = {}
    cursor = db.memberships.find(
        {"user_id": user_id},
        # Only the fields used below (``_id`` is returned by default and
        # kept for log messages).
        {"organization_id": 1, "role_in_org": 1},
    )
    async for doc in cursor:
        try:
            org_id = doc["organization_id"]
            role = OrgRole(doc["role_in_org"])
        except (KeyError, ValueError):
            logging.getLogger(__name__).warning(
                "Skipping malformed membership %s for user %s",
                doc.get("_id"),
                user_id,
            )
            continue
        result[org_id] = role
    return result
|
|
|
|
|
|
async def _cached_memberships(
    user_id: str,
    db: AsyncIOMotorDatabase,
) -> dict[str, OrgRole]:
    """Return a user's memberships, served from Redis when possible.

    Redis is strictly best-effort. If ``get_redis()`` returns a falsy value,
    the memberships come from MongoDB via ``_load_memberships``. The same
    happens if Redis raises or the cached entry cannot be decoded. Failures
    are logged, never raised. Fresh DB results are written back with a
    ``MEMBERSHIP_CACHE_TTL``-second expiry.
    """
    log = logging.getLogger(__name__)
    cache_key = f"mem:user:{user_id}"

    # Resolve the client once and reuse it for both the read and the write-back.
    try:
        redis = get_redis()
    except Exception:
        log.warning("Redis client unavailable; loading memberships from DB", exc_info=True)
        redis = None

    if redis:
        try:
            cached = await redis.get(cache_key)
            if cached:
                raw = json.loads(cached)
                return {org_id: OrgRole(role) for org_id, role in raw.items()}
        except Exception:
            # Covers connection errors and stale/corrupt payloads (bad JSON,
            # roles no longer valid for OrgRole); fall through to the DB.
            log.warning("Membership cache read failed for %s", cache_key, exc_info=True)

    memberships = await _load_memberships(user_id, db)

    if redis:
        try:
            await redis.setex(
                cache_key,
                MEMBERSHIP_CACHE_TTL,
                json.dumps({k: v.value for k, v in memberships.items()}),
            )
        except Exception:
            log.warning("Membership cache write failed for %s", cache_key, exc_info=True)

    return memberships
|
|
|
|
|
|
async def get_membership_context(
    current_user: User = Depends(get_current_user),
    db: AsyncIOMotorDatabase = Depends(get_database),
) -> MembershipContext:
    """FastAPI dependency that builds the request's MembershipContext.

    Users whose role is in PLATFORM_ADMIN_ROLES short-circuit with an empty
    membership map. No lookup is needed because they bypass org checks.
    Everyone else has their memberships loaded through the Redis-backed
    cache.
    """
    if current_user.role in PLATFORM_ADMIN_ROLES:
        return MembershipContext(user=current_user, is_platform_admin=True, memberships={})

    return MembershipContext(
        user=current_user,
        is_platform_admin=False,
        memberships=await _cached_memberships(str(current_user.id), db),
    )
|
|
|
|
|
|
def require_org_role(min_role: OrgRole):
    """Build a dependency enforcing a minimum role in the request's organization.

    The returned coroutine receives ``org_id`` as a request parameter
    (normally the route's path parameter). It checks the caller's
    MembershipContext against that org and returns the context on success.
    Platform admins always pass.

    Raises (from the dependency):
        HTTPException: 403 if the caller lacks ``min_role`` in ``org_id``.
    """

    async def checker(
        org_id: str,
        ctx: MembershipContext = Depends(get_membership_context),
    ) -> MembershipContext:
        if ctx.can_access_org(org_id, min_role):
            return ctx
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"Requires {min_role.value} role in this organization",
        )

    return checker
|
|
|
|
|
|
def require_platform_admin():
    """Build a dependency that only lets platform admins through.

    This is a factory, so use ``Depends(require_platform_admin())`` with the
    call. Passing the bare function to ``Depends`` would hand back the inner
    checker instead of running it.

    Raises (from the dependency):
        HTTPException: 403 for any caller who is not a platform admin.
    """

    async def checker(ctx: MembershipContext = Depends(get_membership_context)) -> MembershipContext:
        if ctx.is_platform_admin:
            return ctx
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Platform admin required")

    return checker
|
|
|
|
|
|
class OrgScopedQuery:
    """Restrict MongoDB query dicts to the organizations a user may see.

    Usage::

        scoped = OrgScopedQuery(ctx)
        query = scoped.filter({"status": "completed"}, org_id_from_request)

    Rules applied by ``filter``:

    * An explicit ``org_id`` pins the query to that org. For non-admins the
      org is first checked against their memberships (403 if not a member).
    * Without ``org_id``, platform admins get ``base_query`` back unscoped.
    * Without ``org_id``, everyone else is limited to all orgs they belong
      to. With no memberships this becomes ``{"$in": []}``, which matches
      nothing.

    Any value ``base_query`` already had under ``org_field`` is replaced by
    the scoping condition. ``base_query`` itself is never mutated.
    """

    def __init__(self, ctx: MembershipContext):
        self.ctx = ctx

    def filter(
        self,
        base_query: dict,
        org_id: str | None = None,
        org_field: str = "organization_id",
    ) -> dict:
        """Return ``base_query`` with organization scoping applied.

        A new dict is returned whenever a scope is added. In the unscoped
        platform-admin case the original ``base_query`` object is returned.

        Raises:
            HTTPException: 403 when a non-admin names an org they are not in.
        """
        ctx = self.ctx

        if org_id:
            if not ctx.is_platform_admin and not ctx.can_access_org(org_id):
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail="Access to this organization is not permitted",
                )
            return {**base_query, org_field: org_id}

        if ctx.is_platform_admin:
            return base_query

        # An empty membership list yields {"$in": []}, deliberately matching nothing.
        return {**base_query, org_field: {"$in": ctx.accessible_org_ids()}}
|
|
|
|
|
|
def assert_user_in_org(
    ctx: "MembershipContext",
    org_id: str,
    min_role: OrgRole = OrgRole.VIEWER,
) -> None:
    """Ensure the context's user holds at least ``min_role`` in ``org_id``.

    This is the imperative counterpart of the ``require_org_role`` dependency,
    callable from inside a handler. Platform admins always pass.

    Raises:
        HTTPException: 403 when access is not permitted.
    """
    allowed = ctx.can_access_org(org_id, min_role)
    if allowed:
        return
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Access to this organization is not permitted",
    )
|
|
|
|
|
|
async def get_job_or_403(
    job_id: str,
    ctx: "MembershipContext",
    db: AsyncIOMotorDatabase,
) -> dict:
    """Load a job document and enforce that the caller may see it.

    Despite the name, every denial is reported as 404 "Job not found", so a
    caller cannot probe for jobs that belong to other organizations.

    The owning organization is read from ``job.organization_id``. Legacy jobs
    without one fall back to the ``client_id`` of the job's project. If
    neither yields an org, only platform admins or the user whose ID equals
    ``job.client_id`` may access the job.

    Raises:
        HTTPException: 404 if the job is missing or the caller lacks access.
    """
    job_doc = await db.jobs.find_one({"_id": job_id})
    if not job_doc:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Job not found")

    owning_org = job_doc.get("organization_id")
    if not owning_org:
        # Legacy job without org: try resolving it via the project.
        project_id = job_doc.get("project_id")
        project = (
            await db.projects.find_one({"_id": project_id}, {"client_id": 1})
            if project_id
            else None
        )
        if project:
            owning_org = project.get("client_id")

    if owning_org:
        permitted = ctx.can_access_org(owning_org)
    else:
        # Truly legacy job (no project, no org): uploader or platform admin only.
        permitted = ctx.is_platform_admin or job_doc.get("client_id") == str(ctx.user.id)

    if not permitted:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Job not found")
    return job_doc
|
|
|
|
|
|
async def bump_user_membership_cache(user_id: str) -> None:
    """Delete a user's cached memberships so the next request re-reads MongoDB.

    Call after any membership write that affects ``user_id``. This is a
    no-op when ``get_redis()`` returns a falsy value. Redis errors are logged
    rather than raised. The stale entry then survives until its
    ``MEMBERSHIP_CACHE_TTL`` expiry.
    """
    try:
        redis = get_redis()
        if redis:
            await redis.delete(f"mem:user:{user_id}")
    except Exception:
        # A failed invalidation can keep revoked access live until the TTL
        # lapses, so surface it instead of silently passing.
        logging.getLogger(__name__).warning(
            "Failed to invalidate membership cache for user %s",
            user_id,
            exc_info=True,
        )
|