ai-cost-tracker/backend/app/services/pricing_engine.py
Vadym Samoilenko 2822e9cb99 fix: pricing engine unit keys and google→vertex_ai provider alias
- compute_total_cost: read token_input/token_output/char (new keys)
  with fallback to old input_tokens/output_tokens/chars for compat
- _PROVIDER_ALIAS: google/gemini → vertex_ai-language-models
- _infer_provider: gemini → vertex_ai-language-models

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-27 14:41:14 +01:00

169 lines
5.5 KiB
Python

"""Pricing engine: load YAML on startup, LiteLLM sync, admin override.
Priority: override(3) > yaml(2) > litellm(1).
compute_cost() picks the highest-priority active record at the given timestamp.
"""
import re
from datetime import date, datetime, timezone
from pathlib import Path
from typing import Optional
import yaml
from motor.motor_asyncio import AsyncIOMotorDatabase
from ..core.logging import get_logger
from ..models.pricing import ModelPrice, SOURCE_PRIORITY
# Logger for this module, obtained from the project's logging helper.
logger = get_logger(__name__)
# Bundled YAML price list: <dir two levels above this file>/pricing/models.yaml
# (app/pricing/models.yaml when this module lives in app/services/).
PRICING_YAML = Path(__file__).parent.parent.joinpath("pricing", "models.yaml")
async def load_yaml_prices(db: AsyncIOMotorDatabase) -> int:
    """Insert YAML prices into ``model_prices`` that are not stored yet.

    Reads ``PRICING_YAML``: a list of entries, each with ``provider``,
    ``model``, ``billing_unit``, ``effective_from`` and a price given by one of
    ``price_per_unit_usd`` / ``price_per_1k_usd`` / ``price_per_1m_usd``
    (checked in that order). Entries without any price key are skipped with a
    warning. One ``source="yaml"`` record is inserted per
    (provider, model, effective_from) that is not already present.

    Existing YAML records are left unchanged, so a price edited in place for an
    already-loaded ``effective_from`` is NOT picked up. Records from other
    sources (e.g. admin overrides) are never touched.

    Args:
        db: Database handle exposing the ``model_prices`` collection.

    Returns:
        Number of newly inserted records (0 if the YAML file is missing).

    Raises:
        KeyError: An entry lacks one of the required keys.
        ValueError: ``effective_from`` is a string that is not an ISO date.
    """
    if not PRICING_YAML.exists():
        logger.warning("pricing/models.yaml not found — skipping YAML load")
        return 0

    # YAML is UTF-8 by spec; don't depend on the platform's locale encoding.
    raw = yaml.safe_load(PRICING_YAML.read_text(encoding="utf-8"))
    count = 0
    for entry in raw or []:
        provider = entry["provider"]
        model = entry["model"]
        billing_unit = entry["billing_unit"]

        # yaml.safe_load resolves an unquoted ISO date (2025-01-01) to a
        # datetime.date, or to a datetime.datetime if a time part is present.
        # date.fromisoformat() rejects non-str input with TypeError; only
        # quoted values arrive as str. Normalise all forms to a plain date.
        raw_from = entry["effective_from"]
        if isinstance(raw_from, datetime):  # test first: datetime subclasses date
            effective_from = raw_from.date()
        elif isinstance(raw_from, date):
            effective_from = raw_from
        else:
            effective_from = date.fromisoformat(str(raw_from))
        # Stored as an ISO string so the $lte/$gt string comparisons used by
        # the cost lookups order chronologically.
        effective_from_iso = effective_from.isoformat()

        # Normalise the price to USD per single unit from whichever key is present.
        if "price_per_unit_usd" in entry:
            ppu = float(entry["price_per_unit_usd"])
        elif "price_per_1k_usd" in entry:
            ppu = float(entry["price_per_1k_usd"]) / 1000
        elif "price_per_1m_usd" in entry:
            ppu = float(entry["price_per_1m_usd"]) / 1_000_000
        else:
            logger.warning(f"No price key in YAML entry for {provider}/{model}")
            continue

        existing = await db.model_prices.find_one({
            "provider": provider,
            "model": model,
            "source": "yaml",
            "effective_from": effective_from_iso,
        })
        if not existing:
            await db.model_prices.insert_one({
                "provider": provider,
                "model": model,
                "billing_unit": billing_unit,
                "price_per_unit_usd": ppu,
                "currency": "USD",
                "effective_from": effective_from_iso,
                "effective_to": None,
                "source": "yaml",
                "created_at": datetime.now(timezone.utc),
            })
            count += 1
    logger.info(f"YAML pricing: {count} new records upserted")
    return count
async def compute_cost(
    db: AsyncIOMotorDatabase,
    provider: str,
    model: str,
    units: dict,  # e.g. {"token_input": 1000} or legacy {"input_tokens": 1000}
    ts: Optional[datetime] = None,
) -> tuple[Optional[float], Optional[str]]:
    """Return (cost_usd, price_id) from a single price record, or (None, None).

    Exactly ONE active price record is used. It is the record with the highest
    source priority (per ``SOURCE_PRIORITY``; records without a ``source`` count
    as ``"litellm"``). Among equal priorities, the latest ``effective_from``
    wins. For models priced on several units at once (separate input/output
    token prices), use ``compute_total_cost``, which sums every unit.

    Args:
        db: Database handle exposing the ``model_prices`` collection.
        provider: Provider key exactly as stored in the price records.
        model: Model name exactly as stored in the price records.
        units: Usage counts keyed by billing unit (``token_input``,
            ``token_output``, ``char``, ``second``). The legacy keys
            ``input_tokens``, ``output_tokens``, ``chars``, ``seconds`` are
            used as fallbacks. Missing keys count as 0.
        ts: Moment of usage; defaults to now (UTC). Only its date matters.

    Returns:
        ``(round(cost, 8), str(record _id))``. Returns ``(None, None)`` when no
        record is active at ``ts`` or the winning record's billing unit is
        unknown.
    """
    if ts is None:
        ts = datetime.now(timezone.utc)
    ts_date = ts.date().isoformat()

    # Active = started on/before ts_date and not ended by it. Dates are ISO
    # strings, so string comparison follows chronological order.
    cursor = db.model_prices.find({
        "provider": provider,
        "model": model,
        "effective_from": {"$lte": ts_date},
        "$or": [{"effective_to": None}, {"effective_to": {"$gt": ts_date}}],
    })
    records = await cursor.to_list(length=100)
    if not records:
        return None, None

    # Highest source priority wins. Among equal priorities, prefer the newest
    # effective_from: YAML loads insert records with effective_to=None, so a
    # superseded price stays "active" and DB order would otherwise decide.
    rec = max(
        records,
        key=lambda r: (
            SOURCE_PRIORITY.get(r.get("source", "litellm"), 0),
            r.get("effective_from") or "",
        ),
    )
    ppu = rec["price_per_unit_usd"]
    billing_unit = rec["billing_unit"]
    price_id = str(rec["_id"])

    # billing_unit -> (current units key, legacy units key)
    unit_keys = {
        "token_input": ("token_input", "input_tokens"),
        "token_output": ("token_output", "output_tokens"),
        "char": ("char", "chars"),
        "second": ("second", "seconds"),
    }
    if billing_unit == "request":
        cost = ppu  # flat fee per call, independent of the usage counts
    elif billing_unit in unit_keys:
        key, legacy_key = unit_keys[billing_unit]
        cost = ppu * units.get(key, units.get(legacy_key, 0))
    else:
        return None, None  # unknown billing unit: cannot price it
    return round(cost, 8), price_id
_PROVIDER_ALIAS = {
"google": "vertex_ai-language-models",
"gemini": "vertex_ai-language-models",
}
async def compute_total_cost(
    db: AsyncIOMotorDatabase,
    provider: str,
    model: str,
    units: dict,
    ts: Optional[datetime] = None,
) -> tuple[Optional[float], Optional[str]]:
    """Return (total_cost_usd, price_id) summed over every billed unit, or (None, None).

    ``provider`` is first mapped through ``_PROVIDER_ALIAS`` (e.g. ``"google"``
    and ``"gemini"`` become ``"vertex_ai-language-models"``). For each billing
    unit (``token_input``, ``token_output``, ``char``, ``second``, ``request``)
    that has an active price, the winning record is chosen independently. It is
    the record with the highest source priority (records without a ``source``
    count as ``"litellm"``); among equal priorities, the latest
    ``effective_from`` wins. The costs of all winning records are then summed.

    Args:
        db: Database handle exposing the ``model_prices`` collection.
        provider: Provider name; aliased via ``_PROVIDER_ALIAS``.
        model: Model name exactly as stored in the price records.
        units: Usage counts keyed by billing unit (``token_input``,
            ``token_output``, ``char``, ``second``). The legacy keys
            ``input_tokens``, ``output_tokens``, ``chars``, ``seconds`` are
            used as fallbacks. Missing keys count as 0.
        ts: Moment of usage; defaults to now (UTC). Only its date matters.

    Returns:
        ``(round(total, 8), price_id)``. ``price_id`` is the ``_id`` of the last
        priced unit's record, so with several units it identifies only one of
        them. Returns ``(None, None)`` when no record is active at ``ts`` or
        none has a billing unit that can be priced.
    """
    if ts is None:
        ts = datetime.now(timezone.utc)
    ts_date = ts.date().isoformat()
    provider = _PROVIDER_ALIAS.get(provider, provider)

    # Active = started on/before ts_date and not ended by it (ISO strings).
    cursor = db.model_prices.find({
        "provider": provider,
        "model": model,
        "effective_from": {"$lte": ts_date},
        "$or": [{"effective_to": None}, {"effective_to": {"$gt": ts_date}}],
    })
    records = await cursor.to_list(length=100)
    if not records:
        return None, None

    def _rank(r: dict) -> tuple:
        # Source priority first. The newest effective_from breaks ties, because
        # YAML loads insert records with effective_to=None, so a superseded
        # price can still match the "active" query.
        return (
            SOURCE_PRIORITY.get(r.get("source", "litellm"), 0),
            r.get("effective_from") or "",
        )

    # Pick the winning record separately for each billing unit.
    by_unit: dict[str, dict] = {}
    for r in records:
        u = r["billing_unit"]
        current = by_unit.get(u)
        if current is None or _rank(r) > _rank(current):
            by_unit[u] = r

    # billing_unit -> (current units key, legacy units key)
    unit_keys = {
        "token_input": ("token_input", "input_tokens"),
        "token_output": ("token_output", "output_tokens"),
        "char": ("char", "chars"),
        "second": ("second", "seconds"),
    }
    total_cost = 0.0
    price_id = None
    for u, rec in by_unit.items():
        ppu = rec["price_per_unit_usd"]
        if u == "request":
            total_cost += ppu  # flat fee per call
        elif u in unit_keys:
            key, legacy_key = unit_keys[u]
            total_cost += ppu * units.get(key, units.get(legacy_key, 0))
        else:
            # Unknown billing unit: contributes nothing, so don't report its id.
            continue
        price_id = str(rec["_id"])

    if price_id is None:
        # Nothing could be priced; mirror compute_cost instead of reporting 0.0.
        return None, None
    return round(total_cost, 8), price_id