diff --git a/backend/app/api/v1/routes_public.py b/backend/app/api/v1/routes_public.py index 99d493e..1c46c53 100644 --- a/backend/app/api/v1/routes_public.py +++ b/backend/app/api/v1/routes_public.py @@ -121,6 +121,7 @@ class RecordResponse(BaseModel): @router.post("/usage/record", response_model=RecordResponse) +@router.post("/record", response_model=RecordResponse, include_in_schema=False) async def record_usage( body: RecordRequest, request: Request, @@ -324,7 +325,7 @@ async def health(db: AsyncIOMotorDatabase = Depends(get_db)): def _infer_provider(model: str) -> str: model_lower = model.lower() if "gemini" in model_lower: - return "google" + return "vertex_ai-language-models" if "gpt" in model_lower or "o1" in model_lower or "o3" in model_lower: return "openai" if "claude" in model_lower: diff --git a/backend/app/services/pricing_engine.py b/backend/app/services/pricing_engine.py index 753d976..b8f9a05 100644 --- a/backend/app/services/pricing_engine.py +++ b/backend/app/services/pricing_engine.py @@ -114,6 +114,12 @@ async def compute_cost( return round(cost, 8), price_id +_PROVIDER_ALIAS = { + "google": "vertex_ai-language-models", + "gemini": "vertex_ai-language-models", +} + + async def compute_total_cost( db: AsyncIOMotorDatabase, provider: str, @@ -125,6 +131,7 @@ async def compute_total_cost( if ts is None: ts = datetime.now(timezone.utc) ts_date = ts.date().isoformat() + provider = _PROVIDER_ALIAS.get(provider, provider) cursor = db.model_prices.find({ "provider": provider, @@ -149,13 +156,13 @@ async def compute_total_cost( ppu = rec["price_per_unit_usd"] price_id = str(rec["_id"]) if u == "token_input": - total_cost += ppu * units.get("input_tokens", 0) + total_cost += ppu * units.get("token_input", units.get("input_tokens", 0)) elif u == "token_output": - total_cost += ppu * units.get("output_tokens", 0) + total_cost += ppu * units.get("token_output", units.get("output_tokens", 0)) elif u == "char": - total_cost += ppu * units.get("chars", 0) + total_cost += ppu * units.get("char", units.get("chars", 0)) elif u == "second": - total_cost += ppu * units.get("seconds", 0) + total_cost += ppu * units.get("second", units.get("seconds", 0)) elif u == "request": total_cost += ppu