diff --git a/backend/.env.example b/backend/.env.example new file mode 100644 index 0000000..25f528c --- /dev/null +++ b/backend/.env.example @@ -0,0 +1,4 @@ +# AI Cost Tracker +COST_TRACKER_BASE_URL=https://optical-dev.oliver.solutions/cost-tracker/v1 +COST_TRACKER_API_KEY= +COST_TRACKER_SOURCE_APP=Sandbox-NotebookLM diff --git a/backend/src/api/routes/notebooks.py b/backend/src/api/routes/notebooks.py index af47ef5..632dc8e 100644 --- a/backend/src/api/routes/notebooks.py +++ b/backend/src/api/routes/notebooks.py @@ -7,10 +7,11 @@ import sys import logging from pathlib import Path -logger = logging.getLogger(__name__) - # Add paths to import backend modules sys.path.insert(0, str(Path(__file__).parent.parent.parent / "notebookllama")) +from cost_tracker import set_user_ctx # noqa: E402 + +logger = logging.getLogger(__name__) # Import JWT authentication sys.path.insert(0, str(Path(__file__).parent.parent)) @@ -411,12 +412,14 @@ async def generate_podcast( final_voice1_id = request.voice1_id or DEFAULT_VOICE1_ID final_voice2_id = request.voice2_id or DEFAULT_VOICE2_ID + set_user_ctx(current_user.email) params = { 'target_length': request.target_length, 'custom_theme': request.custom_theme, 'custom_prompt': None, 'voice1_id': final_voice1_id, - 'voice2_id': final_voice2_id + 'voice2_id': final_voice2_id, + 'user_email': current_user.email, } # Log podcast request with voice selection details @@ -796,6 +799,7 @@ async def gen_flashcards(notebook_id: int, opts: StudioGenerateRequest = None, c docs = _get_doc_data(notebook_id) if not docs: raise HTTPException(status_code=400, detail="No documents with summaries found") + set_user_ctx(current_user.email) result = await generate_flashcards(docs, nb.model_type, opts.dict() if opts else None) data = result.model_dump() _save_studio_key(notebook_id, "flashcards", data) @@ -811,6 +815,7 @@ async def gen_quiz(notebook_id: int, opts: StudioGenerateRequest = None, current docs = _get_doc_data(notebook_id) if not docs: raise 
HTTPException(status_code=400, detail="No documents with summaries found") + set_user_ctx(current_user.email) result = await generate_quiz(docs, nb.model_type, opts.dict() if opts else None) data = result.model_dump() _save_studio_key(notebook_id, "quiz", data) @@ -826,6 +831,7 @@ async def gen_mindmap(notebook_id: int, opts: StudioGenerateRequest = None, curr docs = _get_doc_data(notebook_id) if not docs: raise HTTPException(status_code=400, detail="No documents with summaries found") + set_user_ctx(current_user.email) result = await generate_mindmap(docs, nb.model_type, opts.dict() if opts else None) data = result.model_dump() _save_studio_key(notebook_id, "mindmap", data) @@ -841,6 +847,7 @@ async def gen_slides(notebook_id: int, opts: StudioGenerateRequest = None, curre docs = _get_doc_data(notebook_id) if not docs: raise HTTPException(status_code=400, detail="No documents with summaries found") + set_user_ctx(current_user.email) result = await generate_slides(docs, nb.model_type, opts.dict() if opts else None) data = result.model_dump() _save_studio_key(notebook_id, "slides", data) @@ -856,6 +863,7 @@ async def gen_report(notebook_id: int, opts: StudioGenerateRequest = None, curre docs = _get_doc_data(notebook_id) if not docs: raise HTTPException(status_code=400, detail="No documents with summaries found") + set_user_ctx(current_user.email) result = await generate_report(docs, nb.model_type, opts.dict() if opts else None) data = result.model_dump() _save_studio_key(notebook_id, "report", data) @@ -871,6 +879,7 @@ async def gen_infographic(notebook_id: int, opts: StudioGenerateRequest = None, docs = _get_doc_data(notebook_id) if not docs: raise HTTPException(status_code=400, detail="No documents with summaries found") + set_user_ctx(current_user.email) result = await generate_infographic(docs, nb.model_type, opts.dict() if opts else None) data = result.model_dump() _save_studio_key(notebook_id, "infographic", data) @@ -886,6 +895,7 @@ async def 
gen_datatable(notebook_id: int, opts: StudioGenerateRequest = None, cu docs = _get_doc_data(notebook_id) if not docs: raise HTTPException(status_code=400, detail="No documents with summaries found") + set_user_ctx(current_user.email) result = await generate_datatable(docs, nb.model_type, opts.dict() if opts else None) data = result.model_dump() _save_studio_key(notebook_id, "datatable", data) diff --git a/backend/src/notebookllama/audio.py b/backend/src/notebookllama/audio.py index a0752e0..76dc04e 100644 --- a/backend/src/notebookllama/audio.py +++ b/backend/src/notebookllama/audio.py @@ -86,6 +86,9 @@ class PodcastGenerator(BaseModel): ) ] ) + from cost_tracker import record, get_user_ctx, model_id_for, extract_llama_tokens + inp, out = extract_llama_tokens(response) + await record(model=model_id_for("podcast"), user_external_id=get_user_ctx(), input_tokens=inp, output_tokens=out) return MultiTurnConversation.model_validate_json(response.message.content) @log_api_call("ELEVENLABS", "generate_audio") @@ -101,6 +104,10 @@ class PodcastGenerator(BaseModel): logger.info(f"🔊 [TTS] Starting audio generation for {len(conversation.conversation)} segments") logger.info(f"🔊 [TTS] Voice assignment: speaker1={voice1_name} (id={voice1_id}), speaker2={voice2_name} (id={voice2_id})") + from cost_tracker import record, get_user_ctx + total_chars = sum(len(t.content) for t in conversation.conversation) + await record(model="eleven_turbo_v2_5", user_external_id=get_user_ctx(), chars=total_chars) + files: List[str] = [] for i, turn in enumerate(conversation.conversation): if turn.speaker == "speaker1": diff --git a/backend/src/notebookllama/background_tasks.py b/backend/src/notebookllama/background_tasks.py index f44b6eb..69dc5be 100644 --- a/backend/src/notebookllama/background_tasks.py +++ b/backend/src/notebookllama/background_tasks.py @@ -700,6 +700,8 @@ async def execute_podcast_task(task_id: int): return params = json.loads(task.parameters) if task.parameters else {} + from 
"""Lightweight AI Cost Tracker integration — fail-open and best-effort.

Configuration is read once, at import time, from the environment variables
``COST_TRACKER_BASE_URL``, ``COST_TRACKER_API_KEY`` and
``COST_TRACKER_SOURCE_APP``.  When any of them is empty the tracker is
disabled: ``preflight`` allows every call and ``record`` / ``upsert_user``
do nothing.  No function in this module lets a tracker failure propagate to
the caller.
"""
import os
import logging
from contextvars import ContextVar
from typing import Any, Optional, Tuple

try:
    import httpx
except ImportError:
    # The tracker must never take the application down: without httpx the
    # module still imports and simply reports itself as disabled.
    httpx = None  # type: ignore[assignment]

logger = logging.getLogger(__name__)

_BASE = os.environ.get("COST_TRACKER_BASE_URL", "").rstrip("/")
_KEY = os.environ.get("COST_TRACKER_API_KEY", "")
_APP = os.environ.get("COST_TRACKER_SOURCE_APP", "")
_HEADERS = {"X-API-Key": _KEY} if _KEY else {}

# Per-async-task user context — set in routes/background tasks, read at AI call sites
_user_ctx: ContextVar[str] = ContextVar("ct_user_email", default="")

# Aliases whose model ID is a fixed string and does not depend on llm_factory.
_STATIC_MODEL_IDS = {
    "gpt4o": "gpt-4o",
    "gpt4": "gpt-4",
    "openai": "gpt-4",
}


def _enabled() -> bool:
    """Return True only when httpx is importable and all three settings are set."""
    return bool(httpx is not None and _BASE and _KEY and _APP)


def set_user_ctx(email: str) -> None:
    """Store the identifier of the user on whose behalf AI calls are made.

    A falsy value (``None`` or ``""``) is stored as ``""`` so the context
    variable always holds a string; ``record`` and ``preflight`` treat an
    empty identifier as "no user" and skip the network call.
    """
    _user_ctx.set(email or "")


def get_user_ctx() -> str:
    """Return the identifier set by ``set_user_ctx`` (``""`` if never set)."""
    return _user_ctx.get()


def model_id_for(model_type: str) -> str:
    """Map llm_factory model_type alias to the actual model ID string.

    Unknown aliases are returned unchanged.  If ``llm_factory`` cannot be
    imported, the fixed aliases (``gpt4o``, ``gpt4``, ``openai``) are still
    resolved and every other alias is returned unchanged.
    """
    mapping = dict(_STATIC_MODEL_IDS)
    try:
        from llm_factory import (
            OPENAI_CHAT_MODEL, ANTHROPIC_CHAT_MODEL,
            GEMINI_CHAT_MODEL, GEMINI_FLASH_MODEL, OPENAI_LEGACY_MODEL,
        )
    except Exception:
        # Deliberately broad: any failure while importing llm_factory must
        # not break cost reporting; fall back to the static aliases only.
        return mapping.get(model_type, model_type)
    mapping.update({
        "gpt54-exp": OPENAI_CHAT_MODEL,
        "claude46-exp": ANTHROPIC_CHAT_MODEL,
        "gemini31-exp": GEMINI_CHAT_MODEL,
        "gemini31-flash": GEMINI_FLASH_MODEL,
        "podcast": OPENAI_LEGACY_MODEL,
    })
    return mapping.get(model_type, model_type)


def _lookup(container: Any, name: str) -> Any:
    """Read *name* from a dict (by key) or from any other object (by attribute)."""
    if isinstance(container, dict):
        return container.get(name)
    return getattr(container, name, None)


def _first_count(usage: Any, *names: str) -> int:
    """Return the first truthy value among *names* in *usage* as an int, else 0."""
    for name in names:
        value = _lookup(usage, name)
        if value:
            return int(value)
    return 0


def extract_llama_tokens(response) -> Tuple[int, int]:
    """Extract (input_tokens, output_tokens) from a LlamaIndex ChatResponse.

    ``response.raw`` may be a dict or an object, and so may the usage record
    inside it.  The usage record is looked up under ``usage`` and then under
    ``usage_metadata``.  Accepted field names are
    ``prompt_tokens`` / ``input_tokens`` / ``prompt_token_count`` for input
    and ``completion_tokens`` / ``output_tokens`` / ``candidates_token_count``
    for output.  NOTE(review): the ``usage_metadata`` / ``*_token_count``
    names follow the Gemini response schema — confirm against the provider
    integration actually in use.

    Returns ``(0, 0)`` when no usage data is present or it cannot be parsed.
    """
    try:
        raw = getattr(response, "raw", None)
        if raw is None:
            return 0, 0
        usage = _lookup(raw, "usage") or _lookup(raw, "usage_metadata")
        if not usage:
            return 0, 0
        return (
            _first_count(usage, "prompt_tokens", "input_tokens", "prompt_token_count"),
            _first_count(usage, "completion_tokens", "output_tokens", "candidates_token_count"),
        )
    except Exception as exc:
        # Best-effort by design: malformed usage data must not fail the AI call.
        logger.debug("cost_tracker could not read token usage: %s", exc)
        return 0, 0


async def preflight(
    model: str,
    user_external_id: str,
    project_id: Optional[str] = None,
) -> bool:
    """Check budget before an AI call. Always returns True on error (fail-open).

    Returns True without any network call when the tracker is disabled or
    ``user_external_id`` is empty.  Otherwise POSTs to ``/preflight`` with a
    3 second timeout and returns the truthiness of the ``allow`` field of the
    JSON reply; a missing or null ``allow`` counts as allowed.
    """
    if not _enabled() or not user_external_id:
        return True
    payload: dict = {
        "source_app": _APP,
        "model": model,
        "user_external_id": user_external_id,
    }
    if project_id:
        payload["project_id"] = project_id
    try:
        async with httpx.AsyncClient() as client:
            resp = await client.post(
                f"{_BASE}/preflight",
                headers=_HEADERS,
                json=payload,
                timeout=3.0,
            )
            verdict = resp.json().get("allow")
    except Exception as exc:
        logger.warning("cost_tracker preflight error (allowing): %s", exc)
        return True
    # A null "allow" must not block the call, and callers are promised a bool.
    return True if verdict is None else bool(verdict)


async def record(
    model: str,
    user_external_id: str,
    input_tokens: int = 0,
    output_tokens: int = 0,
    chars: int = 0,
    project_id: Optional[str] = None,
    metadata: Optional[dict] = None,
) -> None:
    """Record AI usage after a call. Best-effort — errors never propagate.

    Nothing is sent when the tracker is disabled, ``user_external_id`` is
    empty, or all of ``input_tokens`` / ``output_tokens`` / ``chars`` are
    zero.  The request is awaited (it is not scheduled in the background),
    so a slow tracker can delay the caller by up to the 5 second timeout.
    Transport errors and non-2xx replies are logged as warnings.
    """
    if not _enabled() or not user_external_id:
        return
    units: dict = {}
    if input_tokens:
        units["token_input"] = input_tokens
    if output_tokens:
        units["token_output"] = output_tokens
    if chars:
        units["char"] = chars
    if not units:
        return
    # NOTE(review): unlike /preflight and /users/upsert, this payload carries
    # no "source_app" and names the project "project_external_id" — confirm
    # against the tracker API that both are intended.
    payload: dict = {
        "model": model,
        "user_external_id": user_external_id,
        "units": units,
    }
    if project_id:
        payload["project_external_id"] = project_id
    if metadata:
        payload["metadata"] = metadata
    try:
        async with httpx.AsyncClient() as client:
            resp = await client.post(
                f"{_BASE}/usage/record",
                headers=_HEADERS,
                json=payload,
                timeout=5.0,
            )
        # Surface rejected records (401, 422, 5xx) in the log instead of
        # dropping them silently; the exception is caught just below.
        resp.raise_for_status()
    except Exception as exc:
        logger.warning("cost_tracker record error (ignoring): %s", exc)


def upsert_user(
    user_external_id: str,
    email: Optional[str] = None,
    full_name: Optional[str] = None,
    role: Optional[str] = None,
) -> None:
    """Enrich user record with profile data. Call once on login. Sync, non-fatal.

    Performs a blocking POST to ``/users/upsert`` with a 3 second timeout;
    only the optional fields that are non-empty are included.  Transport
    errors and non-2xx replies are logged as warnings and never raised.
    """
    if not _enabled():
        return
    try:
        payload: dict = {"source_app": _APP, "user_external_id": user_external_id}
        if email:
            payload["email"] = email
        if full_name:
            payload["full_name"] = full_name
        if role:
            payload["role"] = role
        resp = httpx.post(f"{_BASE}/users/upsert", headers=_HEADERS, json=payload, timeout=3.0)
        resp.raise_for_status()
    except Exception as exc:
        logger.warning("cost_tracker upsert_user error (ignoring): %s", exc)
record(model=model_id_for(model_type), user_external_id=get_user_ctx(), input_tokens=inp, output_tokens=out) # Handle different response types from different models content = response.message.content @@ -208,6 +211,9 @@ Return ONLY the JSON, no additional text.""" )) response = await llm.achat(messages=json_messages) + from cost_tracker import record, get_user_ctx, model_id_for, extract_llama_tokens + inp, out = extract_llama_tokens(response) + await record(model=model_id_for(model_type), user_external_id=get_user_ctx(), input_tokens=inp, output_tokens=out) content = response.message.content print(f"Raw response from {model_type}: {content[:500]}") @@ -232,6 +238,9 @@ Return ONLY the JSON, no additional text.""" # OpenAI works with as_structured_llm llm_podcast = get_structured_llm(model_type, PodcastOutline) response = await llm_podcast.achat(messages=messages) + from cost_tracker import record, get_user_ctx, model_id_for, extract_llama_tokens + inp, out = extract_llama_tokens(response) + await record(model=model_id_for(model_type), user_external_id=get_user_ctx(), input_tokens=inp, output_tokens=out) outline = PodcastOutline.model_validate_json(response.message.content) outline.target_length_minutes = target_length @@ -309,6 +318,9 @@ Format as a natural conversation with approximately {target_turns} speaking turn try: llm = get_llm_by_type(model_type) response = await llm.achat(messages=messages) + from cost_tracker import record, get_user_ctx, model_id_for, extract_llama_tokens + inp, out = extract_llama_tokens(response) + await record(model=model_id_for(model_type), user_external_id=get_user_ctx(), input_tokens=inp, output_tokens=out) script = response.message.content print(f" ✓ Script generated: {len(script)} chars (~{target_turns} turns)") return script diff --git a/backend/src/notebookllama/studio_generators.py b/backend/src/notebookllama/studio_generators.py index 963d4cc..334e017 100644 --- a/backend/src/notebookllama/studio_generators.py +++ 
b/backend/src/notebookllama/studio_generators.py @@ -157,6 +157,9 @@ async def _generate(doc_summaries, model_type, system_msg, user_msg, output_clas else: llm = get_structured_llm(model_type, output_class) response = await llm.achat(messages=messages) + from cost_tracker import record, get_user_ctx, model_id_for, extract_llama_tokens + inp, out = extract_llama_tokens(response) + await record(model=model_id_for(model_type), user_external_id=get_user_ctx(), input_tokens=inp, output_tokens=out) return _parse_response(response, output_class)