feat: integrate AI Cost Tracker for usage analytics
Created cost_tracker.py with async httpx client, ContextVar-based user propagation, and LlamaIndex token extraction helper. Wrapped all AI call sites — studio generators (7 types), notebook synthesis, podcast outline + script, ElevenLabs TTS, and podcast background task. Routes set user context via set_user_ctx(current_user.email) before AI dispatch so every record() call carries user identity without changing generator signatures. Source app: Sandbox-NotebookLM Tracker URL: https://optical-dev.oliver.solutions/cost-tracker/v1 Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
1f1f83c8e6
commit
400a342418
7 changed files with 195 additions and 3 deletions
4
backend/.env.example
Normal file
4
backend/.env.example
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
# AI Cost Tracker
|
||||
COST_TRACKER_BASE_URL=https://optical-dev.oliver.solutions/cost-tracker/v1
|
||||
COST_TRACKER_API_KEY=<generate at https://optical-dev.oliver.solutions/cost-tracker/ → API Keys → + New Key>
|
||||
COST_TRACKER_SOURCE_APP=Sandbox-NotebookLM
|
||||
|
|
@ -7,10 +7,11 @@ import sys
|
|||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Add paths to import backend modules
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent / "notebookllama"))
|
||||
from cost_tracker import set_user_ctx # noqa: E402
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Import JWT authentication
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
|
@ -411,12 +412,14 @@ async def generate_podcast(
|
|||
final_voice1_id = request.voice1_id or DEFAULT_VOICE1_ID
|
||||
final_voice2_id = request.voice2_id or DEFAULT_VOICE2_ID
|
||||
|
||||
set_user_ctx(current_user.email)
|
||||
params = {
|
||||
'target_length': request.target_length,
|
||||
'custom_theme': request.custom_theme,
|
||||
'custom_prompt': None,
|
||||
'voice1_id': final_voice1_id,
|
||||
'voice2_id': final_voice2_id
|
||||
'voice2_id': final_voice2_id,
|
||||
'user_email': current_user.email,
|
||||
}
|
||||
|
||||
# Log podcast request with voice selection details
|
||||
|
|
@ -796,6 +799,7 @@ async def gen_flashcards(notebook_id: int, opts: StudioGenerateRequest = None, c
|
|||
docs = _get_doc_data(notebook_id)
|
||||
if not docs:
|
||||
raise HTTPException(status_code=400, detail="No documents with summaries found")
|
||||
set_user_ctx(current_user.email)
|
||||
result = await generate_flashcards(docs, nb.model_type, opts.dict() if opts else None)
|
||||
data = result.model_dump()
|
||||
_save_studio_key(notebook_id, "flashcards", data)
|
||||
|
|
@ -811,6 +815,7 @@ async def gen_quiz(notebook_id: int, opts: StudioGenerateRequest = None, current
|
|||
docs = _get_doc_data(notebook_id)
|
||||
if not docs:
|
||||
raise HTTPException(status_code=400, detail="No documents with summaries found")
|
||||
set_user_ctx(current_user.email)
|
||||
result = await generate_quiz(docs, nb.model_type, opts.dict() if opts else None)
|
||||
data = result.model_dump()
|
||||
_save_studio_key(notebook_id, "quiz", data)
|
||||
|
|
@ -826,6 +831,7 @@ async def gen_mindmap(notebook_id: int, opts: StudioGenerateRequest = None, curr
|
|||
docs = _get_doc_data(notebook_id)
|
||||
if not docs:
|
||||
raise HTTPException(status_code=400, detail="No documents with summaries found")
|
||||
set_user_ctx(current_user.email)
|
||||
result = await generate_mindmap(docs, nb.model_type, opts.dict() if opts else None)
|
||||
data = result.model_dump()
|
||||
_save_studio_key(notebook_id, "mindmap", data)
|
||||
|
|
@ -841,6 +847,7 @@ async def gen_slides(notebook_id: int, opts: StudioGenerateRequest = None, curre
|
|||
docs = _get_doc_data(notebook_id)
|
||||
if not docs:
|
||||
raise HTTPException(status_code=400, detail="No documents with summaries found")
|
||||
set_user_ctx(current_user.email)
|
||||
result = await generate_slides(docs, nb.model_type, opts.dict() if opts else None)
|
||||
data = result.model_dump()
|
||||
_save_studio_key(notebook_id, "slides", data)
|
||||
|
|
@ -856,6 +863,7 @@ async def gen_report(notebook_id: int, opts: StudioGenerateRequest = None, curre
|
|||
docs = _get_doc_data(notebook_id)
|
||||
if not docs:
|
||||
raise HTTPException(status_code=400, detail="No documents with summaries found")
|
||||
set_user_ctx(current_user.email)
|
||||
result = await generate_report(docs, nb.model_type, opts.dict() if opts else None)
|
||||
data = result.model_dump()
|
||||
_save_studio_key(notebook_id, "report", data)
|
||||
|
|
@ -871,6 +879,7 @@ async def gen_infographic(notebook_id: int, opts: StudioGenerateRequest = None,
|
|||
docs = _get_doc_data(notebook_id)
|
||||
if not docs:
|
||||
raise HTTPException(status_code=400, detail="No documents with summaries found")
|
||||
set_user_ctx(current_user.email)
|
||||
result = await generate_infographic(docs, nb.model_type, opts.dict() if opts else None)
|
||||
data = result.model_dump()
|
||||
_save_studio_key(notebook_id, "infographic", data)
|
||||
|
|
@ -886,6 +895,7 @@ async def gen_datatable(notebook_id: int, opts: StudioGenerateRequest = None, cu
|
|||
docs = _get_doc_data(notebook_id)
|
||||
if not docs:
|
||||
raise HTTPException(status_code=400, detail="No documents with summaries found")
|
||||
set_user_ctx(current_user.email)
|
||||
result = await generate_datatable(docs, nb.model_type, opts.dict() if opts else None)
|
||||
data = result.model_dump()
|
||||
_save_studio_key(notebook_id, "datatable", data)
|
||||
|
|
|
|||
|
|
@ -86,6 +86,9 @@ class PodcastGenerator(BaseModel):
|
|||
)
|
||||
]
|
||||
)
|
||||
from cost_tracker import record, get_user_ctx, model_id_for, extract_llama_tokens
|
||||
inp, out = extract_llama_tokens(response)
|
||||
await record(model=model_id_for("podcast"), user_external_id=get_user_ctx(), input_tokens=inp, output_tokens=out)
|
||||
return MultiTurnConversation.model_validate_json(response.message.content)
|
||||
|
||||
@log_api_call("ELEVENLABS", "generate_audio")
|
||||
|
|
@ -101,6 +104,10 @@ class PodcastGenerator(BaseModel):
|
|||
logger.info(f"🔊 [TTS] Starting audio generation for {len(conversation.conversation)} segments")
|
||||
logger.info(f"🔊 [TTS] Voice assignment: speaker1={voice1_name} (id={voice1_id}), speaker2={voice2_name} (id={voice2_id})")
|
||||
|
||||
from cost_tracker import record, get_user_ctx
|
||||
total_chars = sum(len(t.content) for t in conversation.conversation)
|
||||
await record(model="eleven_turbo_v2_5", user_external_id=get_user_ctx(), chars=total_chars)
|
||||
|
||||
files: List[str] = []
|
||||
for i, turn in enumerate(conversation.conversation):
|
||||
if turn.speaker == "speaker1":
|
||||
|
|
|
|||
|
|
@ -700,6 +700,8 @@ async def execute_podcast_task(task_id: int):
|
|||
return
|
||||
|
||||
params = json.loads(task.parameters) if task.parameters else {}
|
||||
from cost_tracker import set_user_ctx
|
||||
set_user_ctx(params.get('user_email', str(task.user_id)))
|
||||
notebook_id = task.notebook_id
|
||||
target_length = params.get('target_length', 10)
|
||||
custom_theme = params.get('custom_theme')
|
||||
|
|
|
|||
154
backend/src/notebookllama/cost_tracker.py
Normal file
154
backend/src/notebookllama/cost_tracker.py
Normal file
|
|
@ -0,0 +1,154 @@
|
|||
"""Lightweight AI Cost Tracker integration — fail-open, fire-and-forget."""
|
||||
import os
|
||||
import logging
|
||||
from contextvars import ContextVar
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_BASE = os.environ.get("COST_TRACKER_BASE_URL", "").rstrip("/")
|
||||
_KEY = os.environ.get("COST_TRACKER_API_KEY", "")
|
||||
_APP = os.environ.get("COST_TRACKER_SOURCE_APP", "")
|
||||
_HEADERS = {"X-API-Key": _KEY} if _KEY else {}
|
||||
|
||||
# Per-async-task user context — set in routes/background tasks, read at AI call sites
|
||||
_user_ctx: ContextVar[str] = ContextVar("ct_user_email", default="")
|
||||
|
||||
|
||||
def _enabled() -> bool:
|
||||
return bool(_BASE and _KEY and _APP)
|
||||
|
||||
|
||||
def set_user_ctx(email: str) -> None:
|
||||
_user_ctx.set(email)
|
||||
|
||||
|
||||
def get_user_ctx() -> str:
|
||||
return _user_ctx.get()
|
||||
|
||||
|
||||
def model_id_for(model_type: str) -> str:
|
||||
"""Map llm_factory model_type alias to the actual model ID string."""
|
||||
try:
|
||||
from llm_factory import (
|
||||
OPENAI_CHAT_MODEL, ANTHROPIC_CHAT_MODEL,
|
||||
GEMINI_CHAT_MODEL, GEMINI_FLASH_MODEL, OPENAI_LEGACY_MODEL,
|
||||
)
|
||||
return {
|
||||
"gpt54-exp": OPENAI_CHAT_MODEL,
|
||||
"claude46-exp": ANTHROPIC_CHAT_MODEL,
|
||||
"gemini31-exp": GEMINI_CHAT_MODEL,
|
||||
"gemini31-flash": GEMINI_FLASH_MODEL,
|
||||
"gpt4o": "gpt-4o",
|
||||
"gpt4": "gpt-4",
|
||||
"openai": "gpt-4",
|
||||
"podcast": OPENAI_LEGACY_MODEL,
|
||||
}.get(model_type, model_type)
|
||||
except Exception:
|
||||
return model_type
|
||||
|
||||
|
||||
def extract_llama_tokens(response) -> Tuple[int, int]:
|
||||
"""Extract (input_tokens, output_tokens) from a LlamaIndex ChatResponse."""
|
||||
try:
|
||||
raw = getattr(response, "raw", None)
|
||||
if raw is None:
|
||||
return 0, 0
|
||||
if isinstance(raw, dict):
|
||||
usage = raw.get("usage") or {}
|
||||
return (
|
||||
int(usage.get("prompt_tokens") or usage.get("input_tokens") or 0),
|
||||
int(usage.get("completion_tokens") or usage.get("output_tokens") or 0),
|
||||
)
|
||||
u = getattr(raw, "usage", None)
|
||||
if u:
|
||||
inp = int(getattr(u, "prompt_tokens", 0) or getattr(u, "input_tokens", 0) or 0)
|
||||
out = int(getattr(u, "completion_tokens", 0) or getattr(u, "output_tokens", 0) or 0)
|
||||
return inp, out
|
||||
except Exception:
|
||||
pass
|
||||
return 0, 0
|
||||
|
||||
|
||||
async def preflight(
|
||||
model: str,
|
||||
user_external_id: str,
|
||||
project_id: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""Check budget before an AI call. Always returns True on error (fail-open)."""
|
||||
if not _enabled() or not user_external_id:
|
||||
return True
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
r = await client.post(
|
||||
f"{_BASE}/preflight",
|
||||
headers=_HEADERS,
|
||||
json={
|
||||
"source_app": _APP,
|
||||
"model": model,
|
||||
"user_external_id": user_external_id,
|
||||
**({"project_id": project_id} if project_id else {}),
|
||||
},
|
||||
timeout=3.0,
|
||||
)
|
||||
return r.json().get("allow", True)
|
||||
except Exception as exc:
|
||||
logger.warning("cost_tracker preflight error (allowing): %s", exc)
|
||||
return True
|
||||
|
||||
|
||||
async def record(
|
||||
model: str,
|
||||
user_external_id: str,
|
||||
input_tokens: int = 0,
|
||||
output_tokens: int = 0,
|
||||
chars: int = 0,
|
||||
project_id: Optional[str] = None,
|
||||
metadata: Optional[dict] = None,
|
||||
) -> None:
|
||||
"""Record AI usage after a call. Fire-and-forget — errors never propagate."""
|
||||
if not _enabled() or not user_external_id:
|
||||
return
|
||||
units: dict = {}
|
||||
if input_tokens: units["token_input"] = input_tokens
|
||||
if output_tokens: units["token_output"] = output_tokens
|
||||
if chars: units["char"] = chars
|
||||
if not units:
|
||||
return
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
await client.post(
|
||||
f"{_BASE}/usage/record",
|
||||
headers=_HEADERS,
|
||||
json={
|
||||
"model": model,
|
||||
"user_external_id": user_external_id,
|
||||
"units": units,
|
||||
**({"project_external_id": project_id} if project_id else {}),
|
||||
**({"metadata": metadata} if metadata else {}),
|
||||
},
|
||||
timeout=5.0,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("cost_tracker record error (ignoring): %s", exc)
|
||||
|
||||
|
||||
def upsert_user(
|
||||
user_external_id: str,
|
||||
email: Optional[str] = None,
|
||||
full_name: Optional[str] = None,
|
||||
role: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Enrich user record with profile data. Call once on login. Sync, non-fatal."""
|
||||
if not _enabled():
|
||||
return
|
||||
try:
|
||||
payload: dict = {"source_app": _APP, "user_external_id": user_external_id}
|
||||
if email: payload["email"] = email
|
||||
if full_name: payload["full_name"] = full_name
|
||||
if role: payload["role"] = role
|
||||
httpx.post(f"{_BASE}/users/upsert", headers=_HEADERS, json=payload, timeout=3.0)
|
||||
except Exception as exc:
|
||||
logger.warning("cost_tracker upsert_user error (ignoring): %s", exc)
|
||||
|
|
@ -101,8 +101,11 @@ Focus on what's revealed when these documents are considered as a collection, no
|
|||
]
|
||||
|
||||
try:
|
||||
from cost_tracker import record, get_user_ctx, model_id_for, extract_llama_tokens
|
||||
llm_synthesis = get_structured_llm(model_type, NotebookSynthesis)
|
||||
response = await llm_synthesis.achat(messages=messages)
|
||||
inp, out = extract_llama_tokens(response)
|
||||
await record(model=model_id_for(model_type), user_external_id=get_user_ctx(), input_tokens=inp, output_tokens=out)
|
||||
|
||||
# Handle different response types from different models
|
||||
content = response.message.content
|
||||
|
|
@ -208,6 +211,9 @@ Return ONLY the JSON, no additional text."""
|
|||
))
|
||||
|
||||
response = await llm.achat(messages=json_messages)
|
||||
from cost_tracker import record, get_user_ctx, model_id_for, extract_llama_tokens
|
||||
inp, out = extract_llama_tokens(response)
|
||||
await record(model=model_id_for(model_type), user_external_id=get_user_ctx(), input_tokens=inp, output_tokens=out)
|
||||
content = response.message.content
|
||||
|
||||
print(f"Raw response from {model_type}: {content[:500]}")
|
||||
|
|
@ -232,6 +238,9 @@ Return ONLY the JSON, no additional text."""
|
|||
# OpenAI works with as_structured_llm
|
||||
llm_podcast = get_structured_llm(model_type, PodcastOutline)
|
||||
response = await llm_podcast.achat(messages=messages)
|
||||
from cost_tracker import record, get_user_ctx, model_id_for, extract_llama_tokens
|
||||
inp, out = extract_llama_tokens(response)
|
||||
await record(model=model_id_for(model_type), user_external_id=get_user_ctx(), input_tokens=inp, output_tokens=out)
|
||||
outline = PodcastOutline.model_validate_json(response.message.content)
|
||||
|
||||
outline.target_length_minutes = target_length
|
||||
|
|
@ -309,6 +318,9 @@ Format as a natural conversation with approximately {target_turns} speaking turn
|
|||
try:
|
||||
llm = get_llm_by_type(model_type)
|
||||
response = await llm.achat(messages=messages)
|
||||
from cost_tracker import record, get_user_ctx, model_id_for, extract_llama_tokens
|
||||
inp, out = extract_llama_tokens(response)
|
||||
await record(model=model_id_for(model_type), user_external_id=get_user_ctx(), input_tokens=inp, output_tokens=out)
|
||||
script = response.message.content
|
||||
print(f" ✓ Script generated: {len(script)} chars (~{target_turns} turns)")
|
||||
return script
|
||||
|
|
|
|||
|
|
@ -157,6 +157,9 @@ async def _generate(doc_summaries, model_type, system_msg, user_msg, output_clas
|
|||
else:
|
||||
llm = get_structured_llm(model_type, output_class)
|
||||
response = await llm.achat(messages=messages)
|
||||
from cost_tracker import record, get_user_ctx, model_id_for, extract_llama_tokens
|
||||
inp, out = extract_llama_tokens(response)
|
||||
await record(model=model_id_for(model_type), user_external_id=get_user_ctx(), input_tokens=inp, output_tokens=out)
|
||||
return _parse_response(response, output_class)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue