214 lines
No EOL
7.8 KiB
Python
Executable file
214 lines
No EOL
7.8 KiB
Python
Executable file
"""
Quart-compatible JWT authentication.

Replacement for Flask-JWT-Extended for Quart (ASGI) applications.
Provides the ``jwt_required`` decorator and token helpers
(``create_access_token``, ``decode_token``, ``get_jwt_identity``).
"""
|
|
|
|
import functools
import json
import os
import warnings
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional

import jwt
from quart import Response, current_app, g, jsonify, make_response, request
|
|
|
|
# JWT configuration. Setting names mirror Flask-JWT-Extended's so code ported
# from Flask can keep referring to them.
_INSECURE_DEFAULT_SECRET = 'your-secret-key-for-sessions-and-tokens'

# HMAC signing key, read once at import time from the SECRET_KEY environment
# variable. The fallback is public (it lives in source control), so anyone
# could forge tokens while it is in use. Warn rather than fail so existing
# deployments keep starting.
JWT_SECRET_KEY = os.environ.get('SECRET_KEY', _INSECURE_DEFAULT_SECRET)
if not JWT_SECRET_KEY or JWT_SECRET_KEY == _INSECURE_DEFAULT_SECRET:
    warnings.warn(
        "SECRET_KEY is not configured (unset, empty, or the built-in "
        "placeholder); JWTs are being signed with an insecure key. "
        "Set the SECRET_KEY environment variable in production.",
        RuntimeWarning,
    )

# Default lifetime of tokens issued by create_access_token().
JWT_ACCESS_TOKEN_EXPIRES = timedelta(hours=24)

# Algorithm used both to sign and to verify tokens.
JWT_ALGORITHM = 'HS256'
|
|
|
|
|
|
class QuartJWTError(Exception):
    """Raised when a JWT is missing required data, malformed, or expired.

    The exception message is a human-readable reason suitable for logging or
    for embedding in an error response.
    """
|
|
|
|
|
|
def create_access_token(identity: str, expires_delta: Optional[timedelta] = None) -> str:
    """
    Create a signed JWT access token.

    Args:
        identity: User identifier (usually user ID), stored in the ``sub`` claim.
            NOTE(review): recent PyJWT releases reject a non-string ``sub`` when
            decoding, so pass a str (as annotated) rather than an int ID.
        expires_delta: Optional lifetime override. ``None`` means
            ``JWT_ACCESS_TOKEN_EXPIRES``; any other value, including
            ``timedelta(0)``, is used exactly as given.

    Returns:
        JWT token string signed with ``JWT_SECRET_KEY`` using ``JWT_ALGORITHM``.
    """
    # Take a single timezone-aware "now" so that iat and exp derive from the
    # same instant. datetime.utcnow() is naive and deprecated since Python 3.12.
    issued_at = datetime.now(timezone.utc)

    # Compare against None explicitly: a zero timedelta is falsy but is still
    # a deliberate override and must not fall back to the 24h default.
    lifetime = JWT_ACCESS_TOKEN_EXPIRES if expires_delta is None else expires_delta

    payload = {
        'sub': identity,  # Subject (user ID)
        'exp': issued_at + lifetime,
        'iat': issued_at,
        'type': 'access',
    }

    return jwt.encode(payload, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM)
|
|
|
|
|
|
def decode_token(token: str) -> Dict[str, Any]:
    """
    Decode and validate a JWT token.

    The signature is verified with ``JWT_SECRET_KEY``. Only ``JWT_ALGORITHM`` is
    accepted, so a token cannot choose its own algorithm.

    Args:
        token: JWT token string.

    Returns:
        Decoded token payload.

    Raises:
        QuartJWTError: If the token has expired ("Token has expired") or is
            otherwise invalid ("Invalid token: <reason>"). The underlying
            PyJWT exception is chained as ``__cause__`` for debugging.
    """
    try:
        return jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])
    # ExpiredSignatureError subclasses InvalidTokenError, so it must come first.
    except jwt.ExpiredSignatureError as e:
        raise QuartJWTError("Token has expired") from e
    except jwt.InvalidTokenError as e:
        raise QuartJWTError(f"Invalid token: {e}") from e
|
|
|
|
|
|
def get_jwt_identity() -> Optional[str]:
    """
    Return the user ID that ``jwt_required`` stored for the current request.

    Returns:
        The value of ``g.current_user_id``. Returns None when it was never set,
        or when reading ``g`` fails (for example, outside a request/app context).
    """
    identity = None
    try:
        identity = getattr(g, 'current_user_id', None)
    except Exception:
        # Reading `g` outside an active context raises; treat that as "no user".
        identity = None
    return identity
|
|
|
|
|
|
def _extract_bearer_token() -> Optional[str]:
    """Return the token from an ``Authorization: Bearer <token>`` header, or None."""
    auth_header = request.headers.get('Authorization')
    if not auth_header:
        return None
    # Expected format: "Bearer <token>" (scheme compared case-insensitively).
    parts = auth_header.split()
    if len(parts) == 2 and parts[0].lower() == 'bearer':
        return parts[1]
    return None


def _json_error_response(message: str, status: int) -> Response:
    """Build a JSON ``{"error": message}`` response with the given HTTP status."""
    return Response(
        json.dumps({'error': message}),
        status=status,
        mimetype="application/json",
    )


async def _finalize_route_result(result: Any) -> Any:
    """Apply a ``(body, status_code)`` tuple returned by a route.

    If the body is response-like (has ``status_code``), its status is
    overwritten. Otherwise the pair is converted with Quart's
    ``make_response``. Any other result is returned unchanged for Quart to
    handle.
    """
    if not (isinstance(result, tuple) and len(result) == 2):
        return result
    response, status_code = result
    if hasattr(response, 'status_code'):
        response.status_code = status_code
        return response
    # quart.make_response is a coroutine function and must be awaited;
    # returning it bare would hand an un-awaited coroutine to Quart.
    return await make_response(response, status_code)


def jwt_required(optional: bool = False):
    """
    Decorator factory that requires a valid JWT bearer token for route access.

    Use as ``@jwt_required()`` or ``@jwt_required(optional=True)`` on async
    route functions. On success, the token's ``sub`` claim is stored in
    ``g.current_user_id`` (see ``get_jwt_identity``) and the route runs once.

    Args:
        optional: If True, allow access without a token, or with an invalid
            one. In that case ``g.current_user_id`` is set to None.

    Returns:
        A decorator wrapping an async route function.

    When ``optional`` is False, authentication failures short-circuit with:
        - 401 ``{"error": "Missing authorization token"}`` when no bearer token
          is present;
        - 401 ``{"error": "Invalid token: ..."}`` when decoding or validation
          fails;
        - 500 ``{"error": "Authentication error"}`` on an unexpected error
          while validating.

    Exceptions raised by the route itself are not caught here. They propagate
    to Quart's normal error handling.
    """
    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            # Only authentication lives in this try block. The route call is
            # deliberately outside it, so route errors (e.g. abort(404)) are
            # not misreported as auth failures, and the route never runs twice.
            try:
                token = _extract_bearer_token()
                if not token:
                    if not optional:
                        return _json_error_response('Missing authorization token', 401)
                    # No token provided but optional - continue anonymously.
                    g.current_user_id = None
                else:
                    payload = decode_token(token)
                    user_id = payload.get('sub')
                    if not user_id:
                        raise QuartJWTError("No user ID in token")
                    # Store user ID in request context for get_jwt_identity().
                    g.current_user_id = user_id
            except QuartJWTError as e:
                if not optional:
                    return _json_error_response(f'Invalid token: {str(e)}', 401)
                # Invalid token but optional - continue without a user ID.
                g.current_user_id = None
            except Exception as e:
                current_app.logger.exception("JWT validation error: %s", e)
                if not optional:
                    return _json_error_response('Authentication error', 500)
                g.current_user_id = None

            result = await func(*args, **kwargs)
            return await _finalize_route_result(result)

        return wrapper

    return decorator
|
|
|
|
|
|
# Backward-compatible alias that mirrors a Flask-JWT-Extended function name.
def get_current_user():
    """Return the current request's user ID (alias for ``get_jwt_identity``).

    This is the raw identity stored by ``jwt_required`` (or None), not a
    loaded user object.
    """
    current_user_id = get_jwt_identity()
    return current_user_id