diff --git a/backend/app/api/v1/routes_auth.py b/backend/app/api/v1/routes_auth.py index 4ae22f8..611f382 100644 --- a/backend/app/api/v1/routes_auth.py +++ b/backend/app/api/v1/routes_auth.py @@ -37,6 +37,13 @@ router = APIRouter(prefix="/auth", tags=["auth"]) security = HTTPBearer() +async def _get_user_org_ids(user_id: str, db: AsyncIOMotorDatabase) -> list[str]: + """Return list of org IDs the user belongs to — used as a JWT hint only.""" + cursor = db.memberships.find({"user_id": user_id}, {"organization_id": 1}) + memberships = await cursor.to_list(length=200) + return [str(m["organization_id"]) for m in memberships if m.get("organization_id")] + + def _set_auth_cookies(response: Response, refresh_token: str) -> str: """Set httponly refresh_token cookie and readable csrf_token cookie. Returns the csrf token.""" csrf_token = secrets.token_hex(32) @@ -102,7 +109,8 @@ async def login( detail="User account is disabled", ) - access_token = create_access_token(subject=str(user.id)) + org_ids = await _get_user_org_ids(str(user.id), db) + access_token = create_access_token(subject=str(user.id), org_ids=org_ids) refresh_token = create_refresh_token(subject=str(user.id)) _set_auth_cookies(response, refresh_token) @@ -183,7 +191,8 @@ async def microsoft_login( detail="User account is disabled", ) - access_token = create_access_token(subject=str(user.id)) + org_ids = await _get_user_org_ids(str(user.id), db) + access_token = create_access_token(subject=str(user.id), org_ids=org_ids) refresh_token = create_refresh_token(subject=str(user.id)) _set_auth_cookies(response, refresh_token) @@ -253,8 +262,9 @@ async def refresh_token( detail="User account is disabled", ) - # Create new tokens - new_access_token = create_access_token(subject=user_id) + # Create new tokens (include org_ids claim for prefilter hint) + _org_ids = await _get_user_org_ids(user_id, db) + new_access_token = create_access_token(subject=user_id, org_ids=_org_ids) new_refresh_token = create_refresh_token(subject=user_id) # Rotate both refresh and CSRF cookies diff --git a/backend/app/api/v1/routes_invitations.py b/backend/app/api/v1/routes_invitations.py index d0a7706..3892d89 100644 --- a/backend/app/api/v1/routes_invitations.py +++ b/backend/app/api/v1/routes_invitations.py @@ -333,8 +333,9 @@ async def accept_invitation( org_name=org_name, ) - # Issue JWT tokens - access_token = create_access_token(subject=user_id) + # Issue JWT tokens with org_ids claim + _inv_org_ids = [m["organization_id"] async for m in db.memberships.find({"user_id": user_id}, {"organization_id": 1})] + access_token = create_access_token(subject=user_id, org_ids=[str(o) for o in _inv_org_ids if o]) refresh_token = create_refresh_token(subject=user_id) org_name, org_slug = await _get_org_name(org_id, db) diff --git a/backend/app/core/dependencies.py b/backend/app/core/dependencies.py index ef76542..e943e3e 100644 --- a/backend/app/core/dependencies.py +++ b/backend/app/core/dependencies.py @@ -43,7 +43,12 @@ async def get_current_user( detail="User not found", ) - return User(**user_doc) + user = User(**user_doc) + # Attach org_ids hint from token as transient attribute (never used for authz) + token_org_ids = payload.get("org_ids", []) + if token_org_ids: + user.__dict__["org_ids"] = token_org_ids + return user def require_role(required_role: UserRole): diff --git a/backend/app/core/security.py b/backend/app/core/security.py index 960c33b..63d5d1b 100644 --- a/backend/app/core/security.py +++ b/backend/app/core/security.py @@ -11,14 +11,18 @@ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") def create_access_token( - subject: Union[str, Any], expires_delta: Optional[timedelta] = None + subject: Union[str, Any], + expires_delta: Optional[timedelta] = None, + org_ids: list[str] | None = None, ) -> str: if expires_delta: expire = datetime.utcnow() + expires_delta else: expire = datetime.utcnow() + timedelta(minutes=settings.jwt_access_ttl_min) - to_encode = {"exp": expire, "sub": str(subject)} + to_encode: dict[str, Any] = {"exp": expire, "sub": str(subject), "v": 2} + if org_ids: + to_encode["org_ids"] = org_ids encoded_jwt = jwt.encode(to_encode, settings.jwt_secret, algorithm=settings.jwt_alg) return encoded_jwt