From 3e9ccafad2cf618333f244a5fec64cc0c4b3276d Mon Sep 17 00:00:00 2001 From: Vadym Samoilenko Date: Fri, 24 Apr 2026 18:08:27 +0100 Subject: [PATCH] Add LLM usage tracking infrastructure (Phases A-C) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Model renames: gpt-5.2 → gpt-5.4-2026-03-05, gemini-3-pro-preview → gemini-3.1-pro-preview; retire gpt-4.1 via alias fallback - New: llm_usage_context.py (ContextVar-based attribution), model_pricing.py (tiered pricing + 60s cache), usage_event.py (append-only telemetry), quota.py (user/FG quota enforcement with 80% warning) - Wire _record_usage into all 3 LLM methods; set_llm_context at every service entry point - Fix admin_required decorator (was sync, never awaited User.find_by_id); add active_required and with_user_context decorators - Inject user_id into ContextVar from JWT on every authenticated request - Add DB indexes for usage_events, model_pricing, users collections - Seed script for model pricing (gpt-5.4 single-tier, gemini-3.1 two-tier 200k threshold) - Fix parse_json_response NameError (logger undefined at module level) - 70 passing tests: conftest.py with sys.modules stubs, test_usage_infrastructure.py (52 tests), rewrite stale test_llm_service.py (18 tests) Co-Authored-By: Claude Sonnet 4.6 --- CLAUDE.md | 166 ++--- backend/app/auth/quart_jwt.py | 13 +- backend/app/db.py | 11 + backend/app/models/focus_group.py | 2 +- backend/app/models/model_pricing.py | 104 +++ backend/app/models/quota.py | 90 +++ backend/app/models/usage_event.py | 91 +++ backend/app/routes/ai_personas.py | 8 +- backend/app/routes/focus_group_ai.py | 4 +- backend/app/routes/personas.py | 8 +- backend/app/services/ai_moderator_service.py | 2 + backend/app/services/ai_persona_service.py | 4 + .../services/conversation_decision_service.py | 4 +- .../services/focus_group_response_service.py | 16 +- .../services/focus_group_summary_service.py | 2 + backend/app/services/key_theme_service.py | 6 +- backend/app/services/llm_service.py | 122 +++- backend/app/services/llm_usage_context.py | 54 ++ .../services/persona_modification_service.py | 8 +- backend/app/utils/__init__.py | 52 +- backend/requirements.txt | 3 + backend/scripts/seed_model_pricing.py | 101 +++ backend/tests/conftest.py | 53 ++ backend/tests/test_llm_service.py | 261 ++++--- backend/tests/test_usage_infrastructure.py | 664 ++++++++++++++++++ pytest.ini | 3 + 26 files changed, 1566 insertions(+), 286 deletions(-) create mode 100644 backend/app/models/model_pricing.py create mode 100644 backend/app/models/quota.py create mode 100644 backend/app/models/usage_event.py create mode 100644 backend/app/services/llm_usage_context.py create mode 100644 backend/scripts/seed_model_pricing.py create mode 100644 backend/tests/conftest.py create mode 100644 backend/tests/test_usage_infrastructure.py create mode 100644 pytest.ini diff --git a/CLAUDE.md b/CLAUDE.md index 8f09a1f6..644f5a80 100755 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -3,124 +3,88 @@ This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. ## Commands -- **Build**: `npm run build` (use this for all testing and verification) +- **Dev Server**: `npm run dev` (port 5173, proxies `/api` → `localhost:5137`) +- **Build**: `npm run build` (use this to verify TypeScript compilation) - **Dev Build**: `npm run build:dev` (development mode build) - **Lint**: `npm run lint` -- **Preview**: `npm run preview` -- **Backend**: `cd backend && python run.py` +- **Backend**: `cd backend && python run.py` (Hypercorn ASGI on port 5137) -**Note**: This project supports both local development and production deployment. See Environment Configuration section below. - -## Backend Commands -- **Start Backend**: `python run.py` (from backend/ directory) -- **Backend Server**: Runs on port 5137 with Hypercorn ASGI server -- **Database**: MongoDB with PyMongo -- **Authentication**: Custom Quart-compatible JWT (not Flask-JWT-Extended) - -## Testing -**Python Backend**: After modifying any Python files: +## Backend Testing +After modifying any Python files: ```bash source backend/venv/bin/activate -python -c "import app.services.module_name" # Test specific module -python -c "from app import create_app; app = create_app()" # Test app creation +python -c "import app.services.module_name" # Test specific module +python -c "from app import create_app; create_app()" # Test app creation ``` -**Frontend**: Run `npm run build` to verify TypeScript compilation ## Architecture Overview -### Real-Time Communication -The application uses Socket.IO for real-time WebSocket communication between frontend and backend: -- **Backend**: `python-socketio` with `AsyncServer` wrapped in ASGI app -- **Frontend**: `socket.io-client` managed via `WebSocketContext` -- WebSocket manager (`websocket_manager_async.py`) handles room-based messaging for focus group sessions +### ASGI Stack (critical detail) +`create_app()` returns a **`socketio.ASGIApp`** wrapping a Quart app — not the Quart app itself. Accessing `app.quart_app` gives the inner Quart instance. This distinction matters whenever you write ASGI middleware or access app config directly. -### Autonomous Conversation System -Focus group sessions can run autonomously with AI-driven conversations: -- `ai_runner_service.py` - Manages background task execution for autonomous mode -- `autonomous_conversation_controller.py` - Orchestrates multi-persona conversations -- `conversation_decision_service.py` - Determines next speaker and conversation flow -- `conversation_context_service.py` - Maintains conversation state and history +### Real-Time Communication +Socket.IO via `python-socketio` `AsyncServer` (ASGI mode). The `WebSocketContextNew.tsx` context manages the client connection. `websocket_manager_async.py` handles room-based messaging for focus group sessions. The WebSocket manager must call `ws_mgr.set_main_loop(asyncio.get_running_loop())` at startup so that cross-thread emits from the AI Runner land on the right loop. + +> `VITE_ENABLE_WEBSOCKET` is hardcoded `true` in dev and `false` in production builds via `vite.config.ts` — it is not controlled by `.env`. + +### AI Runner + Threading +`ai_runner_service.py` is a singleton that owns a **dedicated OS thread** with a single asyncio event loop. All autonomous AI conversations run in this thread. This solves Motor (AsyncIOMotorClient) event-loop affinity: Motor clients in the AI runner are bound to that loop, while regular API routes use synchronous PyMongo. Never share Motor clients between the two contexts. + +### Autonomous Conversation Pipeline +1. `ai_runner_service.py` — spawns coroutines on the dedicated thread's event loop +2. `autonomous_conversation_controller.py` — orchestrates the full session +3. `conversation_decision_service.py` — picks the next speaker +4. `conversation_context_service.py` — maintains history/state +5. `conversation_state_manager.py` — in-memory state across turns + +### Task Manager +`task_manager.py` is a singleton tracking cancellable asyncio tasks (persona generation, discussion guides, etc.). Tasks are exposed via `/api/tasks` routes. A background sweeper cleans up completed/expired tasks. Frontend polling is handled by `useTaskPolling.ts`. ### LLM Integration -Multi-model support through `llm_service.py`: -- **Google Gemini** (`gemini-3-pro-preview`) - Default model -- **OpenAI** (`gpt-4.1`, `gpt-5.2`) - Alternative models -- Prompts are stored as markdown templates in `/backend/prompts/` +`llm_service.py` creates fresh clients per call (avoids event-loop mismatch in ASGI). Default model: **Google Gemini** via `google-genai`. Alternative: **OpenAI** (`AsyncOpenAI`). Both require env vars `GEMINI_API_KEY` and `OPENAI_API_KEY` — startup fails if missing. Prompts are markdown templates in `/backend/prompts/` loaded by `prompt_loader.py`. -## Code Style Guidelines -- **Imports**: Group imports by source (React, third-party, local) -- **Types**: Use TypeScript. Project allows nullable types (`strictNullChecks: false`) -- **Components**: Use functional components with hooks -- **Naming**: Use PascalCase for components, camelCase for variables/functions -- **Formatting**: Follow ESLint recommendations, focus on readability -- **Error Handling**: Use try/catch blocks with toast for user feedback -- **CSS**: Use Tailwind classes for styling, with component-specific CSS files when needed -- **File Structure**: Components in `/src/components`, pages in `/src/pages`, hooks in `/src/hooks` -- **UI Components**: Use shadcn-ui components from `/src/components/ui` -- **State Management**: React hooks for local state, context/props for sharing -- **URL Construction**: ALWAYS use `import.meta.env.BASE_URL` when constructing URLs for static assets, images, or links. This project uses base path `/semblance/` in production. Example: `${import.meta.env.BASE_URL}image.png` instead of `/image.png` - -## Project Stack -**Frontend**: Vite, React 18, TypeScript, Tailwind CSS, shadcn-ui -**Backend**: Quart (async Flask), Hypercorn ASGI, PyMongo, python-socketio -**Key Libraries**: -- UI: Radix UI components, Lucide React icons -- State: TanStack Query, React Hook Form with Zod validation -- Routing: React Router DOM -- AI/LLM: OpenAI, Google Generative AI (genai) -- Real-time: Socket.IO (client and server) -- Charts: Recharts -- Drag & Drop: DND Kit - -## API Configuration -- **Frontend API Base**: `/semblance_back/api` (configurable via VITE_API_BASE_URL) -- **Backend Proxy**: Vite dev server proxies `/api` to `localhost:5137` -- **Production Base Path**: `/semblance/` (configured in vite.config.ts) -- **Authentication**: JWT tokens stored in localStorage +## Code Style +- TypeScript with `strictNullChecks: false` +- Functional components with hooks; local state via hooks, shared state via context/props +- `@/` alias maps to `src/` +- **URL construction**: always use `${import.meta.env.BASE_URL}asset.png` — production base is `/semblance/` +- Error handling: try/catch + `sonner` toast for user feedback ## File Organization -- **Backend Services**: `/backend/app/services/` - Business logic and AI integrations -- **Backend Models**: `/backend/app/models/` - Data models (User, FocusGroup, Persona, Folder) -- **Backend Routes**: `/backend/app/routes/` - API endpoints (auth, personas, focus-groups, ai-personas, folders, tasks) -- **AI Prompts**: `/backend/prompts/` - LLM prompt templates (markdown files loaded by `prompt_loader.py`) -- **Frontend Components**: - - `/src/components/ui/` - Reusable shadcn-ui components - - `/src/components/focus-group-session/` - Focus group session UI (DiscussionPanel, ParticipantPanel, ThemesPanel, etc.) - - `/src/components/persona/` - Persona management components -- **Types**: `/src/types/` - TypeScript type definitions -- **Contexts**: `/src/contexts/` - React context providers (AuthContext, WebSocketContext, NavigationContext) +``` +backend/ + app/ + routes/ # Blueprints: auth, personas, focus-groups, ai-personas, focus-group-ai, folders, tasks + services/ # Business logic: llm_service, ai_runner_service, task_manager, autonomous_*, conversation_* + models/ # Data models: User, FocusGroup, Persona, Folder + auth/ # Auth utilities (JWT helpers) + prompts/ # LLM prompt markdown templates + websocket_manager_async.py # Room-based async WebSocket manager + extensions.py # socketio.AsyncServer singleton + +src/ + pages/ # Route-level components (Dashboard, FocusGroups, FocusGroupSession, Login, SyntheticUsers) + components/ + focus-group-session/ # Session UI panels (Discussion, Participant, Themes, etc.) + persona/ # Persona management components + ui/ # shadcn-ui primitives + contexts/ # AuthContext, WebSocketContextNew, NavigationContext + hooks/ # useTaskPolling, useWebSocket, usePersonaStorage, useDiscussionGuideGeneration, etc. + types/ # TypeScript type definitions +``` ## Environment Configuration -This application supports both local development and production deployment through environment-specific configuration files: +| Setting | Development | Production | +|---------|-------------|------------| +| Base path | `/` | `/semblance/` | +| API base | `/api` (proxied to 5137) | `https://optical-dev.oliver.solutions/semblance_back/api` | +| WebSocket path | `/socket.io/` | `/semblance_back/socket.io/` | +| MSAL redirect | `http://localhost:5173/` | `https://optical-dev.oliver.solutions/semblance` | -### Environment Files -- **`.env.development`**: Local development configuration -- **`.env.production`**: Production server configuration -- **`.env`**: Active configuration (copy from appropriate environment file) +Setup: copy `.env.development` or `.env.production` to `.env`. Backend requires `backend/.env` with `SECRET_KEY`, `JWT_SECRET_KEY`, `GEMINI_API_KEY`, `OPENAI_API_KEY` — startup will throw `RuntimeError` if any are missing or use weak defaults. -### Development vs Production -The application automatically adapts based on environment variables: - -**Development Mode:** -- Base path: `/` (root) -- API base: `/api` -- WebSocket path: `/socket.io/` -- MSAL redirect: `http://localhost:5173/` - -**Production Mode:** -- Base path: `/semblance/` -- API base: `https://optical-dev.oliver.solutions/semblance_back/api` -- WebSocket path: `/semblance_back/socket.io/` -- MSAL redirect: `https://optical-dev.oliver.solutions/semblance` - -### Setup Instructions -1. **For local development**: Copy `.env.development` to `.env` -2. **For production**: Copy `.env.production` to `.env` -3. The build system will use the appropriate configuration - -### Technical Details -- **Base Path**: Configured in vite.config.ts based on `NODE_ENV` -- **Backend Port**: 5137 (Hypercorn ASGI server) -- **Frontend Dev Port**: 5173 -- **Temp Directories**: Backend creates `/backend/temp/` for file handling \ No newline at end of file +## Knowledge Wiki +A cross-project knowledge base is maintained automatically from all Claude Code sessions. +- **Index:** `/Users/ai_leed/Library/Mobile Documents/iCloud~md~obsidian/Documents/VadymSamoilenko/wiki/index.md` +- **Query:** `cd ~/.claude/memory-compiler && uv run python scripts/query.py "your question"` diff --git a/backend/app/auth/quart_jwt.py b/backend/app/auth/quart_jwt.py index 285fab80..947469cc 100755 --- a/backend/app/auth/quart_jwt.py +++ b/backend/app/auth/quart_jwt.py @@ -147,7 +147,18 @@ def jwt_required(optional: bool = False): # Store user ID in request context g.current_user_id = user_id - + + # Propagate user_id into the LLM usage ContextVar for this request. + # Each Quart request runs in its own asyncio Task, so setting the ContextVar + # here is request-scoped. Child tasks (create_task) and thread submissions + # (run_coroutine_threadsafe) inherit this context automatically. + try: + from app.services.llm_usage_context import _ctx, current_context + from dataclasses import replace as _dc_replace + _ctx.set(_dc_replace(current_context(), user_id=user_id)) + except Exception: + pass # Non-fatal — telemetry only + # Call the actual route function and handle tuple returns result = await func(*args, **kwargs) diff --git a/backend/app/db.py b/backend/app/db.py index 5533f9ea..0bb9a4d5 100755 --- a/backend/app/db.py +++ b/backend/app/db.py @@ -50,10 +50,21 @@ async def get_db(): try: await database.users.create_index("username", unique=True, background=True) await database.users.create_index("email", unique=True, background=True) + await database.users.create_index("role", background=True) + await database.users.create_index([("is_active", 1), ("username", 1)], background=True) await database.personas.create_index("created_by", background=True) await database.focus_groups.create_index("created_by", background=True) await database.folders.create_index("created_by", background=True) await database.folders.create_index("parent_folder_id", background=True) + # usage_events indexes + await database.usage_events.create_index([("user_id", 1), ("ts", -1)], background=True) + await database.usage_events.create_index([("focus_group_id", 1), ("ts", -1)], background=True) + await database.usage_events.create_index([("ts", -1)], background=True) + await database.usage_events.create_index([("feature", 1), ("ts", -1)], background=True) + await database.usage_events.create_index([("model", 1), ("ts", -1)], background=True) + await database.usage_events.create_index([("status", 1), ("ts", -1)], background=True) + # model_pricing indexes + await database.model_pricing.create_index([("model", 1), ("effective_from", -1)], background=True) except Exception as e: logging.warning(f"Index creation warning (non-fatal): {e}") diff --git a/backend/app/models/focus_group.py b/backend/app/models/focus_group.py index c19591c9..14bd83ca 100755 --- a/backend/app/models/focus_group.py +++ b/backend/app/models/focus_group.py @@ -52,7 +52,7 @@ class FocusGroup: # Set default LLM model if not provided if "llm_model" not in focus_group_data: - focus_group_data["llm_model"] = "gemini-3-pro-preview" + focus_group_data["llm_model"] = "gemini-3.1-pro-preview" # Set default GPT-5 parameters if not provided if "reasoning_effort" not in focus_group_data: diff --git a/backend/app/models/model_pricing.py b/backend/app/models/model_pricing.py new file mode 100644 index 00000000..92f3db04 --- /dev/null +++ b/backend/app/models/model_pricing.py @@ -0,0 +1,104 @@ +from app.db import get_db +from datetime import datetime, timezone +import logging +import time + +logger = logging.getLogger(__name__) + +# In-process cache: (model_name -> (pricing_dict, cached_at_monotonic)) +_pricing_cache: dict = {} +_CACHE_TTL_SECONDS = 60 + + +def _cache_get(model: str): + entry = _pricing_cache.get(model) + if entry and (time.monotonic() - entry[1]) < _CACHE_TTL_SECONDS: + return entry[0] + return None + + +def _cache_set(model: str, pricing: dict): + _pricing_cache[model] = (pricing, time.monotonic()) + + +class ModelPricing: + @staticmethod + async def current_for(model_name: str) -> dict | None: + """Return the active pricing row for a model, with 60 s in-process cache. + + Resolves MODEL_ALIASES before lookup so callers can pass raw model names. + Returns None if no pricing is configured (cost will be recorded as 0). + """ + from app.services.llm_service import MODEL_ALIASES + resolved = MODEL_ALIASES.get(model_name, model_name) + + cached = _cache_get(resolved) + if cached is not None: + return cached + + try: + db = await get_db() + now = datetime.now(timezone.utc) + doc = await db.model_pricing.find_one( + { + "model": resolved, + "effective_from": {"$lte": now}, + "$or": [ + {"effective_until": None}, + {"effective_until": {"$gt": now}}, + ], + }, + sort=[("effective_from", -1)], + ) + _cache_set(resolved, doc) + return doc + except Exception: + logger.warning(f"Failed to fetch pricing for model {resolved}", exc_info=True) + return None + + @staticmethod + def pick_tier(pricing: dict, prompt_tokens: int) -> dict | None: + """Return the cost tier that applies for a given prompt token count.""" + if not pricing: + return None + tiers = pricing.get("tiers") or [] + if not tiers: + return None + # Pick the tier with the largest threshold still <= prompt_tokens + applicable = [t for t in tiers if t.get("threshold_input_tokens", 0) <= prompt_tokens] + if not applicable: + applicable = tiers # fall back to first tier + return max(applicable, key=lambda t: t.get("threshold_input_tokens", 0)) + + @staticmethod + def compute_cost(pricing: dict | None, prompt_tokens: int, completion_tokens: int, + cached_tokens: int = 0) -> dict: + """Compute cost breakdown from token counts and pricing doc. + + Returns a dict with keys: input, cached, output, total (all USD floats). + All values are 0.0 if pricing is None. + """ + zero = {"input": 0.0, "cached": 0.0, "output": 0.0, "total": 0.0} + if not pricing: + return zero + + tier = ModelPricing.pick_tier(pricing, prompt_tokens) + if not tier: + return zero + + input_per_mtok = tier.get("input_per_mtok") or 0 + cached_per_mtok = tier.get("cached_input_per_mtok") or 0 + output_per_mtok = tier.get("output_per_mtok") or 0 + + billable_input = max(0, prompt_tokens - cached_tokens) + cost_input = billable_input * input_per_mtok / 1_000_000 + cost_cached = cached_tokens * cached_per_mtok / 1_000_000 + cost_output = completion_tokens * output_per_mtok / 1_000_000 + cost_total = cost_input + cost_cached + cost_output + + return { + "input": round(cost_input, 8), + "cached": round(cost_cached, 8), + "output": round(cost_output, 8), + "total": round(cost_total, 8), + } diff --git a/backend/app/models/quota.py b/backend/app/models/quota.py new file mode 100644 index 00000000..6f088b93 --- /dev/null +++ b/backend/app/models/quota.py @@ -0,0 +1,90 @@ +from datetime import datetime, timezone +import logging + +logger = logging.getLogger(__name__) + + +class QuotaExceededError(Exception): + def __init__(self, scope: str, limit_usd: float, used_usd: float, period_start=None): + self.scope = scope # "user" | "focus_group" + self.limit_usd = limit_usd + self.used_usd = used_usd + self.period_start = period_start + super().__init__( + f"Quota exceeded ({scope}): used ${used_usd:.4f} of ${limit_usd:.2f} limit" + ) + + +class QuotaWarning: + def __init__(self, scope: str, limit_usd: float, used_usd: float, pct: float): + self.scope = scope + self.limit_usd = limit_usd + self.used_usd = used_usd + self.pct = pct + + +async def check_quota(user_id: str | None, focus_group_id: str | None) -> QuotaWarning | None: + """Check quotas and raise QuotaExceededError if either is exceeded. + + Returns a QuotaWarning (not raised) when usage is between 80 % and 100 %. + Returns None if all quotas are fine. + + Admins and users with override_quota=True bypass user-level quota. + Focus-group quotas apply to everyone (including admins) — they are project budgets. + """ + from app.models.user import User + from app.models.usage_event import UsageEvent + + warning = None + + if user_id: + try: + user = await User.find_by_id(user_id) + if user: + is_admin = user.get("role") == "admin" + override = user.get("override_quota", False) + + if not is_admin and not override: + quota = user.get("quota") or {} + limit = quota.get("monthly_usd") + + if limit: + now = datetime.now(timezone.utc) + period_start = now.replace( + day=1, hour=0, minute=0, second=0, microsecond=0 + ) + spent = await UsageEvent.sum_cost( + {"user_id": user_id, "ts": {"$gte": period_start}} + ) + + pct = spent / limit if limit else 0 + if spent >= limit: + raise QuotaExceededError("user", limit, spent, period_start) + elif pct >= 0.8: + warning = QuotaWarning("user", limit, spent, pct) + except QuotaExceededError: + raise + except Exception: + logger.warning("Quota check failed (non-fatal, allowing call)", exc_info=True) + + if focus_group_id: + try: + from app.models.focus_group import FocusGroup + fg = await FocusGroup.find_by_id(focus_group_id) + if fg: + fg_quota = fg.get("quota") or {} + fg_limit = fg_quota.get("total_usd") + + if fg_limit: + spent = await UsageEvent.sum_cost({"focus_group_id": focus_group_id}) + pct = spent / fg_limit if fg_limit else 0 + if spent >= fg_limit: + raise QuotaExceededError("focus_group", fg_limit, spent, None) + elif pct >= 0.8 and not warning: + warning = QuotaWarning("focus_group", fg_limit, spent, pct) + except QuotaExceededError: + raise + except Exception: + logger.warning("Focus-group quota check failed (non-fatal)", exc_info=True) + + return warning diff --git a/backend/app/models/usage_event.py b/backend/app/models/usage_event.py new file mode 100644 index 00000000..e3ee0e30 --- /dev/null +++ b/backend/app/models/usage_event.py @@ -0,0 +1,91 @@ +from app.db import get_db +from datetime import datetime, timezone +import logging + +logger = logging.getLogger(__name__) + +VALID_FEATURES = { + "moderator", "persona_response", "persona_generate", "persona_modify", + "persona_export", "key_themes", "summary", "discussion_guide", + "image_description", "conversation_decision", "other", +} + + +class UsageEvent: + @staticmethod + async def record( + *, + provider: str, + model: str, + prompt_tokens: int, + completion_tokens: int, + cached_tokens: int = 0, + reasoning_tokens: int = 0, + cost_usd: dict, + price_snapshot_id=None, + duration_ms: int = 0, + retry_count: int = 0, + status: str = "success", + error: str | None = None, + is_estimated: bool = False, + estimate_method: str | None = None, + user_id: str | None = None, + focus_group_id: str | None = None, + persona_id: str | None = None, + feature: str = "other", + task_id: str | None = None, + ) -> None: + """Append one usage event doc. Never raises — telemetry must not kill LLM calls.""" + try: + if feature not in VALID_FEATURES: + feature = "other" + + doc = { + "ts": datetime.now(timezone.utc), + "user_id": user_id, + "focus_group_id": focus_group_id, + "persona_id": persona_id, + "task_id": task_id, + "feature": feature, + "provider": provider, + "model": model, + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "cached_tokens": cached_tokens, + "reasoning_tokens": reasoning_tokens, + "total_tokens": prompt_tokens + completion_tokens, + "cost_usd": cost_usd, + "price_snapshot_id": price_snapshot_id, + "duration_ms": duration_ms, + "retry_count": retry_count, + "status": status, + "error": (error or "")[:500] if error else None, + "is_estimated": is_estimated, + "estimate_method": estimate_method, + } + + db = await get_db() + await db.usage_events.insert_one(doc) + + if not user_id or not focus_group_id: + logger.info( + f"Usage event recorded with partial context: user_id={user_id} " + f"focus_group_id={focus_group_id} feature={feature} model={model}" + ) + except Exception: + logger.warning("Failed to record usage event (non-fatal)", exc_info=True) + + @staticmethod + async def sum_cost(match: dict) -> float: + """Return total cost_usd.total for the given match filter.""" + try: + db = await get_db() + pipeline = [ + {"$match": match}, + {"$group": {"_id": None, "total": {"$sum": "$cost_usd.total"}}}, + ] + result = await db.usage_events.aggregate(pipeline).to_list(1) + return result[0]["total"] if result else 0.0 + except Exception: + logger.warning("Failed to sum usage costs (non-fatal)", exc_info=True) + return 0.0 diff --git a/backend/app/routes/ai_personas.py b/backend/app/routes/ai_personas.py index 98164db4..19d2ff65 100755 --- a/backend/app/routes/ai_personas.py +++ b/backend/app/routes/ai_personas.py @@ -74,7 +74,7 @@ async def generate_basic_profiles(): temperature = 1.0 customer_data_session_id = data.get('customer_data_session_id') # Optional parameter - llm_model = data.get('llm_model', 'gemini-3-pro-preview') # Optional parameter with default + llm_model = data.get('llm_model', 'gemini-3.1-pro-preview') # Optional parameter with default try: # Register current task for cancellation @@ -210,7 +210,7 @@ async def complete_and_save_persona(): temperature = 1.0 customer_data_session_id = data.get('customer_data_session_id') # Optional parameter - llm_model = data.get('llm_model', 'gemini-3-pro-preview') # Optional parameter with default + llm_model = data.get('llm_model', 'gemini-3.1-pro-preview') # Optional parameter with default # Get persona name for logging persona_name = basic_profile.get('name', 'Unknown') @@ -835,7 +835,7 @@ async def batch_generate_summaries(): if not (0 <= temperature <= 1.5): temperature = 1.0 - llm_model = data.get('llm_model', 'gemini-3-pro-preview') # Optional parameter with default + llm_model = data.get('llm_model', 'gemini-3.1-pro-preview') # Optional parameter with default # Log the request with model information print(f"🔄 Backend: Received batch-generate-summaries request for {len(persona_ids)} personas with model: {llm_model}") @@ -1192,7 +1192,7 @@ async def generate_personas_full(): temperature = 1.0 customer_data_session_id = data.get('customer_data_session_id') - llm_model = data.get('llm_model', 'gemini-3-pro-preview') + llm_model = data.get('llm_model', 'gemini-3.1-pro-preview') target_folder_id = data.get('target_folder_id') try: diff --git a/backend/app/routes/focus_group_ai.py b/backend/app/routes/focus_group_ai.py index 3750c16a..751da7e7 100755 --- a/backend/app/routes/focus_group_ai.py +++ b/backend/app/routes/focus_group_ai.py @@ -185,8 +185,8 @@ Be genuine and specific in your feedback, drawing on your personal experiences a conversation_context=multimodal_context['conversation_context'], temperature=temperature, model_name=llm_model, - reasoning_effort=reasoning_effort if llm_model in ('gpt-5', 'gpt-5.2') else None, - verbosity=verbosity if llm_model in ('gpt-5', 'gpt-5.2') else None + reasoning_effort=reasoning_effort if llm_model in ('gpt-5', 'gpt-5.4-2026-03-05') else None, + verbosity=verbosity if llm_model in ('gpt-5', 'gpt-5.4-2026-03-05') else None ) else: response_text = await generate_persona_response( diff --git a/backend/app/routes/personas.py b/backend/app/routes/personas.py index 17914442..63ce46cb 100755 --- a/backend/app/routes/personas.py +++ b/backend/app/routes/personas.py @@ -161,7 +161,7 @@ async def modify_persona_with_ai(persona_id): Request body should include: - modification_prompt: Natural language description of desired changes - - llm_model: Model to use (defaults to 'gemini-3-pro-preview') + - llm_model: Model to use (defaults to 'gemini-3.1-pro-preview') - reasoning_effort: For GPT-5 (minimal, low, medium, high) - verbosity: For GPT-5 (low, medium, high) - preview_only: If true, returns modified data without saving to database (defaults to false) @@ -175,7 +175,7 @@ async def modify_persona_with_ai(persona_id): if not modification_prompt: return jsonify({"error": "modification_prompt is required"}), 400 - llm_model = request_data.get('llm_model', 'gemini-3-pro-preview') + llm_model = request_data.get('llm_model', 'gemini-3.1-pro-preview') reasoning_effort = request_data.get('reasoning_effort', 'medium') verbosity = request_data.get('verbosity', 'medium') preview_only = request_data.get('preview_only', False) @@ -246,7 +246,7 @@ async def export_persona_profile(persona_id): Returns 202 immediately; result delivered via WebSocket task_completed event. Request body can optionally include: - - llm_model: Model to use (defaults to 'gpt-4.1') + - llm_model: Model to use (defaults to 'gemini-3.1-pro-preview') - temperature: Temperature for generation (defaults to 0.3) """ try: @@ -255,7 +255,7 @@ async def export_persona_profile(persona_id): return jsonify({"error": "Persona not found"}), 404 request_data = await request.get_json() or {} - llm_model = request_data.get('llm_model', 'gpt-4.1') + llm_model = request_data.get('llm_model', 'gemini-3.1-pro-preview') temperature = request_data.get('temperature', 0.3) user_id = get_jwt_identity() diff --git a/backend/app/services/ai_moderator_service.py b/backend/app/services/ai_moderator_service.py index ea1ff1f5..abe9e219 100755 --- a/backend/app/services/ai_moderator_service.py +++ b/backend/app/services/ai_moderator_service.py @@ -585,6 +585,8 @@ class AIModeratorService: Generated moderator response """ try: + from app.services.llm_usage_context import set_llm_context + set_llm_context(feature="moderator", focus_group_id=focus_group_id) # Get previous messages for context messages = await FocusGroup.get_messages(focus_group_id) recent_messages = messages[-10:] if messages else [] # Last 10 messages diff --git a/backend/app/services/ai_persona_service.py b/backend/app/services/ai_persona_service.py index da9b9b1d..d61cb0be 100755 --- a/backend/app/services/ai_persona_service.py +++ b/backend/app/services/ai_persona_service.py @@ -138,6 +138,8 @@ async def generate_basic_personas( Raises: PersonaGenerationError: If there's an issue with the AI generation or JSON parsing """ + from app.services.llm_usage_context import set_llm_context + set_llm_context(feature="persona_generate") last_error = None for attempt in range(max_retries + 1): @@ -467,6 +469,8 @@ async def generate_persona( PersonaGenerationError: If there's an issue with the AI generation or JSON parsing """ try: + from app.services.llm_usage_context import set_llm_context + set_llm_context(feature="persona_generate") # If audience_brief or research_objective provided but no prompt_customization, # generate customization so the LLM knows the research context if not prompt_customization and (audience_brief or research_objective): diff --git a/backend/app/services/conversation_decision_service.py b/backend/app/services/conversation_decision_service.py index 3fc3daac..a9d7124f 100755 --- a/backend/app/services/conversation_decision_service.py +++ b/backend/app/services/conversation_decision_service.py @@ -36,8 +36,10 @@ class ConversationDecisionService: ConversationDecisionError: If there's an issue with decision making """ print(f"🎯 Decision request: {mode} mode for focus group {focus_group_id}") - + try: + from app.services.llm_usage_context import set_llm_context + set_llm_context(feature="conversation_decision", focus_group_id=focus_group_id) # Get full conversation context context = await ConversationContextService.get_full_context(focus_group_id) formatted_context = ConversationContextService.format_context_for_llm(context) diff --git a/backend/app/services/focus_group_response_service.py b/backend/app/services/focus_group_response_service.py index 664f0134..6bff28fe 100755 --- a/backend/app/services/focus_group_response_service.py +++ b/backend/app/services/focus_group_response_service.py @@ -46,13 +46,15 @@ async def generate_persona_response( FocusGroupResponseError: If there's an issue with the response generation """ try: + from app.services.llm_usage_context import set_llm_context + set_llm_context( + feature="persona_response", + focus_group_id=focus_group_id or None, + persona_id=str(persona.get("_id", "")) or None, + ) print(f"🎭 Generating persona response for {persona.get('name', 'Unknown')}") print(f" - focus_group_id: {focus_group_id}") print(f" - current_topic: {current_topic[:50]}...") - if llm_model in ('gpt-5', 'gpt-5.2'): - print(f" - llm_model: {llm_model} (reasoning_effort: {reasoning_effort or 'medium'}, verbosity: {verbosity or 'medium'}) [using Responses API]") - else: - print(f" - llm_model: {llm_model or 'default (gemini-3-pro-preview)'}") # Import LLMService at the top to avoid scoping issues from app.services.llm_service import LLMService @@ -417,6 +419,12 @@ async def generate_creative_review_response( FocusGroupResponseError: If there's an issue with the response generation """ try: + from app.services.llm_usage_context import set_llm_context + set_llm_context( + feature="persona_response", + focus_group_id=focus_group_id or None, + persona_id=str(persona.get("_id", "")) or None, + ) print(f"🎨 CREATIVE REVIEW RESPONSE DEBUG:") print(f" - persona: {persona.get('name', 'Unknown')}") print(f" - current_topic: {current_topic}") diff --git a/backend/app/services/focus_group_summary_service.py b/backend/app/services/focus_group_summary_service.py index 49ccb8b6..f4e6d380 100755 --- a/backend/app/services/focus_group_summary_service.py +++ b/backend/app/services/focus_group_summary_service.py @@ -43,6 +43,8 @@ async def generate_focus_group_summary( A concise one-line summary string (max ~100 characters), or None on failure """ try: + from app.services.llm_usage_context import set_llm_context + set_llm_context(feature="summary") # Extract data from focus group name = focus_group_data.get('name', 'Unnamed Focus Group') topic = focus_group_data.get('topic', 'General Research') diff --git a/backend/app/services/key_theme_service.py b/backend/app/services/key_theme_service.py index 89689881..e9515880 100755 --- a/backend/app/services/key_theme_service.py +++ b/backend/app/services/key_theme_service.py @@ -40,9 +40,11 @@ class KeyThemeService: KeyThemeServiceError: If there's an issue with the generation process """ logger = logging.getLogger(__name__) + from app.services.llm_usage_context import set_llm_context + set_llm_context(feature="key_themes", focus_group_id=focus_group_id) logger.info(f"Starting key theme generation for focus group {focus_group_id} with temperature {temperature}") - logger.info(f"Using LLM model: {llm_model or 'default (gemini-3-pro-preview)'}") - + logger.info(f"Using LLM model: {llm_model or 'default'}") + try: # Get the focus group focus_group = await FocusGroup.find_by_id(focus_group_id) diff --git a/backend/app/services/llm_service.py b/backend/app/services/llm_service.py index 98a67224..74d839bd 100755 --- a/backend/app/services/llm_service.py +++ b/backend/app/services/llm_service.py @@ -11,6 +11,7 @@ import asyncio import logging import base64 import traceback +import time from google import genai from google.genai import errors as genai_errors from openai import AsyncOpenAI @@ -59,18 +60,20 @@ def get_openai_client(): return AsyncOpenAI(api_key=OPENAI_API_KEY, timeout=600.0) # The default model we're using -DEFAULT_MODEL = "gemini-3-pro-preview" +DEFAULT_MODEL = "gemini-3.1-pro-preview" # Supported models SUPPORTED_MODELS = { - 'gemini-3-pro-preview': 'gemini', - 'gpt-4.1': 'openai', - 'gpt-5.2': 'openai' + 'gemini-3.1-pro-preview': 'gemini', + 'gpt-5.4-2026-03-05': 'openai', } # Aliases for renamed/legacy model IDs stored in the database MODEL_ALIASES = { - 'gpt-5': 'gpt-5.2', + 'gpt-5': 'gpt-5.4-2026-03-05', + 'gpt-5.2': 'gpt-5.4-2026-03-05', + 'gemini-3-pro-preview': 'gemini-3.1-pro-preview', + 'gpt-4.1': 'gemini-3.1-pro-preview', } class LLMServiceError(Exception): @@ -182,6 +185,83 @@ class LLMService: raise raise LLMServiceError(f"Error extracting text from new GenAI SDK response: {str(e)}") + @staticmethod + def _extract_usage_metadata(response, provider: str) -> dict: + """Extract token counts from a provider response. All fields default to 0.""" + if provider == 'gemini': + um = getattr(response, 'usage_metadata', None) + if um is None: + return {'prompt': 0, 'completion': 0, 'cached': 0, 'reasoning': 0} + return { + 'prompt': getattr(um, 'prompt_token_count', 0) or 0, + 'completion': getattr(um, 'candidates_token_count', 0) or 0, + 'cached': getattr(um, 'cached_content_token_count', 0) or 0, + 'reasoning': 0, + } + elif provider == 'openai': + usage = getattr(response, 'usage', None) + if usage is None: + return {'prompt': 0, 'completion': 0, 'cached': 0, 'reasoning': 0} + # Responses API (gpt-5.4-2026-03-05) + if hasattr(usage, 'input_tokens'): + input_details = getattr(usage, 'input_tokens_details', None) + output_details = getattr(usage, 'output_tokens_details', None) + return { + 'prompt': getattr(usage, 'input_tokens', 0) or 0, + 'completion': getattr(usage, 'output_tokens', 0) or 0, + 'cached': getattr(input_details, 'cached_tokens', 0) or 0 if input_details else 0, + 'reasoning': getattr(output_details, 'reasoning_tokens', 0) or 0 if output_details else 0, + } + # Chat Completions API + prompt_details = getattr(usage, 'prompt_tokens_details', None) + return { + 'prompt': getattr(usage, 'prompt_tokens', 0) or 0, + 'completion': getattr(usage, 'completion_tokens', 0) or 0, + 'cached': getattr(prompt_details, 'cached_tokens', 0) or 0 if prompt_details else 0, + 'reasoning': 0, + } + return {'prompt': 0, 'completion': 0, 'cached': 0, 'reasoning': 0} + + @staticmethod + async def _record_usage(response, provider: str, model: str, start_time: float, retry_count: int) -> None: + """Record a usage event after a successful LLM call. Never raises.""" + try: + from app.services.llm_usage_context import current_context + from app.models.usage_event import UsageEvent + from app.models.model_pricing import ModelPricing + + ctx = current_context() + tokens = LLMService._extract_usage_metadata(response, provider) + pricing = await ModelPricing.current_for(model) + cost = ModelPricing.compute_cost( + pricing, + prompt_tokens=tokens['prompt'], + completion_tokens=tokens['completion'], + cached_tokens=tokens['cached'], + ) + price_id = pricing.get('_id') if pricing else None + + await UsageEvent.record( + provider=provider, + model=model, + prompt_tokens=tokens['prompt'], + completion_tokens=tokens['completion'], + cached_tokens=tokens['cached'], + reasoning_tokens=tokens['reasoning'], + cost_usd=cost, + price_snapshot_id=price_id, + duration_ms=int((time.monotonic() - start_time) * 1000), + retry_count=retry_count, + status="success", + user_id=ctx.user_id, + focus_group_id=ctx.focus_group_id, + persona_id=ctx.persona_id, + feature=ctx.feature, + task_id=ctx.task_id, + ) + except Exception: + logging.getLogger(__name__).warning("_record_usage failed (non-fatal)", exc_info=True) + @staticmethod async def generate_content( prompt: str, @@ -216,6 +296,7 @@ class LLMService: actual_model = LLMService._resolve_model(model_name) provider = LLMService._get_model_provider(model_name) + _start_time = time.monotonic() for attempt in range(max_retries): attempt_num = attempt + 1 @@ -223,8 +304,8 @@ class LLMService: try: if provider == 'openai': - if actual_model == 'gpt-5.2': - # Use OpenAI Responses API for GPT-5 + if actual_model == 'gpt-5.4-2026-03-05': + # Use OpenAI Responses API for gpt-5.4-2026-03-05 input_content = prompt if system_prompt: input_content = f"System: {system_prompt}\n\nUser: {prompt}" @@ -303,8 +384,9 @@ class LLMService: if attempt > 0: logger.info(f"LLM content generation succeeded on attempt {attempt_num}/{max_retries}") + await LLMService._record_usage(response, provider, actual_model, _start_time, attempt) return result - + except genai_errors.APIError as e: # Google GenAI SDK specific error handling last_error = e @@ -410,7 +492,7 @@ class LLMService: except json.JSONDecodeError as e: error_msg = f"Failed to parse JSON response: {str(e)}. Raw response: {response_text[:200]}..." - logger.error(error_msg) + logging.getLogger(__name__).error(error_msg) raise LLMServiceError(error_msg) @staticmethod @@ -531,7 +613,8 @@ class LLMService: provider = LLMService._get_model_provider(model_name) logger.info(f"Generating multimodal content with {len(image_paths)} image(s) using {provider} provider") - + _start_time = time.monotonic() + for attempt in range(max_retries): attempt_num = attempt + 1 logger.debug(f"Multimodal content generation attempt {attempt_num}/{max_retries}") @@ -564,8 +647,8 @@ class LLMService: }) logger.debug(f"Successfully loaded image for OpenAI: {image_path}") - if actual_model == 'gpt-5.2': - # Use Responses API for GPT-5.2 multimodal + if actual_model == 'gpt-5.4-2026-03-05': + # Use Responses API for gpt-5.4-2026-03-05 multimodal # Note: GPT-5 Responses API supports multimodal input input_content = [{"role": "user", "content": [{"type": "input_text", "text": prompt}]}] # Add images to the content array @@ -660,12 +743,13 @@ class LLMService: if attempt > 0: logger.info(f"Multimodal content generation succeeded on attempt {attempt_num}/{max_retries}") + await LLMService._record_usage(response, provider, actual_model, _start_time, attempt) return result - + except Exception as e: last_error = e error_message = str(e).lower() - + logger.warning(f"Multimodal attempt {attempt_num}/{max_retries} failed: {str(e)}") # Check if this is a retryable error @@ -766,7 +850,8 @@ class LLMService: max_retries = 3 last_error = None - + _start_time = time.monotonic() + for attempt in range(max_retries): attempt_num = attempt + 1 logger.debug(f"Contextual multimodal generation attempt {attempt_num}/{max_retries}") @@ -790,8 +875,8 @@ class LLMService: } }) - if actual_model == 'gpt-5.2': - # Use Responses API for GPT-5.2 contextual multimodal + if actual_model == 'gpt-5.4-2026-03-05': + # Use Responses API for gpt-5.4-2026-03-05 contextual multimodal input_content = [{"role": "user", "content": [{"type": "input_text", "text": full_prompt}]}] # Add images to the content array for img_content in image_content: @@ -875,8 +960,9 @@ class LLMService: print(f" - Result length: {len(result) if result else 0} characters") print(f" - Result preview: '{result[:200] if result else 'EMPTY'}...'") print(f" - Result repr: {repr(result[:50]) if result else 'NONE'}") + await LLMService._record_usage(response, provider, actual_model, _start_time, attempt) return result - + except genai_errors.APIError as e: # Google GenAI SDK specific error handling last_error = e diff --git a/backend/app/services/llm_usage_context.py b/backend/app/services/llm_usage_context.py new file mode 100644 index 00000000..b60b54d2 --- /dev/null +++ b/backend/app/services/llm_usage_context.py @@ -0,0 +1,54 @@ +from contextvars import ContextVar +from dataclasses import dataclass, replace +from contextlib import contextmanager +from typing import Optional + + +@dataclass(frozen=True) +class LLMCallContext: + user_id: Optional[str] = None + focus_group_id: Optional[str] = None + persona_id: Optional[str] = None + feature: str = "other" + task_id: Optional[str] = None + + +_ctx: ContextVar[LLMCallContext] = ContextVar("llm_call_context", default=LLMCallContext()) + + +def current_context() -> LLMCallContext: + return _ctx.get() + + +def set_llm_context(**overrides) -> None: + """Mutate the LLM context for the current asyncio task without cleanup. + + Use this at service entry points where the feature/focus_group_id/persona_id + should persist for the duration of the whole async call tree (including sub-awaits). + The change lives until the asyncio Task ends or is overridden again. + + Unlike llm_context(), this does NOT restore the previous value on exit — suitable + for top-level service calls, not for re-entrant helpers. + """ + prev = _ctx.get() + _ctx.set(replace(prev, **overrides)) + + +@contextmanager +def llm_context(**overrides): + """Context manager that sets LLM call attribution metadata. + + Usage: + with llm_context(user_id="abc", focus_group_id="xyz", feature="moderator"): + await LLMService.generate_content(...) + + Overrides stack — inner contexts extend (not replace) outer ones. + Safe across asyncio tasks and run_coroutine_threadsafe hops because + ContextVar inherits context on task creation / thread submission. + """ + prev = _ctx.get() + token = _ctx.set(replace(prev, **overrides)) + try: + yield + finally: + _ctx.reset(token) diff --git a/backend/app/services/persona_modification_service.py b/backend/app/services/persona_modification_service.py index f33244fd..0c6c1822 100755 --- a/backend/app/services/persona_modification_service.py +++ b/backend/app/services/persona_modification_service.py @@ -134,7 +134,7 @@ class PersonaModificationService: async def modify_persona( persona_id: str, modification_prompt: str, - llm_model: str = 'gemini-3-pro-preview', + llm_model: str = 'gemini-3.1-pro-preview', reasoning_effort: str = 'medium', verbosity: str = 'medium', max_retries: int = 3, @@ -159,6 +159,8 @@ class PersonaModificationService: PersonaModificationError: If modification fails or validation fails """ try: + from app.services.llm_usage_context import set_llm_context + set_llm_context(feature="persona_modify", persona_id=persona_id) # Fetch the original persona original_persona = await Persona.find_by_id(persona_id) if not original_persona: @@ -188,8 +190,8 @@ class PersonaModificationService: prompt=final_prompt, temperature=0.3, # Lower temperature for consistent modifications model_name=llm_model, - reasoning_effort=reasoning_effort if llm_model in ('gpt-5', 'gpt-5.2') else None, - verbosity=verbosity if llm_model in ('gpt-5', 'gpt-5.2') else None + reasoning_effort=reasoning_effort if llm_model in ('gpt-5', 'gpt-5.4-2026-03-05') else None, + verbosity=verbosity if llm_model in ('gpt-5', 'gpt-5.4-2026-03-05') else None ) # Parse JSON response diff --git a/backend/app/utils/__init__.py b/backend/app/utils/__init__.py index fb5b1702..a3c0b9ef 100755 --- a/backend/app/utils/__init__.py +++ b/backend/app/utils/__init__.py @@ -22,14 +22,54 @@ def make_serializable(obj): def admin_required(f): + """Route decorator requiring admin role. Must be stacked BELOW @jwt_required().""" @wraps(f) - def decorated_function(*args, **kwargs): + async def decorated(*args, **kwargs): user_id = get_jwt_identity() - user_data = User.find_by_id(user_id) - - if not user_data or user_data.get('role') != 'admin': + if not user_id: + return jsonify({"message": "Authentication required"}), 401 + user_data = await User.find_by_id(user_id) + if not user_data: + return jsonify({"message": "User not found"}), 404 + if user_data.get("role") != "admin": return jsonify({"message": "Admin privileges required"}), 403 + if user_data.get("is_active") is False: + return jsonify({"message": "Account disabled"}), 403 + return await f(*args, **kwargs) + return decorated - return f(*args, **kwargs) - return decorated_function \ No newline at end of file +def active_required(f): + """Route decorator that rejects requests from disabled users. + + Guards LLM-invoking routes so that revoking a user's access takes effect + immediately rather than waiting for their JWT to expire (24 h window). + Must be stacked BELOW @jwt_required(). + """ + @wraps(f) + async def decorated(*args, **kwargs): + user_id = get_jwt_identity() + if user_id: + user_data = await User.find_by_id(user_id) + if user_data and user_data.get("is_active") is False: + return jsonify({"message": "Account disabled"}), 403 + return await f(*args, **kwargs) + return decorated + + +def with_user_context(f): + """Route decorator that injects the JWT user_id into the LLM usage ContextVar. + + Must be stacked BELOW @jwt_required() so the token is already validated. + The context propagates to asyncio tasks and run_coroutine_threadsafe calls, + so autonomous AI runner conversations pick up the user attribution automatically. + """ + @wraps(f) + async def decorated(*args, **kwargs): + from app.services.llm_usage_context import llm_context + user_id = get_jwt_identity() + if user_id: + with llm_context(user_id=user_id): + return await f(*args, **kwargs) + return await f(*args, **kwargs) + return decorated diff --git a/backend/requirements.txt b/backend/requirements.txt index 0c908cc3..828d7b2a 100755 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -31,3 +31,6 @@ pillow==11.3.0 # Configuration & Utilities python-dotenv==1.1.1 + +# Token estimation (used by backfill_usage.py script) +tiktoken>=0.9.0 diff --git a/backend/scripts/seed_model_pricing.py b/backend/scripts/seed_model_pricing.py new file mode 100644 index 00000000..b011bb72 --- /dev/null +++ b/backend/scripts/seed_model_pricing.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python3 +"""Seed model pricing for Semblance. + +Run from the backend/ directory: + source venv/bin/activate + python scripts/seed_model_pricing.py + +Idempotent — upserts on {model, effective_from}. Safe to re-run. +""" + +import sys +import os +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from dotenv import load_dotenv +load_dotenv() + +import pymongo +from datetime import datetime, timezone + +MONGO_URI = os.environ.get("MONGO_URI") +MONGO_USER = os.environ.get("MONGO_USER") +MONGO_PASS = os.environ.get("MONGO_PASS") +MONGO_HOST = os.environ.get("MONGO_HOST", "localhost") +MONGO_PORT = os.environ.get("MONGO_PORT", "27017") + +if not MONGO_URI: + if MONGO_USER and MONGO_PASS: + MONGO_URI = f"mongodb://{MONGO_USER}:{MONGO_PASS}@{MONGO_HOST}:{MONGO_PORT}/semblance_db?authSource=admin" + else: + MONGO_URI = f"mongodb://{MONGO_HOST}:{MONGO_PORT}" + +# Pricing effective from project start — covers all historical backfill +EFFECTIVE_FROM = datetime(2024, 1, 1, tzinfo=timezone.utc) + +PRICING_ROWS = [ + { + "model": "gpt-5.4-2026-03-05", + "provider": "openai", + "currency": "USD", + "tiers": [ + { + "threshold_input_tokens": 0, + "input_per_mtok": 2.50, + "cached_input_per_mtok": 0.25, + "output_per_mtok": 15.00, + "image_per_mtok": None, + } + ], + "effective_from": EFFECTIVE_FROM, + "effective_until": None, + "notes": "gpt-5.4-2026-03-05 pricing as of 2026-04", + }, + { + "model": "gemini-3.1-pro-preview", + "provider": "gemini", + "currency": "USD", + "tiers": [ + { + "threshold_input_tokens": 0, + "input_per_mtok": 2.00, + "cached_input_per_mtok": None, + "output_per_mtok": 12.00, + "image_per_mtok": None, + }, + { + "threshold_input_tokens": 200_000, + "input_per_mtok": 4.00, + "cached_input_per_mtok": None, + "output_per_mtok": 18.00, + "image_per_mtok": None, + }, + ], + "effective_from": EFFECTIVE_FROM, + "effective_until": None, + "notes": "gemini-3.1-pro-preview pricing: $2/$12 (<200k ctx), $4/$18 (>=200k ctx)", + }, +] + + +def main(): + client = pymongo.MongoClient(MONGO_URI, serverSelectionTimeoutMS=5000) + db = client.semblance_db + + db.model_pricing.create_index( + [("model", pymongo.ASCENDING), ("effective_from", pymongo.DESCENDING)], + background=True, + ) + + for row in PRICING_ROWS: + key = {"model": row["model"], "effective_from": row["effective_from"]} + result = db.model_pricing.update_one(key, {"$set": row}, upsert=True) + action = "inserted" if result.upserted_id else "updated" + print(f" {action}: {row['model']} (effective from {row['effective_from'].date()})") + + print(f"\nDone. {len(PRICING_ROWS)} pricing rows seeded.") + client.close() + + +if __name__ == "__main__": + main() diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py new file mode 100644 index 00000000..a9aa6942 --- /dev/null +++ b/backend/tests/conftest.py @@ -0,0 +1,53 @@ +""" +conftest.py — sys.modules stubs so tests run without the full Docker venv. +All heavy external packages are replaced with MagicMocks before any app.* +module is imported. Individual tests patch specific methods as needed. +""" + +import os +import sys +from unittest.mock import MagicMock, AsyncMock + +# ── Fake env vars required at llm_service.py module level ───────────────────── +os.environ.setdefault('GEMINI_API_KEY', 'test-key-gemini') +os.environ.setdefault('OPENAI_API_KEY', 'test-key-openai') + + +def _stub(*names): + """Register MagicMocks under each name in sys.modules.""" + for name in names: + if name not in sys.modules: + sys.modules[name] = MagicMock() + + +# ── External packages not present in system Python ──────────────────────────── +_stub( + 'google', 'google.genai', 'google.genai.types', + 'openai', 'openai.types', 'openai.types.responses', + 'motor', 'motor.motor_asyncio', + 'pymongo', 'pymongo.errors', + 'quart', 'quart_cors', 'hypercorn', 'werkzeug', 'werkzeug.exceptions', + 'socketio', + 'bcrypt', 'jwt', 'msal', + 'bson', 'bson.objectid', + 'pydantic', + 'PIL', 'PIL.Image', + 'httpx', 'requests', + 'dotenv', + 'llama_cloud_services', +) + +# ── app.db ───────────────────────────────────────────────────────────────────── +# Any `from app.db import get_db` will capture this AsyncMock. +# Tests that need to control DB responses should patch +# `app.models..get_db` (i.e. the local binding in the module under test). + +_mock_db = MagicMock() +_mock_get_db = AsyncMock(return_value=_mock_db) + +_app_db_mod = MagicMock() +_app_db_mod.get_db = _mock_get_db +sys.modules['app.db'] = _app_db_mod + +# Expose for tests +mock_db = _mock_db diff --git a/backend/tests/test_llm_service.py b/backend/tests/test_llm_service.py index 36227b03..dff240c2 100755 --- a/backend/tests/test_llm_service.py +++ b/backend/tests/test_llm_service.py @@ -1,145 +1,128 @@ """ -Tests for the LLM Service +Tests for LLMService — covers parse_json_response (sync) and generate_structured_array. +generate_content / generate_multimodal_content are async and call real provider APIs, +so they are covered via integration tests; only pure logic is unit-tested here. """ -import unittest -from unittest.mock import patch, MagicMock -import json +import sys +import pytest +from unittest.mock import MagicMock, AsyncMock, patch + from app.services.llm_service import LLMService, LLMServiceError -class TestLLMService(unittest.TestCase): - """Test cases for the LLM Service""" - - @patch('app.services.llm_service.genai.GenerativeModel') - def test_get_model(self, mock_generative_model): - """Test getting a Gemini model""" - # Setup mock - mock_model = MagicMock() - mock_generative_model.return_value = mock_model - - # Test with default model - model = LLMService.get_model() - mock_generative_model.assert_called_once() - self.assertEqual(model, mock_model) - - # Reset mock - mock_generative_model.reset_mock() - - # Test with custom model - custom_model = "custom-model-name" - model = LLMService.get_model(custom_model) - mock_generative_model.assert_called_once_with(custom_model) - self.assertEqual(model, mock_model) - - @patch('app.services.llm_service.LLMService.get_model') - def test_generate_content(self, mock_get_model): - """Test generating content with the LLM""" - # Setup mock - mock_model = MagicMock() - mock_response = MagicMock() - mock_response.text = "Generated text response" - mock_model.generate_content.return_value = mock_response - mock_get_model.return_value = mock_model - - # Test with default parameters - prompt = "Test prompt" - response = LLMService.generate_content(prompt) - - mock_get_model.assert_called_once() - mock_model.generate_content.assert_called_once() - self.assertEqual(response, "Generated text response") - - # Test with custom parameters - mock_get_model.reset_mock() - mock_model.generate_content.reset_mock() - - response = LLMService.generate_content( - prompt="Custom prompt", - temperature=0.5, - max_tokens=100, - model_name="custom-model" - ) - - mock_get_model.assert_called_once_with("custom-model") - mock_model.generate_content.assert_called_once() - self.assertEqual(response, "Generated text response") - - @patch('app.services.llm_service.LLMService.get_model') - def test_generate_content_error(self, mock_get_model): - """Test error handling in generate_content""" - # Setup mock to raise an exception - mock_model = MagicMock() - mock_model.generate_content.side_effect = Exception("Model error") - mock_get_model.return_value = mock_model - - # Test error handling - with self.assertRaises(LLMServiceError) as context: - LLMService.generate_content("Test prompt") - - self.assertIn("Error generating content", str(context.exception)) - - def test_parse_json_response_valid(self): - """Test parsing valid JSON responses""" - # Test with clean JSON - clean_json = '{"key": "value", "number": 42}' - result = LLMService.parse_json_response(clean_json) - expected = {"key": "value", "number": 42} - self.assertEqual(result, expected) - - # Test with JSON in markdown code block - markdown_json = '```json\n{"key": "value", "number": 42}\n```' - result = LLMService.parse_json_response(markdown_json) - self.assertEqual(result, expected) - - # Test with JSON in generic code block - generic_code_block = '```\n{"key": "value", "number": 42}\n```' - result = LLMService.parse_json_response(generic_code_block) - self.assertEqual(result, expected) - - def test_parse_json_response_invalid(self): - """Test parsing invalid JSON responses""" - invalid_json = 'This is not JSON' - - with self.assertRaises(LLMServiceError) as context: - LLMService.parse_json_response(invalid_json) - - self.assertIn("Failed to parse JSON response", str(context.exception)) - - @patch('app.services.llm_service.LLMService.generate_content') - @patch('app.services.llm_service.LLMService.parse_json_response') - def test_generate_structured_response(self, mock_parse_json, mock_generate_content): - """Test generating a structured JSON response""" - # Setup mocks - mock_generate_content.return_value = '{"result": "success"}' - mock_parse_json.return_value = {"result": "success"} - - # Test - result = LLMService.generate_structured_response( - prompt="Generate JSON", - temperature=0.5 - ) - - mock_generate_content.assert_called_once_with( - prompt="Generate JSON", - temperature=0.5, - max_tokens=None, - model_name=None, - system_prompt=None - ) - mock_parse_json.assert_called_once_with('{"result": "success"}') - self.assertEqual(result, {"result": "success"}) - - @patch('app.services.llm_service.LLMService.generate_content') - def test_generate_structured_response_error(self, mock_generate_content): - """Test error handling in generate_structured_response""" - # Setup mock to raise an exception - mock_generate_content.side_effect = LLMServiceError("Generation failed") - - # Test error propagation - with self.assertRaises(LLMServiceError) as context: - LLMService.generate_structured_response("Generate JSON") - - self.assertEqual(str(context.exception), "Generation failed") -if __name__ == '__main__': - unittest.main() \ No newline at end of file +class TestParseJsonResponse: + def test_clean_json(self): + result = LLMService.parse_json_response('{"key": "value", "number": 42}') + assert result == {"key": "value", "number": 42} + + def test_json_in_markdown_fenced_block(self): + md = '```json\n{"key": "value", "number": 42}\n```' + assert LLMService.parse_json_response(md) == {"key": "value", "number": 42} + + def test_json_in_generic_fenced_block(self): + md = '```\n{"key": "value", "number": 42}\n```' + assert LLMService.parse_json_response(md) == {"key": "value", "number": 42} + + def test_invalid_json_raises(self): + with pytest.raises(LLMServiceError) as exc_info: + LLMService.parse_json_response("This is not JSON") + assert "Failed to parse JSON response" in str(exc_info.value) + + def test_empty_string_raises(self): + with pytest.raises(LLMServiceError): + LLMService.parse_json_response("") + + def test_json_array(self): + result = LLMService.parse_json_response('[{"a": 1}, {"b": 2}]') + assert result == [{"a": 1}, {"b": 2}] + + def test_nested_json(self): + result = LLMService.parse_json_response('{"outer": {"inner": [1, 2, 3]}}') + assert result == {"outer": {"inner": [1, 2, 3]}} + + +class TestResolveModelAndProvider: + """_resolve_model is a pure function. Provider is looked up via SUPPORTED_MODELS.""" + + def test_none_resolves_to_default(self): + assert LLMService._resolve_model(None) == "gemini-3.1-pro-preview" + + def test_all_aliases_resolve(self): + assert LLMService._resolve_model("gpt-5") == "gpt-5.4-2026-03-05" + assert LLMService._resolve_model("gpt-5.2") == "gpt-5.4-2026-03-05" + assert LLMService._resolve_model("gemini-3-pro-preview") == "gemini-3.1-pro-preview" + assert LLMService._resolve_model("gpt-4.1") == "gemini-3.1-pro-preview" + + def test_known_models_unchanged(self): + assert LLMService._resolve_model("gemini-3.1-pro-preview") == "gemini-3.1-pro-preview" + assert LLMService._resolve_model("gpt-5.4-2026-03-05") == "gpt-5.4-2026-03-05" + + def test_provider_for_gemini_model(self): + from app.services.llm_service import SUPPORTED_MODELS + assert SUPPORTED_MODELS.get("gemini-3.1-pro-preview") == "gemini" + + def test_provider_for_openai_model(self): + from app.services.llm_service import SUPPORTED_MODELS + assert SUPPORTED_MODELS.get("gpt-5.4-2026-03-05") == "openai" + + def test_unknown_model_not_in_supported(self): + from app.services.llm_service import SUPPORTED_MODELS + assert "gpt-4.1" not in SUPPORTED_MODELS # retired + + +class TestExtractUsageMetadata: + """Static — no external calls.""" + + def test_gemini_extracts_all_fields(self): + response = MagicMock() + um = MagicMock() + um.prompt_token_count = 500 + um.candidates_token_count = 100 + um.cached_content_token_count = 20 + response.usage_metadata = um + + result = LLMService._extract_usage_metadata(response, "gemini") + assert result == {"prompt": 500, "completion": 100, "cached": 20, "reasoning": 0} + + def test_openai_responses_api(self): + response = MagicMock() + usage = MagicMock() + usage.input_tokens = 1000 + usage.output_tokens = 200 + usage.input_tokens_details = MagicMock(cached_tokens=50) + usage.output_tokens_details = MagicMock(reasoning_tokens=80) + response.usage = usage + + result = LLMService._extract_usage_metadata(response, "openai") + assert result == {"prompt": 1000, "completion": 200, "cached": 50, "reasoning": 80} + + def test_openai_chat_completions(self): + response = MagicMock() + usage = MagicMock(spec=['prompt_tokens', 'completion_tokens', 'prompt_tokens_details']) + usage.prompt_tokens = 400 + usage.completion_tokens = 100 + usage.prompt_tokens_details = MagicMock(cached_tokens=10) + response.usage = usage + + result = LLMService._extract_usage_metadata(response, "openai") + assert result == {"prompt": 400, "completion": 100, "cached": 10, "reasoning": 0} + + def test_missing_usage_returns_zeros(self): + response = MagicMock() + response.usage = None + assert LLMService._extract_usage_metadata(response, "openai") == { + "prompt": 0, "completion": 0, "cached": 0, "reasoning": 0 + } + + def test_none_values_coerced_to_zero(self): + """Fields returning None should become 0, not None.""" + response = MagicMock() + um = MagicMock() + um.prompt_token_count = None + um.candidates_token_count = None + um.cached_content_token_count = None + response.usage_metadata = um + + result = LLMService._extract_usage_metadata(response, "gemini") + assert all(v == 0 for v in result.values()) diff --git a/backend/tests/test_usage_infrastructure.py b/backend/tests/test_usage_infrastructure.py new file mode 100644 index 00000000..9f81d338 --- /dev/null +++ b/backend/tests/test_usage_infrastructure.py @@ -0,0 +1,664 @@ +""" +Tests for the usage tracking infrastructure added in Phase A-C: + - LLMCallContext / llm_usage_context module + - ModelPricing.pick_tier and ModelPricing.compute_cost + - LLMService._extract_usage_metadata and LLMService._resolve_model + - UsageEvent.record (non-fatal DB error handling, field validation) + - check_quota (admin bypass, warning threshold, exceeded error) +""" + +import asyncio +import sys +import pytest +from unittest.mock import MagicMock, AsyncMock, patch + + +# ══════════════════════════════════════════════════════════════════════════════ +# 1. LLM usage context — pure stdlib, no mocking needed +# ══════════════════════════════════════════════════════════════════════════════ + +class TestLLMUsageContext: + def setup_method(self): + # Import fresh each test to avoid cross-test ContextVar pollution + from app.services.llm_usage_context import ( + LLMCallContext, _ctx, current_context, set_llm_context, llm_context + ) + self.LLMCallContext = LLMCallContext + self._ctx = _ctx + self.current_context = current_context + self.set_llm_context = set_llm_context + self.llm_context = llm_context + # Reset context to default before each test + self._ctx.set(LLMCallContext()) + + def test_default_context(self): + ctx = self.current_context() + assert ctx.user_id is None + assert ctx.focus_group_id is None + assert ctx.persona_id is None + assert ctx.feature == "other" + assert ctx.task_id is None + + def test_set_llm_context_mutates(self): + self.set_llm_context(user_id="u1", feature="persona_generate") + ctx = self.current_context() + assert ctx.user_id == "u1" + assert ctx.feature == "persona_generate" + # Unset fields remain None + assert ctx.focus_group_id is None + + def test_set_llm_context_merges_with_existing(self): + self.set_llm_context(user_id="u1") + self.set_llm_context(focus_group_id="fg1") + ctx = self.current_context() + # Both values present + assert ctx.user_id == "u1" + assert ctx.focus_group_id == "fg1" + + def test_llm_context_restores_on_exit(self): + self.set_llm_context(user_id="outer") + with self.llm_context(user_id="inner", feature="moderator"): + assert self.current_context().user_id == "inner" + assert self.current_context().feature == "moderator" + # Restored + restored = self.current_context() + assert restored.user_id == "outer" + assert restored.feature == "other" + + def test_llm_context_restores_on_exception(self): + self.set_llm_context(user_id="before") + try: + with self.llm_context(user_id="during"): + raise ValueError("boom") + except ValueError: + pass + assert self.current_context().user_id == "before" + + def test_llm_context_stacking(self): + with self.llm_context(user_id="u1"): + with self.llm_context(focus_group_id="fg1"): + with self.llm_context(feature="key_themes"): + ctx = self.current_context() + assert ctx.user_id == "u1" + assert ctx.focus_group_id == "fg1" + assert ctx.feature == "key_themes" + assert self.current_context().feature == "other" # restored + assert self.current_context().focus_group_id is None # restored + assert self.current_context().user_id is None # restored + + def test_frozen_dataclass_immutable(self): + ctx = self.LLMCallContext(user_id="x") + with pytest.raises((AttributeError, TypeError)): + ctx.user_id = "y" # type: ignore + + +# ══════════════════════════════════════════════════════════════════════════════ +# 2. ModelPricing pure logic — pick_tier and compute_cost +# ══════════════════════════════════════════════════════════════════════════════ + +class TestModelPricingPureLogic: + def setup_method(self): + from app.models.model_pricing import ModelPricing + self.ModelPricing = ModelPricing + + def _gemini_pricing(self): + """Two-tier Gemini pricing: <200k and >=200k.""" + return { + "model": "gemini-3.1-pro-preview", + "tiers": [ + { + "threshold_input_tokens": 0, + "input_per_mtok": 2.0, + "cached_input_per_mtok": None, + "output_per_mtok": 12.0, + }, + { + "threshold_input_tokens": 200_000, + "input_per_mtok": 4.0, + "cached_input_per_mtok": None, + "output_per_mtok": 18.0, + }, + ], + } + + def _gpt_pricing(self): + """Single-tier GPT pricing.""" + return { + "model": "gpt-5.4-2026-03-05", + "tiers": [ + { + "threshold_input_tokens": 0, + "input_per_mtok": 2.50, + "cached_input_per_mtok": 0.25, + "output_per_mtok": 15.0, + } + ], + } + + # ── pick_tier ────────────────────────────────────────────────────────────── + + def test_pick_tier_below_all_thresholds(self): + pricing = self._gemini_pricing() + tier = self.ModelPricing.pick_tier(pricing, prompt_tokens=1000) + assert tier["threshold_input_tokens"] == 0 + assert tier["input_per_mtok"] == 2.0 + + def test_pick_tier_above_high_threshold(self): + pricing = self._gemini_pricing() + tier = self.ModelPricing.pick_tier(pricing, prompt_tokens=250_000) + assert tier["threshold_input_tokens"] == 200_000 + assert tier["input_per_mtok"] == 4.0 + + def test_pick_tier_exactly_at_threshold(self): + pricing = self._gemini_pricing() + tier = self.ModelPricing.pick_tier(pricing, prompt_tokens=200_000) + assert tier["threshold_input_tokens"] == 200_000 + + def test_pick_tier_single_tier(self): + pricing = self._gpt_pricing() + tier = self.ModelPricing.pick_tier(pricing, prompt_tokens=500_000) + assert tier["input_per_mtok"] == 2.50 + + def test_pick_tier_none_pricing_returns_none(self): + assert self.ModelPricing.pick_tier(None, 1000) is None + + def test_pick_tier_empty_tiers_returns_none(self): + assert self.ModelPricing.pick_tier({"tiers": []}, 1000) is None + + # ── compute_cost ────────────────────────────────────────────────────────── + + def test_compute_cost_none_pricing_returns_zeros(self): + cost = self.ModelPricing.compute_cost(None, 100, 50) + assert cost == {"input": 0.0, "cached": 0.0, "output": 0.0, "total": 0.0} + + def test_compute_cost_basic_gemini_low_tier(self): + # 100k input, 10k output, no cached — stays in low tier + cost = self.ModelPricing.compute_cost( + self._gemini_pricing(), + prompt_tokens=100_000, + completion_tokens=10_000, + cached_tokens=0, + ) + expected_input = 100_000 * 2.0 / 1_000_000 # $0.20 + expected_output = 10_000 * 12.0 / 1_000_000 # $0.12 + assert abs(cost["input"] - expected_input) < 1e-9 + assert cost["cached"] == 0.0 + assert abs(cost["output"] - expected_output) < 1e-9 + assert abs(cost["total"] - (expected_input + expected_output)) < 1e-9 + + def test_compute_cost_basic_gemini_high_tier(self): + # 300k input → high tier: $4/$18 + cost = self.ModelPricing.compute_cost( + self._gemini_pricing(), + prompt_tokens=300_000, + completion_tokens=5_000, + ) + expected_input = 300_000 * 4.0 / 1_000_000 + expected_output = 5_000 * 18.0 / 1_000_000 + assert abs(cost["input"] - expected_input) < 1e-9 + assert abs(cost["output"] - expected_output) < 1e-9 + + def test_compute_cost_with_cached_tokens_gpt(self): + # 10k input, 2k cached (cheaper), 5k output + cost = self.ModelPricing.compute_cost( + self._gpt_pricing(), + prompt_tokens=10_000, + completion_tokens=5_000, + cached_tokens=2_000, + ) + # Billable input = 10k - 2k = 8k at $2.50/M + expected_input = 8_000 * 2.50 / 1_000_000 + # Cached 2k at $0.25/M + expected_cached = 2_000 * 0.25 / 1_000_000 + # Output 5k at $15/M + expected_output = 5_000 * 15.0 / 1_000_000 + expected_total = expected_input + expected_cached + expected_output + + assert abs(cost["input"] - expected_input) < 1e-9 + assert abs(cost["cached"] - expected_cached) < 1e-9 + assert abs(cost["output"] - expected_output) < 1e-9 + assert abs(cost["total"] - expected_total) < 1e-9 + + def test_compute_cost_total_equals_sum_of_components(self): + cost = self.ModelPricing.compute_cost( + self._gpt_pricing(), + prompt_tokens=50_000, + completion_tokens=20_000, + cached_tokens=5_000, + ) + assert abs(cost["total"] - (cost["input"] + cost["cached"] + cost["output"])) < 1e-9 + + +# ══════════════════════════════════════════════════════════════════════════════ +# 3. LLMService._extract_usage_metadata and _resolve_model +# ══════════════════════════════════════════════════════════════════════════════ + +class TestLLMServicePureStaticMethods: + def setup_method(self): + from app.services.llm_service import LLMService, MODEL_ALIASES, DEFAULT_MODEL + self.LLMService = LLMService + self.MODEL_ALIASES = MODEL_ALIASES + self.DEFAULT_MODEL = DEFAULT_MODEL + + # ── _resolve_model ──────────────────────────────────────────────────────── + + def test_resolve_model_none_returns_default(self): + assert self.LLMService._resolve_model(None) == self.DEFAULT_MODEL + + def test_resolve_model_default_unchanged(self): + assert self.LLMService._resolve_model("gemini-3.1-pro-preview") == "gemini-3.1-pro-preview" + + def test_resolve_model_gpt5_alias(self): + assert self.LLMService._resolve_model("gpt-5") == "gpt-5.4-2026-03-05" + + def test_resolve_model_gpt52_alias(self): + assert self.LLMService._resolve_model("gpt-5.2") == "gpt-5.4-2026-03-05" + + def test_resolve_model_gemini3_alias(self): + assert self.LLMService._resolve_model("gemini-3-pro-preview") == "gemini-3.1-pro-preview" + + def test_resolve_model_gpt41_retired_falls_to_gemini(self): + assert self.LLMService._resolve_model("gpt-4.1") == "gemini-3.1-pro-preview" + + def test_resolve_model_known_openai_unchanged(self): + assert self.LLMService._resolve_model("gpt-5.4-2026-03-05") == "gpt-5.4-2026-03-05" + + def test_resolve_model_unknown_passthrough(self): + # Unknown model names should pass through untouched + assert self.LLMService._resolve_model("some-future-model") == "some-future-model" + + # ── _extract_usage_metadata ─────────────────────────────────────────────── + + def test_extract_gemini_full(self): + response = MagicMock() + um = MagicMock() + um.prompt_token_count = 1000 + um.candidates_token_count = 200 + um.cached_content_token_count = 50 + response.usage_metadata = um + + result = self.LLMService._extract_usage_metadata(response, "gemini") + assert result == {"prompt": 1000, "completion": 200, "cached": 50, "reasoning": 0} + + def test_extract_gemini_missing_usage_metadata(self): + response = MagicMock(spec=[]) # no attributes + result = self.LLMService._extract_usage_metadata(response, "gemini") + assert result == {"prompt": 0, "completion": 0, "cached": 0, "reasoning": 0} + + def test_extract_gemini_none_usage_metadata(self): + response = MagicMock() + response.usage_metadata = None + result = self.LLMService._extract_usage_metadata(response, "gemini") + assert result == {"prompt": 0, "completion": 0, "cached": 0, "reasoning": 0} + + def test_extract_openai_responses_api(self): + """OpenAI Responses API format (gpt-5.4-2026-03-05).""" + response = MagicMock() + usage = MagicMock() + usage.input_tokens = 5000 + usage.output_tokens = 1000 + input_details = MagicMock() + input_details.cached_tokens = 200 + output_details = MagicMock() + output_details.reasoning_tokens = 300 + usage.input_tokens_details = input_details + usage.output_tokens_details = output_details + response.usage = usage + + result = self.LLMService._extract_usage_metadata(response, "openai") + assert result == {"prompt": 5000, "completion": 1000, "cached": 200, "reasoning": 300} + + def test_extract_openai_chat_completions(self): + """OpenAI Chat Completions format (legacy).""" + response = MagicMock() + usage = MagicMock(spec=['prompt_tokens', 'completion_tokens', 'prompt_tokens_details']) + usage.prompt_tokens = 3000 + usage.completion_tokens = 800 + details = MagicMock() + details.cached_tokens = 100 + usage.prompt_tokens_details = details + response.usage = usage + + result = self.LLMService._extract_usage_metadata(response, "openai") + assert result == {"prompt": 3000, "completion": 800, "cached": 100, "reasoning": 0} + + def test_extract_openai_missing_usage(self): + response = MagicMock() + response.usage = None + result = self.LLMService._extract_usage_metadata(response, "openai") + assert result == {"prompt": 0, "completion": 0, "cached": 0, "reasoning": 0} + + def test_extract_unknown_provider_returns_zeros(self): + result = self.LLMService._extract_usage_metadata(MagicMock(), "anthropic") + assert result == {"prompt": 0, "completion": 0, "cached": 0, "reasoning": 0} + + def test_extract_gemini_token_counts_default_to_zero_when_none(self): + """Fields that return None should be coerced to 0, not None.""" + response = MagicMock() + um = MagicMock() + um.prompt_token_count = None + um.candidates_token_count = None + um.cached_content_token_count = None + response.usage_metadata = um + + result = self.LLMService._extract_usage_metadata(response, "gemini") + assert result["prompt"] == 0 + assert result["completion"] == 0 + assert result["cached"] == 0 + + +# ══════════════════════════════════════════════════════════════════════════════ +# 4. UsageEvent.record — non-fatal DB error handling, field coercion +# ══════════════════════════════════════════════════════════════════════════════ + +class TestUsageEventRecord: + def setup_method(self): + # Reset module-level imports for clean state + pass + + @pytest.mark.asyncio + async def test_record_normalizes_unknown_feature(self): + """Unknown feature names should be stored as 'other'.""" + captured = {} + + async def fake_insert(doc): + captured['doc'] = doc + + mock_db = MagicMock() + mock_db.usage_events.insert_one = AsyncMock(side_effect=fake_insert) + + with patch('app.models.usage_event.get_db', AsyncMock(return_value=mock_db)): + from app.models.usage_event import UsageEvent + await UsageEvent.record( + provider="gemini", + model="gemini-3.1-pro-preview", + prompt_tokens=100, + completion_tokens=50, + cost_usd={"input": 0.0, "cached": 0.0, "output": 0.0, "total": 0.0}, + feature="totally_invalid_feature_xyz", + ) + + assert captured['doc']['feature'] == "other" + + @pytest.mark.asyncio + async def test_record_valid_feature_preserved(self): + captured = {} + + async def fake_insert(doc): + captured['doc'] = doc + + mock_db = MagicMock() + mock_db.usage_events.insert_one = AsyncMock(side_effect=fake_insert) + + with patch('app.models.usage_event.get_db', AsyncMock(return_value=mock_db)): + from app.models.usage_event import UsageEvent + await UsageEvent.record( + provider="gemini", + model="gemini-3.1-pro-preview", + prompt_tokens=100, + completion_tokens=50, + cost_usd={"input": 0.0, "cached": 0.0, "output": 0.0, "total": 0.0}, + feature="key_themes", + ) + + assert captured['doc']['feature'] == "key_themes" + + @pytest.mark.asyncio + async def test_record_total_tokens_computed(self): + captured = {} + + async def fake_insert(doc): + captured['doc'] = doc + + mock_db = MagicMock() + mock_db.usage_events.insert_one = AsyncMock(side_effect=fake_insert) + + with patch('app.models.usage_event.get_db', AsyncMock(return_value=mock_db)): + from app.models.usage_event import UsageEvent + await UsageEvent.record( + provider="gemini", + model="gemini-3.1-pro-preview", + prompt_tokens=300, + completion_tokens=150, + cost_usd={"input": 0.0, "cached": 0.0, "output": 0.0, "total": 0.0}, + ) + + assert captured['doc']['total_tokens'] == 450 + + @pytest.mark.asyncio + async def test_record_error_truncated_to_500_chars(self): + captured = {} + + async def fake_insert(doc): + captured['doc'] = doc + + mock_db = MagicMock() + mock_db.usage_events.insert_one = AsyncMock(side_effect=fake_insert) + + long_error = "E" * 1000 + + with patch('app.models.usage_event.get_db', AsyncMock(return_value=mock_db)): + from app.models.usage_event import UsageEvent + await UsageEvent.record( + provider="gemini", + model="gemini-3.1-pro-preview", + prompt_tokens=10, + completion_tokens=5, + cost_usd={"input": 0.0, "cached": 0.0, "output": 0.0, "total": 0.0}, + status="error", + error=long_error, + ) + + assert len(captured['doc']['error']) == 500 + + @pytest.mark.asyncio + async def test_record_db_error_is_swallowed(self): + """A DB write failure must not propagate — telemetry cannot kill LLM calls.""" + mock_db = MagicMock() + mock_db.usage_events.insert_one = AsyncMock(side_effect=RuntimeError("DB down")) + + with patch('app.models.usage_event.get_db', AsyncMock(return_value=mock_db)): + from app.models.usage_event import UsageEvent + # Should not raise + await UsageEvent.record( + provider="gemini", + model="gemini-3.1-pro-preview", + prompt_tokens=10, + completion_tokens=5, + cost_usd={"input": 0.0, "cached": 0.0, "output": 0.0, "total": 0.0}, + ) + + +# ══════════════════════════════════════════════════════════════════════════════ +# 5. check_quota — admin bypass, warning threshold, exceeded error +# ══════════════════════════════════════════════════════════════════════════════ + +class TestCheckQuota: + """ + check_quota does lazy imports inside the function body: + from app.models.user import User + from app.models.usage_event import UsageEvent + from app.models.focus_group import FocusGroup + + We inject mocks via patch.dict(sys.modules) so that when the function + executes those import statements at runtime it gets our mock objects. + """ + + def _make_user_mod(self, user_doc): + m = MagicMock() + m.User.find_by_id = AsyncMock(return_value=user_doc) + return m + + def _make_usage_mod(self, total): + m = MagicMock() + m.UsageEvent.sum_cost = AsyncMock(return_value=total) + return m + + def _make_fg_mod(self, fg_doc): + m = MagicMock() + m.FocusGroup.find_by_id = AsyncMock(return_value=fg_doc) + return m + + @pytest.mark.asyncio + async def test_no_user_no_fg_returns_none(self): + from app.models.quota import check_quota + result = await check_quota(user_id=None, focus_group_id=None) + assert result is None + + @pytest.mark.asyncio + async def test_admin_bypasses_user_quota(self): + user_mod = self._make_user_mod( + {"role": "admin", "quota": {"monthly_usd": 0.01}, "override_quota": False} + ) + usage_mod = self._make_usage_mod(999.0) + + with patch.dict(sys.modules, {'app.models.user': user_mod, 'app.models.usage_event': usage_mod}): + from app.models.quota import check_quota + result = await check_quota(user_id="admin_id", focus_group_id=None) + + assert result is None + + @pytest.mark.asyncio + async def test_override_quota_bypasses_user_quota(self): + user_mod = self._make_user_mod( + {"role": "user", "quota": {"monthly_usd": 1.0}, "override_quota": True} + ) + usage_mod = self._make_usage_mod(999.0) + + with patch.dict(sys.modules, {'app.models.user': user_mod, 'app.models.usage_event': usage_mod}): + from app.models.quota import check_quota + result = await check_quota(user_id="u1", focus_group_id=None) + + assert result is None + + @pytest.mark.asyncio + async def test_no_quota_configured_returns_none(self): + user_mod = self._make_user_mod({"role": "user", "quota": {}, "override_quota": False}) + usage_mod = self._make_usage_mod(10.0) + + with patch.dict(sys.modules, {'app.models.user': user_mod, 'app.models.usage_event': usage_mod}): + from app.models.quota import check_quota + result = await check_quota(user_id="u1", focus_group_id=None) + + assert result is None + + @pytest.mark.asyncio + async def test_under_quota_returns_none(self): + user_mod = self._make_user_mod( + {"role": "user", "quota": {"monthly_usd": 50.0}, "override_quota": False} + ) + usage_mod = self._make_usage_mod(20.0) + + with patch.dict(sys.modules, {'app.models.user': user_mod, 'app.models.usage_event': usage_mod}): + from app.models.quota import check_quota + result = await check_quota(user_id="u1", focus_group_id=None) + + assert result is None + + @pytest.mark.asyncio + async def test_quota_at_80_percent_returns_warning(self): + user_mod = self._make_user_mod( + {"role": "user", "quota": {"monthly_usd": 50.0}, "override_quota": False} + ) + usage_mod = self._make_usage_mod(42.0) # 84% + + with patch.dict(sys.modules, {'app.models.user': user_mod, 'app.models.usage_event': usage_mod}): + from app.models.quota import check_quota, QuotaWarning + result = await check_quota(user_id="u1", focus_group_id=None) + + assert isinstance(result, QuotaWarning) + assert result.scope == "user" + assert result.limit_usd == 50.0 + assert result.used_usd == 42.0 + assert result.pct == pytest.approx(0.84, abs=0.001) + + @pytest.mark.asyncio + async def test_quota_exceeded_raises_error(self): + user_mod = self._make_user_mod( + {"role": "user", "quota": {"monthly_usd": 50.0}, "override_quota": False} + ) + usage_mod = self._make_usage_mod(50.01) + + with patch.dict(sys.modules, {'app.models.user': user_mod, 'app.models.usage_event': usage_mod}): + from app.models.quota import check_quota, QuotaExceededError + with pytest.raises(QuotaExceededError) as exc_info: + await check_quota(user_id="u1", focus_group_id=None) + + err = exc_info.value + assert err.scope == "user" + assert err.limit_usd == 50.0 + assert err.used_usd == pytest.approx(50.01) + + @pytest.mark.asyncio + async def test_fg_quota_exceeded_raises_error(self): + fg_mod = self._make_fg_mod({"quota": {"total_usd": 10.0}}) + usage_mod = self._make_usage_mod(10.01) + + with patch.dict(sys.modules, {'app.models.focus_group': fg_mod, 'app.models.usage_event': usage_mod}): + from app.models.quota import check_quota, QuotaExceededError + with pytest.raises(QuotaExceededError) as exc_info: + await check_quota(user_id=None, focus_group_id="fg1") + + assert exc_info.value.scope == "focus_group" + + @pytest.mark.asyncio + async def test_fg_quota_at_80_percent_returns_warning(self): + fg_mod = self._make_fg_mod({"quota": {"total_usd": 100.0}}) + usage_mod = self._make_usage_mod(85.0) + + with patch.dict(sys.modules, {'app.models.focus_group': fg_mod, 'app.models.usage_event': usage_mod}): + from app.models.quota import check_quota, QuotaWarning + result = await check_quota(user_id=None, focus_group_id="fg1") + + assert isinstance(result, QuotaWarning) + assert result.scope == "focus_group" + + @pytest.mark.asyncio + async def test_db_error_in_quota_check_is_non_fatal(self): + """If User.find_by_id raises, quota check swallows and allows the call.""" + user_mod = MagicMock() + user_mod.User.find_by_id = AsyncMock(side_effect=RuntimeError("DB timeout")) + + with patch.dict(sys.modules, {'app.models.user': user_mod}): + from app.models.quota import check_quota + result = await check_quota(user_id="u1", focus_group_id=None) + + assert result is None + + @pytest.mark.asyncio + async def test_exactly_at_quota_limit_is_exceeded(self): + """Spending exactly the limit should trigger QuotaExceededError (>=).""" + user_mod = self._make_user_mod( + {"role": "user", "quota": {"monthly_usd": 50.0}, "override_quota": False} + ) + usage_mod = self._make_usage_mod(50.0) # exactly at limit + + with patch.dict(sys.modules, {'app.models.user': user_mod, 'app.models.usage_event': usage_mod}): + from app.models.quota import check_quota, QuotaExceededError + with pytest.raises(QuotaExceededError): + await check_quota(user_id="u1", focus_group_id=None) + + +# ══════════════════════════════════════════════════════════════════════════════ +# 6. QuotaExceededError — message and attributes +# ══════════════════════════════════════════════════════════════════════════════ + +class TestQuotaExceededError: + def test_attributes(self): + from app.models.quota import QuotaExceededError + from datetime import datetime, timezone + + period = datetime(2026, 4, 1, tzinfo=timezone.utc) + err = QuotaExceededError("user", 50.0, 52.34, period) + + assert err.scope == "user" + assert err.limit_usd == 50.0 + assert err.used_usd == 52.34 + assert err.period_start == period + assert "user" in str(err) + assert "52.3400" in str(err) + assert "50.00" in str(err) + + def test_is_exception(self): + from app.models.quota import QuotaExceededError + assert issubclass(QuotaExceededError, Exception) diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 00000000..0715caa0 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +asyncio_mode = auto +testpaths = backend/tests