"""Token usage tracking - records LLM usage to the database.""" import logging from typing import Any from uuid import UUID from sqlalchemy.ext.asyncio import AsyncSession from app.models.audit import TokenUsageLog logger = logging.getLogger(__name__) async def record_token_usage( db: AsyncSession, instance_id: UUID, agent_name: str, usage: dict[str, Any], ) -> TokenUsageLog: """Record token usage from an LLM call to the database. Args: db: Async database session. instance_id: The locale instance ID this usage is for. agent_name: Name of the agent that made the call. usage: Usage dict from LLMClient with keys: input_tokens, output_tokens, total_tokens, estimated_cost_usd, model. Returns: The created TokenUsageLog record. """ log_entry = TokenUsageLog( instance_id=instance_id, agent_name=agent_name, model=usage.get("model", "unknown"), input_tokens=usage.get("input_tokens", 0), output_tokens=usage.get("output_tokens", 0), total_tokens=usage.get("total_tokens", 0), estimated_cost_usd=usage.get("estimated_cost_usd", 0.0), ) db.add(log_entry) await db.flush() logger.info( f"Token usage recorded: agent={agent_name}, " f"tokens={usage.get('total_tokens', 0)}, " f"cost=${usage.get('estimated_cost_usd', 0.0):.6f}" ) return log_entry async def get_total_usage_for_instance( db: AsyncSession, instance_id: UUID, ) -> dict[str, Any]: """Get aggregated token usage for a locale instance. Args: db: Async database session. instance_id: The locale instance ID. Returns: Dict with total_tokens, total_cost, by_agent breakdown. """ from sqlalchemy import func, select result = await db.execute( select( func.sum(TokenUsageLog.input_tokens).label("total_input"), func.sum(TokenUsageLog.output_tokens).label("total_output"), func.sum(TokenUsageLog.total_tokens).label("total_tokens"), func.sum(TokenUsageLog.estimated_cost_usd).label("total_cost"), ).where(TokenUsageLog.instance_id == instance_id) ) row = result.one() # By-agent breakdown agent_result = await db.execute( select( TokenUsageLog.agent_name, func.sum(TokenUsageLog.total_tokens).label("tokens"), func.sum(TokenUsageLog.estimated_cost_usd).label("cost"), ) .where(TokenUsageLog.instance_id == instance_id) .group_by(TokenUsageLog.agent_name) ) by_agent = { agent_name: {"tokens": tokens, "cost": float(cost)} for agent_name, tokens, cost in agent_result.all() } return { "total_input_tokens": row.total_input or 0, "total_output_tokens": row.total_output or 0, "total_tokens": row.total_tokens or 0, "total_cost_usd": float(row.total_cost or 0.0), "by_agent": by_agent, }