- Add google_tts/google_tts entry to models.yaml (16 USD/1M chars, WaveNet tier) - Add scripts/migrate_video_accessibility.py for historical data backfill: migrates 25 users and 103 jobs (Gemini + TTS usage records) from accessible_video MongoDB into cost tracker Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
250 lines
9 KiB
Python
250 lines
9 KiB
Python
"""
|
|
Backfill migration: video-accessibility → ai-cost-tracker
|
|
|
|
Reads completed jobs from accessible_video MongoDB, estimates token usage
|
|
from stored text, and POSTs to cost tracker API.
|
|
|
|
Usage (run inside optical where both MongoDBs are accessible via SSH tunnel):
|
|
python migrate_video_accessibility.py [--dry-run] [--limit N] [--users-only] [--jobs-only]
|
|
|
|
Environment variables (all have defaults; COST_TRACKER_KEY must be set unless --dry-run is used):
|
|
SOURCE_MONGO_URL — accessible_video MongoDB (default: mongodb://mongodb:27017)
|
|
COST_TRACKER_URL — cost tracker API base URL
|
|
COST_TRACKER_KEY — API key with record + upsert scopes
|
|
"""
|
|
|
|
import argparse
|
|
import os
|
|
import sys
|
|
from datetime import datetime, timezone
|
|
|
|
try:
|
|
import pymongo
|
|
import requests
|
|
except ImportError:
|
|
print("pip install pymongo requests")
|
|
sys.exit(1)
|
|
|
|
# Settings for the backfill. The three connection values are read from the
# environment once, at import time; everything else is a fixed constant.
CONFIG = {
    # Source database holding the jobs and users to migrate.
    "source_mongo_url": os.environ.get("SOURCE_MONGO_URL", "mongodb://mongodb:27017"),
    "source_db": "accessible_video",
    # Cost tracker API target and credentials.
    "cost_tracker_url": os.environ.get("COST_TRACKER_URL", "http://localhost:8000"),
    "cost_tracker_key": os.environ.get("COST_TRACKER_KEY", ""),
    "source_app": "video-accessibility",
    # Provider/model identifiers attached to the usage records.
    "gemini_model": "gemini-2.0-flash-001",
    "gemini_provider": "google",
    "tts_model": "google_tts",
    "tts_provider": "google_tts",
    "elevenlabs_model": "eleven_multilingual_v2",
    "elevenlabs_provider": "elevenlabs",
    # Divisor used to turn character counts into estimated token counts.
    "chars_per_token": 4,
}

# Job statuses selected by the jobs query.
STATUSES_TO_MIGRATE = [
    "completed",
    "done",
    "success",
    "pending_qc",
    "pending_final_review",
]
|
|
|
|
|
|
def estimate_gemini_tokens(job: dict) -> tuple[int, int]:
|
|
ingestion = job.get("ai", {}).get("ingestion_json", {})
|
|
if not isinstance(ingestion, dict):
|
|
ingestion = {}
|
|
|
|
# Input: whisper transcript that was fed to Gemini
|
|
input_text = ingestion.get("transcript_plaintext", "")
|
|
brand_ctx = job.get("brand_context", "") or ""
|
|
if isinstance(brand_ctx, str):
|
|
input_text += brand_ctx
|
|
|
|
# Output: audio descriptions + summary generated by Gemini
|
|
output_text = ingestion.get("audio_description_vtt", "") + ingestion.get("summary", "")
|
|
# Also count captions if they differ from the transcript (Gemini re-timed them)
|
|
captions = ingestion.get("captions_vtt", "")
|
|
output_text += captions
|
|
|
|
cpp = CONFIG["chars_per_token"]
|
|
input_tokens = max(50, len(input_text) // cpp)
|
|
output_tokens = max(20, len(output_text) // cpp)
|
|
return input_tokens, output_tokens
|
|
|
|
|
|
def estimate_tts_chars(job: dict) -> int:
    """Estimate the number of characters sent to TTS for a job.

    Reads ``job["ai"]["ingestion_json"]["audio_description_vtt"]``, drops
    blank lines, lines starting with ``WEBVTT`` and cue timing lines (those
    containing ``-->``), and returns the length of the remaining lines joined
    by single spaces.

    Args:
        job: Job document.

    Returns:
        Character count, or 0 when the audio description is missing, null or
        not a string.
    """
    # Stored documents may carry explicit nulls, so validate the type at
    # every level instead of relying on ``.get`` defaults.
    ai = job.get("ai")
    ingestion = ai.get("ingestion_json") if isinstance(ai, dict) else None
    if not isinstance(ingestion, dict):
        return 0

    ad_vtt = ingestion.get("audio_description_vtt")
    if not isinstance(ad_vtt, str):
        return 0

    # NOTE(review): cue identifiers or NOTE blocks, if the stored VTT contains
    # any, are not filtered out and count as spoken text.
    text_lines = [
        line
        for line in ad_vtt.splitlines()
        if line.strip() and not line.startswith("WEBVTT") and "-->" not in line
    ]
    return len(" ".join(text_lines))
|
|
|
|
|
|
def has_elevenlabs(job: dict) -> bool:
    """Report whether the job's ``requested_outputs`` mentions "eleven".

    The match is a case-insensitive substring search. For a list, each
    element's string form is searched; for a dict, the string form of the
    whole dict is searched. Any other type yields ``False``.
    """
    requested = job.get("requested_outputs", {})

    if isinstance(requested, list):
        return any("eleven" in str(item).lower() for item in requested)

    if isinstance(requested, dict):
        # NOTE(review): searching the dict's string form matches keys as well
        # as values, so a key naming ElevenLabs with a false value still
        # counts — confirm this is intended.
        return "eleven" in str(requested).lower()

    return False
|
|
|
|
|
|
def api_headers() -> dict:
    """Build the HTTP headers for a cost tracker request.

    Returns:
        A new dict with the API key from ``CONFIG["cost_tracker_key"]`` and a
        JSON content type.
    """
    api_key = CONFIG["cost_tracker_key"]
    return {"X-API-Key": api_key, "Content-Type": "application/json"}
|
|
|
|
|
|
def post(path: str, payload: dict, dry_run: bool) -> bool:
    """POST ``payload`` as JSON to the cost tracker API.

    Args:
        path: API path appended to ``CONFIG["cost_tracker_url"]``.
        payload: JSON-serialisable request body.
        dry_run: When true, print the request instead of sending it.

    Returns:
        ``True`` on a 200/201 response (or in dry-run mode); ``False`` on any
        other status code or when the request itself fails (connection
        error, timeout, ...). Transport failures are reported rather than
        raised so that one bad request does not abort the whole run.
    """
    if dry_run:
        print(f" [DRY] POST {path} {payload}")
        return True

    # Tolerate a trailing slash on the configured base URL.
    url = CONFIG["cost_tracker_url"].rstrip("/") + path
    try:
        resp = requests.post(url, json=payload, headers=api_headers(), timeout=15)
    except requests.RequestException as exc:
        print(f" [ERR] request failed: {exc}")
        return False

    if resp.status_code not in (200, 201):
        print(f" [ERR] {resp.status_code}: {resp.text[:200]}")
        return False

    # The request was accepted; the body is only inspected to display the
    # computed cost, so an empty or non-JSON body must not turn into a failure.
    try:
        data = resp.json()
    except ValueError:
        return True

    # For the record endpoint, show the cost result.
    if isinstance(data, dict) and "cost_usd" in data:
        cost = data.get("cost_usd")
        missing = data.get("pricing_missing", False)
        print(f" cost={cost} pricing_missing={missing}")
    return True
|
|
|
|
|
|
def migrate_users(db, dry_run: bool) -> int:
    """Upsert source users into the cost tracker.

    Iterates ``db.users`` documents whose ``is_active`` is not ``False``,
    skips those whose email contains ``example.com``, and posts the rest to
    ``/v1/users/upsert``.

    Args:
        db: Source database handle exposing a ``users`` collection.
        dry_run: Forwarded to ``post``.

    Returns:
        Number of users for which ``post`` reported success.
    """
    print("\n── Users ────────────────────────────────────────────────────")
    migrated = 0
    active_users = db.users.find({"is_active": {"$ne": False}})

    for doc in active_users:
        external_id = str(doc["_id"])
        address = doc.get("email")
        full_name = doc.get("full_name")

        # Demo/placeholder accounts are not migrated.
        is_placeholder = bool(address) and "example.com" in address
        if is_placeholder:
            print(f" skip {external_id} ({address}) — placeholder")
            continue

        print(f" upsert {external_id} | {address} | {full_name}")
        body = {
            "external_id": external_id,
            "email": address,
            "full_name": full_name,
            "role": doc.get("role", "user"),
        }
        if post("/v1/users/upsert", body, dry_run):
            migrated += 1

    print(f"Users migrated: {migrated}")
    return migrated
|
|
|
|
|
|
def _job_timestamp(job: dict) -> str | None:
    """Return the job's stored ``created_at``/``updated_at`` as ISO-8601 text, or None."""
    raw = job.get("created_at") or job.get("updated_at")
    if not raw:
        return None
    if isinstance(raw, datetime):
        return raw.isoformat()
    try:
        return datetime.fromisoformat(str(raw)).isoformat()
    except ValueError:
        # A malformed stored value must not abort the whole migration.
        return None


def _usage_payload(
    job_id: str,
    user_id: str,
    kind: str,
    provider: str,
    model: str,
    units: dict,
    source_ts: str | None,
) -> dict:
    """Build the ``/v1/usage/record`` body for one backfilled usage record."""
    metadata = {"backfill": True, "source": CONFIG["source_app"], "estimated": True}
    if source_ts is not None:
        metadata["source_timestamp"] = source_ts
    return {
        "request_id": f"backfill-{job_id}-{kind}",
        "user_external_id": user_id,
        "job_external_id": job_id,
        "provider": provider,
        "model": model,
        "units": units,
        "status": "success",
        "metadata": metadata,
    }


def migrate_jobs(db, dry_run: bool, limit: int | None) -> dict:
    """Post estimated Gemini and TTS usage records for migrated jobs.

    For every job whose status is in ``STATUSES_TO_MIGRATE`` one Gemini record
    is posted, plus one TTS record when audio description text is present.

    The job's original timestamp was previously computed and then discarded;
    it is now carried in ``metadata["source_timestamp"]``.
    NOTE(review): the cost tracker's record schema is not visible here. If it
    accepts an explicit event time, send the value in that field so the
    backfilled usage is dated historically.

    Args:
        db: Source database handle exposing a ``jobs`` collection.
        dry_run: Forwarded to ``post``.
        limit: Maximum number of jobs to process; falsy means no limit.

    Returns:
        Counters: ``processed``, ``gemini_ok``, ``tts_ok``, ``skipped`` (jobs
        without TTS text) and ``errors`` (failed posts).
    """
    print("\n── Jobs ─────────────────────────────────────────────────────")
    query = {"status": {"$in": STATUSES_TO_MIGRATE}}
    total = db.jobs.count_documents(query)
    cursor = db.jobs.find(query)
    if limit:
        cursor = cursor.limit(limit)
    limit_note = f" (limit={limit})" if limit else ""
    print(f"Jobs to process: {total}{limit_note}")

    stats = {"processed": 0, "gemini_ok": 0, "tts_ok": 0, "skipped": 0, "errors": 0}

    for job in cursor:
        job_id = str(job["_id"])
        # ``or`` also covers an explicit null client_id, which would
        # otherwise be recorded as the literal user "None".
        user_id = str(job.get("client_id") or "unknown")
        source_ts = _job_timestamp(job)

        print(f"\n job={job_id} user={user_id} status={job.get('status')}")

        # ── Gemini record ──
        input_tok, output_tok = estimate_gemini_tokens(job)
        print(f" gemini: ~{input_tok} in / {output_tok} out tokens")
        gemini_payload = _usage_payload(
            job_id,
            user_id,
            "gemini",
            CONFIG["gemini_provider"],
            CONFIG["gemini_model"],
            {"token_input": input_tok, "token_output": output_tok},
            source_ts,
        )
        if post("/v1/usage/record", gemini_payload, dry_run):
            stats["gemini_ok"] += 1
        else:
            stats["errors"] += 1

        # ── TTS record ──
        tts_chars = estimate_tts_chars(job)
        if tts_chars > 0:
            # Decide the vendor once so provider and model always agree.
            if has_elevenlabs(job):
                tts_provider = CONFIG["elevenlabs_provider"]
                tts_model = CONFIG["elevenlabs_model"]
            else:
                tts_provider = CONFIG["tts_provider"]
                tts_model = CONFIG["tts_model"]
            print(f" tts ({tts_provider}): ~{tts_chars} chars")
            tts_payload = _usage_payload(
                job_id,
                user_id,
                "tts",
                tts_provider,
                tts_model,
                {"char": tts_chars},
                source_ts,
            )
            if post("/v1/usage/record", tts_payload, dry_run):
                stats["tts_ok"] += 1
            else:
                stats["errors"] += 1
        else:
            stats["skipped"] += 1
            print(" tts: skipped (no audio description text found)")

        stats["processed"] += 1

    print("\n── Summary ──────────────────────────────────────────────────")
    for key, value in stats.items():
        print(f" {key}: {value}")
    return stats
|
|
|
|
|
|
def main():
    """Command-line entry point: parse flags, connect to MongoDB, run the migration.

    Flags:
        --dry-run      Print the requests instead of sending them.
        --limit N      Process at most N jobs.
        --users-only   Migrate users only.
        --jobs-only    Migrate jobs only.

    ``--users-only`` and ``--jobs-only`` are mutually exclusive: given
    together they used to skip both phases silently, so argparse now rejects
    the combination. Exits with status 1 when the API key is missing (outside
    dry-run) or the source MongoDB is unreachable.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dry-run", action="store_true")
    parser.add_argument("--limit", type=int, default=None)
    scope = parser.add_mutually_exclusive_group()
    scope.add_argument("--users-only", action="store_true")
    scope.add_argument("--jobs-only", action="store_true")
    args = parser.parse_args()

    if not CONFIG["cost_tracker_key"] and not args.dry_run:
        print("ERROR: COST_TRACKER_KEY not set. Use --dry-run or set env var.")
        sys.exit(1)

    # NOTE(review): the Mongo URL is printed verbatim; if it ever embeds
    # credentials they will end up in the console output.
    print(f"Source: {CONFIG['source_mongo_url']}/{CONFIG['source_db']}")
    print(f"Target: {CONFIG['cost_tracker_url']}")
    print(f"Dry run: {args.dry_run}")

    client = pymongo.MongoClient(CONFIG["source_mongo_url"], serverSelectionTimeoutMS=5000)
    try:
        # Force a round trip up front so an unreachable server is reported
        # here, before any migration work starts.
        try:
            client.server_info()
        except Exception as e:  # top-level boundary: report and exit
            print(f"Cannot connect to source MongoDB: {e}")
            sys.exit(1)

        db = client[CONFIG["source_db"]]

        if not args.jobs_only:
            migrate_users(db, dry_run=args.dry_run)
        if not args.users_only:
            migrate_jobs(db, dry_run=args.dry_run, limit=args.limit)
    finally:
        # Release the connection on every exit path, including sys.exit()
        # and errors raised by the migration steps.
        client.close()

    print("\nDone.")


if __name__ == "__main__":
    main()
|