Fix backfill pricing: read from model_pricing collection + --delete-existing-estimates flag
This commit is contained in:
parent
66c8e1762e
commit
d7ee22e557
1 changed files with 41 additions and 15 deletions
|
|
@ -45,22 +45,39 @@ def _estimate_tokens(text: str, model: str) -> dict:
|
|||
return {"prompt": n, "completion": 0}
|
||||
|
||||
|
||||
def _estimate_cost(prompt_tokens: int, completion_tokens: int, model: str) -> float:
|
||||
"""Rough cost estimate in USD."""
|
||||
rate_per_m = {
|
||||
"gemini": (0.35, 1.05),
|
||||
"gpt-4": (30.00, 60.00),
|
||||
"gpt-3": (0.50, 1.50),
|
||||
}
|
||||
key = "gemini"
|
||||
if model:
|
||||
m = model.lower()
|
||||
if "gpt-4" in m or "gpt-5" in m:
|
||||
key = "gpt-4"
|
||||
elif "gpt-3" in m:
|
||||
key = "gpt-3"
|
||||
_pricing_cache: dict = {}
|
||||
|
||||
input_rate, output_rate = rate_per_m[key]
|
||||
def _load_pricing(db) -> None:
|
||||
"""Load current pricing from model_pricing collection into cache."""
|
||||
for row in db.model_pricing.find({"effective_until": None}):
|
||||
model = row.get("model", "")
|
||||
tiers = row.get("tiers") or []
|
||||
if tiers:
|
||||
t = tiers[0]
|
||||
_pricing_cache[model] = (
|
||||
t.get("input_per_mtok", 2.0),
|
||||
t.get("output_per_mtok", 12.0),
|
||||
)
|
||||
|
||||
|
||||
def _estimate_cost(prompt_tokens: int, completion_tokens: int, model: str) -> float:
|
||||
"""Cost estimate in USD using model_pricing collection rates."""
|
||||
# Try exact match, then prefix match
|
||||
rates = _pricing_cache.get(model)
|
||||
if not rates:
|
||||
for key, val in _pricing_cache.items():
|
||||
if model and key and (key in model or model in key):
|
||||
rates = val
|
||||
break
|
||||
# Final fallback matching seed_model_pricing.py values
|
||||
if not rates:
|
||||
m = (model or "").lower()
|
||||
if "gpt-5" in m or "gpt-4" in m:
|
||||
rates = (2.50, 15.00) # gpt-5.4 pricing from seed
|
||||
else:
|
||||
rates = (2.00, 12.00) # gemini-3.1-pro-preview pricing from seed
|
||||
|
||||
input_rate, output_rate = rates
|
||||
cost = (prompt_tokens / 1_000_000) * input_rate + (completion_tokens / 1_000_000) * output_rate
|
||||
return round(cost, 8)
|
||||
|
||||
|
|
@ -257,6 +274,8 @@ def backfill_personas(db, dry_run: bool) -> int:
|
|||
def main():
|
||||
parser = argparse.ArgumentParser(description="Backfill usage_events from existing data")
|
||||
parser.add_argument("--dry-run", action="store_true", help="Preview without writing")
|
||||
parser.add_argument("--delete-existing-estimates", action="store_true",
|
||||
help="Delete previously created estimated events before backfilling")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.dry_run:
|
||||
|
|
@ -264,6 +283,13 @@ def main():
|
|||
|
||||
db = connect()
|
||||
|
||||
if args.delete_existing_estimates and not args.dry_run:
|
||||
result = db.usage_events.delete_many({"is_estimated": True})
|
||||
print(f"Deleted {result.deleted_count} existing estimated events\n")
|
||||
|
||||
_load_pricing(db)
|
||||
print(f"Loaded {len(_pricing_cache)} pricing rows: {list(_pricing_cache.keys())}")
|
||||
|
||||
total = 0
|
||||
total += backfill_messages(db, args.dry_run)
|
||||
total += backfill_personas(db, args.dry_run)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue