Fix backfill: use accumulated conversation context for prompt estimation
Old logic used output text length as a proxy for prompt tokens — completely wrong. Real Gemini calls send the full conversation history as context, so the prompt grows with every turn.

New logic:
- completion_tokens = len(response_text) / 3.8 (what was generated)
- prompt_tokens = base_template + sum(all_prior_messages_in_fg) / 3.8
- persona_response base: 1500 tok (template + persona details + topic)
- moderator base: 1200 tok (moderator template + fg context)
- persona_generate base: 2500 tok (persona-detailed-generation.md template)

Also:
- Sorts messages chronologically per focus group before processing
- Accumulates context correctly so turn N includes turns 0..N-1 as context
- Idempotency via pre-fetched set instead of per-doc find_one queries
- cost_usd breakdown now has correct input/output split (not 40/60 guess)
- Dry-run prints per-focus-group cost estimates for sanity checking

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
57508e8e55
commit
d0ad8e67be
1 changed files with 205 additions and 107 deletions
|
|
@ -5,11 +5,14 @@ Backfill usage_events from existing focus-group messages and personas.
|
|||
Creates estimated usage_event docs (is_estimated=True) so the admin dashboard
|
||||
can show historical cost data for sessions that pre-date the usage tracking system.
|
||||
|
||||
Idempotent: skips documents that already have an estimated event in the collection.
|
||||
Token estimation approach:
|
||||
- completion = actual output text length / 3.8 chars-per-token
|
||||
- prompt = base template size + ALL prior messages in conversation (accumulated context)
|
||||
This mirrors the real LLM call: each turn sends the full conversation history.
|
||||
|
||||
Usage:
|
||||
cd backend
|
||||
python scripts/backfill_usage.py [--dry-run]
|
||||
python scripts/backfill_usage.py [--dry-run] [--delete-existing-estimates]
|
||||
|
||||
Environment:
|
||||
MONGO_URI — connection string (falls back to localhost:27017 without auth)
|
||||
|
|
@ -19,36 +22,47 @@ Environment:
|
|||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from pymongo import MongoClient
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# Token estimation helpers
|
||||
# Prompt template size constants (measured from actual files in backend/prompts/)
|
||||
# These are the BASE tokens before any dynamic content is added.
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
def _estimate_tokens(text: str, model: str) -> dict:
|
||||
"""Estimate prompt/completion tokens for a piece of text."""
|
||||
if not text:
|
||||
return {"prompt": 0, "completion": 0}
|
||||
# focus-group-response.md (~941 tok) + persona details (~350 tok) + topic/instructions (~200 tok)
|
||||
BASE_PROMPT_PERSONA_RESPONSE = 1_500
|
||||
# ai-moderator-system.md (~738 tok) + focus group context (~500 tok)
|
||||
BASE_PROMPT_MODERATOR = 1_200
|
||||
# persona-detailed-generation.md (~2307 tok) + focus group brief (~200 tok)
|
||||
BASE_PROMPT_PERSONA_GENERATE = 2_500
|
||||
|
||||
if model and ("gpt" in model.lower() or "openai" in model.lower()):
|
||||
try:
|
||||
import tiktoken
|
||||
enc = tiktoken.encoding_for_model("gpt-4")
|
||||
n = len(enc.encode(text))
|
||||
return {"prompt": n, "completion": 0}
|
||||
except Exception:
|
||||
pass
|
||||
CHARS_PER_TOKEN = 3.8 # Gemini approximation
|
||||
|
||||
# Gemini / unknown: ~3.8 chars per token
|
||||
n = max(1, int(len(text) / 3.8))
|
||||
return {"prompt": n, "completion": 0}
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# Token helpers
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
def _chars_to_tokens(chars: int) -> int:
    """Convert a character count into an estimated token count.

    Divides by the module-level ``CHARS_PER_TOKEN`` ratio, truncates to an
    integer, and never returns less than 1.
    """
    estimated = int(chars / CHARS_PER_TOKEN)
    return estimated if estimated > 1 else 1
|
||||
|
||||
|
||||
def _to_str(v) -> str:
|
||||
if isinstance(v, list):
|
||||
return " ".join(str(i) for i in v if i)
|
||||
return str(v) if v else ""
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# Pricing
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
_pricing_cache: dict = {}
|
||||
|
||||
|
||||
def _load_pricing(db) -> None:
|
||||
"""Load current pricing from model_pricing collection into cache."""
|
||||
for row in db.model_pricing.find({"effective_until": None}):
|
||||
model = row.get("model", "")
|
||||
tiers = row.get("tiers") or []
|
||||
|
|
@ -60,26 +74,31 @@ def _load_pricing(db) -> None:
|
|||
)
|
||||
|
||||
|
||||
def _estimate_cost(prompt_tokens: int, completion_tokens: int, model: str) -> float:
|
||||
"""Cost estimate in USD using model_pricing collection rates."""
|
||||
# Try exact match, then prefix match
|
||||
def _estimate_cost(prompt_tokens: int, completion_tokens: int, model: str) -> dict:
|
||||
rates = _pricing_cache.get(model)
|
||||
if not rates:
|
||||
for key, val in _pricing_cache.items():
|
||||
if model and key and (key in model or model in key):
|
||||
rates = val
|
||||
break
|
||||
# Final fallback matching seed_model_pricing.py values
|
||||
if not rates:
|
||||
m = (model or "").lower()
|
||||
if "gpt-5" in m or "gpt-4" in m:
|
||||
rates = (2.50, 15.00) # gpt-5.4 pricing from seed
|
||||
rates = (2.50, 15.00)
|
||||
else:
|
||||
rates = (2.00, 12.00) # gemini-3.1-pro-preview pricing from seed
|
||||
rates = (2.00, 12.00)
|
||||
|
||||
input_rate, output_rate = rates
|
||||
cost = (prompt_tokens / 1_000_000) * input_rate + (completion_tokens / 1_000_000) * output_rate
|
||||
return round(cost, 8)
|
||||
cost_input = (prompt_tokens / 1_000_000) * input_rate
|
||||
cost_output = (completion_tokens / 1_000_000) * output_rate
|
||||
total = round(cost_input + cost_output, 8)
|
||||
return {
|
||||
"input": round(cost_input, 8),
|
||||
"output": round(cost_output, 8),
|
||||
"cached": 0.0,
|
||||
"reasoning": 0.0,
|
||||
"total": total,
|
||||
}
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
|
|
@ -101,14 +120,20 @@ def connect():
|
|||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# Backfill focus-group messages
|
||||
# Messages are in the separate `focus_group_messages` collection (NOT embedded).
|
||||
# Fields: focus_group_id (str), text, type, senderId, created_at
|
||||
#
|
||||
# Real prompt structure per call:
|
||||
# system prompt template (~1200-1500 tok) + all prior messages (accumulated)
|
||||
# Real completion:
|
||||
# the response text
|
||||
#
|
||||
# We sort messages per focus group by timestamp and accumulate context,
|
||||
# so that message N has all N-1 prior messages as context — matching reality.
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
def backfill_messages(db, dry_run: bool) -> int:
|
||||
created = 0
|
||||
|
||||
# Build a lookup: focus_group_id -> {llm_model, user_id}
|
||||
# Build focus-group metadata lookup
|
||||
fg_meta = {}
|
||||
for fg in db.focus_groups.find({}, {"llm_model": 1, "created_by": 1}):
|
||||
fg_meta[str(fg["_id"])] = {
|
||||
|
|
@ -116,72 +141,107 @@ def backfill_messages(db, dry_run: bool) -> int:
|
|||
"user_id": str(fg.get("created_by") or ""),
|
||||
}
|
||||
|
||||
total_messages = db.focus_group_messages.count_documents({})
|
||||
print(f"\n[messages] Found {total_messages} messages across all focus groups")
|
||||
# Collect all messages, group by focus_group_id, sort chronologically
|
||||
all_messages = list(db.focus_group_messages.find({}))
|
||||
print(f"\n[messages] Found {len(all_messages)} messages across all focus groups")
|
||||
|
||||
for msg in db.focus_group_messages.find({}):
|
||||
msg_id = str(msg["_id"])
|
||||
# Bucket by focus group
|
||||
by_fg: dict = defaultdict(list)
|
||||
for msg in all_messages:
|
||||
fg_id = str(msg.get("focus_group_id") or "")
|
||||
|
||||
# Skip non-AI messages (only persona responses and moderator questions cost money)
|
||||
msg_type = msg.get("type", "")
|
||||
# Only AI-generated messages cost money
|
||||
if msg_type not in ("response", "question", "moderator", "ai", ""):
|
||||
continue
|
||||
by_fg[fg_id].append(msg)
|
||||
|
||||
# Idempotent check
|
||||
if db.usage_events.find_one({"source_message_id": msg_id, "is_estimated": True}):
|
||||
continue
|
||||
# Sort each group chronologically
|
||||
def _ts(m):
|
||||
t = m.get("created_at") or m.get("timestamp")
|
||||
if isinstance(t, str):
|
||||
try:
|
||||
return datetime.fromisoformat(t)
|
||||
except Exception:
|
||||
pass
|
||||
if isinstance(t, datetime):
|
||||
return t
|
||||
return datetime.min
|
||||
|
||||
for fg_id, msgs in by_fg.items():
|
||||
msgs.sort(key=_ts)
|
||||
|
||||
# Already-estimated message IDs (for idempotency)
|
||||
existing_ids = set(
|
||||
str(e["source_message_id"])
|
||||
for e in db.usage_events.find(
|
||||
{"is_estimated": True, "source_message_id": {"$exists": True}},
|
||||
{"source_message_id": 1}
|
||||
)
|
||||
)
|
||||
|
||||
for fg_id, msgs in by_fg.items():
|
||||
meta = fg_meta.get(fg_id, {"model": "gemini-3.1-pro-preview", "user_id": ""})
|
||||
fg_model = meta["model"]
|
||||
user_id = meta["user_id"]
|
||||
provider = "gemini" if "gemini" in fg_model.lower() else "openai"
|
||||
|
||||
text = msg.get("text") or msg.get("content") or ""
|
||||
tokens = _estimate_tokens(text, fg_model)
|
||||
tokens["completion"] = max(1, int(len(text) / 5.0))
|
||||
cost = _estimate_cost(tokens["prompt"], tokens["completion"], fg_model)
|
||||
accumulated_context_chars = 0 # sum of all prior message text lengths
|
||||
|
||||
ts = msg.get("created_at") or msg.get("timestamp")
|
||||
if isinstance(ts, str):
|
||||
try:
|
||||
ts = datetime.fromisoformat(ts)
|
||||
except Exception:
|
||||
ts = None
|
||||
ts = ts or datetime.now(timezone.utc)
|
||||
for msg in msgs:
|
||||
msg_id = str(msg["_id"])
|
||||
if msg_id in existing_ids:
|
||||
# Still accumulate context so subsequent messages are correct
|
||||
text = msg.get("text") or msg.get("content") or ""
|
||||
accumulated_context_chars += len(text)
|
||||
continue
|
||||
|
||||
feature = "moderator" if msg_type in ("question", "moderator") else "persona_response"
|
||||
text = msg.get("text") or msg.get("content") or ""
|
||||
msg_type = msg.get("type", "")
|
||||
|
||||
event = {
|
||||
"ts": ts,
|
||||
"provider": "gemini" if "gemini" in fg_model.lower() else "openai",
|
||||
"model": fg_model,
|
||||
"feature": feature,
|
||||
"user_id": user_id,
|
||||
"focus_group_id": fg_id,
|
||||
"persona_id": str(msg.get("senderId") or msg.get("persona_id") or ""),
|
||||
"prompt_tokens": tokens["prompt"],
|
||||
"completion_tokens": tokens["completion"],
|
||||
"cached_tokens": 0,
|
||||
"reasoning_tokens": 0,
|
||||
"total_tokens": tokens["prompt"] + tokens["completion"],
|
||||
"cost_usd": {
|
||||
"input": round(cost * 0.4, 8),
|
||||
"output": round(cost * 0.6, 8),
|
||||
"cached": 0,
|
||||
"reasoning": 0,
|
||||
"total": cost,
|
||||
},
|
||||
"duration_ms": 0,
|
||||
"retry_count": 0,
|
||||
"status": "success",
|
||||
"is_estimated": True,
|
||||
"estimate_method": "char_div_3_8",
|
||||
"source_message_id": msg_id,
|
||||
}
|
||||
# completion = what the model actually generated
|
||||
completion_tokens = _chars_to_tokens(len(text))
|
||||
|
||||
if not dry_run:
|
||||
db.usage_events.insert_one(event)
|
||||
created += 1
|
||||
# prompt = base template + full conversation history up to this point
|
||||
context_tokens = _chars_to_tokens(accumulated_context_chars)
|
||||
if msg_type in ("question", "moderator"):
|
||||
prompt_tokens = BASE_PROMPT_MODERATOR + context_tokens
|
||||
else:
|
||||
prompt_tokens = BASE_PROMPT_PERSONA_RESPONSE + context_tokens
|
||||
|
||||
cost = _estimate_cost(prompt_tokens, completion_tokens, fg_model)
|
||||
|
||||
ts = _ts(msg) or datetime.now(timezone.utc)
|
||||
|
||||
feature = "moderator" if msg_type in ("question", "moderator") else "persona_response"
|
||||
|
||||
event = {
|
||||
"ts": ts,
|
||||
"provider": provider,
|
||||
"model": fg_model,
|
||||
"feature": feature,
|
||||
"user_id": user_id,
|
||||
"focus_group_id": fg_id,
|
||||
"persona_id": str(msg.get("senderId") or msg.get("persona_id") or ""),
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"cached_tokens": 0,
|
||||
"reasoning_tokens": 0,
|
||||
"total_tokens": prompt_tokens + completion_tokens,
|
||||
"cost_usd": cost,
|
||||
"duration_ms": 0,
|
||||
"retry_count": 0,
|
||||
"status": "success",
|
||||
"is_estimated": True,
|
||||
"estimate_method": "accumulated_context",
|
||||
"source_message_id": msg_id,
|
||||
}
|
||||
|
||||
if not dry_run:
|
||||
db.usage_events.insert_one(event)
|
||||
created += 1
|
||||
|
||||
# Add this message to the accumulated context for subsequent messages
|
||||
accumulated_context_chars += len(text)
|
||||
|
||||
print(f"[messages] {'Would create' if dry_run else 'Created'} {created} estimated usage events")
|
||||
return created
|
||||
|
|
@ -189,7 +249,9 @@ def backfill_messages(db, dry_run: bool) -> int:
|
|||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# Backfill persona generation
|
||||
# Personas: fields background, description, name; created_by = user_id
|
||||
#
|
||||
# Real prompt: persona-detailed-generation.md template (~2307 tok) + fg brief (~200 tok)
|
||||
# Real completion: the generated persona profile text
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
def backfill_personas(db, dry_run: bool) -> int:
|
||||
|
|
@ -197,32 +259,37 @@ def backfill_personas(db, dry_run: bool) -> int:
|
|||
personas = list(db.personas.find({}))
|
||||
print(f"\n[personas] Found {len(personas)} personas to process")
|
||||
|
||||
existing_persona_ids = set(
|
||||
str(e["source_persona_id"])
|
||||
for e in db.usage_events.find(
|
||||
{"is_estimated": True, "source_persona_id": {"$exists": True}, "feature": "persona_generate"},
|
||||
{"source_persona_id": 1}
|
||||
)
|
||||
)
|
||||
|
||||
for persona in personas:
|
||||
persona_id = str(persona["_id"])
|
||||
if persona_id in existing_persona_ids:
|
||||
continue
|
||||
|
||||
def _to_str(v):
|
||||
if isinstance(v, list):
|
||||
return " ".join(str(i) for i in v if i)
|
||||
return str(v) if v else ""
|
||||
|
||||
# Use background + description as the generation text
|
||||
# The generated output is the persona profile text
|
||||
text = " ".join(filter(None, [
|
||||
_to_str(persona.get("background")),
|
||||
_to_str(persona.get("description")),
|
||||
_to_str(persona.get("goals")),
|
||||
_to_str(persona.get("name")),
|
||||
])).strip()
|
||||
|
||||
if not text:
|
||||
continue
|
||||
|
||||
# Idempotent check
|
||||
if db.usage_events.find_one({"source_persona_id": persona_id, "feature": "persona_generate", "is_estimated": True}):
|
||||
continue
|
||||
|
||||
model = "gemini-3.1-pro-preview"
|
||||
tokens = _estimate_tokens(text, model)
|
||||
tokens["completion"] = max(1, int(len(text) / 4.0))
|
||||
cost = _estimate_cost(tokens["prompt"], tokens["completion"], model)
|
||||
# completion = the generated persona text
|
||||
completion_tokens = _chars_to_tokens(len(text))
|
||||
# prompt = template + focus group brief (fixed base)
|
||||
prompt_tokens = BASE_PROMPT_PERSONA_GENERATE
|
||||
|
||||
cost = _estimate_cost(prompt_tokens, completion_tokens, model)
|
||||
|
||||
ts = persona.get("created_at") or persona.get("updatedAt") or datetime.now(timezone.utc)
|
||||
if isinstance(ts, str):
|
||||
|
|
@ -237,25 +304,19 @@ def backfill_personas(db, dry_run: bool) -> int:
|
|||
"model": model,
|
||||
"feature": "persona_generate",
|
||||
"user_id": str(persona.get("created_by") or persona.get("user_id") or ""),
|
||||
"focus_group_id": "",
|
||||
"focus_group_id": str(persona.get("focus_group_id") or ""),
|
||||
"persona_id": persona_id,
|
||||
"prompt_tokens": tokens["prompt"],
|
||||
"completion_tokens": tokens["completion"],
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"cached_tokens": 0,
|
||||
"reasoning_tokens": 0,
|
||||
"total_tokens": tokens["prompt"] + tokens["completion"],
|
||||
"cost_usd": {
|
||||
"input": round(cost * 0.4, 8),
|
||||
"output": round(cost * 0.6, 8),
|
||||
"cached": 0,
|
||||
"reasoning": 0,
|
||||
"total": cost,
|
||||
},
|
||||
"total_tokens": prompt_tokens + completion_tokens,
|
||||
"cost_usd": cost,
|
||||
"duration_ms": 0,
|
||||
"retry_count": 0,
|
||||
"status": "success",
|
||||
"is_estimated": True,
|
||||
"estimate_method": "char_div_3_8",
|
||||
"estimate_method": "accumulated_context",
|
||||
"source_persona_id": persona_id,
|
||||
}
|
||||
|
||||
|
|
@ -290,6 +351,10 @@ def main():
|
|||
_load_pricing(db)
|
||||
print(f"Loaded {len(_pricing_cache)} pricing rows: {list(_pricing_cache.keys())}")
|
||||
|
||||
# Dry-run: show a sample of what the cost distribution looks like
|
||||
if args.dry_run:
|
||||
_dry_run_sample(db)
|
||||
|
||||
total = 0
|
||||
total += backfill_messages(db, args.dry_run)
|
||||
total += backfill_personas(db, args.dry_run)
|
||||
|
|
@ -297,5 +362,38 @@ def main():
|
|||
print(f"\n{'[DRY RUN] ' if args.dry_run else ''}Backfill complete — {total} events total")
|
||||
|
||||
|
||||
def _dry_run_sample(db):
    """Print estimated costs for the five largest focus groups as a sanity check.

    Reads every document in ``focus_group_messages``, keeps the AI-generated
    message types, orders each focus group's messages chronologically, and
    accumulates prior message text as prompt context. Moderator/question
    messages use ``BASE_PROMPT_MODERATOR`` and all others use
    ``BASE_PROMPT_PERSONA_RESPONSE``.

    Args:
        db: Database handle exposing the ``focus_group_messages`` and
            ``focus_groups`` collections.

    Returns:
        None. Results are printed to stdout; nothing is written to the database.
    """
    from collections import defaultdict

    # bson ships with pymongo, which this script already depends on.
    from bson import ObjectId
    from bson.errors import InvalidId

    ai_types = ("response", "question", "moderator", "ai", "")
    moderator_types = ("question", "moderator")
    default_model = "gemini-3.1-pro-preview"

    def _sort_key(m):
        # Accept either field name and either a datetime or an ISO string.
        # Anything unparseable sorts first via datetime.min.
        t = m.get("created_at") or m.get("timestamp")
        if isinstance(t, str):
            try:
                t = datetime.fromisoformat(t)
            except ValueError:
                return datetime.min
        if not isinstance(t, datetime):
            return datetime.min
        if t.tzinfo is not None:
            # Normalise to naive UTC so aware and naive values stay comparable.
            t = t.astimezone(timezone.utc).replace(tzinfo=None)
        return t

    by_fg: dict = defaultdict(list)
    for msg in db.focus_group_messages.find({}):
        fg_id = str(msg.get("focus_group_id") or "")
        if msg.get("type", "") in ai_types:
            by_fg[fg_id].append(msg)

    print("\n[dry-run sample] Estimated cost per focus group (top 5 by message count):")
    fg_models = {
        str(fg["_id"]): fg.get("llm_model") or default_model
        for fg in db.focus_groups.find({}, {"llm_model": 1})
    }

    rows = []
    for fg_id, msgs in by_fg.items():
        model = fg_models.get(fg_id, default_model)
        accumulated = 0
        total_cost = 0
        for msg in sorted(msgs, key=_sort_key):
            text = msg.get("text") or msg.get("content") or ""
            completion = _chars_to_tokens(len(text))
            # Pick the base template by message type so the dry-run total
            # matches what the real backfill would write.
            if msg.get("type", "") in moderator_types:
                base = BASE_PROMPT_MODERATOR
            else:
                base = BASE_PROMPT_PERSONA_RESPONSE
            prompt = base + _chars_to_tokens(accumulated)
            cost = _estimate_cost(prompt, completion, model)
            total_cost += cost["total"]
            accumulated += len(text)
        rows.append((fg_id, len(msgs), total_cost))

    for fg_id, count, cost in sorted(rows, key=lambda r: -r[1])[:5]:
        fg = None
        if fg_id:
            try:
                fg = db.focus_groups.find_one({"_id": ObjectId(fg_id)}, {"name": 1})
            except InvalidId:
                # Non-ObjectId focus-group ids fall back to the raw id prefix.
                fg = None
        name = (fg or {}).get("name", fg_id[:8])
        print(f" {name}: {count} messages → estimated ${cost:.4f}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue