fix: code health sweep — M-01 through M-07

M-01 authz.py: move cache_key above try block to avoid NameError when
     first Redis call returns None
M-02 main.py: re-enable validation middleware (was TEMPORARILY DISABLED)
M-03 routes_auth.py / main.py: replace print() debug lines with
     structured logger calls; logger now module-level in routes_auth.py
M-04 gcs.py: asyncio.get_event_loop() → get_running_loop() (deprecation)
M-05 translate_and_synthesize.py: bind loop vars in closure defaults
     to fix B023 ruff warnings (transcreate/translate_captions/etc.)
M-06 rate_limiting.py: only trust X-Forwarded-For when X-Forwarded-Proto
     is https; use rightmost entry (proxy-appended) not leftmost
M-07 validation.py: extend MongoDB operator blocklist to cover $expr,
     $function, $accumulator, $nin, $gte, $lte, $jsonSchema, $mod

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Vadym Samoilenko 2026-04-29 14:18:02 +01:00
parent 86ef5a86fb
commit 4c6624c3d4
7 changed files with 53 additions and 64 deletions

View file

@ -6,29 +6,31 @@ from motor.motor_asyncio import AsyncIOMotorDatabase
from ...core.config import settings
from ...core.database import get_database
from ...core.logging import get_logger
from ...core.security import (
create_access_token,
create_refresh_token,
decode_token,
verify_password,
)
from ...models.audit_log import AuditAction, AuditLogSeverity
from ...models.user import User, AuthProvider, UserRole
from ...schemas.auth import (
LoginRequest,
LoginResponse,
LogoutResponse,
RefreshResponse,
MicrosoftLoginRequest,
MicrosoftLoginResponse,
)
from ...services.microsoft_auth import (
get_microsoft_auth_service,
MicrosoftTokenValidationError,
MicrosoftAuthError,
RefreshResponse,
)
from ...services.audit_logger import log_auth_success, log_auth_failure, audit_logger
from ...models.audit_log import AuditAction, AuditLogSeverity
from ...services.microsoft_auth import (
get_microsoft_auth_service,
MicrosoftAuthError,
MicrosoftTokenValidationError,
)
logger = get_logger(__name__)
router = APIRouter(prefix="/auth", tags=["auth"])
security = HTTPBearer()
@ -191,29 +193,23 @@ async def refresh_token(
db: AsyncIOMotorDatabase = Depends(get_database),
):
refresh_token = request.cookies.get("refresh_token")
print(f"🔍 REFRESH DEBUG: Cookie exists: {bool(refresh_token)}")
if not refresh_token:
print("🚨 REFRESH ERROR: No refresh token in cookies")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Refresh token not found",
)
try:
print(f"🔍 REFRESH DEBUG: Attempting to decode token...")
payload = decode_token(refresh_token)
print(f"🔍 REFRESH DEBUG: Token decoded successfully, type={payload.get('type')}")
if payload.get("type") != "refresh":
print(f"🚨 REFRESH ERROR: Wrong token type: {payload.get('type')}")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid token type",
)
user_id = payload.get("sub")
print(f"🔍 REFRESH DEBUG: User ID from token: {user_id}")
if not user_id:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -250,7 +246,7 @@ async def refresh_token(
max_age=settings.jwt_refresh_ttl_days * 24 * 60 * 60,
)
print(f"🔍 REFRESH DEBUG: Refresh successful for user {user_id}")
logger.info("Token refresh successful for user %s", user_id)
return RefreshResponse(
access_token=new_access_token,
user_id=user_id,
@ -263,10 +259,7 @@ async def refresh_token(
raise
except Exception as e:
import traceback
from ...core.logging import get_logger
get_logger(__name__).exception(
"Refresh token error: %s\n%s", type(e).__name__, traceback.format_exc()
)
logger.exception("Refresh token error: %s\n%s", type(e).__name__, traceback.format_exc())
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid refresh token",

View file

@ -64,10 +64,10 @@ async def _cached_memberships(
db: AsyncIOMotorDatabase,
) -> dict[str, OrgRole]:
"""Load memberships, with Redis cache (60s TTL)."""
cache_key = f"mem:user:{user_id}"
try:
redis = get_redis()
if redis:
cache_key = f"mem:user:{user_id}"
cached = await redis.get(cache_key)
if cached:
raw = json.loads(cached)

View file

@ -144,16 +144,14 @@ async def cors_error_handler(request, call_next):
try:
response = await call_next(request)
except Exception as e:
# LOG THE EXCEPTION BEFORE HANDLING IT
print(f"🚨 EXCEPTION IN CORS MIDDLEWARE: {e}")
import traceback
print(f"Traceback:\n{traceback.format_exc()}")
from .core.logging import get_logger as _get_logger
_get_logger(__name__).exception("🚨 CORS middleware caught: %s\n%s", e, traceback.format_exc())
# Handle any unhandled exceptions and add CORS headers
from fastapi.responses import JSONResponse
response = JSONResponse(
status_code=500,
content={"detail": "Internal server error", "error": str(e)}
content={"detail": "Internal server error"},
)
# Always add CORS headers for allowed origins
@ -215,18 +213,14 @@ async def general_exception_handler(request: Request, exc: Exception):
from .core.logging import get_logger
logger = get_logger(__name__)
logger.error(f"Unhandled exception in {request.method} {request.url.path}: {exc}")
logger.error(f"Exception type: {type(exc).__name__}")
logger.error(f"Traceback: {traceback.format_exc()}")
# Also print to stdout for immediate visibility
print(f"🚨 UNHANDLED EXCEPTION: {request.method} {request.url.path}")
print(f"Exception: {exc}")
print(f"Traceback:\n{traceback.format_exc()}")
logger.exception(
"🚨 Unhandled %s %s: %s\n%s",
request.method, request.url.path, exc, traceback.format_exc(),
)
response = JSONResponse(
status_code=500,
content={"detail": "Internal server error", "error": str(exc)}
content={"detail": "Internal server error"},
)
# Add CORS headers
@ -248,11 +242,7 @@ async def rate_limiting_middleware(request, call_next):
@app.middleware("http")
async def validation_middleware(request, call_next):
"""Apply request validation middleware."""
# TEMPORARILY DISABLED FOR DEBUGGING
return await call_next(request)
# Skip middleware for auth endpoints during debugging
if request.url.path in ["/api/v1/auth/login", "/api/v1/auth/refresh"]:
if request.url.path in ["/health", "/metrics", "/api/v1/auth/login", "/api/v1/auth/refresh"]:
return await call_next(request)
if hasattr(app.state, 'validation_middleware'):
return await app.state.validation_middleware(request, call_next)

View file

@ -113,16 +113,19 @@ class RateLimitMiddleware:
def _get_client_identifier(self, request: Request) -> str:
"""Get client identifier for rate limiting."""
# Try to get user ID from JWT token
user = getattr(request.state, 'user', None)
if user:
return f"user:{user.id}"
# Fall back to IP address
forwarded_for = request.headers.get("X-Forwarded-For")
if forwarded_for:
return f"ip:{forwarded_for.split(',')[0].strip()}"
# Only trust X-Forwarded-For when the request arrived via HTTPS (i.e. through
# the Apache/nginx reverse proxy). On plain HTTP (direct connections, local
# dev) the header can be forged, so we fall back to the socket IP.
if request.headers.get("X-Forwarded-Proto") == "https":
forwarded_for = request.headers.get("X-Forwarded-For")
if forwarded_for:
# Take the right-most IP added by the trusted proxy, not client-supplied ones.
return f"ip:{forwarded_for.split(',')[-1].strip()}"
client_ip = request.client.host if request.client else "unknown"
return f"ip:{client_ip}"

View file

@ -58,8 +58,11 @@ class RequestValidator:
r"[;&|`](?!\s*$)", # Allow $ but not as command separator
r"(rm|wget|curl|nc|bash|sh|cmd|powershell)\s+",
# MongoDB injection
r"\$where|\$ne|\$gt|\$lt|\$regex",
# MongoDB injection — NoSQL operator abuse
r"\$where|\$expr|\$function|\$accumulator"
r"|\$ne|\$nin|\$not"
r"|\$gt|\$gte|\$lt|\$lte"
r"|\$regex|\$jsonSchema|\$mod",
]
self.compiled_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in self.malicious_patterns]

View file

@ -40,7 +40,7 @@ class GCSService:
return f"gs://{settings.gcs_bucket}/{destination_path}"
loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()
try:
return await loop.run_in_executor(self.executor, _upload)
except Exception as e:
@ -61,7 +61,7 @@ class GCSService:
return f"gs://{settings.gcs_bucket}/{destination_path}"
loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()
try:
return await loop.run_in_executor(self.executor, _upload)
except Exception as e:
@ -90,7 +90,7 @@ class GCSService:
version="v4"
)
loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()
try:
return await loop.run_in_executor(self.executor, _get_signed_url)
except NotFound:
@ -106,7 +106,7 @@ class GCSService:
blob.delete()
return True
loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()
try:
return await loop.run_in_executor(self.executor, _delete)
except NotFound:
@ -121,7 +121,7 @@ class GCSService:
blob = self.bucket.blob(blob_path)
return blob.exists()
loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()
return await loop.run_in_executor(self.executor, _exists)

View file

@ -359,13 +359,13 @@ async def _async_translate_and_synthesize(job_id: str):
try:
if language in transcreation_languages:
# TRADITIONAL MODE with transcreation: cultural adaptation
async def transcreate():
async def transcreate(_lang=language, _gloss=_glossary):
return await gemini_service.transcreate_content(
source_captions_vtt,
source_ad_vtt,
language,
_lang,
brief="Standard accessibility content",
glossary_block=_glossary,
glossary_block=_gloss,
_cost_ctx=_cost_ctx,
)
@ -376,17 +376,17 @@ async def _async_translate_and_synthesize(job_id: str):
else:
# TRADITIONAL MODE: Use Gemini translation (6-36x cheaper than Google Translate API)
async def translate_captions():
async def translate_captions(_lang=language, _gloss=_glossary):
return await gemini_service.translate_vtt(
source_captions_vtt, language, source_language=source_language,
glossary_block=_glossary,
source_captions_vtt, _lang, source_language=source_language,
glossary_block=_gloss,
_cost_ctx=_cost_ctx,
)
async def translate_ad():
async def translate_ad(_lang=language, _gloss=_glossary):
return await gemini_service.translate_vtt(
source_ad_vtt, language, source_language=source_language,
glossary_block=_glossary,
source_ad_vtt, _lang, source_language=source_language,
glossary_block=_gloss,
_cost_ctx=_cost_ctx,
)
@ -412,10 +412,10 @@ async def _async_translate_and_synthesize(job_id: str):
"origin": origin
}
if sdh_requested and source_sdh_vtt:
async def translate_sdh():
async def translate_sdh(_lang=language, _gloss=_glossary):
return await gemini_service.translate_vtt(
source_sdh_vtt, language, source_language=source_language,
glossary_block=_glossary,
source_sdh_vtt, _lang, source_language=source_language,
glossary_block=_gloss,
_cost_ctx=_cost_ctx,
)
translated_sdh = await retry_with_backoff(translate_sdh, max_retries=3)