feat: integrate AI Cost Tracker for usage analytics

Created cost_tracker.py with async httpx client, ContextVar-based user
propagation, and LlamaIndex token extraction helper. Wrapped all AI call
sites — studio generators (7 types), notebook synthesis, podcast outline
+ script, ElevenLabs TTS, and podcast background task. Routes set user
context via set_user_ctx(current_user.email) before AI dispatch so every
record() call carries user identity without changing generator signatures.

Source app: Sandbox-NotebookLM
Tracker URL: https://optical-dev.oliver.solutions/cost-tracker/v1

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Vadym Samoilenko 2026-04-27 16:12:01 +01:00
parent 1f1f83c8e6
commit 400a342418
7 changed files with 195 additions and 3 deletions

4
backend/.env.example Normal file
View file

@ -0,0 +1,4 @@
# AI Cost Tracker
COST_TRACKER_BASE_URL=https://optical-dev.oliver.solutions/cost-tracker/v1
COST_TRACKER_API_KEY=<generate at https://optical-dev.oliver.solutions/cost-tracker/ → API Keys → + New Key>
COST_TRACKER_SOURCE_APP=Sandbox-NotebookLM

View file

@ -7,10 +7,11 @@ import sys
import logging
from pathlib import Path
logger = logging.getLogger(__name__)
# Add paths to import backend modules
sys.path.insert(0, str(Path(__file__).parent.parent.parent / "notebookllama"))
from cost_tracker import set_user_ctx # noqa: E402
logger = logging.getLogger(__name__)
# Import JWT authentication
sys.path.insert(0, str(Path(__file__).parent.parent))
@ -411,12 +412,14 @@ async def generate_podcast(
final_voice1_id = request.voice1_id or DEFAULT_VOICE1_ID
final_voice2_id = request.voice2_id or DEFAULT_VOICE2_ID
set_user_ctx(current_user.email)
params = {
'target_length': request.target_length,
'custom_theme': request.custom_theme,
'custom_prompt': None,
'voice1_id': final_voice1_id,
'voice2_id': final_voice2_id
'voice2_id': final_voice2_id,
'user_email': current_user.email,
}
# Log podcast request with voice selection details
@ -796,6 +799,7 @@ async def gen_flashcards(notebook_id: int, opts: StudioGenerateRequest = None, c
docs = _get_doc_data(notebook_id)
if not docs:
raise HTTPException(status_code=400, detail="No documents with summaries found")
set_user_ctx(current_user.email)
result = await generate_flashcards(docs, nb.model_type, opts.dict() if opts else None)
data = result.model_dump()
_save_studio_key(notebook_id, "flashcards", data)
@ -811,6 +815,7 @@ async def gen_quiz(notebook_id: int, opts: StudioGenerateRequest = None, current
docs = _get_doc_data(notebook_id)
if not docs:
raise HTTPException(status_code=400, detail="No documents with summaries found")
set_user_ctx(current_user.email)
result = await generate_quiz(docs, nb.model_type, opts.dict() if opts else None)
data = result.model_dump()
_save_studio_key(notebook_id, "quiz", data)
@ -826,6 +831,7 @@ async def gen_mindmap(notebook_id: int, opts: StudioGenerateRequest = None, curr
docs = _get_doc_data(notebook_id)
if not docs:
raise HTTPException(status_code=400, detail="No documents with summaries found")
set_user_ctx(current_user.email)
result = await generate_mindmap(docs, nb.model_type, opts.dict() if opts else None)
data = result.model_dump()
_save_studio_key(notebook_id, "mindmap", data)
@ -841,6 +847,7 @@ async def gen_slides(notebook_id: int, opts: StudioGenerateRequest = None, curre
docs = _get_doc_data(notebook_id)
if not docs:
raise HTTPException(status_code=400, detail="No documents with summaries found")
set_user_ctx(current_user.email)
result = await generate_slides(docs, nb.model_type, opts.dict() if opts else None)
data = result.model_dump()
_save_studio_key(notebook_id, "slides", data)
@ -856,6 +863,7 @@ async def gen_report(notebook_id: int, opts: StudioGenerateRequest = None, curre
docs = _get_doc_data(notebook_id)
if not docs:
raise HTTPException(status_code=400, detail="No documents with summaries found")
set_user_ctx(current_user.email)
result = await generate_report(docs, nb.model_type, opts.dict() if opts else None)
data = result.model_dump()
_save_studio_key(notebook_id, "report", data)
@ -871,6 +879,7 @@ async def gen_infographic(notebook_id: int, opts: StudioGenerateRequest = None,
docs = _get_doc_data(notebook_id)
if not docs:
raise HTTPException(status_code=400, detail="No documents with summaries found")
set_user_ctx(current_user.email)
result = await generate_infographic(docs, nb.model_type, opts.dict() if opts else None)
data = result.model_dump()
_save_studio_key(notebook_id, "infographic", data)
@ -886,6 +895,7 @@ async def gen_datatable(notebook_id: int, opts: StudioGenerateRequest = None, cu
docs = _get_doc_data(notebook_id)
if not docs:
raise HTTPException(status_code=400, detail="No documents with summaries found")
set_user_ctx(current_user.email)
result = await generate_datatable(docs, nb.model_type, opts.dict() if opts else None)
data = result.model_dump()
_save_studio_key(notebook_id, "datatable", data)

View file

@ -86,6 +86,9 @@ class PodcastGenerator(BaseModel):
)
]
)
from cost_tracker import record, get_user_ctx, model_id_for, extract_llama_tokens
inp, out = extract_llama_tokens(response)
await record(model=model_id_for("podcast"), user_external_id=get_user_ctx(), input_tokens=inp, output_tokens=out)
return MultiTurnConversation.model_validate_json(response.message.content)
@log_api_call("ELEVENLABS", "generate_audio")
@ -101,6 +104,10 @@ class PodcastGenerator(BaseModel):
logger.info(f"🔊 [TTS] Starting audio generation for {len(conversation.conversation)} segments")
logger.info(f"🔊 [TTS] Voice assignment: speaker1={voice1_name} (id={voice1_id}), speaker2={voice2_name} (id={voice2_id})")
from cost_tracker import record, get_user_ctx
total_chars = sum(len(t.content) for t in conversation.conversation)
await record(model="eleven_turbo_v2_5", user_external_id=get_user_ctx(), chars=total_chars)
files: List[str] = []
for i, turn in enumerate(conversation.conversation):
if turn.speaker == "speaker1":

View file

@ -700,6 +700,8 @@ async def execute_podcast_task(task_id: int):
return
params = json.loads(task.parameters) if task.parameters else {}
from cost_tracker import set_user_ctx
set_user_ctx(params.get('user_email', str(task.user_id)))
notebook_id = task.notebook_id
target_length = params.get('target_length', 10)
custom_theme = params.get('custom_theme')

View file

@ -0,0 +1,154 @@
"""Lightweight AI Cost Tracker integration — fail-open, fire-and-forget."""
import os
import logging
from contextvars import ContextVar
from typing import Optional, Tuple
import httpx
# Module-level logger used for all tracker warnings.
logger = logging.getLogger(__name__)
# Tracker settings are read from the environment once, at import time.
# Trailing slashes are stripped from the base URL so endpoint paths can be
# appended with a single "/".
_BASE = os.environ.get("COST_TRACKER_BASE_URL", "").rstrip("/")
_KEY = os.environ.get("COST_TRACKER_API_KEY", "")
_APP = os.environ.get("COST_TRACKER_SOURCE_APP", "")
# The API-key header is only built when a key is configured.
_HEADERS = {"X-API-Key": _KEY} if _KEY else {}
# Per-async-task user context — set in routes/background tasks, read at AI call sites
# Reads return "" when nothing has been set in the current context.
_user_ctx: ContextVar[str] = ContextVar("ct_user_email", default="")
def _enabled() -> bool:
    """Return True only when base URL, API key and source app are all set."""
    return all((_BASE, _KEY, _APP))
def set_user_ctx(email: str) -> None:
    """Store *email* as the user identity for the current context.

    The token returned by ``ContextVar.set`` is discarded, so the previous
    value cannot be restored through this helper.
    """
    _user_ctx.set(email)
    return None
def get_user_ctx() -> str:
    """Return the user identity stored in the current context ("" if unset)."""
    current_user = _user_ctx.get()
    return current_user
def model_id_for(model_type: str) -> str:
    """Map an llm_factory model_type alias to the actual model ID string.

    Aliases with a fixed model ID are resolved locally, so they keep working
    even when ``llm_factory`` cannot be imported. The remaining aliases are
    looked up one at a time as constants on ``llm_factory``; a missing
    constant only affects the alias that needs it.

    Args:
        model_type: Alias such as ``"gpt4o"`` or ``"claude46-exp"``.

    Returns:
        The mapped model ID, or ``model_type`` unchanged when the alias is
        unknown or its constant cannot be resolved.
    """
    # Aliases whose model ID does not depend on llm_factory.
    fixed_ids = {
        "gpt4o": "gpt-4o",
        "gpt4": "gpt-4",
        "openai": "gpt-4",
    }
    # Aliases resolved through a constant defined in llm_factory.
    factory_constants = {
        "gpt54-exp": "OPENAI_CHAT_MODEL",
        "claude46-exp": "ANTHROPIC_CHAT_MODEL",
        "gemini31-exp": "GEMINI_CHAT_MODEL",
        "gemini31-flash": "GEMINI_FLASH_MODEL",
        "podcast": "OPENAI_LEGACY_MODEL",
    }
    try:
        if model_type in fixed_ids:
            return fixed_ids[model_type]
        constant_name = factory_constants.get(model_type)
        if constant_name is None:
            return model_type
        # Imported lazily, as before, so importing this module never
        # requires llm_factory to be importable.
        import llm_factory

        return getattr(llm_factory, constant_name)
    except Exception:
        # Fail open: an unresolvable alias is reported as-is rather than
        # breaking the AI call site.
        return model_type
def extract_llama_tokens(response) -> Tuple[int, int]:
    """Extract (input_tokens, output_tokens) from a LlamaIndex ChatResponse.

    ``response.raw`` and its ``usage`` entry may each be either a mapping or
    an attribute-style object; every combination is handled. Both the
    ``prompt_tokens``/``completion_tokens`` and the
    ``input_tokens``/``output_tokens`` naming schemes are recognised, in that
    order of preference.

    Args:
        response: Object with an optional ``raw`` attribute.

    Returns:
        ``(input_tokens, output_tokens)``; ``(0, 0)`` when usage data is
        missing or cannot be read.
    """
    # NOTE(review): only a "usage" entry is inspected. Providers that report
    # usage under another key would still yield (0, 0) — confirm against the
    # integrations in use.

    def _lookup(container, name):
        # Read one field from either a mapping or an attribute-style object.
        if isinstance(container, dict):
            return container.get(name)
        return getattr(container, name, None)

    def _first_count(usage, *names) -> int:
        # First truthy value among the candidate field names, as an int.
        for name in names:
            value = _lookup(usage, name)
            if value:
                return int(value)
        return 0

    try:
        raw = getattr(response, "raw", None)
        if raw is None:
            return 0, 0
        usage = _lookup(raw, "usage")
        if not usage:
            return 0, 0
        return (
            _first_count(usage, "prompt_tokens", "input_tokens"),
            _first_count(usage, "completion_tokens", "output_tokens"),
        )
    except Exception as exc:
        # Best-effort: token accounting must never break the AI call itself.
        logger.debug("cost_tracker token extraction failed: %s", exc)
    return 0, 0
async def preflight(
    model: str,
    user_external_id: str,
    project_id: Optional[str] = None,
) -> bool:
    """Ask the tracker whether an AI call may proceed (fail-open).

    Returns True without contacting the tracker when it is not configured or
    no user is given, and True on any exception. Otherwise returns the
    ``"allow"`` field of the JSON reply, defaulting to True when absent.
    """
    if not (_enabled() and user_external_id):
        return True

    body = {
        "source_app": _APP,
        "model": model,
        "user_external_id": user_external_id,
    }
    # project_id is only sent when the caller supplied a non-empty one.
    if project_id:
        body["project_id"] = project_id

    try:
        async with httpx.AsyncClient() as http:
            reply = await http.post(
                f"{_BASE}/preflight",
                headers=_HEADERS,
                json=body,
                timeout=3.0,
            )
            return reply.json().get("allow", True)
    except Exception as err:
        logger.warning("cost_tracker preflight error (allowing): %s", err)
        return True
async def record(
    model: str,
    user_external_id: str,
    input_tokens: int = 0,
    output_tokens: int = 0,
    chars: int = 0,
    project_id: Optional[str] = None,
    metadata: Optional[dict] = None,
) -> None:
    """Record AI usage after a call. Errors never propagate.

    The POST is awaited inline with a 5-second httpx timeout, so the caller
    waits for the tracker. Transport failures and non-2xx responses are both
    logged as warnings and otherwise ignored.

    Nothing is sent when the tracker is not configured, when
    ``user_external_id`` is empty, or when all of ``input_tokens``,
    ``output_tokens`` and ``chars`` are zero.

    Args:
        model: Model identifier reported to the tracker.
        user_external_id: Identity of the user the usage is attributed to.
        input_tokens: Sent as the ``token_input`` unit when non-zero.
        output_tokens: Sent as the ``token_output`` unit when non-zero.
        chars: Sent as the ``char`` unit when non-zero.
        project_id: Sent as ``project_external_id`` when non-empty.
        metadata: Sent as ``metadata`` when non-empty.
    """
    if not _enabled() or not user_external_id:
        return

    # Only non-zero quantities are reported.
    units = {
        unit: amount
        for unit, amount in (
            ("token_input", input_tokens),
            ("token_output", output_tokens),
            ("char", chars),
        )
        if amount
    }
    if not units:
        return

    # NOTE(review): unlike preflight() and upsert_user(), this payload does
    # not include "source_app" even though _enabled() requires it — confirm
    # the tracker derives the app from the API key for this endpoint.
    payload: dict = {
        "model": model,
        "user_external_id": user_external_id,
        "units": units,
    }
    if project_id:
        payload["project_external_id"] = project_id
    if metadata:
        payload["metadata"] = metadata

    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"{_BASE}/usage/record",
                headers=_HEADERS,
                json=payload,
                timeout=5.0,
            )
            # Surface rejected records (bad key, invalid payload, server
            # error) in the log instead of dropping them silently.
            response.raise_for_status()
    except Exception as exc:
        logger.warning("cost_tracker record error (ignoring): %s", exc)
def upsert_user(
    user_external_id: str,
    email: Optional[str] = None,
    full_name: Optional[str] = None,
    role: Optional[str] = None,
) -> None:
    """Enrich user record with profile data. Call once on login. Sync, non-fatal.

    Issues a blocking ``httpx.post`` with a 3-second timeout. Transport
    failures and non-2xx responses are logged as warnings and otherwise
    ignored. Does nothing when the tracker is not configured.

    Args:
        user_external_id: Identity of the user to enrich.
        email: Sent as ``email`` when non-empty.
        full_name: Sent as ``full_name`` when non-empty.
        role: Sent as ``role`` when non-empty.
    """
    if not _enabled():
        return
    try:
        payload: dict = {"source_app": _APP, "user_external_id": user_external_id}
        # Optional profile fields are only sent when provided.
        optional_fields = (("email", email), ("full_name", full_name), ("role", role))
        payload.update({key: value for key, value in optional_fields if value})
        response = httpx.post(
            f"{_BASE}/users/upsert",
            headers=_HEADERS,
            json=payload,
            timeout=3.0,
        )
        # Surface rejected upserts in the log instead of dropping them silently.
        response.raise_for_status()
    except Exception as exc:
        logger.warning("cost_tracker upsert_user error (ignoring): %s", exc)

View file

@ -101,8 +101,11 @@ Focus on what's revealed when these documents are considered as a collection, no
]
try:
from cost_tracker import record, get_user_ctx, model_id_for, extract_llama_tokens
llm_synthesis = get_structured_llm(model_type, NotebookSynthesis)
response = await llm_synthesis.achat(messages=messages)
inp, out = extract_llama_tokens(response)
await record(model=model_id_for(model_type), user_external_id=get_user_ctx(), input_tokens=inp, output_tokens=out)
# Handle different response types from different models
content = response.message.content
@ -208,6 +211,9 @@ Return ONLY the JSON, no additional text."""
))
response = await llm.achat(messages=json_messages)
from cost_tracker import record, get_user_ctx, model_id_for, extract_llama_tokens
inp, out = extract_llama_tokens(response)
await record(model=model_id_for(model_type), user_external_id=get_user_ctx(), input_tokens=inp, output_tokens=out)
content = response.message.content
print(f"Raw response from {model_type}: {content[:500]}")
@ -232,6 +238,9 @@ Return ONLY the JSON, no additional text."""
# OpenAI works with as_structured_llm
llm_podcast = get_structured_llm(model_type, PodcastOutline)
response = await llm_podcast.achat(messages=messages)
from cost_tracker import record, get_user_ctx, model_id_for, extract_llama_tokens
inp, out = extract_llama_tokens(response)
await record(model=model_id_for(model_type), user_external_id=get_user_ctx(), input_tokens=inp, output_tokens=out)
outline = PodcastOutline.model_validate_json(response.message.content)
outline.target_length_minutes = target_length
@ -309,6 +318,9 @@ Format as a natural conversation with approximately {target_turns} speaking turn
try:
llm = get_llm_by_type(model_type)
response = await llm.achat(messages=messages)
from cost_tracker import record, get_user_ctx, model_id_for, extract_llama_tokens
inp, out = extract_llama_tokens(response)
await record(model=model_id_for(model_type), user_external_id=get_user_ctx(), input_tokens=inp, output_tokens=out)
script = response.message.content
print(f" ✓ Script generated: {len(script)} chars (~{target_turns} turns)")
return script

View file

@ -157,6 +157,9 @@ async def _generate(doc_summaries, model_type, system_msg, user_msg, output_clas
else:
llm = get_structured_llm(model_type, output_class)
response = await llm.achat(messages=messages)
from cost_tracker import record, get_user_ctx, model_id_for, extract_llama_tokens
inp, out = extract_llama_tokens(response)
await record(model=model_id_for(model_type), user_external_id=get_user_ctx(), input_tokens=inp, output_tokens=out)
return _parse_response(response, output_class)