cohorta/backend/app/auth/quart_jwt.py
2025-12-19 19:26:16 +00:00

214 lines
No EOL
7.8 KiB
Python
Executable file

"""
Quart-compatible JWT Authentication
Replacement for Flask-JWT-Extended to work with Quart ASGI applications.
Provides jwt_required decorator and token management functions.
"""
import functools
import json
import os
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional

import jwt
from quart import Response, current_app, g, jsonify, request
# JWT configuration, intended to stay compatible with Flask-JWT-Extended.
# NOTE(review): when the SECRET_KEY environment variable is unset, a fixed,
# publicly visible fallback secret is used for signing, which would let anyone
# forge tokens. Confirm every deployment sets SECRET_KEY.
JWT_SECRET_KEY = os.getenv('SECRET_KEY', 'your-secret-key-for-sessions-and-tokens')
# Default lifetime of access tokens issued by create_access_token.
JWT_ACCESS_TOKEN_EXPIRES = timedelta(days=1)
# HMAC-SHA256 symmetric signing; the same key signs and verifies.
JWT_ALGORITHM = 'HS256'
class QuartJWTError(Exception):
    """Signals that a JWT could not be used to authenticate a Quart request
    (expired, failed validation, or carries no user ID)."""
def create_access_token(identity: str, expires_delta: Optional[timedelta] = None) -> str:
    """
    Create a signed JWT access token for ``identity``.

    Args:
        identity: User identifier stored in the ``sub`` claim (usually user ID).
        expires_delta: Token lifetime. ``None`` uses ``JWT_ACCESS_TOKEN_EXPIRES``;
            any other value (including ``timedelta(0)``) is honored as given.

    Returns:
        The token produced by ``jwt.encode``, signed with ``JWT_SECRET_KEY``
        using ``JWT_ALGORITHM``. Claims: ``sub``, ``exp``, ``iat`` and
        ``type='access'``.
    """
    # One timezone-aware UTC timestamp for both claims (datetime.utcnow() is
    # deprecated since Python 3.12 and returns a naive datetime).
    issued_at = datetime.now(timezone.utc)
    # Compare against None explicitly: timedelta(0) is falsy, and a truthiness
    # test would silently replace it with the module default lifetime.
    lifetime = JWT_ACCESS_TOKEN_EXPIRES if expires_delta is None else expires_delta
    payload = {
        'sub': identity,  # Subject (user ID)
        'exp': issued_at + lifetime,
        'iat': issued_at,
        'type': 'access',
    }
    return jwt.encode(payload, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM)
def decode_token(token: str) -> Dict[str, Any]:
    """
    Decode a JWT, verifying it with ``JWT_SECRET_KEY`` / ``JWT_ALGORITHM``.

    Args:
        token: Encoded JWT string.

    Returns:
        The decoded claims dictionary.

    Raises:
        QuartJWTError: ``"Token has expired"`` when PyJWT reports an expired
            signature, otherwise ``"Invalid token: <reason>"`` for any other
            PyJWT validation failure. The original PyJWT exception is chained
            as ``__cause__``.
    """
    try:
        return jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])
    # ExpiredSignatureError subclasses InvalidTokenError, so it must be caught first.
    except jwt.ExpiredSignatureError as exc:
        raise QuartJWTError("Token has expired") from exc
    except jwt.InvalidTokenError as exc:
        raise QuartJWTError(f"Invalid token: {exc}") from exc
def get_jwt_identity() -> Optional[str]:
    """
    Return the user ID recorded for the current request by ``jwt_required``.

    Returns:
        ``g.current_user_id`` if present; None if it was never set or if
        ``g`` cannot be accessed at all.
    """
    try:
        identity = getattr(g, 'current_user_id', None)
    except Exception:
        # Any failure to read ``g`` (presumably no active app/request context)
        # is treated as "no authenticated user".
        identity = None
    return identity
def _extract_bearer_token(auth_header: Optional[str]) -> Optional[str]:
    """Return the token from an ``Authorization: Bearer <token>`` header, or None."""
    if not auth_header:
        return None
    parts = auth_header.split()
    if len(parts) == 2 and parts[0].lower() == 'bearer':
        return parts[1]
    return None


def _json_error(message: str, status: int):
    """Build a JSON error response whose body is ``{"error": message}``."""
    return Response(
        json.dumps({'error': message}),
        status=status,
        mimetype="application/json"
    )


def jwt_required(optional: bool = False):
    """
    Decorator that requires a valid JWT bearer token for a Quart route.

    The token is read from the ``Authorization: Bearer <token>`` header and
    validated with ``decode_token``. On success the token's ``sub`` claim is
    stored in ``g.current_user_id`` (read back by ``get_jwt_identity``).

    Args:
        optional: If True, requests with no token, or with a token that fails
            validation, still reach the route with ``g.current_user_id = None``.

    When ``optional`` is False the decorator itself answers with:
        401 ``{"error": "Missing authorization token"}`` if no bearer token;
        401 ``{"error": "Invalid token: ..."}`` if validation rejects it;
        500 ``{"error": "Authentication error"}`` on an unexpected failure
        while validating the token.

    The route is awaited exactly once and its return value is passed through
    unchanged, so Quart applies its native handling of ``(body, status)`` /
    ``(body, headers)`` tuples. Exceptions raised by the route (including
    ``abort()``) propagate to Quart's error handlers rather than being
    reported as authentication errors.
    """
    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            user_id = None
            token = _extract_bearer_token(request.headers.get('Authorization'))
            if token:
                # Only token validation lives inside the try blocks; the route
                # call below must not be caught (or retried) by these handlers.
                try:
                    payload = decode_token(token)
                    user_id = payload.get('sub')
                    if not user_id:
                        raise QuartJWTError("No user ID in token")
                except QuartJWTError as e:
                    if not optional:
                        # NOTE: decode_token messages for malformed tokens already
                        # start with "Invalid token:", so the prefix can appear
                        # twice; the text is kept as-is for client compatibility.
                        return _json_error(f'Invalid token: {str(e)}', 401)
                    user_id = None
                except Exception as e:
                    current_app.logger.exception("JWT validation error: %s", e)
                    if not optional:
                        return _json_error('Authentication error', 500)
                    user_id = None
            elif not optional:
                return _json_error('Missing authorization token', 401)

            # Store user ID in request context (None for anonymous access).
            g.current_user_id = user_id
            return await func(*args, **kwargs)
        return wrapper
    return decorator
# Backward-compatible alias mirroring a Flask-JWT-Extended function name.
def get_current_user():
    """
    Return the current request's user ID; identical to ``get_jwt_identity()``.

    NOTE(review): Flask-JWT-Extended's ``get_current_user`` returns the user
    object produced by its user-lookup loader, not the ID. Confirm callers of
    this alias expect the raw ID.
    """
    current_id = get_jwt_identity()
    return current_id