198 lines
No EOL
6.9 KiB
Python
198 lines
No EOL
6.9 KiB
Python
import json
|
|
import logging
|
|
import requests
|
|
from functools import wraps
|
|
# Use more specific imports to avoid potential name conflicts
|
|
import jose.jwt as jwt
|
|
from jose.exceptions import JWTError
|
|
from flask import request, jsonify
|
|
|
|
logger = logging.getLogger('video_query')

# Azure AD (Microsoft Entra ID) configuration.
# These are the standard Azure AD endpoints on login.microsoftonline.com;
# Azure AD B2C tenants publish keys on <tenant>.b2clogin.com instead, so this
# is NOT a B2C configuration.
TENANT_ID = 'e519c2e6-bc6d-4fdf-8d9c-923c2f002385'  # directory (tenant) ID
# Application (client) ID -- presumably this API's own app registration; confirm.
CLIENT_ID = '9079054c-9620-4757-a256-23413042f1ef'
# Tenant's public token-signing keys (JSON Web Key Set), v2.0 discovery endpoint.
JWKS_URI = f'https://login.microsoftonline.com/{TENANT_ID}/discovery/v2.0/keys'

# Process-wide cache of the fetched JWKS document (None until first fetch).
jwks_cache = None
# When jwks_cache was last fetched (None until recorded).
jwks_last_updated = None
|
|
|
|
def get_jwks(*, force_refresh=False, max_age=3600, timeout=10):
    """Fetch the JWKS (JSON Web Key Set) from Azure AD, with a time-bounded cache.

    The key set is cached in the module globals ``jwks_cache`` /
    ``jwks_last_updated`` and re-fetched once it is older than ``max_age``
    seconds. Without expiry, a signing-key rotation would never be picked up.

    Args:
        force_refresh: Ignore the cache and fetch a fresh key set.
        max_age: Maximum cache age in seconds before a re-fetch.
        timeout: Seconds to wait for the JWKS endpoint before giving up.

    Returns:
        The parsed JWKS document (a dict; callers read its ``'keys'`` list).

    Raises:
        requests.RequestException: the fetch failed and nothing is cached.
        ValueError: the response was not a JSON object and nothing is cached.
    """
    import time

    global jwks_cache, jwks_last_updated

    now = time.monotonic()  # monotonic seconds, immune to wall-clock jumps
    cache_is_fresh = (
        bool(jwks_cache)
        and jwks_last_updated is not None
        and now - jwks_last_updated < max_age
    )
    if cache_is_fresh and not force_refresh:
        return jwks_cache

    try:
        logger.info("Fetching JWKS from %s", JWKS_URI)
        # requests has no default timeout; a stalled endpoint would otherwise
        # block the calling request thread indefinitely.
        response = requests.get(JWKS_URI, timeout=timeout)
        response.raise_for_status()
        fetched = response.json()
        if not isinstance(fetched, dict):
            raise ValueError("JWKS response is not a JSON object")
    except (requests.RequestException, ValueError) as e:
        if jwks_cache:
            # Best effort: stale keys beat failing every request on a blip.
            logger.warning("Error refreshing JWKS, using cached keys: %s", e)
            return jwks_cache
        logger.error("Error fetching JWKS: %s", e)
        raise

    jwks_cache = fetched
    jwks_last_updated = now
    return jwks_cache
|
|
|
|
def _find_signing_key(jwks, kid):
    """Return the RSA public-key fields of the JWK in *jwks* whose ``kid`` matches, or None."""
    for key in jwks.get('keys', []):
        if key.get('kid') == kid:
            return {field: key.get(field) for field in ('kty', 'kid', 'use', 'n', 'e')}
    return None


def verify_token(token, *, audiences=None, issuers=None, leeway=60):
    """Verify a JWT issued by Azure AD and return its claims.

    The RS256 signature is checked against the tenant's JWKS. The ``exp`` and
    ``nbf`` claims are checked (with ``leeway`` seconds of clock skew), and so
    are ``aud`` and ``iss``. The keys at JWKS_URI are not specific to this
    application, so a valid signature alone does not show the token was
    issued for this API. The audience/issuer checks bind it to
    CLIENT_ID / TENANT_ID.

    Tokens that cannot be fully verified are rejected. Never add a fallback
    that decodes the payload without verifying the signature: that lets any
    caller forge arbitrary claims.

    Args:
        token: The raw token, optionally prefixed with the ``Bearer`` scheme
            (any letter case), e.g. the Authorization header value.
        audiences: Accepted ``aud`` values. Defaults to CLIENT_ID and
            ``api://CLIENT_ID``. Pass explicitly if the API uses a custom
            App ID URI.
        issuers: Accepted ``iss`` values. Defaults to this tenant's v2.0
            (login.microsoftonline.com) and v1.0 (sts.windows.net) issuers.
        leeway: Allowed clock skew in seconds for time-based claims.

    Returns:
        The decoded claims dict, or None if the token is missing or invalid.
    """
    global jwks_cache

    if not token:
        return None

    # RFC 7235 auth schemes are case-insensitive; also accept a bare token.
    token = token.strip()
    scheme, sep, credentials = token.partition(' ')
    if sep and scheme.lower() == 'bearer':
        token = credentials.strip()
    if not token:
        return None

    accepted_audiences = (
        tuple(audiences) if audiences is not None
        else (CLIENT_ID, f'api://{CLIENT_ID}')
    )
    accepted_issuers = (
        tuple(issuers) if issuers is not None
        else (
            f'https://login.microsoftonline.com/{TENANT_ID}/v2.0',
            f'https://sts.windows.net/{TENANT_ID}/',
        )
    )

    try:
        header = jwt.get_unverified_header(token)
        kid = header.get('kid')
        if not kid:
            logger.warning("No 'kid' found in token header")
            return None

        rsa_key = _find_signing_key(get_jwks(), kid)
        if rsa_key is None:
            # The signing key may have rotated since the JWKS was cached:
            # drop the cache and retry once against a fresh key set.
            # NOTE: an unknown kid costs one outbound fetch per call; consider
            # rate-limiting if this gets abused.
            jwks_cache = None
            rsa_key = _find_signing_key(get_jwks(), kid)
        if rsa_key is None:
            logger.warning("No matching key found for kid: %s", kid)
            return None

        claims = jwt.decode(
            token,
            rsa_key,
            algorithms=['RS256'],
            options={
                'verify_signature': True,
                'verify_exp': True,
                'verify_nbf': True,
                'verify_iat': False,
                # aud/iss are checked below against several accepted values;
                # jose's built-in audience check takes only one string.
                'verify_aud': False,
                'verify_iss': False,
                'verify_sub': False,
                'verify_jti': False,
                # at_hash can only be checked when the access token is supplied.
                'verify_at_hash': False,
                'leeway': leeway,
            },
        )
    except JWTError as e:
        # Malformed token, bad signature, expired, not yet valid, ...
        logger.warning("Token validation failed: %s", e)
        return None
    except requests.RequestException as e:
        logger.error("Could not fetch JWKS to validate token: %s", e)
        return None
    except Exception:
        # Authentication boundary: fail closed on anything unexpected.
        logger.exception("Unexpected error while validating token")
        return None

    token_audiences = claims.get('aud')
    if isinstance(token_audiences, str):
        token_audiences = [token_audiences]
    if not isinstance(token_audiences, list) or not any(
        aud in accepted_audiences for aud in token_audiences
    ):
        logger.warning("Token audience not accepted: %s", claims.get('aud'))
        return None

    if claims.get('iss') not in accepted_issuers:
        logger.warning("Token issuer not accepted: %s", claims.get('iss'))
        return None

    logger.info("Token validated successfully with full verification")
    return claims
|
|
|
|
def require_auth(f):
    """Decorator that rejects requests lacking a valid bearer token.

    The Authorization header is passed to ``verify_token``. On success the
    returned claims are stored on ``request.user`` and the wrapped view runs.
    Otherwise a JSON 401 response is returned. Per RFC 6750 §3, that response
    carries a ``WWW-Authenticate: Bearer`` challenge, with
    ``error="invalid_token"`` when a token was presented but rejected.
    """
    @wraps(f)
    def decorated(*args, **kwargs):
        auth_header = request.headers.get('Authorization')

        if not auth_header:
            logger.warning("No Authorization header in request")
            return (
                jsonify({'success': False, 'message': 'Authentication required'}),
                401,
                {'WWW-Authenticate': 'Bearer'},
            )

        payload = verify_token(auth_header)
        if not payload:
            logger.warning("Invalid token")
            return (
                jsonify({'success': False, 'message': 'Invalid token'}),
                401,
                {'WWW-Authenticate': 'Bearer error="invalid_token"'},
            )

        # Expose the verified claims to the route handler.
        request.user = payload
        logger.info("Request authenticated for user: %s", payload.get('name', 'Unknown'))

        return f(*args, **kwargs)

    return decorated
|
|
|
|
def lenient_auth(f):
    """Best-effort authentication decorator that never blocks the request.

    If the Authorization header holds a token that ``verify_token`` accepts,
    its claims become ``request.user``. In every other case (no header, or a
    token that fails verification), ``request.user`` is set to a fresh
    ``{"name": "Anonymous"}`` placeholder. The wrapped view always runs, so it
    must not treat ``request.user`` as proof of identity.
    """
    @wraps(f)
    def wrapper(*args, **kwargs):
        header_value = request.headers.get('Authorization')
        # Only attempt verification when a non-empty header was sent.
        claims = verify_token(header_value) if header_value else None

        if claims:
            request.user = claims
            logger.info(f"Request authenticated for user: {claims.get('name', 'Unknown')}")
        else:
            # Log which failure mode occurred, then fall back to anonymous.
            if header_value:
                logger.warning("Invalid token but continuing with request")
            else:
                logger.warning("No Authorization header, continuing anyway")
            request.user = {"name": "Anonymous"}

        # Proceed irrespective of the authentication outcome.
        return f(*args, **kwargs)

    return wrapper