Add LLM usage tracking infrastructure (Phases A-C)

- Model renames: gpt-5.2 → gpt-5.4-2026-03-05, gemini-3-pro-preview → gemini-3.1-pro-preview; retire gpt-4.1 via alias fallback
- New: llm_usage_context.py (ContextVar-based attribution), model_pricing.py (tiered pricing + 60s cache), usage_event.py (append-only telemetry), quota.py (user/FG quota enforcement with 80% warning)
- Wire _record_usage into all 3 LLM methods; set_llm_context at every service entry point
- Fix admin_required decorator (was sync, never awaited User.find_by_id); add active_required and with_user_context decorators
- Inject user_id into ContextVar from JWT on every authenticated request
- Add DB indexes for usage_events, model_pricing, users collections
- Seed script for model pricing (gpt-5.4 single-tier, gemini-3.1 two-tier 200k threshold)
- Fix parse_json_response NameError (logger undefined at module level)
- 70 passing tests: conftest.py with sys.modules stubs, test_usage_infrastructure.py (52 tests), rewrite stale test_llm_service.py (18 tests)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Vadym Samoilenko 2026-04-24 18:08:27 +01:00
parent 0bf6043fad
commit 3e9ccafad2
26 changed files with 1566 additions and 286 deletions

164
CLAUDE.md
View file

@ -3,124 +3,88 @@
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## Commands
- **Build**: `npm run build` (use this for all testing and verification)
- **Dev Server**: `npm run dev` (port 5173, proxies `/api``localhost:5137`)
- **Build**: `npm run build` (use this to verify TypeScript compilation)
- **Dev Build**: `npm run build:dev` (development mode build)
- **Lint**: `npm run lint`
- **Preview**: `npm run preview`
- **Backend**: `cd backend && python run.py`
- **Backend**: `cd backend && python run.py` (Hypercorn ASGI on port 5137)
**Note**: This project supports both local development and production deployment. See Environment Configuration section below.
## Backend Commands
- **Start Backend**: `python run.py` (from backend/ directory)
- **Backend Server**: Runs on port 5137 with Hypercorn ASGI server
- **Database**: MongoDB with PyMongo
- **Authentication**: Custom Quart-compatible JWT (not Flask-JWT-Extended)
## Testing
**Python Backend**: After modifying any Python files:
## Backend Testing
After modifying any Python files:
```bash
source backend/venv/bin/activate
python -c "import app.services.module_name" # Test specific module
python -c "from app import create_app; app = create_app()" # Test app creation
python -c "from app import create_app; create_app()" # Test app creation
```
**Frontend**: Run `npm run build` to verify TypeScript compilation
## Architecture Overview
### Real-Time Communication
The application uses Socket.IO for real-time WebSocket communication between frontend and backend:
- **Backend**: `python-socketio` with `AsyncServer` wrapped in ASGI app
- **Frontend**: `socket.io-client` managed via `WebSocketContext`
- WebSocket manager (`websocket_manager_async.py`) handles room-based messaging for focus group sessions
### ASGI Stack (critical detail)
`create_app()` returns a **`socketio.ASGIApp`** wrapping a Quart app — not the Quart app itself. Accessing `app.quart_app` gives the inner Quart instance. This distinction matters whenever you write ASGI middleware or access app config directly.
### Autonomous Conversation System
Focus group sessions can run autonomously with AI-driven conversations:
- `ai_runner_service.py` - Manages background task execution for autonomous mode
- `autonomous_conversation_controller.py` - Orchestrates multi-persona conversations
- `conversation_decision_service.py` - Determines next speaker and conversation flow
- `conversation_context_service.py` - Maintains conversation state and history
### Real-Time Communication
Socket.IO via `python-socketio` `AsyncServer` (ASGI mode). The `WebSocketContextNew.tsx` context manages the client connection. `websocket_manager_async.py` handles room-based messaging for focus group sessions. The WebSocket manager must call `ws_mgr.set_main_loop(asyncio.get_running_loop())` at startup so that cross-thread emits from the AI Runner land on the right loop.
> `VITE_ENABLE_WEBSOCKET` is hardcoded `true` in dev and `false` in production builds via `vite.config.ts` — it is not controlled by `.env`.
### AI Runner + Threading
`ai_runner_service.py` is a singleton that owns a **dedicated OS thread** with a single asyncio event loop. All autonomous AI conversations run in this thread. This solves Motor (AsyncIOMotorClient) event-loop affinity: Motor clients in the AI runner are bound to that loop, while regular API routes use synchronous PyMongo. Never share Motor clients between the two contexts.
### Autonomous Conversation Pipeline
1. `ai_runner_service.py` — spawns coroutines on the dedicated thread's event loop
2. `autonomous_conversation_controller.py` — orchestrates the full session
3. `conversation_decision_service.py` — picks the next speaker
4. `conversation_context_service.py` — maintains history/state
5. `conversation_state_manager.py` — in-memory state across turns
### Task Manager
`task_manager.py` is a singleton tracking cancellable asyncio tasks (persona generation, discussion guides, etc.). Tasks are exposed via `/api/tasks` routes. A background sweeper cleans up completed/expired tasks. Frontend polling is handled by `useTaskPolling.ts`.
### LLM Integration
Multi-model support through `llm_service.py`:
- **Google Gemini** (`gemini-3-pro-preview`) - Default model
- **OpenAI** (`gpt-4.1`, `gpt-5.2`) - Alternative models
- Prompts are stored as markdown templates in `/backend/prompts/`
`llm_service.py` creates fresh clients per call (avoids event-loop mismatch in ASGI). Default model: **Google Gemini** via `google-genai`. Alternative: **OpenAI** (`AsyncOpenAI`). Both require env vars `GEMINI_API_KEY` and `OPENAI_API_KEY` — startup fails if missing. Prompts are markdown templates in `/backend/prompts/` loaded by `prompt_loader.py`.
## Code Style Guidelines
- **Imports**: Group imports by source (React, third-party, local)
- **Types**: Use TypeScript. Project allows nullable types (`strictNullChecks: false`)
- **Components**: Use functional components with hooks
- **Naming**: Use PascalCase for components, camelCase for variables/functions
- **Formatting**: Follow ESLint recommendations, focus on readability
- **Error Handling**: Use try/catch blocks with toast for user feedback
- **CSS**: Use Tailwind classes for styling, with component-specific CSS files when needed
- **File Structure**: Components in `/src/components`, pages in `/src/pages`, hooks in `/src/hooks`
- **UI Components**: Use shadcn-ui components from `/src/components/ui`
- **State Management**: React hooks for local state, context/props for sharing
- **URL Construction**: ALWAYS use `import.meta.env.BASE_URL` when constructing URLs for static assets, images, or links. This project uses base path `/semblance/` in production. Example: `${import.meta.env.BASE_URL}image.png` instead of `/image.png`
## Project Stack
**Frontend**: Vite, React 18, TypeScript, Tailwind CSS, shadcn-ui
**Backend**: Quart (async Flask), Hypercorn ASGI, PyMongo, python-socketio
**Key Libraries**:
- UI: Radix UI components, Lucide React icons
- State: TanStack Query, React Hook Form with Zod validation
- Routing: React Router DOM
- AI/LLM: OpenAI, Google Generative AI (genai)
- Real-time: Socket.IO (client and server)
- Charts: Recharts
- Drag & Drop: DND Kit
## API Configuration
- **Frontend API Base**: `/semblance_back/api` (configurable via VITE_API_BASE_URL)
- **Backend Proxy**: Vite dev server proxies `/api` to `localhost:5137`
- **Production Base Path**: `/semblance/` (configured in vite.config.ts)
- **Authentication**: JWT tokens stored in localStorage
## Code Style
- TypeScript with `strictNullChecks: false`
- Functional components with hooks; local state via hooks, shared state via context/props
- `@/` alias maps to `src/`
- **URL construction**: always use `${import.meta.env.BASE_URL}asset.png` — production base is `/semblance/`
- Error handling: try/catch + `sonner` toast for user feedback
## File Organization
- **Backend Services**: `/backend/app/services/` - Business logic and AI integrations
- **Backend Models**: `/backend/app/models/` - Data models (User, FocusGroup, Persona, Folder)
- **Backend Routes**: `/backend/app/routes/` - API endpoints (auth, personas, focus-groups, ai-personas, folders, tasks)
- **AI Prompts**: `/backend/prompts/` - LLM prompt templates (markdown files loaded by `prompt_loader.py`)
- **Frontend Components**:
- `/src/components/ui/` - Reusable shadcn-ui components
- `/src/components/focus-group-session/` - Focus group session UI (DiscussionPanel, ParticipantPanel, ThemesPanel, etc.)
- `/src/components/persona/` - Persona management components
- **Types**: `/src/types/` - TypeScript type definitions
- **Contexts**: `/src/contexts/` - React context providers (AuthContext, WebSocketContext, NavigationContext)
```
backend/
app/
routes/ # Blueprints: auth, personas, focus-groups, ai-personas, focus-group-ai, folders, tasks
services/ # Business logic: llm_service, ai_runner_service, task_manager, autonomous_*, conversation_*
models/ # Data models: User, FocusGroup, Persona, Folder
auth/ # Auth utilities (JWT helpers)
prompts/ # LLM prompt markdown templates
websocket_manager_async.py # Room-based async WebSocket manager
extensions.py # socketio.AsyncServer singleton
src/
pages/ # Route-level components (Dashboard, FocusGroups, FocusGroupSession, Login, SyntheticUsers)
components/
focus-group-session/ # Session UI panels (Discussion, Participant, Themes, etc.)
persona/ # Persona management components
ui/ # shadcn-ui primitives
contexts/ # AuthContext, WebSocketContextNew, NavigationContext
hooks/ # useTaskPolling, useWebSocket, usePersonaStorage, useDiscussionGuideGeneration, etc.
types/ # TypeScript type definitions
```
## Environment Configuration
This application supports both local development and production deployment through environment-specific configuration files:
| Setting | Development | Production |
|---------|-------------|------------|
| Base path | `/` | `/semblance/` |
| API base | `/api` (proxied to 5137) | `https://optical-dev.oliver.solutions/semblance_back/api` |
| WebSocket path | `/socket.io/` | `/semblance_back/socket.io/` |
| MSAL redirect | `http://localhost:5173/` | `https://optical-dev.oliver.solutions/semblance` |
### Environment Files
- **`.env.development`**: Local development configuration
- **`.env.production`**: Production server configuration
- **`.env`**: Active configuration (copy from appropriate environment file)
Setup: copy `.env.development` or `.env.production` to `.env`. Backend requires `backend/.env` with `SECRET_KEY`, `JWT_SECRET_KEY`, `GEMINI_API_KEY`, `OPENAI_API_KEY` — startup will throw `RuntimeError` if any are missing or use weak defaults.
### Development vs Production
The application automatically adapts based on environment variables:
**Development Mode:**
- Base path: `/` (root)
- API base: `/api`
- WebSocket path: `/socket.io/`
- MSAL redirect: `http://localhost:5173/`
**Production Mode:**
- Base path: `/semblance/`
- API base: `https://optical-dev.oliver.solutions/semblance_back/api`
- WebSocket path: `/semblance_back/socket.io/`
- MSAL redirect: `https://optical-dev.oliver.solutions/semblance`
### Setup Instructions
1. **For local development**: Copy `.env.development` to `.env`
2. **For production**: Copy `.env.production` to `.env`
3. The build system will use the appropriate configuration
### Technical Details
- **Base Path**: Configured in vite.config.ts based on `NODE_ENV`
- **Backend Port**: 5137 (Hypercorn ASGI server)
- **Frontend Dev Port**: 5173
- **Temp Directories**: Backend creates `/backend/temp/` for file handling
## Knowledge Wiki
A cross-project knowledge base is maintained automatically from all Claude Code sessions.
- **Index:** `/Users/ai_leed/Library/Mobile Documents/iCloud~md~obsidian/Documents/VadymSamoilenko/wiki/index.md`
- **Query:** `cd ~/.claude/memory-compiler && uv run python scripts/query.py "your question"`

View file

@ -148,6 +148,17 @@ def jwt_required(optional: bool = False):
# Store user ID in request context
g.current_user_id = user_id
# Propagate user_id into the LLM usage ContextVar for this request.
# Each Quart request runs in its own asyncio Task, so setting the ContextVar
# here is request-scoped. Child tasks (create_task) and thread submissions
# (run_coroutine_threadsafe) inherit this context automatically.
try:
from app.services.llm_usage_context import _ctx, current_context
from dataclasses import replace as _dc_replace
_ctx.set(_dc_replace(current_context(), user_id=user_id))
except Exception:
pass # Non-fatal — telemetry only
# Call the actual route function and handle tuple returns
result = await func(*args, **kwargs)

View file

@ -50,10 +50,21 @@ async def get_db():
try:
await database.users.create_index("username", unique=True, background=True)
await database.users.create_index("email", unique=True, background=True)
await database.users.create_index("role", background=True)
await database.users.create_index([("is_active", 1), ("username", 1)], background=True)
await database.personas.create_index("created_by", background=True)
await database.focus_groups.create_index("created_by", background=True)
await database.folders.create_index("created_by", background=True)
await database.folders.create_index("parent_folder_id", background=True)
# usage_events indexes
await database.usage_events.create_index([("user_id", 1), ("ts", -1)], background=True)
await database.usage_events.create_index([("focus_group_id", 1), ("ts", -1)], background=True)
await database.usage_events.create_index([("ts", -1)], background=True)
await database.usage_events.create_index([("feature", 1), ("ts", -1)], background=True)
await database.usage_events.create_index([("model", 1), ("ts", -1)], background=True)
await database.usage_events.create_index([("status", 1), ("ts", -1)], background=True)
# model_pricing indexes
await database.model_pricing.create_index([("model", 1), ("effective_from", -1)], background=True)
except Exception as e:
logging.warning(f"Index creation warning (non-fatal): {e}")

View file

@ -52,7 +52,7 @@ class FocusGroup:
# Set default LLM model if not provided
if "llm_model" not in focus_group_data:
focus_group_data["llm_model"] = "gemini-3-pro-preview"
focus_group_data["llm_model"] = "gemini-3.1-pro-preview"
# Set default GPT-5 parameters if not provided
if "reasoning_effort" not in focus_group_data:

View file

@ -0,0 +1,104 @@
from app.db import get_db
from datetime import datetime, timezone
import logging
import time
logger = logging.getLogger(__name__)
# In-process cache: (model_name -> (pricing_dict, cached_at_monotonic))
_pricing_cache: dict = {}
_CACHE_TTL_SECONDS = 60
def _cache_get(model: str):
entry = _pricing_cache.get(model)
if entry and (time.monotonic() - entry[1]) < _CACHE_TTL_SECONDS:
return entry[0]
return None
def _cache_set(model: str, pricing: dict):
_pricing_cache[model] = (pricing, time.monotonic())
class ModelPricing:
@staticmethod
async def current_for(model_name: str) -> dict | None:
"""Return the active pricing row for a model, with 60 s in-process cache.
Resolves MODEL_ALIASES before lookup so callers can pass raw model names.
Returns None if no pricing is configured (cost will be recorded as 0).
"""
from app.services.llm_service import MODEL_ALIASES
resolved = MODEL_ALIASES.get(model_name, model_name)
cached = _cache_get(resolved)
if cached is not None:
return cached
try:
db = await get_db()
now = datetime.now(timezone.utc)
doc = await db.model_pricing.find_one(
{
"model": resolved,
"effective_from": {"$lte": now},
"$or": [
{"effective_until": None},
{"effective_until": {"$gt": now}},
],
},
sort=[("effective_from", -1)],
)
_cache_set(resolved, doc)
return doc
except Exception:
logger.warning(f"Failed to fetch pricing for model {resolved}", exc_info=True)
return None
@staticmethod
def pick_tier(pricing: dict, prompt_tokens: int) -> dict | None:
"""Return the cost tier that applies for a given prompt token count."""
if not pricing:
return None
tiers = pricing.get("tiers") or []
if not tiers:
return None
# Pick the tier with the largest threshold still <= prompt_tokens
applicable = [t for t in tiers if t.get("threshold_input_tokens", 0) <= prompt_tokens]
if not applicable:
applicable = tiers # fall back to first tier
return max(applicable, key=lambda t: t.get("threshold_input_tokens", 0))
@staticmethod
def compute_cost(pricing: dict | None, prompt_tokens: int, completion_tokens: int,
cached_tokens: int = 0) -> dict:
"""Compute cost breakdown from token counts and pricing doc.
Returns a dict with keys: input, cached, output, total (all USD floats).
All values are 0.0 if pricing is None.
"""
zero = {"input": 0.0, "cached": 0.0, "output": 0.0, "total": 0.0}
if not pricing:
return zero
tier = ModelPricing.pick_tier(pricing, prompt_tokens)
if not tier:
return zero
input_per_mtok = tier.get("input_per_mtok") or 0
cached_per_mtok = tier.get("cached_input_per_mtok") or 0
output_per_mtok = tier.get("output_per_mtok") or 0
billable_input = max(0, prompt_tokens - cached_tokens)
cost_input = billable_input * input_per_mtok / 1_000_000
cost_cached = cached_tokens * cached_per_mtok / 1_000_000
cost_output = completion_tokens * output_per_mtok / 1_000_000
cost_total = cost_input + cost_cached + cost_output
return {
"input": round(cost_input, 8),
"cached": round(cost_cached, 8),
"output": round(cost_output, 8),
"total": round(cost_total, 8),
}

View file

@ -0,0 +1,90 @@
from datetime import datetime, timezone
import logging
logger = logging.getLogger(__name__)
class QuotaExceededError(Exception):
def __init__(self, scope: str, limit_usd: float, used_usd: float, period_start=None):
self.scope = scope # "user" | "focus_group"
self.limit_usd = limit_usd
self.used_usd = used_usd
self.period_start = period_start
super().__init__(
f"Quota exceeded ({scope}): used ${used_usd:.4f} of ${limit_usd:.2f} limit"
)
class QuotaWarning:
def __init__(self, scope: str, limit_usd: float, used_usd: float, pct: float):
self.scope = scope
self.limit_usd = limit_usd
self.used_usd = used_usd
self.pct = pct
async def check_quota(user_id: str | None, focus_group_id: str | None) -> QuotaWarning | None:
"""Check quotas and raise QuotaExceededError if either is exceeded.
Returns a QuotaWarning (not raised) when usage is between 80 % and 100 %.
Returns None if all quotas are fine.
Admins and users with override_quota=True bypass user-level quota.
Focus-group quotas apply to everyone (including admins) they are project budgets.
"""
from app.models.user import User
from app.models.usage_event import UsageEvent
warning = None
if user_id:
try:
user = await User.find_by_id(user_id)
if user:
is_admin = user.get("role") == "admin"
override = user.get("override_quota", False)
if not is_admin and not override:
quota = user.get("quota") or {}
limit = quota.get("monthly_usd")
if limit:
now = datetime.now(timezone.utc)
period_start = now.replace(
day=1, hour=0, minute=0, second=0, microsecond=0
)
spent = await UsageEvent.sum_cost(
{"user_id": user_id, "ts": {"$gte": period_start}}
)
pct = spent / limit if limit else 0
if spent >= limit:
raise QuotaExceededError("user", limit, spent, period_start)
elif pct >= 0.8:
warning = QuotaWarning("user", limit, spent, pct)
except QuotaExceededError:
raise
except Exception:
logger.warning("Quota check failed (non-fatal, allowing call)", exc_info=True)
if focus_group_id:
try:
from app.models.focus_group import FocusGroup
fg = await FocusGroup.find_by_id(focus_group_id)
if fg:
fg_quota = fg.get("quota") or {}
fg_limit = fg_quota.get("total_usd")
if fg_limit:
spent = await UsageEvent.sum_cost({"focus_group_id": focus_group_id})
pct = spent / fg_limit if fg_limit else 0
if spent >= fg_limit:
raise QuotaExceededError("focus_group", fg_limit, spent, None)
elif pct >= 0.8 and not warning:
warning = QuotaWarning("focus_group", fg_limit, spent, pct)
except QuotaExceededError:
raise
except Exception:
logger.warning("Focus-group quota check failed (non-fatal)", exc_info=True)
return warning

View file

@ -0,0 +1,91 @@
from app.db import get_db
from datetime import datetime, timezone
import logging
logger = logging.getLogger(__name__)
VALID_FEATURES = {
"moderator", "persona_response", "persona_generate", "persona_modify",
"persona_export", "key_themes", "summary", "discussion_guide",
"image_description", "conversation_decision", "other",
}
class UsageEvent:
@staticmethod
async def record(
*,
provider: str,
model: str,
prompt_tokens: int,
completion_tokens: int,
cached_tokens: int = 0,
reasoning_tokens: int = 0,
cost_usd: dict,
price_snapshot_id=None,
duration_ms: int = 0,
retry_count: int = 0,
status: str = "success",
error: str | None = None,
is_estimated: bool = False,
estimate_method: str | None = None,
user_id: str | None = None,
focus_group_id: str | None = None,
persona_id: str | None = None,
feature: str = "other",
task_id: str | None = None,
) -> None:
"""Append one usage event doc. Never raises — telemetry must not kill LLM calls."""
try:
if feature not in VALID_FEATURES:
feature = "other"
doc = {
"ts": datetime.now(timezone.utc),
"user_id": user_id,
"focus_group_id": focus_group_id,
"persona_id": persona_id,
"task_id": task_id,
"feature": feature,
"provider": provider,
"model": model,
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"cached_tokens": cached_tokens,
"reasoning_tokens": reasoning_tokens,
"total_tokens": prompt_tokens + completion_tokens,
"cost_usd": cost_usd,
"price_snapshot_id": price_snapshot_id,
"duration_ms": duration_ms,
"retry_count": retry_count,
"status": status,
"error": (error or "")[:500] if error else None,
"is_estimated": is_estimated,
"estimate_method": estimate_method,
}
db = await get_db()
await db.usage_events.insert_one(doc)
if not user_id or not focus_group_id:
logger.info(
f"Usage event recorded with partial context: user_id={user_id} "
f"focus_group_id={focus_group_id} feature={feature} model={model}"
)
except Exception:
logger.warning("Failed to record usage event (non-fatal)", exc_info=True)
@staticmethod
async def sum_cost(match: dict) -> float:
"""Return total cost_usd.total for the given match filter."""
try:
db = await get_db()
pipeline = [
{"$match": match},
{"$group": {"_id": None, "total": {"$sum": "$cost_usd.total"}}},
]
result = await db.usage_events.aggregate(pipeline).to_list(1)
return result[0]["total"] if result else 0.0
except Exception:
logger.warning("Failed to sum usage costs (non-fatal)", exc_info=True)
return 0.0

View file

@ -74,7 +74,7 @@ async def generate_basic_profiles():
temperature = 1.0
customer_data_session_id = data.get('customer_data_session_id') # Optional parameter
llm_model = data.get('llm_model', 'gemini-3-pro-preview') # Optional parameter with default
llm_model = data.get('llm_model', 'gemini-3.1-pro-preview') # Optional parameter with default
try:
# Register current task for cancellation
@ -210,7 +210,7 @@ async def complete_and_save_persona():
temperature = 1.0
customer_data_session_id = data.get('customer_data_session_id') # Optional parameter
llm_model = data.get('llm_model', 'gemini-3-pro-preview') # Optional parameter with default
llm_model = data.get('llm_model', 'gemini-3.1-pro-preview') # Optional parameter with default
# Get persona name for logging
persona_name = basic_profile.get('name', 'Unknown')
@ -835,7 +835,7 @@ async def batch_generate_summaries():
if not (0 <= temperature <= 1.5):
temperature = 1.0
llm_model = data.get('llm_model', 'gemini-3-pro-preview') # Optional parameter with default
llm_model = data.get('llm_model', 'gemini-3.1-pro-preview') # Optional parameter with default
# Log the request with model information
print(f"🔄 Backend: Received batch-generate-summaries request for {len(persona_ids)} personas with model: {llm_model}")
@ -1192,7 +1192,7 @@ async def generate_personas_full():
temperature = 1.0
customer_data_session_id = data.get('customer_data_session_id')
llm_model = data.get('llm_model', 'gemini-3-pro-preview')
llm_model = data.get('llm_model', 'gemini-3.1-pro-preview')
target_folder_id = data.get('target_folder_id')
try:

View file

@ -185,8 +185,8 @@ Be genuine and specific in your feedback, drawing on your personal experiences a
conversation_context=multimodal_context['conversation_context'],
temperature=temperature,
model_name=llm_model,
reasoning_effort=reasoning_effort if llm_model in ('gpt-5', 'gpt-5.2') else None,
verbosity=verbosity if llm_model in ('gpt-5', 'gpt-5.2') else None
reasoning_effort=reasoning_effort if llm_model in ('gpt-5', 'gpt-5.4-2026-03-05') else None,
verbosity=verbosity if llm_model in ('gpt-5', 'gpt-5.4-2026-03-05') else None
)
else:
response_text = await generate_persona_response(

View file

@ -161,7 +161,7 @@ async def modify_persona_with_ai(persona_id):
Request body should include:
- modification_prompt: Natural language description of desired changes
- llm_model: Model to use (defaults to 'gemini-3-pro-preview')
- llm_model: Model to use (defaults to 'gemini-3.1-pro-preview')
- reasoning_effort: For GPT-5 (minimal, low, medium, high)
- verbosity: For GPT-5 (low, medium, high)
- preview_only: If true, returns modified data without saving to database (defaults to false)
@ -175,7 +175,7 @@ async def modify_persona_with_ai(persona_id):
if not modification_prompt:
return jsonify({"error": "modification_prompt is required"}), 400
llm_model = request_data.get('llm_model', 'gemini-3-pro-preview')
llm_model = request_data.get('llm_model', 'gemini-3.1-pro-preview')
reasoning_effort = request_data.get('reasoning_effort', 'medium')
verbosity = request_data.get('verbosity', 'medium')
preview_only = request_data.get('preview_only', False)
@ -246,7 +246,7 @@ async def export_persona_profile(persona_id):
Returns 202 immediately; result delivered via WebSocket task_completed event.
Request body can optionally include:
- llm_model: Model to use (defaults to 'gpt-4.1')
- llm_model: Model to use (defaults to 'gemini-3.1-pro-preview')
- temperature: Temperature for generation (defaults to 0.3)
"""
try:
@ -255,7 +255,7 @@ async def export_persona_profile(persona_id):
return jsonify({"error": "Persona not found"}), 404
request_data = await request.get_json() or {}
llm_model = request_data.get('llm_model', 'gpt-4.1')
llm_model = request_data.get('llm_model', 'gemini-3.1-pro-preview')
temperature = request_data.get('temperature', 0.3)
user_id = get_jwt_identity()

View file

@ -585,6 +585,8 @@ class AIModeratorService:
Generated moderator response
"""
try:
from app.services.llm_usage_context import set_llm_context
set_llm_context(feature="moderator", focus_group_id=focus_group_id)
# Get previous messages for context
messages = await FocusGroup.get_messages(focus_group_id)
recent_messages = messages[-10:] if messages else [] # Last 10 messages

View file

@ -138,6 +138,8 @@ async def generate_basic_personas(
Raises:
PersonaGenerationError: If there's an issue with the AI generation or JSON parsing
"""
from app.services.llm_usage_context import set_llm_context
set_llm_context(feature="persona_generate")
last_error = None
for attempt in range(max_retries + 1):
@ -467,6 +469,8 @@ async def generate_persona(
PersonaGenerationError: If there's an issue with the AI generation or JSON parsing
"""
try:
from app.services.llm_usage_context import set_llm_context
set_llm_context(feature="persona_generate")
# If audience_brief or research_objective provided but no prompt_customization,
# generate customization so the LLM knows the research context
if not prompt_customization and (audience_brief or research_objective):

View file

@ -38,6 +38,8 @@ class ConversationDecisionService:
print(f"🎯 Decision request: {mode} mode for focus group {focus_group_id}")
try:
from app.services.llm_usage_context import set_llm_context
set_llm_context(feature="conversation_decision", focus_group_id=focus_group_id)
# Get full conversation context
context = await ConversationContextService.get_full_context(focus_group_id)
formatted_context = ConversationContextService.format_context_for_llm(context)

View file

@ -46,13 +46,15 @@ async def generate_persona_response(
FocusGroupResponseError: If there's an issue with the response generation
"""
try:
from app.services.llm_usage_context import set_llm_context
set_llm_context(
feature="persona_response",
focus_group_id=focus_group_id or None,
persona_id=str(persona.get("_id", "")) or None,
)
print(f"🎭 Generating persona response for {persona.get('name', 'Unknown')}")
print(f" - focus_group_id: {focus_group_id}")
print(f" - current_topic: {current_topic[:50]}...")
if llm_model in ('gpt-5', 'gpt-5.2'):
print(f" - llm_model: {llm_model} (reasoning_effort: {reasoning_effort or 'medium'}, verbosity: {verbosity or 'medium'}) [using Responses API]")
else:
print(f" - llm_model: {llm_model or 'default (gemini-3-pro-preview)'}")
# Import LLMService at the top to avoid scoping issues
from app.services.llm_service import LLMService
@ -417,6 +419,12 @@ async def generate_creative_review_response(
FocusGroupResponseError: If there's an issue with the response generation
"""
try:
from app.services.llm_usage_context import set_llm_context
set_llm_context(
feature="persona_response",
focus_group_id=focus_group_id or None,
persona_id=str(persona.get("_id", "")) or None,
)
print(f"🎨 CREATIVE REVIEW RESPONSE DEBUG:")
print(f" - persona: {persona.get('name', 'Unknown')}")
print(f" - current_topic: {current_topic}")

View file

@ -43,6 +43,8 @@ async def generate_focus_group_summary(
A concise one-line summary string (max ~100 characters), or None on failure
"""
try:
from app.services.llm_usage_context import set_llm_context
set_llm_context(feature="summary")
# Extract data from focus group
name = focus_group_data.get('name', 'Unnamed Focus Group')
topic = focus_group_data.get('topic', 'General Research')

View file

@ -40,8 +40,10 @@ class KeyThemeService:
KeyThemeServiceError: If there's an issue with the generation process
"""
logger = logging.getLogger(__name__)
from app.services.llm_usage_context import set_llm_context
set_llm_context(feature="key_themes", focus_group_id=focus_group_id)
logger.info(f"Starting key theme generation for focus group {focus_group_id} with temperature {temperature}")
logger.info(f"Using LLM model: {llm_model or 'default (gemini-3-pro-preview)'}")
logger.info(f"Using LLM model: {llm_model or 'default'}")
try:
# Get the focus group

View file

@ -11,6 +11,7 @@ import asyncio
import logging
import base64
import traceback
import time
from google import genai
from google.genai import errors as genai_errors
from openai import AsyncOpenAI
@ -59,18 +60,20 @@ def get_openai_client():
return AsyncOpenAI(api_key=OPENAI_API_KEY, timeout=600.0)
# The default model we're using
DEFAULT_MODEL = "gemini-3-pro-preview"
DEFAULT_MODEL = "gemini-3.1-pro-preview"
# Supported models
SUPPORTED_MODELS = {
'gemini-3-pro-preview': 'gemini',
'gpt-4.1': 'openai',
'gpt-5.2': 'openai'
'gemini-3.1-pro-preview': 'gemini',
'gpt-5.4-2026-03-05': 'openai',
}
# Aliases for renamed/legacy model IDs stored in the database
MODEL_ALIASES = {
'gpt-5': 'gpt-5.2',
'gpt-5': 'gpt-5.4-2026-03-05',
'gpt-5.2': 'gpt-5.4-2026-03-05',
'gemini-3-pro-preview': 'gemini-3.1-pro-preview',
'gpt-4.1': 'gemini-3.1-pro-preview',
}
class LLMServiceError(Exception):
@ -182,6 +185,83 @@ class LLMService:
raise
raise LLMServiceError(f"Error extracting text from new GenAI SDK response: {str(e)}")
@staticmethod
def _extract_usage_metadata(response, provider: str) -> dict:
"""Extract token counts from a provider response. All fields default to 0."""
if provider == 'gemini':
um = getattr(response, 'usage_metadata', None)
if um is None:
return {'prompt': 0, 'completion': 0, 'cached': 0, 'reasoning': 0}
return {
'prompt': getattr(um, 'prompt_token_count', 0) or 0,
'completion': getattr(um, 'candidates_token_count', 0) or 0,
'cached': getattr(um, 'cached_content_token_count', 0) or 0,
'reasoning': 0,
}
elif provider == 'openai':
usage = getattr(response, 'usage', None)
if usage is None:
return {'prompt': 0, 'completion': 0, 'cached': 0, 'reasoning': 0}
# Responses API (gpt-5.4-2026-03-05)
if hasattr(usage, 'input_tokens'):
input_details = getattr(usage, 'input_tokens_details', None)
output_details = getattr(usage, 'output_tokens_details', None)
return {
'prompt': getattr(usage, 'input_tokens', 0) or 0,
'completion': getattr(usage, 'output_tokens', 0) or 0,
'cached': getattr(input_details, 'cached_tokens', 0) or 0 if input_details else 0,
'reasoning': getattr(output_details, 'reasoning_tokens', 0) or 0 if output_details else 0,
}
# Chat Completions API
prompt_details = getattr(usage, 'prompt_tokens_details', None)
return {
'prompt': getattr(usage, 'prompt_tokens', 0) or 0,
'completion': getattr(usage, 'completion_tokens', 0) or 0,
'cached': getattr(prompt_details, 'cached_tokens', 0) or 0 if prompt_details else 0,
'reasoning': 0,
}
return {'prompt': 0, 'completion': 0, 'cached': 0, 'reasoning': 0}
@staticmethod
async def _record_usage(response, provider: str, model: str, start_time: float, retry_count: int) -> None:
"""Record a usage event after a successful LLM call. Never raises."""
try:
from app.services.llm_usage_context import current_context
from app.models.usage_event import UsageEvent
from app.models.model_pricing import ModelPricing
ctx = current_context()
tokens = LLMService._extract_usage_metadata(response, provider)
pricing = await ModelPricing.current_for(model)
cost = ModelPricing.compute_cost(
pricing,
prompt_tokens=tokens['prompt'],
completion_tokens=tokens['completion'],
cached_tokens=tokens['cached'],
)
price_id = pricing.get('_id') if pricing else None
await UsageEvent.record(
provider=provider,
model=model,
prompt_tokens=tokens['prompt'],
completion_tokens=tokens['completion'],
cached_tokens=tokens['cached'],
reasoning_tokens=tokens['reasoning'],
cost_usd=cost,
price_snapshot_id=price_id,
duration_ms=int((time.monotonic() - start_time) * 1000),
retry_count=retry_count,
status="success",
user_id=ctx.user_id,
focus_group_id=ctx.focus_group_id,
persona_id=ctx.persona_id,
feature=ctx.feature,
task_id=ctx.task_id,
)
except Exception:
logging.getLogger(__name__).warning("_record_usage failed (non-fatal)", exc_info=True)
@staticmethod
async def generate_content(
prompt: str,
@ -216,6 +296,7 @@ class LLMService:
actual_model = LLMService._resolve_model(model_name)
provider = LLMService._get_model_provider(model_name)
_start_time = time.monotonic()
for attempt in range(max_retries):
attempt_num = attempt + 1
@ -223,8 +304,8 @@ class LLMService:
try:
if provider == 'openai':
if actual_model == 'gpt-5.2':
# Use OpenAI Responses API for GPT-5
if actual_model == 'gpt-5.4-2026-03-05':
# Use OpenAI Responses API for gpt-5.4-2026-03-05
input_content = prompt
if system_prompt:
input_content = f"System: {system_prompt}\n\nUser: {prompt}"
@ -303,6 +384,7 @@ class LLMService:
if attempt > 0:
logger.info(f"LLM content generation succeeded on attempt {attempt_num}/{max_retries}")
await LLMService._record_usage(response, provider, actual_model, _start_time, attempt)
return result
except genai_errors.APIError as e:
@ -410,7 +492,7 @@ class LLMService:
except json.JSONDecodeError as e:
error_msg = f"Failed to parse JSON response: {str(e)}. Raw response: {response_text[:200]}..."
logger.error(error_msg)
logging.getLogger(__name__).error(error_msg)
raise LLMServiceError(error_msg)
@staticmethod
@ -531,6 +613,7 @@ class LLMService:
provider = LLMService._get_model_provider(model_name)
logger.info(f"Generating multimodal content with {len(image_paths)} image(s) using {provider} provider")
_start_time = time.monotonic()
for attempt in range(max_retries):
attempt_num = attempt + 1
@ -564,8 +647,8 @@ class LLMService:
})
logger.debug(f"Successfully loaded image for OpenAI: {image_path}")
if actual_model == 'gpt-5.2':
# Use Responses API for GPT-5.2 multimodal
if actual_model == 'gpt-5.4-2026-03-05':
# Use Responses API for gpt-5.4-2026-03-05 multimodal
# Note: GPT-5 Responses API supports multimodal input
input_content = [{"role": "user", "content": [{"type": "input_text", "text": prompt}]}]
# Add images to the content array
@ -660,6 +743,7 @@ class LLMService:
if attempt > 0:
logger.info(f"Multimodal content generation succeeded on attempt {attempt_num}/{max_retries}")
await LLMService._record_usage(response, provider, actual_model, _start_time, attempt)
return result
except Exception as e:
@ -766,6 +850,7 @@ class LLMService:
max_retries = 3
last_error = None
_start_time = time.monotonic()
for attempt in range(max_retries):
attempt_num = attempt + 1
@ -790,8 +875,8 @@ class LLMService:
}
})
if actual_model == 'gpt-5.2':
# Use Responses API for GPT-5.2 contextual multimodal
if actual_model == 'gpt-5.4-2026-03-05':
# Use Responses API for gpt-5.4-2026-03-05 contextual multimodal
input_content = [{"role": "user", "content": [{"type": "input_text", "text": full_prompt}]}]
# Add images to the content array
for img_content in image_content:
@ -875,6 +960,7 @@ class LLMService:
print(f" - Result length: {len(result) if result else 0} characters")
print(f" - Result preview: '{result[:200] if result else 'EMPTY'}...'")
print(f" - Result repr: {repr(result[:50]) if result else 'NONE'}")
await LLMService._record_usage(response, provider, actual_model, _start_time, attempt)
return result
except genai_errors.APIError as e:

View file

@ -0,0 +1,54 @@
from contextvars import ContextVar
from dataclasses import dataclass, replace
from contextlib import contextmanager
from typing import Optional
@dataclass(frozen=True)
class LLMCallContext:
user_id: Optional[str] = None
focus_group_id: Optional[str] = None
persona_id: Optional[str] = None
feature: str = "other"
task_id: Optional[str] = None
_ctx: ContextVar[LLMCallContext] = ContextVar("llm_call_context", default=LLMCallContext())
def current_context() -> LLMCallContext:
return _ctx.get()
def set_llm_context(**overrides) -> None:
"""Mutate the LLM context for the current asyncio task without cleanup.
Use this at service entry points where the feature/focus_group_id/persona_id
should persist for the duration of the whole async call tree (including sub-awaits).
The change lives until the asyncio Task ends or is overridden again.
Unlike llm_context(), this does NOT restore the previous value on exit suitable
for top-level service calls, not for re-entrant helpers.
"""
prev = _ctx.get()
_ctx.set(replace(prev, **overrides))
@contextmanager
def llm_context(**overrides):
"""Context manager that sets LLM call attribution metadata.
Usage:
with llm_context(user_id="abc", focus_group_id="xyz", feature="moderator"):
await LLMService.generate_content(...)
Overrides stack inner contexts extend (not replace) outer ones.
Safe across asyncio tasks and run_coroutine_threadsafe hops because
ContextVar inherits context on task creation / thread submission.
"""
prev = _ctx.get()
token = _ctx.set(replace(prev, **overrides))
try:
yield
finally:
_ctx.reset(token)

View file

@ -134,7 +134,7 @@ class PersonaModificationService:
async def modify_persona(
persona_id: str,
modification_prompt: str,
llm_model: str = 'gemini-3-pro-preview',
llm_model: str = 'gemini-3.1-pro-preview',
reasoning_effort: str = 'medium',
verbosity: str = 'medium',
max_retries: int = 3,
@ -159,6 +159,8 @@ class PersonaModificationService:
PersonaModificationError: If modification fails or validation fails
"""
try:
from app.services.llm_usage_context import set_llm_context
set_llm_context(feature="persona_modify", persona_id=persona_id)
# Fetch the original persona
original_persona = await Persona.find_by_id(persona_id)
if not original_persona:
@ -188,8 +190,8 @@ class PersonaModificationService:
prompt=final_prompt,
temperature=0.3, # Lower temperature for consistent modifications
model_name=llm_model,
reasoning_effort=reasoning_effort if llm_model in ('gpt-5', 'gpt-5.2') else None,
verbosity=verbosity if llm_model in ('gpt-5', 'gpt-5.2') else None
reasoning_effort=reasoning_effort if llm_model in ('gpt-5', 'gpt-5.4-2026-03-05') else None,
verbosity=verbosity if llm_model in ('gpt-5', 'gpt-5.4-2026-03-05') else None
)
# Parse JSON response

View file

@ -22,14 +22,54 @@ def make_serializable(obj):
def admin_required(f):
"""Route decorator requiring admin role. Must be stacked BELOW @jwt_required()."""
@wraps(f)
def decorated_function(*args, **kwargs):
async def decorated(*args, **kwargs):
user_id = get_jwt_identity()
user_data = User.find_by_id(user_id)
if not user_data or user_data.get('role') != 'admin':
if not user_id:
return jsonify({"message": "Authentication required"}), 401
user_data = await User.find_by_id(user_id)
if not user_data:
return jsonify({"message": "User not found"}), 404
if user_data.get("role") != "admin":
return jsonify({"message": "Admin privileges required"}), 403
if user_data.get("is_active") is False:
return jsonify({"message": "Account disabled"}), 403
return await f(*args, **kwargs)
return decorated
return f(*args, **kwargs)
return decorated_function
def active_required(f):
"""Route decorator that rejects requests from disabled users.
Guards LLM-invoking routes so that revoking a user's access takes effect
immediately rather than waiting for their JWT to expire (24 h window).
Must be stacked BELOW @jwt_required().
"""
@wraps(f)
async def decorated(*args, **kwargs):
user_id = get_jwt_identity()
if user_id:
user_data = await User.find_by_id(user_id)
if user_data and user_data.get("is_active") is False:
return jsonify({"message": "Account disabled"}), 403
return await f(*args, **kwargs)
return decorated
def with_user_context(f):
"""Route decorator that injects the JWT user_id into the LLM usage ContextVar.
Must be stacked BELOW @jwt_required() so the token is already validated.
The context propagates to asyncio tasks and run_coroutine_threadsafe calls,
so autonomous AI runner conversations pick up the user attribution automatically.
"""
@wraps(f)
async def decorated(*args, **kwargs):
from app.services.llm_usage_context import llm_context
user_id = get_jwt_identity()
if user_id:
with llm_context(user_id=user_id):
return await f(*args, **kwargs)
return await f(*args, **kwargs)
return decorated

View file

@ -31,3 +31,6 @@ pillow==11.3.0
# Configuration & Utilities
python-dotenv==1.1.1
# Token estimation (used by backfill_usage.py script)
tiktoken>=0.9.0

View file

@ -0,0 +1,101 @@
#!/usr/bin/env python3
"""Seed model pricing for Semblance.
Run from the backend/ directory:
source venv/bin/activate
python scripts/seed_model_pricing.py
Idempotent upserts on {model, effective_from}. Safe to re-run.
"""
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from dotenv import load_dotenv
load_dotenv()
import pymongo
from datetime import datetime, timezone
MONGO_URI = os.environ.get("MONGO_URI")
MONGO_USER = os.environ.get("MONGO_USER")
MONGO_PASS = os.environ.get("MONGO_PASS")
MONGO_HOST = os.environ.get("MONGO_HOST", "localhost")
MONGO_PORT = os.environ.get("MONGO_PORT", "27017")
if not MONGO_URI:
if MONGO_USER and MONGO_PASS:
MONGO_URI = f"mongodb://{MONGO_USER}:{MONGO_PASS}@{MONGO_HOST}:{MONGO_PORT}/semblance_db?authSource=admin"
else:
MONGO_URI = f"mongodb://{MONGO_HOST}:{MONGO_PORT}"
# Pricing effective from project start — covers all historical backfill
EFFECTIVE_FROM = datetime(2024, 1, 1, tzinfo=timezone.utc)
PRICING_ROWS = [
{
"model": "gpt-5.4-2026-03-05",
"provider": "openai",
"currency": "USD",
"tiers": [
{
"threshold_input_tokens": 0,
"input_per_mtok": 2.50,
"cached_input_per_mtok": 0.25,
"output_per_mtok": 15.00,
"image_per_mtok": None,
}
],
"effective_from": EFFECTIVE_FROM,
"effective_until": None,
"notes": "gpt-5.4-2026-03-05 pricing as of 2026-04",
},
{
"model": "gemini-3.1-pro-preview",
"provider": "gemini",
"currency": "USD",
"tiers": [
{
"threshold_input_tokens": 0,
"input_per_mtok": 2.00,
"cached_input_per_mtok": None,
"output_per_mtok": 12.00,
"image_per_mtok": None,
},
{
"threshold_input_tokens": 200_000,
"input_per_mtok": 4.00,
"cached_input_per_mtok": None,
"output_per_mtok": 18.00,
"image_per_mtok": None,
},
],
"effective_from": EFFECTIVE_FROM,
"effective_until": None,
"notes": "gemini-3.1-pro-preview pricing: $2/$12 (<200k ctx), $4/$18 (>=200k ctx)",
},
]
def main():
client = pymongo.MongoClient(MONGO_URI, serverSelectionTimeoutMS=5000)
db = client.semblance_db
db.model_pricing.create_index(
[("model", pymongo.ASCENDING), ("effective_from", pymongo.DESCENDING)],
background=True,
)
for row in PRICING_ROWS:
key = {"model": row["model"], "effective_from": row["effective_from"]}
result = db.model_pricing.update_one(key, {"$set": row}, upsert=True)
action = "inserted" if result.upserted_id else "updated"
print(f" {action}: {row['model']} (effective from {row['effective_from'].date()})")
print(f"\nDone. {len(PRICING_ROWS)} pricing rows seeded.")
client.close()
if __name__ == "__main__":
main()

53
backend/tests/conftest.py Normal file
View file

@ -0,0 +1,53 @@
"""
conftest.py sys.modules stubs so tests run without the full Docker venv.
All heavy external packages are replaced with MagicMocks before any app.*
module is imported. Individual tests patch specific methods as needed.
"""
import os
import sys
from unittest.mock import MagicMock, AsyncMock
# ── Fake env vars required at llm_service.py module level ─────────────────────
os.environ.setdefault('GEMINI_API_KEY', 'test-key-gemini')
os.environ.setdefault('OPENAI_API_KEY', 'test-key-openai')
def _stub(*names):
"""Register MagicMocks under each name in sys.modules."""
for name in names:
if name not in sys.modules:
sys.modules[name] = MagicMock()
# ── External packages not present in system Python ────────────────────────────
_stub(
'google', 'google.genai', 'google.genai.types',
'openai', 'openai.types', 'openai.types.responses',
'motor', 'motor.motor_asyncio',
'pymongo', 'pymongo.errors',
'quart', 'quart_cors', 'hypercorn', 'werkzeug', 'werkzeug.exceptions',
'socketio',
'bcrypt', 'jwt', 'msal',
'bson', 'bson.objectid',
'pydantic',
'PIL', 'PIL.Image',
'httpx', 'requests',
'dotenv',
'llama_cloud_services',
)
# ── app.db ─────────────────────────────────────────────────────────────────────
# Any `from app.db import get_db` will capture this AsyncMock.
# Tests that need to control DB responses should patch
# `app.models.<module>.get_db` (i.e. the local binding in the module under test).
_mock_db = MagicMock()
_mock_get_db = AsyncMock(return_value=_mock_db)
_app_db_mod = MagicMock()
_app_db_mod.get_db = _mock_get_db
sys.modules['app.db'] = _app_db_mod
# Expose for tests
mock_db = _mock_db

View file

@ -1,145 +1,128 @@
"""
Tests for the LLM Service
Tests for LLMService covers parse_json_response (sync) and generate_structured_array.
generate_content / generate_multimodal_content are async and call real provider APIs,
so they are covered via integration tests; only pure logic is unit-tested here.
"""
import unittest
from unittest.mock import patch, MagicMock
import json
import sys
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
from app.services.llm_service import LLMService, LLMServiceError
class TestLLMService(unittest.TestCase):
"""Test cases for the LLM Service"""
@patch('app.services.llm_service.genai.GenerativeModel')
def test_get_model(self, mock_generative_model):
"""Test getting a Gemini model"""
# Setup mock
mock_model = MagicMock()
mock_generative_model.return_value = mock_model
class TestParseJsonResponse:
def test_clean_json(self):
result = LLMService.parse_json_response('{"key": "value", "number": 42}')
assert result == {"key": "value", "number": 42}
# Test with default model
model = LLMService.get_model()
mock_generative_model.assert_called_once()
self.assertEqual(model, mock_model)
def test_json_in_markdown_fenced_block(self):
md = '```json\n{"key": "value", "number": 42}\n```'
assert LLMService.parse_json_response(md) == {"key": "value", "number": 42}
# Reset mock
mock_generative_model.reset_mock()
def test_json_in_generic_fenced_block(self):
md = '```\n{"key": "value", "number": 42}\n```'
assert LLMService.parse_json_response(md) == {"key": "value", "number": 42}
# Test with custom model
custom_model = "custom-model-name"
model = LLMService.get_model(custom_model)
mock_generative_model.assert_called_once_with(custom_model)
self.assertEqual(model, mock_model)
def test_invalid_json_raises(self):
with pytest.raises(LLMServiceError) as exc_info:
LLMService.parse_json_response("This is not JSON")
assert "Failed to parse JSON response" in str(exc_info.value)
@patch('app.services.llm_service.LLMService.get_model')
def test_generate_content(self, mock_get_model):
"""Test generating content with the LLM"""
# Setup mock
mock_model = MagicMock()
mock_response = MagicMock()
mock_response.text = "Generated text response"
mock_model.generate_content.return_value = mock_response
mock_get_model.return_value = mock_model
def test_empty_string_raises(self):
with pytest.raises(LLMServiceError):
LLMService.parse_json_response("")
# Test with default parameters
prompt = "Test prompt"
response = LLMService.generate_content(prompt)
def test_json_array(self):
result = LLMService.parse_json_response('[{"a": 1}, {"b": 2}]')
assert result == [{"a": 1}, {"b": 2}]
mock_get_model.assert_called_once()
mock_model.generate_content.assert_called_once()
self.assertEqual(response, "Generated text response")
def test_nested_json(self):
result = LLMService.parse_json_response('{"outer": {"inner": [1, 2, 3]}}')
assert result == {"outer": {"inner": [1, 2, 3]}}
# Test with custom parameters
mock_get_model.reset_mock()
mock_model.generate_content.reset_mock()
response = LLMService.generate_content(
prompt="Custom prompt",
temperature=0.5,
max_tokens=100,
model_name="custom-model"
)
class TestResolveModelAndProvider:
"""_resolve_model is a pure function. Provider is looked up via SUPPORTED_MODELS."""
mock_get_model.assert_called_once_with("custom-model")
mock_model.generate_content.assert_called_once()
self.assertEqual(response, "Generated text response")
def test_none_resolves_to_default(self):
assert LLMService._resolve_model(None) == "gemini-3.1-pro-preview"
@patch('app.services.llm_service.LLMService.get_model')
def test_generate_content_error(self, mock_get_model):
"""Test error handling in generate_content"""
# Setup mock to raise an exception
mock_model = MagicMock()
mock_model.generate_content.side_effect = Exception("Model error")
mock_get_model.return_value = mock_model
def test_all_aliases_resolve(self):
assert LLMService._resolve_model("gpt-5") == "gpt-5.4-2026-03-05"
assert LLMService._resolve_model("gpt-5.2") == "gpt-5.4-2026-03-05"
assert LLMService._resolve_model("gemini-3-pro-preview") == "gemini-3.1-pro-preview"
assert LLMService._resolve_model("gpt-4.1") == "gemini-3.1-pro-preview"
# Test error handling
with self.assertRaises(LLMServiceError) as context:
LLMService.generate_content("Test prompt")
def test_known_models_unchanged(self):
assert LLMService._resolve_model("gemini-3.1-pro-preview") == "gemini-3.1-pro-preview"
assert LLMService._resolve_model("gpt-5.4-2026-03-05") == "gpt-5.4-2026-03-05"
self.assertIn("Error generating content", str(context.exception))
def test_provider_for_gemini_model(self):
from app.services.llm_service import SUPPORTED_MODELS
assert SUPPORTED_MODELS.get("gemini-3.1-pro-preview") == "gemini"
def test_parse_json_response_valid(self):
"""Test parsing valid JSON responses"""
# Test with clean JSON
clean_json = '{"key": "value", "number": 42}'
result = LLMService.parse_json_response(clean_json)
expected = {"key": "value", "number": 42}
self.assertEqual(result, expected)
def test_provider_for_openai_model(self):
from app.services.llm_service import SUPPORTED_MODELS
assert SUPPORTED_MODELS.get("gpt-5.4-2026-03-05") == "openai"
# Test with JSON in markdown code block
markdown_json = '```json\n{"key": "value", "number": 42}\n```'
result = LLMService.parse_json_response(markdown_json)
self.assertEqual(result, expected)
def test_unknown_model_not_in_supported(self):
from app.services.llm_service import SUPPORTED_MODELS
assert "gpt-4.1" not in SUPPORTED_MODELS # retired
# Test with JSON in generic code block
generic_code_block = '```\n{"key": "value", "number": 42}\n```'
result = LLMService.parse_json_response(generic_code_block)
self.assertEqual(result, expected)
def test_parse_json_response_invalid(self):
"""Test parsing invalid JSON responses"""
invalid_json = 'This is not JSON'
class TestExtractUsageMetadata:
"""Static — no external calls."""
with self.assertRaises(LLMServiceError) as context:
LLMService.parse_json_response(invalid_json)
def test_gemini_extracts_all_fields(self):
response = MagicMock()
um = MagicMock()
um.prompt_token_count = 500
um.candidates_token_count = 100
um.cached_content_token_count = 20
response.usage_metadata = um
self.assertIn("Failed to parse JSON response", str(context.exception))
result = LLMService._extract_usage_metadata(response, "gemini")
assert result == {"prompt": 500, "completion": 100, "cached": 20, "reasoning": 0}
@patch('app.services.llm_service.LLMService.generate_content')
@patch('app.services.llm_service.LLMService.parse_json_response')
def test_generate_structured_response(self, mock_parse_json, mock_generate_content):
"""Test generating a structured JSON response"""
# Setup mocks
mock_generate_content.return_value = '{"result": "success"}'
mock_parse_json.return_value = {"result": "success"}
def test_openai_responses_api(self):
response = MagicMock()
usage = MagicMock()
usage.input_tokens = 1000
usage.output_tokens = 200
usage.input_tokens_details = MagicMock(cached_tokens=50)
usage.output_tokens_details = MagicMock(reasoning_tokens=80)
response.usage = usage
# Test
result = LLMService.generate_structured_response(
prompt="Generate JSON",
temperature=0.5
)
result = LLMService._extract_usage_metadata(response, "openai")
assert result == {"prompt": 1000, "completion": 200, "cached": 50, "reasoning": 80}
mock_generate_content.assert_called_once_with(
prompt="Generate JSON",
temperature=0.5,
max_tokens=None,
model_name=None,
system_prompt=None
)
mock_parse_json.assert_called_once_with('{"result": "success"}')
self.assertEqual(result, {"result": "success"})
def test_openai_chat_completions(self):
response = MagicMock()
usage = MagicMock(spec=['prompt_tokens', 'completion_tokens', 'prompt_tokens_details'])
usage.prompt_tokens = 400
usage.completion_tokens = 100
usage.prompt_tokens_details = MagicMock(cached_tokens=10)
response.usage = usage
@patch('app.services.llm_service.LLMService.generate_content')
def test_generate_structured_response_error(self, mock_generate_content):
"""Test error handling in generate_structured_response"""
# Setup mock to raise an exception
mock_generate_content.side_effect = LLMServiceError("Generation failed")
result = LLMService._extract_usage_metadata(response, "openai")
assert result == {"prompt": 400, "completion": 100, "cached": 10, "reasoning": 0}
# Test error propagation
with self.assertRaises(LLMServiceError) as context:
LLMService.generate_structured_response("Generate JSON")
def test_missing_usage_returns_zeros(self):
response = MagicMock()
response.usage = None
assert LLMService._extract_usage_metadata(response, "openai") == {
"prompt": 0, "completion": 0, "cached": 0, "reasoning": 0
}
self.assertEqual(str(context.exception), "Generation failed")
def test_none_values_coerced_to_zero(self):
"""Fields returning None should become 0, not None."""
response = MagicMock()
um = MagicMock()
um.prompt_token_count = None
um.candidates_token_count = None
um.cached_content_token_count = None
response.usage_metadata = um
if __name__ == '__main__':
unittest.main()
result = LLMService._extract_usage_metadata(response, "gemini")
assert all(v == 0 for v in result.values())

View file

@ -0,0 +1,664 @@
"""
Tests for the usage tracking infrastructure added in Phase A-C:
- LLMCallContext / llm_usage_context module
- ModelPricing.pick_tier and ModelPricing.compute_cost
- LLMService._extract_usage_metadata and LLMService._resolve_model
- UsageEvent.record (non-fatal DB error handling, field validation)
- check_quota (admin bypass, warning threshold, exceeded error)
"""
import asyncio
import sys
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
# ══════════════════════════════════════════════════════════════════════════════
# 1. LLM usage context — pure stdlib, no mocking needed
# ══════════════════════════════════════════════════════════════════════════════
class TestLLMUsageContext:
def setup_method(self):
# Import fresh each test to avoid cross-test ContextVar pollution
from app.services.llm_usage_context import (
LLMCallContext, _ctx, current_context, set_llm_context, llm_context
)
self.LLMCallContext = LLMCallContext
self._ctx = _ctx
self.current_context = current_context
self.set_llm_context = set_llm_context
self.llm_context = llm_context
# Reset context to default before each test
self._ctx.set(LLMCallContext())
def test_default_context(self):
ctx = self.current_context()
assert ctx.user_id is None
assert ctx.focus_group_id is None
assert ctx.persona_id is None
assert ctx.feature == "other"
assert ctx.task_id is None
def test_set_llm_context_mutates(self):
self.set_llm_context(user_id="u1", feature="persona_generate")
ctx = self.current_context()
assert ctx.user_id == "u1"
assert ctx.feature == "persona_generate"
# Unset fields remain None
assert ctx.focus_group_id is None
def test_set_llm_context_merges_with_existing(self):
self.set_llm_context(user_id="u1")
self.set_llm_context(focus_group_id="fg1")
ctx = self.current_context()
# Both values present
assert ctx.user_id == "u1"
assert ctx.focus_group_id == "fg1"
def test_llm_context_restores_on_exit(self):
self.set_llm_context(user_id="outer")
with self.llm_context(user_id="inner", feature="moderator"):
assert self.current_context().user_id == "inner"
assert self.current_context().feature == "moderator"
# Restored
restored = self.current_context()
assert restored.user_id == "outer"
assert restored.feature == "other"
def test_llm_context_restores_on_exception(self):
self.set_llm_context(user_id="before")
try:
with self.llm_context(user_id="during"):
raise ValueError("boom")
except ValueError:
pass
assert self.current_context().user_id == "before"
def test_llm_context_stacking(self):
with self.llm_context(user_id="u1"):
with self.llm_context(focus_group_id="fg1"):
with self.llm_context(feature="key_themes"):
ctx = self.current_context()
assert ctx.user_id == "u1"
assert ctx.focus_group_id == "fg1"
assert ctx.feature == "key_themes"
assert self.current_context().feature == "other" # restored
assert self.current_context().focus_group_id is None # restored
assert self.current_context().user_id is None # restored
def test_frozen_dataclass_immutable(self):
ctx = self.LLMCallContext(user_id="x")
with pytest.raises((AttributeError, TypeError)):
ctx.user_id = "y" # type: ignore
# ══════════════════════════════════════════════════════════════════════════════
# 2. ModelPricing pure logic — pick_tier and compute_cost
# ══════════════════════════════════════════════════════════════════════════════
class TestModelPricingPureLogic:
def setup_method(self):
from app.models.model_pricing import ModelPricing
self.ModelPricing = ModelPricing
def _gemini_pricing(self):
"""Two-tier Gemini pricing: <200k and >=200k."""
return {
"model": "gemini-3.1-pro-preview",
"tiers": [
{
"threshold_input_tokens": 0,
"input_per_mtok": 2.0,
"cached_input_per_mtok": None,
"output_per_mtok": 12.0,
},
{
"threshold_input_tokens": 200_000,
"input_per_mtok": 4.0,
"cached_input_per_mtok": None,
"output_per_mtok": 18.0,
},
],
}
def _gpt_pricing(self):
"""Single-tier GPT pricing."""
return {
"model": "gpt-5.4-2026-03-05",
"tiers": [
{
"threshold_input_tokens": 0,
"input_per_mtok": 2.50,
"cached_input_per_mtok": 0.25,
"output_per_mtok": 15.0,
}
],
}
# ── pick_tier ──────────────────────────────────────────────────────────────
def test_pick_tier_below_all_thresholds(self):
pricing = self._gemini_pricing()
tier = self.ModelPricing.pick_tier(pricing, prompt_tokens=1000)
assert tier["threshold_input_tokens"] == 0
assert tier["input_per_mtok"] == 2.0
def test_pick_tier_above_high_threshold(self):
pricing = self._gemini_pricing()
tier = self.ModelPricing.pick_tier(pricing, prompt_tokens=250_000)
assert tier["threshold_input_tokens"] == 200_000
assert tier["input_per_mtok"] == 4.0
def test_pick_tier_exactly_at_threshold(self):
pricing = self._gemini_pricing()
tier = self.ModelPricing.pick_tier(pricing, prompt_tokens=200_000)
assert tier["threshold_input_tokens"] == 200_000
def test_pick_tier_single_tier(self):
pricing = self._gpt_pricing()
tier = self.ModelPricing.pick_tier(pricing, prompt_tokens=500_000)
assert tier["input_per_mtok"] == 2.50
def test_pick_tier_none_pricing_returns_none(self):
assert self.ModelPricing.pick_tier(None, 1000) is None
def test_pick_tier_empty_tiers_returns_none(self):
assert self.ModelPricing.pick_tier({"tiers": []}, 1000) is None
# ── compute_cost ──────────────────────────────────────────────────────────
def test_compute_cost_none_pricing_returns_zeros(self):
cost = self.ModelPricing.compute_cost(None, 100, 50)
assert cost == {"input": 0.0, "cached": 0.0, "output": 0.0, "total": 0.0}
def test_compute_cost_basic_gemini_low_tier(self):
# 100k input, 10k output, no cached — stays in low tier
cost = self.ModelPricing.compute_cost(
self._gemini_pricing(),
prompt_tokens=100_000,
completion_tokens=10_000,
cached_tokens=0,
)
expected_input = 100_000 * 2.0 / 1_000_000 # $0.20
expected_output = 10_000 * 12.0 / 1_000_000 # $0.12
assert abs(cost["input"] - expected_input) < 1e-9
assert cost["cached"] == 0.0
assert abs(cost["output"] - expected_output) < 1e-9
assert abs(cost["total"] - (expected_input + expected_output)) < 1e-9
def test_compute_cost_basic_gemini_high_tier(self):
# 300k input → high tier: $4/$18
cost = self.ModelPricing.compute_cost(
self._gemini_pricing(),
prompt_tokens=300_000,
completion_tokens=5_000,
)
expected_input = 300_000 * 4.0 / 1_000_000
expected_output = 5_000 * 18.0 / 1_000_000
assert abs(cost["input"] - expected_input) < 1e-9
assert abs(cost["output"] - expected_output) < 1e-9
def test_compute_cost_with_cached_tokens_gpt(self):
# 10k input, 2k cached (cheaper), 5k output
cost = self.ModelPricing.compute_cost(
self._gpt_pricing(),
prompt_tokens=10_000,
completion_tokens=5_000,
cached_tokens=2_000,
)
# Billable input = 10k - 2k = 8k at $2.50/M
expected_input = 8_000 * 2.50 / 1_000_000
# Cached 2k at $0.25/M
expected_cached = 2_000 * 0.25 / 1_000_000
# Output 5k at $15/M
expected_output = 5_000 * 15.0 / 1_000_000
expected_total = expected_input + expected_cached + expected_output
assert abs(cost["input"] - expected_input) < 1e-9
assert abs(cost["cached"] - expected_cached) < 1e-9
assert abs(cost["output"] - expected_output) < 1e-9
assert abs(cost["total"] - expected_total) < 1e-9
def test_compute_cost_total_equals_sum_of_components(self):
cost = self.ModelPricing.compute_cost(
self._gpt_pricing(),
prompt_tokens=50_000,
completion_tokens=20_000,
cached_tokens=5_000,
)
assert abs(cost["total"] - (cost["input"] + cost["cached"] + cost["output"])) < 1e-9
# ══════════════════════════════════════════════════════════════════════════════
# 3. LLMService._extract_usage_metadata and _resolve_model
# ══════════════════════════════════════════════════════════════════════════════
class TestLLMServicePureStaticMethods:
def setup_method(self):
from app.services.llm_service import LLMService, MODEL_ALIASES, DEFAULT_MODEL
self.LLMService = LLMService
self.MODEL_ALIASES = MODEL_ALIASES
self.DEFAULT_MODEL = DEFAULT_MODEL
# ── _resolve_model ────────────────────────────────────────────────────────
def test_resolve_model_none_returns_default(self):
assert self.LLMService._resolve_model(None) == self.DEFAULT_MODEL
def test_resolve_model_default_unchanged(self):
assert self.LLMService._resolve_model("gemini-3.1-pro-preview") == "gemini-3.1-pro-preview"
def test_resolve_model_gpt5_alias(self):
assert self.LLMService._resolve_model("gpt-5") == "gpt-5.4-2026-03-05"
def test_resolve_model_gpt52_alias(self):
assert self.LLMService._resolve_model("gpt-5.2") == "gpt-5.4-2026-03-05"
def test_resolve_model_gemini3_alias(self):
assert self.LLMService._resolve_model("gemini-3-pro-preview") == "gemini-3.1-pro-preview"
def test_resolve_model_gpt41_retired_falls_to_gemini(self):
assert self.LLMService._resolve_model("gpt-4.1") == "gemini-3.1-pro-preview"
def test_resolve_model_known_openai_unchanged(self):
assert self.LLMService._resolve_model("gpt-5.4-2026-03-05") == "gpt-5.4-2026-03-05"
def test_resolve_model_unknown_passthrough(self):
# Unknown model names should pass through untouched
assert self.LLMService._resolve_model("some-future-model") == "some-future-model"
# ── _extract_usage_metadata ───────────────────────────────────────────────
def test_extract_gemini_full(self):
response = MagicMock()
um = MagicMock()
um.prompt_token_count = 1000
um.candidates_token_count = 200
um.cached_content_token_count = 50
response.usage_metadata = um
result = self.LLMService._extract_usage_metadata(response, "gemini")
assert result == {"prompt": 1000, "completion": 200, "cached": 50, "reasoning": 0}
def test_extract_gemini_missing_usage_metadata(self):
response = MagicMock(spec=[]) # no attributes
result = self.LLMService._extract_usage_metadata(response, "gemini")
assert result == {"prompt": 0, "completion": 0, "cached": 0, "reasoning": 0}
def test_extract_gemini_none_usage_metadata(self):
response = MagicMock()
response.usage_metadata = None
result = self.LLMService._extract_usage_metadata(response, "gemini")
assert result == {"prompt": 0, "completion": 0, "cached": 0, "reasoning": 0}
def test_extract_openai_responses_api(self):
"""OpenAI Responses API format (gpt-5.4-2026-03-05)."""
response = MagicMock()
usage = MagicMock()
usage.input_tokens = 5000
usage.output_tokens = 1000
input_details = MagicMock()
input_details.cached_tokens = 200
output_details = MagicMock()
output_details.reasoning_tokens = 300
usage.input_tokens_details = input_details
usage.output_tokens_details = output_details
response.usage = usage
result = self.LLMService._extract_usage_metadata(response, "openai")
assert result == {"prompt": 5000, "completion": 1000, "cached": 200, "reasoning": 300}
def test_extract_openai_chat_completions(self):
"""OpenAI Chat Completions format (legacy)."""
response = MagicMock()
usage = MagicMock(spec=['prompt_tokens', 'completion_tokens', 'prompt_tokens_details'])
usage.prompt_tokens = 3000
usage.completion_tokens = 800
details = MagicMock()
details.cached_tokens = 100
usage.prompt_tokens_details = details
response.usage = usage
result = self.LLMService._extract_usage_metadata(response, "openai")
assert result == {"prompt": 3000, "completion": 800, "cached": 100, "reasoning": 0}
def test_extract_openai_missing_usage(self):
response = MagicMock()
response.usage = None
result = self.LLMService._extract_usage_metadata(response, "openai")
assert result == {"prompt": 0, "completion": 0, "cached": 0, "reasoning": 0}
def test_extract_unknown_provider_returns_zeros(self):
result = self.LLMService._extract_usage_metadata(MagicMock(), "anthropic")
assert result == {"prompt": 0, "completion": 0, "cached": 0, "reasoning": 0}
def test_extract_gemini_token_counts_default_to_zero_when_none(self):
"""Fields that return None should be coerced to 0, not None."""
response = MagicMock()
um = MagicMock()
um.prompt_token_count = None
um.candidates_token_count = None
um.cached_content_token_count = None
response.usage_metadata = um
result = self.LLMService._extract_usage_metadata(response, "gemini")
assert result["prompt"] == 0
assert result["completion"] == 0
assert result["cached"] == 0
# ══════════════════════════════════════════════════════════════════════════════
# 4. UsageEvent.record — non-fatal DB error handling, field coercion
# ══════════════════════════════════════════════════════════════════════════════
class TestUsageEventRecord:
def setup_method(self):
# Reset module-level imports for clean state
pass
@pytest.mark.asyncio
async def test_record_normalizes_unknown_feature(self):
"""Unknown feature names should be stored as 'other'."""
captured = {}
async def fake_insert(doc):
captured['doc'] = doc
mock_db = MagicMock()
mock_db.usage_events.insert_one = AsyncMock(side_effect=fake_insert)
with patch('app.models.usage_event.get_db', AsyncMock(return_value=mock_db)):
from app.models.usage_event import UsageEvent
await UsageEvent.record(
provider="gemini",
model="gemini-3.1-pro-preview",
prompt_tokens=100,
completion_tokens=50,
cost_usd={"input": 0.0, "cached": 0.0, "output": 0.0, "total": 0.0},
feature="totally_invalid_feature_xyz",
)
assert captured['doc']['feature'] == "other"
@pytest.mark.asyncio
async def test_record_valid_feature_preserved(self):
captured = {}
async def fake_insert(doc):
captured['doc'] = doc
mock_db = MagicMock()
mock_db.usage_events.insert_one = AsyncMock(side_effect=fake_insert)
with patch('app.models.usage_event.get_db', AsyncMock(return_value=mock_db)):
from app.models.usage_event import UsageEvent
await UsageEvent.record(
provider="gemini",
model="gemini-3.1-pro-preview",
prompt_tokens=100,
completion_tokens=50,
cost_usd={"input": 0.0, "cached": 0.0, "output": 0.0, "total": 0.0},
feature="key_themes",
)
assert captured['doc']['feature'] == "key_themes"
@pytest.mark.asyncio
async def test_record_total_tokens_computed(self):
captured = {}
async def fake_insert(doc):
captured['doc'] = doc
mock_db = MagicMock()
mock_db.usage_events.insert_one = AsyncMock(side_effect=fake_insert)
with patch('app.models.usage_event.get_db', AsyncMock(return_value=mock_db)):
from app.models.usage_event import UsageEvent
await UsageEvent.record(
provider="gemini",
model="gemini-3.1-pro-preview",
prompt_tokens=300,
completion_tokens=150,
cost_usd={"input": 0.0, "cached": 0.0, "output": 0.0, "total": 0.0},
)
assert captured['doc']['total_tokens'] == 450
@pytest.mark.asyncio
async def test_record_error_truncated_to_500_chars(self):
captured = {}
async def fake_insert(doc):
captured['doc'] = doc
mock_db = MagicMock()
mock_db.usage_events.insert_one = AsyncMock(side_effect=fake_insert)
long_error = "E" * 1000
with patch('app.models.usage_event.get_db', AsyncMock(return_value=mock_db)):
from app.models.usage_event import UsageEvent
await UsageEvent.record(
provider="gemini",
model="gemini-3.1-pro-preview",
prompt_tokens=10,
completion_tokens=5,
cost_usd={"input": 0.0, "cached": 0.0, "output": 0.0, "total": 0.0},
status="error",
error=long_error,
)
assert len(captured['doc']['error']) == 500
@pytest.mark.asyncio
async def test_record_db_error_is_swallowed(self):
"""A DB write failure must not propagate — telemetry cannot kill LLM calls."""
mock_db = MagicMock()
mock_db.usage_events.insert_one = AsyncMock(side_effect=RuntimeError("DB down"))
with patch('app.models.usage_event.get_db', AsyncMock(return_value=mock_db)):
from app.models.usage_event import UsageEvent
# Should not raise
await UsageEvent.record(
provider="gemini",
model="gemini-3.1-pro-preview",
prompt_tokens=10,
completion_tokens=5,
cost_usd={"input": 0.0, "cached": 0.0, "output": 0.0, "total": 0.0},
)
# ══════════════════════════════════════════════════════════════════════════════
# 5. check_quota — admin bypass, warning threshold, exceeded error
# ══════════════════════════════════════════════════════════════════════════════
class TestCheckQuota:
"""
check_quota does lazy imports inside the function body:
from app.models.user import User
from app.models.usage_event import UsageEvent
from app.models.focus_group import FocusGroup
We inject mocks via patch.dict(sys.modules) so that when the function
executes those import statements at runtime it gets our mock objects.
"""
def _make_user_mod(self, user_doc):
m = MagicMock()
m.User.find_by_id = AsyncMock(return_value=user_doc)
return m
def _make_usage_mod(self, total):
m = MagicMock()
m.UsageEvent.sum_cost = AsyncMock(return_value=total)
return m
def _make_fg_mod(self, fg_doc):
m = MagicMock()
m.FocusGroup.find_by_id = AsyncMock(return_value=fg_doc)
return m
@pytest.mark.asyncio
async def test_no_user_no_fg_returns_none(self):
from app.models.quota import check_quota
result = await check_quota(user_id=None, focus_group_id=None)
assert result is None
@pytest.mark.asyncio
async def test_admin_bypasses_user_quota(self):
user_mod = self._make_user_mod(
{"role": "admin", "quota": {"monthly_usd": 0.01}, "override_quota": False}
)
usage_mod = self._make_usage_mod(999.0)
with patch.dict(sys.modules, {'app.models.user': user_mod, 'app.models.usage_event': usage_mod}):
from app.models.quota import check_quota
result = await check_quota(user_id="admin_id", focus_group_id=None)
assert result is None
@pytest.mark.asyncio
async def test_override_quota_bypasses_user_quota(self):
user_mod = self._make_user_mod(
{"role": "user", "quota": {"monthly_usd": 1.0}, "override_quota": True}
)
usage_mod = self._make_usage_mod(999.0)
with patch.dict(sys.modules, {'app.models.user': user_mod, 'app.models.usage_event': usage_mod}):
from app.models.quota import check_quota
result = await check_quota(user_id="u1", focus_group_id=None)
assert result is None
@pytest.mark.asyncio
async def test_no_quota_configured_returns_none(self):
user_mod = self._make_user_mod({"role": "user", "quota": {}, "override_quota": False})
usage_mod = self._make_usage_mod(10.0)
with patch.dict(sys.modules, {'app.models.user': user_mod, 'app.models.usage_event': usage_mod}):
from app.models.quota import check_quota
result = await check_quota(user_id="u1", focus_group_id=None)
assert result is None
@pytest.mark.asyncio
async def test_under_quota_returns_none(self):
user_mod = self._make_user_mod(
{"role": "user", "quota": {"monthly_usd": 50.0}, "override_quota": False}
)
usage_mod = self._make_usage_mod(20.0)
with patch.dict(sys.modules, {'app.models.user': user_mod, 'app.models.usage_event': usage_mod}):
from app.models.quota import check_quota
result = await check_quota(user_id="u1", focus_group_id=None)
assert result is None
@pytest.mark.asyncio
async def test_quota_at_80_percent_returns_warning(self):
user_mod = self._make_user_mod(
{"role": "user", "quota": {"monthly_usd": 50.0}, "override_quota": False}
)
usage_mod = self._make_usage_mod(42.0) # 84%
with patch.dict(sys.modules, {'app.models.user': user_mod, 'app.models.usage_event': usage_mod}):
from app.models.quota import check_quota, QuotaWarning
result = await check_quota(user_id="u1", focus_group_id=None)
assert isinstance(result, QuotaWarning)
assert result.scope == "user"
assert result.limit_usd == 50.0
assert result.used_usd == 42.0
assert result.pct == pytest.approx(0.84, abs=0.001)
@pytest.mark.asyncio
async def test_quota_exceeded_raises_error(self):
user_mod = self._make_user_mod(
{"role": "user", "quota": {"monthly_usd": 50.0}, "override_quota": False}
)
usage_mod = self._make_usage_mod(50.01)
with patch.dict(sys.modules, {'app.models.user': user_mod, 'app.models.usage_event': usage_mod}):
from app.models.quota import check_quota, QuotaExceededError
with pytest.raises(QuotaExceededError) as exc_info:
await check_quota(user_id="u1", focus_group_id=None)
err = exc_info.value
assert err.scope == "user"
assert err.limit_usd == 50.0
assert err.used_usd == pytest.approx(50.01)
@pytest.mark.asyncio
async def test_fg_quota_exceeded_raises_error(self):
fg_mod = self._make_fg_mod({"quota": {"total_usd": 10.0}})
usage_mod = self._make_usage_mod(10.01)
with patch.dict(sys.modules, {'app.models.focus_group': fg_mod, 'app.models.usage_event': usage_mod}):
from app.models.quota import check_quota, QuotaExceededError
with pytest.raises(QuotaExceededError) as exc_info:
await check_quota(user_id=None, focus_group_id="fg1")
assert exc_info.value.scope == "focus_group"
@pytest.mark.asyncio
async def test_fg_quota_at_80_percent_returns_warning(self):
fg_mod = self._make_fg_mod({"quota": {"total_usd": 100.0}})
usage_mod = self._make_usage_mod(85.0)
with patch.dict(sys.modules, {'app.models.focus_group': fg_mod, 'app.models.usage_event': usage_mod}):
from app.models.quota import check_quota, QuotaWarning
result = await check_quota(user_id=None, focus_group_id="fg1")
assert isinstance(result, QuotaWarning)
assert result.scope == "focus_group"
@pytest.mark.asyncio
async def test_db_error_in_quota_check_is_non_fatal(self):
"""If User.find_by_id raises, quota check swallows and allows the call."""
user_mod = MagicMock()
user_mod.User.find_by_id = AsyncMock(side_effect=RuntimeError("DB timeout"))
with patch.dict(sys.modules, {'app.models.user': user_mod}):
from app.models.quota import check_quota
result = await check_quota(user_id="u1", focus_group_id=None)
assert result is None
@pytest.mark.asyncio
async def test_exactly_at_quota_limit_is_exceeded(self):
"""Spending exactly the limit should trigger QuotaExceededError (>=)."""
user_mod = self._make_user_mod(
{"role": "user", "quota": {"monthly_usd": 50.0}, "override_quota": False}
)
usage_mod = self._make_usage_mod(50.0) # exactly at limit
with patch.dict(sys.modules, {'app.models.user': user_mod, 'app.models.usage_event': usage_mod}):
from app.models.quota import check_quota, QuotaExceededError
with pytest.raises(QuotaExceededError):
await check_quota(user_id="u1", focus_group_id=None)
# ══════════════════════════════════════════════════════════════════════════════
# 6. QuotaExceededError — message and attributes
# ══════════════════════════════════════════════════════════════════════════════
class TestQuotaExceededError:
def test_attributes(self):
from app.models.quota import QuotaExceededError
from datetime import datetime, timezone
period = datetime(2026, 4, 1, tzinfo=timezone.utc)
err = QuotaExceededError("user", 50.0, 52.34, period)
assert err.scope == "user"
assert err.limit_usd == 50.0
assert err.used_usd == 52.34
assert err.period_start == period
assert "user" in str(err)
assert "52.3400" in str(err)
assert "50.00" in str(err)
def test_is_exception(self):
from app.models.quota import QuotaExceededError
assert issubclass(QuotaExceededError, Exception)

3
pytest.ini Normal file
View file

@ -0,0 +1,3 @@
[pytest]
asyncio_mode = auto
testpaths = backend/tests