from datetime import datetime, timezone
import logging
import time

from app.db import get_db

logger = logging.getLogger(__name__)

# In-process cache: model_name -> (pricing_dict_or_None, cached_at_monotonic)
_pricing_cache: dict = {}
_CACHE_TTL_SECONDS = 60

# Sentinel that lets callers tell "nothing cached" apart from a cached ``None``
# (i.e. a model that was looked up and has no pricing row).
_CACHE_MISS = object()


def _cache_get(model: str, default=None):
    """Return the cached pricing for *model*, or *default* if absent/expired.

    An entry is considered fresh while it is younger than
    ``_CACHE_TTL_SECONDS`` on the monotonic clock. Because a cached value may
    itself be ``None``, callers that need to distinguish a miss from a cached
    ``None`` should pass a sentinel as *default*.
    """
    entry = _pricing_cache.get(model)
    if entry is not None and (time.monotonic() - entry[1]) < _CACHE_TTL_SECONDS:
        return entry[0]
    return default


def _cache_set(model: str, pricing: dict):
    """Store *pricing* (may be ``None``) for *model*, stamped with the current monotonic time."""
    _pricing_cache[model] = (pricing, time.monotonic())


class ModelPricing:
    """Static helpers for looking up model pricing and computing request cost."""

    @staticmethod
    async def current_for(model_name: str) -> dict | None:
        """Return the active pricing row for a model, with 60 s in-process cache.

        Resolves MODEL_ALIASES before lookup so callers can pass raw model
        names. The active row is the one with the latest ``effective_from``
        that is not in the future and whose ``effective_until`` is either
        ``None`` or still ahead of now (UTC).

        Returns None if no pricing is configured (cost will be recorded as 0)
        or if the lookup fails. A "no pricing" result is cached like any other
        result; a failed lookup is logged and NOT cached, so the next call
        retries.
        """
        # Imported at call time rather than module level
        # (presumably to avoid a circular import — TODO confirm).
        from app.services.llm_service import MODEL_ALIASES

        resolved = MODEL_ALIASES.get(model_name, model_name)

        # Use the sentinel so a cached ``None`` (model without pricing) is
        # served from the cache instead of re-querying on every call.
        cached = _cache_get(resolved, _CACHE_MISS)
        if cached is not _CACHE_MISS:
            return cached

        try:
            db = await get_db()
            now = datetime.now(timezone.utc)
            doc = await db.model_pricing.find_one(
                {
                    "model": resolved,
                    "effective_from": {"$lte": now},
                    "$or": [
                        {"effective_until": None},
                        {"effective_until": {"$gt": now}},
                    ],
                },
                sort=[("effective_from", -1)],
            )
        except Exception:
            # Best-effort: pricing failures must not break the caller.
            logger.warning(
                "Failed to fetch pricing for model %s", resolved, exc_info=True
            )
            return None

        _cache_set(resolved, doc)
        return doc

    @staticmethod
    def pick_tier(pricing: dict, prompt_tokens: int) -> dict | None:
        """Return the cost tier that applies for a given prompt token count.

        The applicable tier is the one with the largest
        ``threshold_input_tokens`` that is still <= *prompt_tokens*. A missing
        or ``None`` threshold counts as 0. If every tier's threshold exceeds
        *prompt_tokens*, the lowest-threshold tier is used as the fallback.

        Returns None when *pricing* is empty or has no tiers.
        """
        if not pricing:
            return None
        tiers = pricing.get("tiers") or []
        if not tiers:
            return None

        prompt_tokens = prompt_tokens or 0

        def threshold(tier: dict) -> int:
            # ``or 0`` also covers an explicit None stored in the document.
            return tier.get("threshold_input_tokens") or 0

        applicable = [t for t in tiers if threshold(t) <= prompt_tokens]
        if applicable:
            return max(applicable, key=threshold)
        # Nothing qualifies: fall back to the lowest tier rather than the
        # highest one, which would bill small prompts at large-prompt rates.
        return min(tiers, key=threshold)

    @staticmethod
    def compute_cost(pricing: dict | None, prompt_tokens: int,
                     completion_tokens: int, cached_tokens: int = 0) -> dict:
        """Compute cost breakdown from token counts and pricing doc.

        Rates are read from the tier chosen by ``pick_tier`` and are expressed
        per million tokens. Cached tokens are subtracted from the prompt
        tokens (never below zero) before the regular input rate is applied,
        and are billed separately at the cached-input rate. Token counts
        passed as ``None`` are treated as 0.

        Returns a dict with keys: input, cached, output, total (all USD
        floats, rounded to 8 decimals). All values are 0.0 if pricing is None
        or no tier can be selected.
        """
        zero = {"input": 0.0, "cached": 0.0, "output": 0.0, "total": 0.0}
        if not pricing:
            return zero

        # Upstream usage payloads may carry None for absent counters.
        prompt_tokens = prompt_tokens or 0
        completion_tokens = completion_tokens or 0
        cached_tokens = cached_tokens or 0

        tier = ModelPricing.pick_tier(pricing, prompt_tokens)
        if not tier:
            return zero

        input_per_mtok = tier.get("input_per_mtok") or 0
        cached_per_mtok = tier.get("cached_input_per_mtok") or 0
        output_per_mtok = tier.get("output_per_mtok") or 0

        billable_input = max(0, prompt_tokens - cached_tokens)
        cost_input = billable_input * input_per_mtok / 1_000_000
        cost_cached = cached_tokens * cached_per_mtok / 1_000_000
        cost_output = completion_tokens * output_per_mtok / 1_000_000
        cost_total = cost_input + cost_cached + cost_output

        return {
            "input": round(cost_input, 8),
            "cached": round(cost_cached, 8),
            "output": round(cost_output, 8),
            "total": round(cost_total, 8),
        }