feat(mt-16): JWT org_ids claim + transient user.org_ids in deps

- create_access_token gains optional org_ids: list[str] param; encodes
  {exp, sub, org_ids, v:2} — org_ids is a prefilter hint only, never
  used as authorization source of truth (Redis cache is authoritative)
- Login, MS login, refresh endpoints: fetch memberships and include
  org_ids in issued access tokens via _get_user_org_ids() helper
- routes_invitations.py accept flow: same org_ids population on token
- get_current_user: reads org_ids from payload, attaches as transient
  user.__dict__["org_ids"] — available to OrgScopedQuery for prefilter
- Force logout: rotate JWT_SECRET env var at deployment time (no code
  change needed; all existing tokens immediately invalidated)

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
Vadym Samoilenko 2026-04-29 20:46:39 +01:00
parent 54fcf47887
commit 4623b89aeb
4 changed files with 29 additions and 9 deletions

View file

@ -37,6 +37,13 @@ router = APIRouter(prefix="/auth", tags=["auth"])
security = HTTPBearer()
async def _get_user_org_ids(user_id: str, db: AsyncIOMotorDatabase) -> list[str]:
"""Return list of org IDs the user belongs to — used as a JWT hint only."""
cursor = db.memberships.find({"user_id": user_id}, {"organization_id": 1})
memberships = await cursor.to_list(length=200)
return [str(m["organization_id"]) for m in memberships if m.get("organization_id")]
def _set_auth_cookies(response: Response, refresh_token: str) -> str:
"""Set httponly refresh_token cookie and readable csrf_token cookie. Returns the csrf token."""
csrf_token = secrets.token_hex(32)
@ -102,7 +109,8 @@ async def login(
detail="User account is disabled",
)
access_token = create_access_token(subject=str(user.id))
org_ids = await _get_user_org_ids(str(user.id), db)
access_token = create_access_token(subject=str(user.id), org_ids=org_ids)
refresh_token = create_refresh_token(subject=str(user.id))
_set_auth_cookies(response, refresh_token)
@ -183,7 +191,8 @@ async def microsoft_login(
detail="User account is disabled",
)
access_token = create_access_token(subject=str(user.id))
org_ids = await _get_user_org_ids(str(user.id), db)
access_token = create_access_token(subject=str(user.id), org_ids=org_ids)
refresh_token = create_refresh_token(subject=str(user.id))
_set_auth_cookies(response, refresh_token)
@ -253,8 +262,9 @@ async def refresh_token(
detail="User account is disabled",
)
# Create new tokens
new_access_token = create_access_token(subject=user_id)
# Create new tokens (include org_ids claim for prefilter hint)
_org_ids = await _get_user_org_ids(user_id, db)
new_access_token = create_access_token(subject=user_id, org_ids=_org_ids)
new_refresh_token = create_refresh_token(subject=user_id)
# Rotate both refresh and CSRF cookies

View file

@ -333,8 +333,9 @@ async def accept_invitation(
org_name=org_name,
)
# Issue JWT tokens
access_token = create_access_token(subject=user_id)
# Issue JWT tokens with org_ids claim
_inv_org_ids = [m["organization_id"] async for m in db.memberships.find({"user_id": user_id}, {"organization_id": 1})]
access_token = create_access_token(subject=user_id, org_ids=[str(o) for o in _inv_org_ids if o])
refresh_token = create_refresh_token(subject=user_id)
org_name, org_slug = await _get_org_name(org_id, db)

View file

@ -43,7 +43,12 @@ async def get_current_user(
detail="User not found",
)
return User(**user_doc)
user = User(**user_doc)
# Attach org_ids hint from token as transient attribute (never used for authz)
token_org_ids = payload.get("org_ids", [])
if token_org_ids:
user.__dict__["org_ids"] = token_org_ids
return user
def require_role(required_role: UserRole):

View file

@ -11,14 +11,18 @@ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def create_access_token(
subject: Union[str, Any], expires_delta: Optional[timedelta] = None
subject: Union[str, Any],
expires_delta: Optional[timedelta] = None,
org_ids: list[str] | None = None,
) -> str:
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(minutes=settings.jwt_access_ttl_min)
to_encode = {"exp": expire, "sub": str(subject)}
to_encode: dict[str, Any] = {"exp": expire, "sub": str(subject), "v": 2}
if org_ids:
to_encode["org_ids"] = org_ids
encoded_jwt = jwt.encode(to_encode, settings.jwt_secret, algorithm=settings.jwt_alg)
return encoded_jwt