feat(mt-16): JWT org_ids claim + transient user.org_ids in deps
- create_access_token gains optional org_ids: list[str] param; encodes
{exp, sub, org_ids, v:2} — org_ids is a prefilter hint only, never
used as authorization source of truth (Redis cache is authoritative)
- Login, MS login, refresh endpoints: fetch memberships and include
org_ids in issued access tokens via _get_user_org_ids() helper
- routes_invitations.py accept flow: same org_ids population on token
- get_current_user: reads org_ids from payload, attaches as transient
user.__dict__["org_ids"] — available to OrgScopedQuery for prefilter
- Force logout: rotate JWT_SECRET env var at deployment time (no code
change needed; all existing tokens immediately invalidated)
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
parent
54fcf47887
commit
4623b89aeb
4 changed files with 29 additions and 9 deletions
|
|
@ -37,6 +37,13 @@ router = APIRouter(prefix="/auth", tags=["auth"])
|
|||
security = HTTPBearer()
|
||||
|
||||
|
||||
async def _get_user_org_ids(user_id: str, db: AsyncIOMotorDatabase) -> list[str]:
|
||||
"""Return list of org IDs the user belongs to — used as a JWT hint only."""
|
||||
cursor = db.memberships.find({"user_id": user_id}, {"organization_id": 1})
|
||||
memberships = await cursor.to_list(length=200)
|
||||
return [str(m["organization_id"]) for m in memberships if m.get("organization_id")]
|
||||
|
||||
|
||||
def _set_auth_cookies(response: Response, refresh_token: str) -> str:
|
||||
"""Set httponly refresh_token cookie and readable csrf_token cookie. Returns the csrf token."""
|
||||
csrf_token = secrets.token_hex(32)
|
||||
|
|
@ -102,7 +109,8 @@ async def login(
|
|||
detail="User account is disabled",
|
||||
)
|
||||
|
||||
access_token = create_access_token(subject=str(user.id))
|
||||
org_ids = await _get_user_org_ids(str(user.id), db)
|
||||
access_token = create_access_token(subject=str(user.id), org_ids=org_ids)
|
||||
refresh_token = create_refresh_token(subject=str(user.id))
|
||||
|
||||
_set_auth_cookies(response, refresh_token)
|
||||
|
|
@ -183,7 +191,8 @@ async def microsoft_login(
|
|||
detail="User account is disabled",
|
||||
)
|
||||
|
||||
access_token = create_access_token(subject=str(user.id))
|
||||
org_ids = await _get_user_org_ids(str(user.id), db)
|
||||
access_token = create_access_token(subject=str(user.id), org_ids=org_ids)
|
||||
refresh_token = create_refresh_token(subject=str(user.id))
|
||||
|
||||
_set_auth_cookies(response, refresh_token)
|
||||
|
|
@ -253,8 +262,9 @@ async def refresh_token(
|
|||
detail="User account is disabled",
|
||||
)
|
||||
|
||||
# Create new tokens
|
||||
new_access_token = create_access_token(subject=user_id)
|
||||
# Create new tokens (include org_ids claim for prefilter hint)
|
||||
_org_ids = await _get_user_org_ids(user_id, db)
|
||||
new_access_token = create_access_token(subject=user_id, org_ids=_org_ids)
|
||||
new_refresh_token = create_refresh_token(subject=user_id)
|
||||
|
||||
# Rotate both refresh and CSRF cookies
|
||||
|
|
|
|||
|
|
@ -333,8 +333,9 @@ async def accept_invitation(
|
|||
org_name=org_name,
|
||||
)
|
||||
|
||||
# Issue JWT tokens
|
||||
access_token = create_access_token(subject=user_id)
|
||||
# Issue JWT tokens with org_ids claim
|
||||
_inv_org_ids = [m["organization_id"] async for m in db.memberships.find({"user_id": user_id}, {"organization_id": 1})]
|
||||
access_token = create_access_token(subject=user_id, org_ids=[str(o) for o in _inv_org_ids if o])
|
||||
refresh_token = create_refresh_token(subject=user_id)
|
||||
|
||||
org_name, org_slug = await _get_org_name(org_id, db)
|
||||
|
|
|
|||
|
|
@ -43,7 +43,12 @@ async def get_current_user(
|
|||
detail="User not found",
|
||||
)
|
||||
|
||||
return User(**user_doc)
|
||||
user = User(**user_doc)
|
||||
# Attach org_ids hint from token as transient attribute (never used for authz)
|
||||
token_org_ids = payload.get("org_ids", [])
|
||||
if token_org_ids:
|
||||
user.__dict__["org_ids"] = token_org_ids
|
||||
return user
|
||||
|
||||
|
||||
def require_role(required_role: UserRole):
|
||||
|
|
|
|||
|
|
@ -11,14 +11,18 @@ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
|||
|
||||
|
||||
def create_access_token(
|
||||
subject: Union[str, Any], expires_delta: Optional[timedelta] = None
|
||||
subject: Union[str, Any],
|
||||
expires_delta: Optional[timedelta] = None,
|
||||
org_ids: list[str] | None = None,
|
||||
) -> str:
|
||||
if expires_delta:
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
else:
|
||||
expire = datetime.utcnow() + timedelta(minutes=settings.jwt_access_ttl_min)
|
||||
|
||||
to_encode = {"exp": expire, "sub": str(subject)}
|
||||
to_encode: dict[str, Any] = {"exp": expire, "sub": str(subject), "v": 2}
|
||||
if org_ids:
|
||||
to_encode["org_ids"] = org_ids
|
||||
encoded_jwt = jwt.encode(to_encode, settings.jwt_secret, algorithm=settings.jwt_alg)
|
||||
return encoded_jwt
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue