fix: code health sweep — M-01 through M-07
M-01 authz.py: assign cache_key before the try block to avoid a NameError
when get_redis() returns None (it was previously set only inside `if redis:`)
M-02 main.py: re-enable validation middleware (was TEMPORARILY DISABLED)
M-03 routes_auth.py / main.py: replace print() debug lines with
structured logger calls; logger now module-level in routes_auth.py;
500 responses in main.py no longer include str(exc) in the JSON body
M-04 gcs.py: asyncio.get_event_loop() → get_running_loop() (deprecation)
M-05 translate_and_synthesize.py: bind loop vars in closure defaults
to fix B023 ruff warnings (transcreate/translate_captions/etc.)
M-06 rate_limiting.py: only trust X-Forwarded-For when X-Forwarded-Proto
is https; use rightmost entry (proxy-appended) not leftmost
M-07 validation.py: extend MongoDB operator blocklist to cover $expr,
$function, $accumulator, $nin, $not, $gte, $lte, $jsonSchema, $mod
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
86ef5a86fb
commit
4c6624c3d4
7 changed files with 53 additions and 64 deletions
|
|
@ -6,29 +6,31 @@ from motor.motor_asyncio import AsyncIOMotorDatabase
|
|||
|
||||
from ...core.config import settings
|
||||
from ...core.database import get_database
|
||||
from ...core.logging import get_logger
|
||||
from ...core.security import (
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
decode_token,
|
||||
verify_password,
|
||||
)
|
||||
from ...models.audit_log import AuditAction, AuditLogSeverity
|
||||
from ...models.user import User, AuthProvider, UserRole
|
||||
from ...schemas.auth import (
|
||||
LoginRequest,
|
||||
LoginResponse,
|
||||
LogoutResponse,
|
||||
RefreshResponse,
|
||||
MicrosoftLoginRequest,
|
||||
MicrosoftLoginResponse,
|
||||
)
|
||||
from ...services.microsoft_auth import (
|
||||
get_microsoft_auth_service,
|
||||
MicrosoftTokenValidationError,
|
||||
MicrosoftAuthError,
|
||||
RefreshResponse,
|
||||
)
|
||||
from ...services.audit_logger import log_auth_success, log_auth_failure, audit_logger
|
||||
from ...models.audit_log import AuditAction, AuditLogSeverity
|
||||
from ...services.microsoft_auth import (
|
||||
get_microsoft_auth_service,
|
||||
MicrosoftAuthError,
|
||||
MicrosoftTokenValidationError,
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
security = HTTPBearer()
|
||||
|
||||
|
|
@ -191,29 +193,23 @@ async def refresh_token(
|
|||
db: AsyncIOMotorDatabase = Depends(get_database),
|
||||
):
|
||||
refresh_token = request.cookies.get("refresh_token")
|
||||
print(f"🔍 REFRESH DEBUG: Cookie exists: {bool(refresh_token)}")
|
||||
|
||||
if not refresh_token:
|
||||
print("🚨 REFRESH ERROR: No refresh token in cookies")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Refresh token not found",
|
||||
)
|
||||
|
||||
try:
|
||||
print(f"🔍 REFRESH DEBUG: Attempting to decode token...")
|
||||
payload = decode_token(refresh_token)
|
||||
print(f"🔍 REFRESH DEBUG: Token decoded successfully, type={payload.get('type')}")
|
||||
|
||||
if payload.get("type") != "refresh":
|
||||
print(f"🚨 REFRESH ERROR: Wrong token type: {payload.get('type')}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token type",
|
||||
)
|
||||
|
||||
user_id = payload.get("sub")
|
||||
print(f"🔍 REFRESH DEBUG: User ID from token: {user_id}")
|
||||
if not user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
|
|
@ -250,7 +246,7 @@ async def refresh_token(
|
|||
max_age=settings.jwt_refresh_ttl_days * 24 * 60 * 60,
|
||||
)
|
||||
|
||||
print(f"🔍 REFRESH DEBUG: Refresh successful for user {user_id}")
|
||||
logger.info("Token refresh successful for user %s", user_id)
|
||||
return RefreshResponse(
|
||||
access_token=new_access_token,
|
||||
user_id=user_id,
|
||||
|
|
@ -263,10 +259,7 @@ async def refresh_token(
|
|||
raise
|
||||
except Exception as e:
|
||||
import traceback
|
||||
from ...core.logging import get_logger
|
||||
get_logger(__name__).exception(
|
||||
"Refresh token error: %s\n%s", type(e).__name__, traceback.format_exc()
|
||||
)
|
||||
logger.exception("Refresh token error: %s\n%s", type(e).__name__, traceback.format_exc())
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid refresh token",
|
||||
|
|
|
|||
|
|
@ -64,10 +64,10 @@ async def _cached_memberships(
|
|||
db: AsyncIOMotorDatabase,
|
||||
) -> dict[str, OrgRole]:
|
||||
"""Load memberships, with Redis cache (60s TTL)."""
|
||||
cache_key = f"mem:user:{user_id}"
|
||||
try:
|
||||
redis = get_redis()
|
||||
if redis:
|
||||
cache_key = f"mem:user:{user_id}"
|
||||
cached = await redis.get(cache_key)
|
||||
if cached:
|
||||
raw = json.loads(cached)
|
||||
|
|
|
|||
|
|
@ -144,16 +144,14 @@ async def cors_error_handler(request, call_next):
|
|||
try:
|
||||
response = await call_next(request)
|
||||
except Exception as e:
|
||||
# LOG THE EXCEPTION BEFORE HANDLING IT
|
||||
print(f"🚨 EXCEPTION IN CORS MIDDLEWARE: {e}")
|
||||
import traceback
|
||||
print(f"Traceback:\n{traceback.format_exc()}")
|
||||
from .core.logging import get_logger as _get_logger
|
||||
_get_logger(__name__).exception("🚨 CORS middleware caught: %s\n%s", e, traceback.format_exc())
|
||||
|
||||
# Handle any unhandled exceptions and add CORS headers
|
||||
from fastapi.responses import JSONResponse
|
||||
response = JSONResponse(
|
||||
status_code=500,
|
||||
content={"detail": "Internal server error", "error": str(e)}
|
||||
content={"detail": "Internal server error"},
|
||||
)
|
||||
|
||||
# Always add CORS headers for allowed origins
|
||||
|
|
@ -215,18 +213,14 @@ async def general_exception_handler(request: Request, exc: Exception):
|
|||
from .core.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger.error(f"Unhandled exception in {request.method} {request.url.path}: {exc}")
|
||||
logger.error(f"Exception type: {type(exc).__name__}")
|
||||
logger.error(f"Traceback: {traceback.format_exc()}")
|
||||
|
||||
# Also print to stdout for immediate visibility
|
||||
print(f"🚨 UNHANDLED EXCEPTION: {request.method} {request.url.path}")
|
||||
print(f"Exception: {exc}")
|
||||
print(f"Traceback:\n{traceback.format_exc()}")
|
||||
logger.exception(
|
||||
"🚨 Unhandled %s %s: %s\n%s",
|
||||
request.method, request.url.path, exc, traceback.format_exc(),
|
||||
)
|
||||
|
||||
response = JSONResponse(
|
||||
status_code=500,
|
||||
content={"detail": "Internal server error", "error": str(exc)}
|
||||
content={"detail": "Internal server error"},
|
||||
)
|
||||
|
||||
# Add CORS headers
|
||||
|
|
@ -248,11 +242,7 @@ async def rate_limiting_middleware(request, call_next):
|
|||
@app.middleware("http")
|
||||
async def validation_middleware(request, call_next):
|
||||
"""Apply request validation middleware."""
|
||||
# TEMPORARILY DISABLED FOR DEBUGGING
|
||||
return await call_next(request)
|
||||
|
||||
# Skip middleware for auth endpoints during debugging
|
||||
if request.url.path in ["/api/v1/auth/login", "/api/v1/auth/refresh"]:
|
||||
if request.url.path in ["/health", "/metrics", "/api/v1/auth/login", "/api/v1/auth/refresh"]:
|
||||
return await call_next(request)
|
||||
if hasattr(app.state, 'validation_middleware'):
|
||||
return await app.state.validation_middleware(request, call_next)
|
||||
|
|
|
|||
|
|
@ -113,16 +113,19 @@ class RateLimitMiddleware:
|
|||
|
||||
def _get_client_identifier(self, request: Request) -> str:
|
||||
"""Get client identifier for rate limiting."""
|
||||
# Try to get user ID from JWT token
|
||||
user = getattr(request.state, 'user', None)
|
||||
if user:
|
||||
return f"user:{user.id}"
|
||||
|
||||
# Fall back to IP address
|
||||
forwarded_for = request.headers.get("X-Forwarded-For")
|
||||
if forwarded_for:
|
||||
return f"ip:{forwarded_for.split(',')[0].strip()}"
|
||||
|
||||
|
||||
# Only trust X-Forwarded-For when the request arrived via HTTPS (i.e. through
|
||||
# the Apache/nginx reverse proxy). On plain HTTP (direct connections, local
|
||||
# dev) the header can be forged, so we fall back to the socket IP.
|
||||
if request.headers.get("X-Forwarded-Proto") == "https":
|
||||
forwarded_for = request.headers.get("X-Forwarded-For")
|
||||
if forwarded_for:
|
||||
# Take the right-most IP added by the trusted proxy, not client-supplied ones.
|
||||
return f"ip:{forwarded_for.split(',')[-1].strip()}"
|
||||
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
return f"ip:{client_ip}"
|
||||
|
||||
|
|
|
|||
|
|
@ -58,8 +58,11 @@ class RequestValidator:
|
|||
r"[;&|`](?!\s*$)", # Allow $ but not as command separator
|
||||
r"(rm|wget|curl|nc|bash|sh|cmd|powershell)\s+",
|
||||
|
||||
# MongoDB injection
|
||||
r"\$where|\$ne|\$gt|\$lt|\$regex",
|
||||
# MongoDB injection — NoSQL operator abuse
|
||||
r"\$where|\$expr|\$function|\$accumulator"
|
||||
r"|\$ne|\$nin|\$not"
|
||||
r"|\$gt|\$gte|\$lt|\$lte"
|
||||
r"|\$regex|\$jsonSchema|\$mod",
|
||||
]
|
||||
|
||||
self.compiled_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in self.malicious_patterns]
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ class GCSService:
|
|||
|
||||
return f"gs://{settings.gcs_bucket}/{destination_path}"
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
loop = asyncio.get_running_loop()
|
||||
try:
|
||||
return await loop.run_in_executor(self.executor, _upload)
|
||||
except Exception as e:
|
||||
|
|
@ -61,7 +61,7 @@ class GCSService:
|
|||
|
||||
return f"gs://{settings.gcs_bucket}/{destination_path}"
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
loop = asyncio.get_running_loop()
|
||||
try:
|
||||
return await loop.run_in_executor(self.executor, _upload)
|
||||
except Exception as e:
|
||||
|
|
@ -90,7 +90,7 @@ class GCSService:
|
|||
version="v4"
|
||||
)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
loop = asyncio.get_running_loop()
|
||||
try:
|
||||
return await loop.run_in_executor(self.executor, _get_signed_url)
|
||||
except NotFound:
|
||||
|
|
@ -106,7 +106,7 @@ class GCSService:
|
|||
blob.delete()
|
||||
return True
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
loop = asyncio.get_running_loop()
|
||||
try:
|
||||
return await loop.run_in_executor(self.executor, _delete)
|
||||
except NotFound:
|
||||
|
|
@ -121,7 +121,7 @@ class GCSService:
|
|||
blob = self.bucket.blob(blob_path)
|
||||
return blob.exists()
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
loop = asyncio.get_running_loop()
|
||||
return await loop.run_in_executor(self.executor, _exists)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -359,13 +359,13 @@ async def _async_translate_and_synthesize(job_id: str):
|
|||
try:
|
||||
if language in transcreation_languages:
|
||||
# TRADITIONAL MODE with transcreation: cultural adaptation
|
||||
async def transcreate():
|
||||
async def transcreate(_lang=language, _gloss=_glossary):
|
||||
return await gemini_service.transcreate_content(
|
||||
source_captions_vtt,
|
||||
source_ad_vtt,
|
||||
language,
|
||||
_lang,
|
||||
brief="Standard accessibility content",
|
||||
glossary_block=_glossary,
|
||||
glossary_block=_gloss,
|
||||
_cost_ctx=_cost_ctx,
|
||||
)
|
||||
|
||||
|
|
@ -376,17 +376,17 @@ async def _async_translate_and_synthesize(job_id: str):
|
|||
|
||||
else:
|
||||
# TRADITIONAL MODE: Use Gemini translation (6-36x cheaper than Google Translate API)
|
||||
async def translate_captions():
|
||||
async def translate_captions(_lang=language, _gloss=_glossary):
|
||||
return await gemini_service.translate_vtt(
|
||||
source_captions_vtt, language, source_language=source_language,
|
||||
glossary_block=_glossary,
|
||||
source_captions_vtt, _lang, source_language=source_language,
|
||||
glossary_block=_gloss,
|
||||
_cost_ctx=_cost_ctx,
|
||||
)
|
||||
|
||||
async def translate_ad():
|
||||
async def translate_ad(_lang=language, _gloss=_glossary):
|
||||
return await gemini_service.translate_vtt(
|
||||
source_ad_vtt, language, source_language=source_language,
|
||||
glossary_block=_glossary,
|
||||
source_ad_vtt, _lang, source_language=source_language,
|
||||
glossary_block=_gloss,
|
||||
_cost_ctx=_cost_ctx,
|
||||
)
|
||||
|
||||
|
|
@ -412,10 +412,10 @@ async def _async_translate_and_synthesize(job_id: str):
|
|||
"origin": origin
|
||||
}
|
||||
if sdh_requested and source_sdh_vtt:
|
||||
async def translate_sdh():
|
||||
async def translate_sdh(_lang=language, _gloss=_glossary):
|
||||
return await gemini_service.translate_vtt(
|
||||
source_sdh_vtt, language, source_language=source_language,
|
||||
glossary_block=_glossary,
|
||||
source_sdh_vtt, _lang, source_language=source_language,
|
||||
glossary_block=_gloss,
|
||||
_cost_ctx=_cost_ctx,
|
||||
)
|
||||
translated_sdh = await retry_with_backoff(translate_sdh, max_retries=3)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue