- compute_total_cost: read token_input/token_output/char (new keys) with fallback to old input_tokens/output_tokens/chars for compat - _PROVIDER_ALIAS: google/gemini → vertex_ai-language-models - _infer_provider: gemini → vertex_ai-language-models Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
169 lines
5.5 KiB
Python
169 lines
5.5 KiB
Python
"""Pricing engine: load YAML on startup, LiteLLM sync, admin override.
|
|
|
|
Priority: override(3) > yaml(2) > litellm(1).
|
|
compute_cost() picks the highest-priority active record at the given timestamp.
|
|
"""
|
|
import re
|
|
from datetime import date, datetime, timezone
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
import yaml
|
|
from motor.motor_asyncio import AsyncIOMotorDatabase
|
|
|
|
from ..core.logging import get_logger
|
|
from ..models.pricing import ModelPrice, SOURCE_PRIORITY
|
|
|
|
# Module-level logger obtained from the project's logging helper.
logger = get_logger(__name__)

# Bundled price sheet read by load_yaml_prices(): ``pricing/models.yaml`` inside
# the directory one level above this module's own directory.
PRICING_YAML = Path(__file__).parent.parent.joinpath("pricing", "models.yaml")
|
|
|
|
|
|
def _coerce_effective_from(value: object) -> date:
    """Normalise a YAML ``effective_from`` value to a ``date``.

    ``yaml.safe_load`` resolves an unquoted ``2024-01-01`` to a ``datetime.date``
    (or a ``datetime`` when a time part is present), while a quoted value stays a
    string, so all three forms are accepted.

    Raises:
        ValueError: if a string value is not an ISO date.
    """
    if isinstance(value, datetime):  # must come first: datetime subclasses date
        return value.date()
    if isinstance(value, date):
        return value
    return date.fromisoformat(str(value))


def _price_per_unit(entry: dict) -> Optional[float]:
    """Return the USD price per single unit from whichever price key ``entry`` uses.

    Keys are checked in the order ``price_per_unit_usd``, ``price_per_1k_usd``,
    ``price_per_1m_usd``; returns None when none of them is present.

    Raises:
        TypeError / ValueError: if the price value cannot be converted to float.
    """
    for key, divisor in (
        ("price_per_unit_usd", 1),
        ("price_per_1k_usd", 1000),
        ("price_per_1m_usd", 1_000_000),
    ):
        if key in entry:
            return float(entry[key]) / divisor
    return None


async def load_yaml_prices(db: AsyncIOMotorDatabase) -> int:
    """Insert prices from ``PRICING_YAML`` into ``model_prices`` if not already stored.

    Each entry (the file must contain a list of mappings) becomes a
    ``source="yaml"`` record with ``effective_to=None``, unless a yaml record
    already exists for the same provider, model and ``effective_from``. Existing
    records are never modified: editing a price in place for an existing
    ``effective_from`` does not update the stored record, and older records for
    the same model are not closed. Only ``yaml`` records are queried or written,
    so override records are never touched.

    Malformed entries (missing keys, unparseable date or price, no price key)
    are logged and skipped instead of aborting the whole load.

    Args:
        db: Motor database holding the ``model_prices`` collection.

    Returns:
        Number of records inserted (0 if the file is missing or not a list).
    """
    if not PRICING_YAML.exists():
        logger.warning("pricing/models.yaml not found — skipping YAML load")
        return 0

    raw = yaml.safe_load(PRICING_YAML.read_text(encoding="utf-8"))
    entries = raw if raw is not None else []  # an empty file loads as None
    if not isinstance(entries, list):
        logger.warning(
            f"pricing/models.yaml must contain a list of entries, "
            f"got {type(entries).__name__} — skipping YAML load"
        )
        return 0

    count = 0
    for entry in entries:
        try:
            provider = entry["provider"]
            model = entry["model"]
            billing_unit = entry["billing_unit"]
            # Stored as an ISO string so it compares correctly with the
            # date strings used by the price lookup queries.
            effective_from = _coerce_effective_from(entry["effective_from"]).isoformat()
            ppu = _price_per_unit(entry)
        except (KeyError, TypeError, ValueError) as exc:
            logger.warning(f"Skipping malformed YAML pricing entry {entry!r}: {exc!r}")
            continue

        if ppu is None:
            logger.warning(f"No price key in YAML entry for {provider}/{model}")
            continue

        existing = await db.model_prices.find_one({
            "provider": provider,
            "model": model,
            "source": "yaml",
            "effective_from": effective_from,
        })
        if existing:
            continue

        await db.model_prices.insert_one({
            "provider": provider,
            "model": model,
            "billing_unit": billing_unit,
            "price_per_unit_usd": ppu,
            "currency": "USD",
            "effective_from": effective_from,
            "effective_to": None,
            "source": "yaml",
            "created_at": datetime.now(timezone.utc),
        })
        count += 1

    logger.info(f"YAML pricing: {count} new records upserted")
    return count
|
|
|
|
|
|
# Caller-facing provider names that are stored under a different provider key
# in ``model_prices``.
_PROVIDER_ALIAS = {
    "google": "vertex_ai-language-models",
    "gemini": "vertex_ai-language-models",
}

# Usage keys read for each metered billing unit: the current key first, then
# the legacy key still accepted for compatibility with older callers.
_UNIT_KEYS: dict[str, tuple[str, str]] = {
    "token_input": ("token_input", "input_tokens"),
    "token_output": ("token_output", "output_tokens"),
    "char": ("char", "chars"),
    "second": ("second", "seconds"),
}


async def _find_active_prices(
    db: AsyncIOMotorDatabase,
    provider: str,
    model: str,
    ts: Optional[datetime],
) -> list[dict]:
    """Return up to 100 price records for provider/model active on the date of ``ts``.

    ``ts`` defaults to now (UTC). A record is active when ``effective_from`` is on
    or before that date and ``effective_to`` is unset or after it. Dates are
    compared as ISO ``YYYY-MM-DD`` strings, the form ``load_yaml_prices`` stores.
    """
    if ts is None:
        ts = datetime.now(timezone.utc)
    ts_date = ts.date().isoformat()
    cursor = db.model_prices.find({
        "provider": provider,
        "model": model,
        "effective_from": {"$lte": ts_date},
        "$or": [{"effective_to": None}, {"effective_to": {"$gt": ts_date}}],
    })
    return await cursor.to_list(length=100)


def _price_rank(record: dict) -> tuple[int, str]:
    """Ranking key for choosing between active price records (higher wins).

    Source priority decides first (a record without ``source`` ranks as
    ``"litellm"``). Among equal priorities the latest ``effective_from`` wins,
    so a newer price supersedes an older one that was never closed with
    ``effective_to``; full ties keep the first record seen.
    """
    priority = SOURCE_PRIORITY.get(record.get("source", "litellm"), 0)
    return priority, str(record.get("effective_from") or "")


def _billed_quantity(billing_unit: str, units: dict) -> Optional[float]:
    """Return how many ``billing_unit`` units ``units`` reports.

    A ``"request"`` price is charged once per call. Missing usage keys count as
    0. Returns None for a billing unit this module does not know how to price.
    """
    if billing_unit == "request":
        return 1
    keys = _UNIT_KEYS.get(billing_unit)
    if keys is None:
        return None
    current_key, legacy_key = keys
    return units.get(current_key, units.get(legacy_key, 0))


async def compute_cost(
    db: AsyncIOMotorDatabase,
    provider: str,
    model: str,
    units: dict,  # e.g. {"token_input": 1000}; legacy {"input_tokens": 1000} also accepted
    ts: Optional[datetime] = None,
) -> tuple[Optional[float], Optional[str]]:
    """Price usage against the single best active price record.

    The highest-ranked active record is chosen regardless of its billing unit
    and only that unit is charged, so for models with separate input/output
    token prices ``compute_total_cost`` is the appropriate function.

    Args:
        db: Motor database holding the ``model_prices`` collection.
        provider: Provider name; aliases such as ``"google"`` are normalised.
        model: Model name as stored in ``model_prices``.
        units: Usage counts keyed by billing unit (current or legacy key names).
        ts: Time of usage; defaults to now (UTC). Only its date is used.

    Returns:
        ``(cost_usd rounded to 8 decimals, price_id)``, or ``(None, None)`` if no
        active price exists or its billing unit is not recognized.
    """
    provider = _PROVIDER_ALIAS.get(provider, provider)
    records = await _find_active_prices(db, provider, model, ts)
    if not records:
        return None, None

    rec = max(records, key=_price_rank)
    quantity = _billed_quantity(rec["billing_unit"], units)
    if quantity is None:
        return None, None
    return round(rec["price_per_unit_usd"] * quantity, 8), str(rec["_id"])


async def compute_total_cost(
    db: AsyncIOMotorDatabase,
    provider: str,
    model: str,
    units: dict,
    ts: Optional[datetime] = None,
) -> tuple[Optional[float], Optional[str]]:
    """Price usage by summing the best active record for every billing unit.

    For each billing unit present among the active records, the highest-ranked
    record is used and charged for the matching usage in ``units`` (e.g. input
    and output tokens are priced separately and added together). Records with
    an unrecognized billing unit are ignored.

    Args:
        db: Motor database holding the ``model_prices`` collection.
        provider: Provider name; aliases such as ``"gemini"`` are normalised.
        model: Model name as stored in ``model_prices``.
        units: Usage counts keyed by billing unit (current or legacy key names).
        ts: Time of usage; defaults to now (UTC). Only its date is used.

    Returns:
        ``(total_cost_usd rounded to 8 decimals, price_id)`` where ``price_id``
        is the ``_id`` of the last priced unit record (only one of possibly
        several contributing prices), or ``(None, None)`` when no active record
        has a recognized billing unit.
    """
    provider = _PROVIDER_ALIAS.get(provider, provider)
    records = await _find_active_prices(db, provider, model, ts)

    best_by_unit: dict[str, dict] = {}
    for rec in records:
        unit = rec["billing_unit"]
        current = best_by_unit.get(unit)
        if current is None or _price_rank(rec) > _price_rank(current):
            best_by_unit[unit] = rec

    total_cost = 0.0
    price_id = None
    for unit, rec in best_by_unit.items():
        quantity = _billed_quantity(unit, units)
        if quantity is None:
            # Unknown unit: contributes nothing and must not be reported as the price used.
            continue
        total_cost += rec["price_per_unit_usd"] * quantity
        price_id = str(rec["_id"])

    if price_id is None:
        return None, None
    return round(total_cost, 8), price_id
|