cohorta/backend/app/models/model_pricing.py
Vadym Samoilenko 3e9ccafad2 Add LLM usage tracking infrastructure (Phases A-C)
- Model renames: gpt-5.2 → gpt-5.4-2026-03-05, gemini-3-pro-preview → gemini-3.1-pro-preview; retire gpt-4.1 via alias fallback
- New: llm_usage_context.py (ContextVar-based attribution), model_pricing.py (tiered pricing + 60s cache), usage_event.py (append-only telemetry), quota.py (user/FG quota enforcement with 80% warning)
- Wire _record_usage into all 3 LLM methods; set_llm_context at every service entry point
- Fix admin_required decorator (was sync, never awaited User.find_by_id); add active_required and with_user_context decorators
- Inject user_id into ContextVar from JWT on every authenticated request
- Add DB indexes for usage_events, model_pricing, users collections
- Seed script for model pricing (gpt-5.4 single-tier, gemini-3.1 two-tier 200k threshold)
- Fix parse_json_response NameError (logger undefined at module level)
- 70 passing tests: conftest.py with sys.modules stubs, test_usage_infrastructure.py (52 tests), rewrite stale test_llm_service.py (18 tests)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-24 18:08:27 +01:00

104 lines
3.7 KiB
Python

from app.db import get_db
from datetime import datetime, timezone
import logging
import time
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# In-process pricing cache, keyed by model name. Each value is a 2-tuple
# (pricing, cached_at) where cached_at is a time.monotonic() reading taken by
# _cache_set when the entry was written.
_pricing_cache: dict = {}
# Maximum age, in seconds, at which _cache_get still serves a cache entry.
_CACHE_TTL_SECONDS = 60
def _cache_get(model: str):
    """Look up ``model`` in the in-process pricing cache.

    Returns the stored pricing value when an entry exists and is younger than
    ``_CACHE_TTL_SECONDS``; otherwise returns None. Expired entries are not
    evicted here, they are simply ignored until overwritten.
    """
    hit = _pricing_cache.get(model)
    if not hit:
        return None
    # Monotonic clock: the age cannot go negative if the wall clock is adjusted.
    age = time.monotonic() - hit[1]
    if age < _CACHE_TTL_SECONDS:
        return hit[0]
    return None
def _cache_set(model: str, pricing: dict):
    """Store ``pricing`` for ``model``, stamped with the current monotonic time."""
    stamped_entry = (pricing, time.monotonic())
    _pricing_cache[model] = stamped_entry
class ModelPricing:
@staticmethod
async def current_for(model_name: str) -> dict | None:
"""Return the active pricing row for a model, with 60 s in-process cache.
Resolves MODEL_ALIASES before lookup so callers can pass raw model names.
Returns None if no pricing is configured (cost will be recorded as 0).
"""
from app.services.llm_service import MODEL_ALIASES
resolved = MODEL_ALIASES.get(model_name, model_name)
cached = _cache_get(resolved)
if cached is not None:
return cached
try:
db = await get_db()
now = datetime.now(timezone.utc)
doc = await db.model_pricing.find_one(
{
"model": resolved,
"effective_from": {"$lte": now},
"$or": [
{"effective_until": None},
{"effective_until": {"$gt": now}},
],
},
sort=[("effective_from", -1)],
)
_cache_set(resolved, doc)
return doc
except Exception:
logger.warning(f"Failed to fetch pricing for model {resolved}", exc_info=True)
return None
@staticmethod
def pick_tier(pricing: dict, prompt_tokens: int) -> dict | None:
"""Return the cost tier that applies for a given prompt token count."""
if not pricing:
return None
tiers = pricing.get("tiers") or []
if not tiers:
return None
# Pick the tier with the largest threshold still <= prompt_tokens
applicable = [t for t in tiers if t.get("threshold_input_tokens", 0) <= prompt_tokens]
if not applicable:
applicable = tiers # fall back to first tier
return max(applicable, key=lambda t: t.get("threshold_input_tokens", 0))
@staticmethod
def compute_cost(pricing: dict | None, prompt_tokens: int, completion_tokens: int,
cached_tokens: int = 0) -> dict:
"""Compute cost breakdown from token counts and pricing doc.
Returns a dict with keys: input, cached, output, total (all USD floats).
All values are 0.0 if pricing is None.
"""
zero = {"input": 0.0, "cached": 0.0, "output": 0.0, "total": 0.0}
if not pricing:
return zero
tier = ModelPricing.pick_tier(pricing, prompt_tokens)
if not tier:
return zero
input_per_mtok = tier.get("input_per_mtok") or 0
cached_per_mtok = tier.get("cached_input_per_mtok") or 0
output_per_mtok = tier.get("output_per_mtok") or 0
billable_input = max(0, prompt_tokens - cached_tokens)
cost_input = billable_input * input_per_mtok / 1_000_000
cost_cached = cached_tokens * cached_per_mtok / 1_000_000
cost_output = completion_tokens * output_per_mtok / 1_000_000
cost_total = cost_input + cost_cached + cost_output
return {
"input": round(cost_input, 8),
"cached": round(cost_cached, 8),
"output": round(cost_output, 8),
"total": round(cost_total, 8),
}