Backend: - token_version in JWT (bump_token_version, get_token_version on User model); jwt_required checks tv claim → 401 when the token's tv is older than the user's current version; login routes embed version - Quota pre-flight in all 3 LLM public methods (QuotaExceededError bubbles up) - AI runner catches QuotaExceededError → sets status paused_quota + emits WS event - Admin routes: POST /users (create), POST /users/<id>/reset-password, POST /pricing, GET /focus-groups with aggregated cost; PUT /users/<id> now bumps token_version on disable or role change - backfill_usage.py: idempotent estimated-event generator for historical data, tiktoken for GPT models, char/3.8 for Gemini, --dry-run flag Frontend: - 402 interceptor dispatches quota_exceeded CustomEvent - adminApi: createUser, resetPassword, createPricing, listFocusGroups - UsersTab: New User dialog + Reset Password in edit dialog - PricingTab: New Price dialog (model, provider, input/output/cached prices) - FocusGroupsTab: focus groups table sorted by total cost - Admin.tsx: 4th tab (Focus Groups) - FocusGroupSession: admin-only cost badge + dismissible quota-exceeded banner Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
248 lines
No EOL
9.7 KiB
Python
Executable file
248 lines
No EOL
9.7 KiB
Python
Executable file
"""
|
|
Quart-compatible JWT Authentication
|
|
|
|
Replacement for Flask-JWT-Extended to work with Quart ASGI applications.
|
|
Provides jwt_required decorator and token management functions.
|
|
"""
|
|
|
|
import os
|
|
import jwt
|
|
import functools
|
|
import json
|
|
from datetime import datetime, timedelta, timezone
|
|
from typing import Optional, Dict, Any
|
|
from quart import request, g, current_app, jsonify, Response
|
|
|
|
# JWT configuration. SECRET_KEY is read from the environment at import time and the
# module refuses to load (RuntimeError) when it is missing, blank/whitespace-only, or one
# of the known placeholder values, so a misconfigured server fails fast instead of
# signing tokens with a guessable key.
_raw_secret = os.environ.get('SECRET_KEY', '')
_weak_defaults = {'dev-secret-key', 'your-secret-key-for-sessions-and-tokens', '', 'change-me'}
# Check a stripped copy so a whitespace-only value (or a padded placeholder) is rejected
# as well; '' in _weak_defaults covers the unset case. The key itself is used verbatim
# below so tokens that were already issued remain valid.
if _raw_secret.strip() in _weak_defaults:
    raise RuntimeError(
        "SECRET_KEY environment variable is not set or uses a weak default. "
        "Set a strong random value in backend/.env before starting the server."
    )
JWT_SECRET_KEY = _raw_secret
# Default lifetime for access tokens when the caller passes no explicit expires_delta.
JWT_ACCESS_TOKEN_EXPIRES = timedelta(hours=24)
# Single symmetric algorithm used for both signing and verification.
# NOTE(review): RFC 7518 §3.2 asks for an HS256 key of at least 256 bits; length is not
# enforced here to avoid breaking existing deployments.
JWT_ALGORITHM = 'HS256'
|
|
|
|
|
|
class QuartJWTError(Exception):
    """Signal that a request's JWT cannot be used for authentication.

    Raised by ``decode_token`` for expired or otherwise invalid tokens, and by
    ``jwt_required`` when a decoded token carries no ``sub`` claim. The message is
    echoed to clients inside the 401 ``{"error": "Invalid token: ..."}`` body, so it
    must never contain secrets.
    """
|
|
|
|
|
|
def create_access_token(identity: str, expires_delta: Optional[timedelta] = None, token_version: int = 0) -> str:
    """
    Create a signed JWT access token.

    Args:
        identity: User identifier stored as the ``sub`` claim (usually the user ID).
            NOTE: recent PyJWT releases (2.10+) reject a non-string ``sub`` on decode,
            so callers should pass a str.
        expires_delta: Token lifetime. ``None`` selects JWT_ACCESS_TOKEN_EXPIRES; any
            explicit value, including ``timedelta(0)`` or a negative delta, is honoured.
        token_version: Stored as the ``tv`` claim; jwt_required rejects tokens whose
            ``tv`` is lower than the user's current token version.

    Returns:
        JWT token string signed with JWT_SECRET_KEY using JWT_ALGORITHM.
    """
    # One clock reading so ``iat`` and ``exp`` are exactly ``lifetime`` apart.
    issued_at = datetime.now(timezone.utc)
    # ``is None`` rather than truthiness: timedelta(0) is falsy but is still an
    # explicit request for an already-expired token, not "use the default".
    lifetime = JWT_ACCESS_TOKEN_EXPIRES if expires_delta is None else expires_delta

    payload = {
        'sub': identity,  # Subject (user ID)
        'exp': issued_at + lifetime,
        'iat': issued_at,
        'type': 'access',
        'tv': token_version,
    }

    return jwt.encode(payload, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM)
|
|
|
|
|
|
def decode_token(token: str) -> Dict[str, Any]:
    """
    Decode and validate a JWT token.

    The signature is verified with JWT_SECRET_KEY, and only JWT_ALGORITHM is accepted
    (the ``algorithms`` list is pinned, so a token cannot choose its own algorithm).

    Args:
        token: JWT token string

    Returns:
        Decoded token payload

    Raises:
        QuartJWTError: If the token is expired or otherwise invalid. The original
            PyJWT exception is chained as ``__cause__`` for debugging.
    """
    try:
        return jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])
    except jwt.ExpiredSignatureError as exc:
        # Must precede InvalidTokenError: ExpiredSignatureError is a subclass of it.
        raise QuartJWTError("Token has expired") from exc
    except jwt.InvalidTokenError as exc:
        raise QuartJWTError(f"Invalid token: {str(exc)}") from exc
|
|
|
|
|
|
def get_jwt_identity() -> Optional[str]:
    """
    Return the user ID bound to the current request.

    Returns:
        The value stored in ``g.current_user_id`` (set by jwt_required), or None when
        nothing is stored or ``g`` cannot be read at all.
    """
    try:
        user_id = getattr(g, 'current_user_id', None)
    except Exception:
        # Reading ``g`` can fail (e.g. outside an app context); report "no identity".
        user_id = None
    return user_id
|
|
|
|
|
|
def _json_error(message: str, status: int) -> "Response":
    """Build the ``{"error": message}`` JSON response returned by jwt_required."""
    return Response(
        json.dumps({'error': message}),
        status=status,
        mimetype="application/json",
    )


def _bearer_token_from_request() -> Optional[str]:
    """Return the token from an ``Authorization: Bearer <token>`` header, else None.

    Any other header shape (absent, different scheme, extra parts) yields None and is
    treated exactly like a request without credentials.
    """
    auth_header = request.headers.get('Authorization')
    if not auth_header:
        return None
    parts = auth_header.split()
    if len(parts) == 2 and parts[0].lower() == 'bearer':
        return parts[1]
    return None


async def _token_version_revoked(user_id: str, payload: Dict[str, Any]) -> bool:
    """Return True when the token's ``tv`` claim is older than the user's current version.

    Tokens without a ``tv`` claim count as version 0. This check fails open
    deliberately: any error while looking up the version is logged and treated as
    "not revoked", so a database problem cannot lock every user out.
    NOTE(review): if get_token_version raises for a deleted user, that user's token
    is accepted here; confirm its contract.
    """
    try:
        tv_in_token = payload.get("tv", 0)
        from app.models.user import User as _User
        current_tv = await _User.get_token_version(user_id)
        return tv_in_token < current_tv
    except Exception as exc:
        current_app.logger.warning(
            "Token-version check skipped for user %s: %s", user_id, exc
        )
        return False


def _bind_usage_user(user_id: str) -> None:
    """Record ``user_id`` in the LLM-usage ContextVar for the rest of this request.

    The ContextVar is set from within the request's own task, so the value is scoped
    to the request. Tasks started with asyncio.create_task copy the current context
    and inherit it. NOTE(review): loop.run_in_executor does not copy contexts
    (asyncio.to_thread does); confirm for thread-offloaded LLM calls.
    Failures are ignored because this is telemetry only.
    """
    try:
        from app.services.llm_usage_context import _ctx, current_context
        from dataclasses import replace as _dc_replace
        _ctx.set(_dc_replace(current_context(), user_id=user_id))
    except Exception:
        pass  # Non-fatal — telemetry only


async def _normalize_view_result(result: Any) -> Any:
    """Turn a view's ``(body, status_or_headers)`` 2-tuple into a response object.

    If ``body`` is a response object and the second item is an int status, the status
    is applied in place. Every other 2-tuple is delegated to Quart's make_response,
    which handles headers and plain bodies. make_response is a coroutine and must be
    awaited; returning it un-awaited would hand Quart a coroutine object instead of a
    response. Non-tuple results are returned unchanged for Quart to finalise.
    """
    if not (isinstance(result, tuple) and len(result) == 2):
        return result
    body, status = result
    if hasattr(body, 'status_code') and isinstance(status, int):
        body.status_code = status
        return body
    from quart import make_response
    return await make_response(body, status)


async def _authenticate_request(optional: bool) -> "Optional[Response]":
    """Authenticate the current request and bind ``g.current_user_id``.

    Returns:
        An error response that must be sent instead of running the view, or None
        when the view may run (authenticated, or anonymous on an optional route).

    Raises:
        QuartJWTError: If a bearer token is present but unusable (rejected by
            decode_token, or missing the ``sub`` claim).
    """
    token = _bearer_token_from_request()
    if not token:
        if optional:
            # No token provided but optional - allow anonymous access.
            g.current_user_id = None
            return None
        return _json_error('Missing authorization token', 401)

    payload = decode_token(token)
    user_id = payload.get('sub')
    if not user_id:
        raise QuartJWTError("No user ID in token")

    # Store user ID in request context.
    g.current_user_id = user_id

    # Token-version check: invalidates old tokens after a password reset or disable.
    # This 401 applies even on optional routes, matching the previous behaviour.
    if await _token_version_revoked(user_id, payload):
        return _json_error("Token invalidated", 401)

    _bind_usage_user(user_id)
    return None


def jwt_required(optional: bool = False):
    """
    Decorator to require a valid JWT bearer token for route access.

    On success ``g.current_user_id`` holds the token's ``sub`` claim, and the user ID
    is propagated into the LLM-usage ContextVar before the view runs. A view's
    ``(body, status)`` tuple return is normalised into a response object.

    Responses sent instead of running the view:
        401 "Missing authorization token": no bearer token (required routes only).
        401 "Invalid token: ...": token rejected or lacking ``sub`` (required only).
        401 "Token invalidated": ``tv`` claim older than the user's token version
            (also on optional routes).
        500 "Authentication error": unexpected failure while authenticating
            (required routes only).
    On optional routes, the missing, invalid and unexpected-failure cases run the view
    anonymously with ``g.current_user_id = None`` instead.

    Args:
        optional: If True, allow access without token (but still decode if present)
    """
    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            # Only authentication is guarded. The view runs outside this try so its own
            # exceptions (including abort()/HTTPException) reach Quart's error handlers
            # instead of being reported as auth failures. This also prevents a second,
            # anonymous execution of the view on optional routes.
            try:
                error_response = await _authenticate_request(optional)
            except QuartJWTError as e:
                if not optional:
                    # NOTE(review): decode_token messages already start with
                    # "Invalid token:", so the prefix can appear twice. Left as-is in
                    # case clients match on the text.
                    return _json_error(f'Invalid token: {str(e)}', 401)
                # Invalid token but optional - allow access without user ID.
                g.current_user_id = None
            except Exception as e:
                current_app.logger.exception("JWT validation error: %s", e)
                if not optional:
                    return _json_error('Authentication error', 500)
                g.current_user_id = None
            else:
                if error_response is not None:
                    return error_response

            result = await func(*args, **kwargs)
            return await _normalize_view_result(result)

        return wrapper

    return decorator
|
|
|
|
|
|
# Compatibility shim so code ported from Flask-JWT-Extended keeps its import names.
def get_current_user():
    """Return the current request's user ID; an alias of get_jwt_identity().

    Unlike Flask-JWT-Extended's get_current_user, which returns a loaded user object,
    this returns only the ID, and None when no user is bound to the request.
    """
    user_id = get_jwt_identity()
    return user_id