- New services/cost_tracker.py: sync httpx preflight()/record() + async wrappers; BudgetExceeded exception; no-op when COST_TRACKER_BASE_URL is empty
- Preflight budget check added before ingestion (Gemini), per-language translation (video-native + traditional), and per-language TTS dispatch
- _record_gemini_usage and _record_tts_cost now call cost_tracker directly; removes broken asyncio.get_event_loop() hack from sync Celery worker
- Fix: _cost_ctx now threaded into extract_accessibility_targeted (video-native path)
- Fix: user_id/cost_project_id now propagated through dispatch_language_tts → synthesize_cue_task.s() and the rerender_accessible_video.py re-render path
- Remove oliver-cost-tracker SDK dependency (was commented-out/never installed)
- Drop cost_tracker_outbox_path setting and get_cost_tracker() factory
- Update COST_TRACKER_BASE_URL default to optical-dev.oliver.solutions in .env.prod.example, docker-compose.yml, and all Cloud Run service yamls
- Cloud Run yamls use Secret Manager ref (cost-tracker-api-key) for the API key

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
569 lines
18 KiB
Python
569 lines
18 KiB
Python
"""
|
|
TTS Synthesis Tasks - Per-cue parallel synthesis for audio descriptions.
|
|
|
|
This module provides Celery tasks for synthesizing audio description cues
|
|
in parallel using a dedicated TTS worker with concurrency=8.
|
|
"""
|
|
|
|
import asyncio
|
|
import hashlib
|
|
import io
|
|
import time
|
|
from typing import Any, Optional
|
|
|
|
from celery import group
|
|
from celery.result import AsyncResult
|
|
from pydub import AudioSegment
|
|
|
|
from ..core.config import settings
|
|
from ..core.logging import get_logger
|
|
from ..services.gcs import gcs_service
|
|
from ..services.gemini_tts import gemini_tts_service, TTSSynthesisError
|
|
from ..services.tts import tts_service
|
|
from . import celery_app
|
|
|
|
logger = get_logger(__name__)


# Maps this module's TTS provider name ("gemini", "google", "elevenlabs") to the
# provider string reported to the cost tracker. Keyed by provider only — the
# model/tier is mapped separately by _TTS_MODEL_STRINGS. Providers not listed
# are reported unchanged (_record_tts_cost looks up with .get(provider, provider)).
_TTS_PROVIDER_MODEL_MAP = {
    "gemini": "google",
    "google": "google_tts",
    "elevenlabs": "elevenlabs",
}

# Maps the short model/tier name used in TTS preferences to the model string
# reported to the cost tracker. Names not listed are reported unchanged
# (_record_tts_cost looks up with .get(model, model)).
_TTS_MODEL_STRINGS = {
    "flash": "gemini-2.5-flash-preview-tts",
    "pro": "gemini-2.5-pro-preview-tts",
    "standard": "standard",
    "wavenet": "wavenet",
    "neural2": "neural2",
    "elevenlabs": "eleven_multilingual_v2",
}
|
|
|
|
|
|
def _record_tts_cost(
    provider: str,
    model: str,
    text: str,
    user_id: str,
    job_id: str,
    project_id: Optional[str],
    latency_ms: int,
) -> None:
    """Report one TTS synthesis to the cost tracker; never raises.

    Provider and model names are translated to the cost tracker's vocabulary via
    _TTS_PROVIDER_MODEL_MAP / _TTS_MODEL_STRINGS (unknown names pass through).
    Usage is reported as the character count of ``text``.

    Any exception — including failing to import the cost_tracker module — is
    logged as a warning and swallowed, so cost tracking cannot fail a TTS task.

    NOTE(review): callers pass the *configured* provider/model. If the synthesis
    actually succeeded on a fallback provider, or the provider is not Gemini but
    the model is "flash"/"pro", the recorded provider/model pair will not match
    what was really used — confirm whether the cost tracker needs the real pair.
    """
    try:
        from ..services.cost_tracker import record

        tracker_model = _TTS_MODEL_STRINGS.get(model, model)
        tracker_provider = _TTS_PROVIDER_MODEL_MAP.get(provider, provider)
        usage = {
            "model": tracker_model,
            "provider": tracker_provider,
            "user_external_id": user_id,
            "project_id": project_id,
            "job_external_id": job_id,
            "chars": len(text),
            "latency_ms": latency_ms,
        }
        record(**usage)
    except Exception as e:
        logger.warning(f"Cost tracker TTS record failed (non-fatal): {e}")
|
|
|
|
|
|
@celery_app.task(
    bind=True,
    queue="tts",
    time_limit=120,  # 2 minutes max per cue
    soft_time_limit=100,
    max_retries=3,
    default_retry_delay=2,
)
def synthesize_cue_task(
    self,
    job_id: str,
    language: str,
    cue_index: int,
    text: str,
    start_time: float,
    end_time: float,
    voice_name: Optional[str],
    provider: str,
    model: str,
    speed: float,
    style_prompt: str,
    stability: float = 0.5,
    similarity_boost: float = 0.5,
    user_id: Optional[str] = None,
    cost_project_id: Optional[str] = None,
) -> dict:
    """
    Synthesize a single AD cue and upload it to GCS immediately.

    Bound Celery task on the "tts" queue, so many cues can be synthesized in
    parallel by the TTS worker. Failures are retried up to ``max_retries`` times
    with exponential backoff plus jitter; once retries are exhausted a failure
    dict is returned instead of raising.

    Args:
        job_id: Job identifier for GCS path construction
        language: Language code (e.g., "en", "es")
        cue_index: Zero-based cue index
        text: AD text to synthesize
        start_time: VTT cue start time (seconds)
        end_time: VTT cue end time (seconds)
        voice_name: TTS voice name for the configured provider (optional)
        provider: TTS provider ("gemini", "google", "elevenlabs")
        model: Model/tier name (e.g. "flash", "pro")
        speed: Speech rate multiplier
        style_prompt: Style instructions
        stability: ElevenLabs voice stability
        similarity_boost: ElevenLabs similarity boost
        user_id: User to attribute the TTS cost to ("system" if None)
        cost_project_id: Cost-tracker project id (optional)

    Returns:
        dict with cue_index, job_id, language, gcs_uri, content_hash,
        start_time, end_time, duration, text, success, error_message.
        On permanent failure gcs_uri and content_hash are None, duration is
        0.0 and success is False.
    """
    # Monotonic clock: elapsed time must not jump with wall-clock adjustments.
    start_ts = time.monotonic()
    logger.info(
        f"TTS cue synthesis started: job={job_id}, lang={language}, "
        f"cue={cue_index}, provider={provider}, attempt={self.request.retries + 1}/{self.max_retries + 1}"
    )

    # Only synthesis + upload are retried; bookkeeping after a successful upload
    # must not trigger a second (billed) synthesis.
    try:
        # Run async synthesis in sync context
        audio_bytes, duration = _run_async(
            _synthesize_single_cue(
                text=text,
                voice_name=voice_name,
                language=language,
                provider=provider,
                model=model,
                speed=speed,
                style_prompt=style_prompt,
                stability=stability,
                similarity_boost=similarity_boost,
            )
        )
        gcs_uri, content_hash = _upload_cue_to_gcs(job_id, language, cue_index, audio_bytes, text)
    except Exception as e:
        error_message = str(e)
        logger.warning(
            f"TTS cue synthesis attempt failed: job={job_id}, lang={language}, "
            f"cue={cue_index}, attempt={self.request.retries + 1}/{self.max_retries + 1}, error={e}"
        )

        if self.request.retries < self.max_retries:
            import random

            # Exponential backoff (1s, 2s, 4s, ...) plus up to 1s of jitter.
            delay = (2 ** self.request.retries) + random.uniform(0, 1)
            logger.info(
                f"Retrying TTS cue {cue_index} in {delay:.1f}s "
                f"(attempt {self.request.retries + 2}/{self.max_retries + 1})"
            )
            raise self.retry(exc=e, countdown=delay)

        # Max retries exhausted - return failure result instead of raising
        logger.error(
            f"TTS cue synthesis FAILED after {self.max_retries + 1} attempts: "
            f"job={job_id}, lang={language}, cue={cue_index}, error={e}"
        )
        return {
            "cue_index": cue_index,
            "job_id": job_id,
            "language": language,
            "gcs_uri": None,
            # Same key set as the success result so consumers can index uniformly.
            "content_hash": None,
            "start_time": start_time,
            "end_time": end_time,
            "duration": 0.0,
            "text": text,
            "success": False,
            "error_message": error_message,
        }

    elapsed_ms = (time.monotonic() - start_ts) * 1000
    logger.info(
        f"TTS cue synthesis complete: job={job_id}, lang={language}, "
        f"cue={cue_index}, duration={duration:.2f}s, elapsed={elapsed_ms:.0f}ms"
    )

    # Synchronous, best-effort: _record_tts_cost logs and swallows its own errors.
    # The latency reported covers synthesis plus GCS upload.
    _record_tts_cost(provider, model, text, user_id or "system", job_id, cost_project_id, int(elapsed_ms))

    return {
        "cue_index": cue_index,
        "job_id": job_id,
        "language": language,
        "gcs_uri": gcs_uri,
        "content_hash": content_hash,
        "start_time": start_time,
        "end_time": end_time,
        "duration": duration,
        "text": text,
        "success": True,
        "error_message": None,
    }
|
|
|
|
|
|
async def _synthesize_single_cue(
|
|
text: str,
|
|
voice_name: Optional[str],
|
|
language: str,
|
|
provider: str,
|
|
model: str,
|
|
speed: float,
|
|
style_prompt: str,
|
|
stability: float = 0.5,
|
|
similarity_boost: float = 0.5,
|
|
) -> tuple[bytes, float]:
|
|
"""
|
|
Synthesize a single cue's text to audio.
|
|
|
|
Returns:
|
|
Tuple of (audio_bytes, duration_seconds)
|
|
"""
|
|
# Ensure proper punctuation for natural TTS flow
|
|
text = text.strip()
|
|
if text and not text.endswith(('.', '!', '?')):
|
|
text += "."
|
|
|
|
# Extract simple language code for Gemini (e.g., "en-US" -> "en")
|
|
simple_lang = language.split("-")[0] if "-" in language else language
|
|
|
|
# Build ordered provider list — configured provider first, then fallbacks
|
|
if provider == "gemini":
|
|
providers_to_try = ["gemini"]
|
|
elif provider == "google":
|
|
providers_to_try = ["google", "gemini"]
|
|
elif provider == "elevenlabs":
|
|
providers_to_try = ["elevenlabs", "google", "gemini"]
|
|
else:
|
|
raise ValueError(f"Unknown TTS provider: {provider}")
|
|
|
|
audio_bytes: Optional[bytes] = None
|
|
last_error: Optional[Exception] = None
|
|
|
|
for attempt_provider in providers_to_try:
|
|
try:
|
|
if attempt_provider == "gemini":
|
|
audio_bytes = await gemini_tts_service.synthesize_text(
|
|
text,
|
|
voice_name or gemini_tts_service.default_voice,
|
|
simple_lang,
|
|
model=model,
|
|
speed=speed,
|
|
style_prompt=style_prompt
|
|
)
|
|
elif attempt_provider == "google":
|
|
language_code = f"{simple_lang}-US" if simple_lang == "en" else f"{simple_lang}-{simple_lang.upper()}"
|
|
audio_bytes = await tts_service._synthesize_text_google(text, language_code, voice_name)
|
|
elif attempt_provider == "elevenlabs":
|
|
language_code = f"{simple_lang}-US" if simple_lang == "en" else f"{simple_lang}-{simple_lang.upper()}"
|
|
voice_id = tts_service._get_elevenlabs_voice(language_code, voice_name)
|
|
audio_bytes = await tts_service._synthesize_text_elevenlabs(
|
|
text, voice_id,
|
|
stability=stability, similarity_boost=similarity_boost,
|
|
)
|
|
if audio_bytes:
|
|
if attempt_provider != provider:
|
|
logger.warning(
|
|
f"TTS provider '{provider}' failed — used fallback '{attempt_provider}' successfully"
|
|
)
|
|
break
|
|
except Exception as e:
|
|
last_error = e
|
|
logger.warning(f"TTS provider '{attempt_provider}' failed: {e}")
|
|
|
|
if not audio_bytes:
|
|
raise last_error or ValueError("All TTS providers failed")
|
|
|
|
# Get actual duration from audio
|
|
audio_segment = AudioSegment.from_file(io.BytesIO(audio_bytes), format="mp3")
|
|
duration = len(audio_segment) / 1000.0 # Convert ms to seconds
|
|
|
|
return audio_bytes, duration
|
|
|
|
|
|
def _upload_cue_to_gcs(job_id: str, language: str, cue_index: int, audio_bytes: bytes, text: str = "") -> tuple[str, str]:
    """
    Store one synthesized cue as an MP3 object in the GCS bucket.

    Object name: {job_id}/{language}/ad_cues/cue_{index}_{hash}.mp3, where
    {hash} is the first 12 hex characters of the SHA-256 of the cue text.
    Because the hash is part of the name, editing a cue's text yields a new
    object name, so audio for old text is never mistaken for the new text.
    The cue index is also in the name, so renumbering cues changes names too.

    Returns:
        Tuple of (gs:// URI built from settings.gcs_bucket, content_hash)
    """
    text_digest = hashlib.sha256(text.encode()).hexdigest()
    content_hash = text_digest[:12]

    cue_filename = f"cue_{cue_index}_{content_hash}.mp3"
    blob_path = f"{job_id}/{language}/ad_cues/{cue_filename}"

    mime_type = "audio/mpeg"
    target = gcs_service.bucket.blob(blob_path)
    target.content_type = mime_type
    target.upload_from_string(audio_bytes, content_type=mime_type)

    gcs_uri = f"gs://{settings.gcs_bucket}/{blob_path}"
    logger.debug(f"Uploaded TTS cue to {gcs_uri}")
    return gcs_uri, content_hash
|
|
|
|
|
|
def parse_cue_index_from_blob_name(blob_name: str) -> Optional[int]:
    """
    Extract the cue index from a GCS object name, or None if it is not a cue file.

    Both naming schemes are recognised:
      - Legacy:  ...ad_cues/cue_0.mp3               -> 0
      - Current: ...ad_cues/cue_0_abc123def456.mp3  -> 0
    """
    base = blob_name.rpartition("/")[2]
    prefix, suffix = "cue_", ".mp3"
    if not (base.startswith(prefix) and base.endswith(suffix)):
        return None

    stem = base[len(prefix):-len(suffix)]
    index_text, _sep, _hash = stem.partition("_")
    try:
        return int(index_text)
    except ValueError:
        return None
|
|
|
|
|
|
def _run_async(coro):
|
|
"""Run an async coroutine in a sync context."""
|
|
try:
|
|
loop = asyncio.get_event_loop()
|
|
if loop.is_running():
|
|
# Create a new loop if current one is running
|
|
import concurrent.futures
|
|
with concurrent.futures.ThreadPoolExecutor() as pool:
|
|
future = pool.submit(asyncio.run, coro)
|
|
return future.result()
|
|
return loop.run_until_complete(coro)
|
|
except RuntimeError:
|
|
# No event loop exists
|
|
return asyncio.run(coro)
|
|
|
|
|
|
def dispatch_language_tts(
    job_id: str,
    language: str,
    cues: list[dict],
    tts_preferences: dict,
    user_id: Optional[str] = None,
    cost_project_id: Optional[str] = None,
) -> Optional[AsyncResult]:
    """
    Dispatch a Celery group of per-cue synthesis tasks for one language.

    The cues run in parallel, up to the TTS worker's concurrency limit.

    Cue indices are positions in ``cues``. Cues whose text is missing, None or
    whitespace-only are skipped without renumbering the rest, so indices stay
    aligned with ``cues``.

    Preference keys explicitly set to None are treated as absent and fall back
    to the defaults used for missing keys.

    Args:
        job_id: Job identifier
        language: Language code
        cues: Parsed VTT cues, each with start_time, end_time, text
        tts_preferences: TTS configuration dict
        user_id: User to attribute TTS cost to (forwarded to each task)
        cost_project_id: Cost-tracker project id (forwarded to each task)

    Returns:
        The handle returned by ``group(...).apply_async()`` (a Celery
        GroupResult), which can be polled for completion, or None if there was
        no non-empty cue to synthesize.
    """

    def _pref(key: str, default: Any) -> Any:
        # None means "not set" — same treatment as a missing key.
        value = tts_preferences.get(key)
        return default if value is None else value

    # Extract TTS settings
    voices_per_language = _pref("voices_per_language", {})
    voice_name = voices_per_language.get(language, tts_preferences.get("default_voice"))
    provider = _pref("provider", "gemini")
    model = _pref("model", "flash")
    speed = _pref("speed", 1.0)
    style_preset = _pref("style_preset", "neutral")
    custom_style_prompt = tts_preferences.get("custom_style_prompt")
    stability = _pref("stability", 0.5)
    similarity_boost = _pref("similarity_boost", 0.5)

    # Resolve style prompt from preset or custom
    if style_preset == "custom" and custom_style_prompt:
        style_prompt = custom_style_prompt
    else:
        style_prompt = settings.gemini_tts_style_prompts.get(style_preset, "")

    logger.info(
        f"Dispatching {len(cues)} TTS cue tasks for job={job_id}, lang={language}, "
        f"provider={provider}, model={model}, speed={speed}x"
    )

    cue_tasks = [
        synthesize_cue_task.s(
            job_id=job_id,
            language=language,
            cue_index=i,
            text=cue["text"],
            start_time=cue["start_time"],
            end_time=cue["end_time"],
            voice_name=voice_name,
            provider=provider,
            model=model,
            speed=speed,
            style_prompt=style_prompt,
            stability=stability,
            similarity_boost=similarity_boost,
            user_id=user_id,
            cost_project_id=cost_project_id,
        )
        for i, cue in enumerate(cues)
        if (cue.get("text") or "").strip()  # Skip empty cues
    ]

    if not cue_tasks:
        logger.warning(f"No valid cues to synthesize for job={job_id}, lang={language}")
        return None

    # Create and dispatch group (parallel execution)
    result = group(cue_tasks).apply_async()

    logger.info(
        f"Dispatched TTS group: job={job_id}, lang={language}, "
        f"task_count={len(cue_tasks)}, group_id={result.id}"
    )

    return result
|
|
|
|
|
|
def parse_ad_cues(vtt_content: str) -> list[dict]:
|
|
"""
|
|
Parse audio description VTT and extract timing + text.
|
|
|
|
Returns list of dicts with start_time, end_time, text.
|
|
"""
|
|
lines = vtt_content.strip().split('\n')
|
|
cues = []
|
|
|
|
i = 0
|
|
while i < len(lines):
|
|
line = lines[i].strip()
|
|
|
|
# Skip header and empty lines
|
|
if line == "WEBVTT" or line == "" or line.startswith("NOTE"):
|
|
i += 1
|
|
continue
|
|
|
|
# Check for timing line
|
|
if " --> " in line:
|
|
timing_parts = line.split(" --> ")
|
|
start_time = _parse_timestamp(timing_parts[0].strip())
|
|
end_time = _parse_timestamp(timing_parts[1].strip())
|
|
|
|
# Get text from next line(s)
|
|
i += 1
|
|
text_lines = []
|
|
while i < len(lines) and lines[i].strip() != "":
|
|
text_lines.append(lines[i].strip())
|
|
i += 1
|
|
|
|
if text_lines:
|
|
cues.append({
|
|
"start_time": start_time,
|
|
"end_time": end_time,
|
|
"text": " ".join(text_lines)
|
|
})
|
|
else:
|
|
i += 1
|
|
|
|
return cues
|
|
|
|
|
|
def _parse_timestamp(timestamp: str) -> float:
|
|
"""Convert VTT timestamp to seconds."""
|
|
parts = timestamp.split(":")
|
|
|
|
if len(parts) == 3: # HH:MM:SS.mmm
|
|
hours, minutes, seconds = parts
|
|
elif len(parts) == 2: # MM:SS.mmm
|
|
hours, minutes, seconds = "0", parts[0], parts[1]
|
|
else:
|
|
raise ValueError(f"Invalid timestamp format: {timestamp}")
|
|
|
|
# Parse seconds and milliseconds
|
|
sec_parts = seconds.split(".")
|
|
seconds_val = int(sec_parts[0])
|
|
milliseconds = int(sec_parts[1]) if len(sec_parts) > 1 else 0
|
|
|
|
total_seconds = (
|
|
int(hours) * 3600 +
|
|
int(minutes) * 60 +
|
|
seconds_val +
|
|
milliseconds / 1000.0
|
|
)
|
|
|
|
return total_seconds
|
|
|
|
|
|
def update_vtt_cue_text(vtt_content: str, cue_index: int, new_text: str) -> str:
    """
    Return the VTT content with the text of one cue replaced.

    Cues are counted in file order, one per line containing " --> " (lines
    starting with "NOTE" excluded). The selected cue keeps its timing line, and
    all of its non-blank text lines are replaced by the single string
    ``new_text``. Every other line is copied unchanged; if ``cue_index``
    matches no cue, the content comes back unchanged apart from the strip below.

    The content is processed as ``vtt_content.strip()``, so leading/trailing
    whitespace (including a final newline) is not preserved.

    NOTE(review): parse_ad_cues() drops timing lines that have no text, while
    this function counts every timing line, so indices can disagree for a VTT
    containing an empty cue — confirm callers never produce one.

    Args:
        vtt_content: Original VTT file content
        cue_index: Zero-based index of cue to update
        new_text: New text for the cue

    Returns:
        Updated VTT content
    """
    output: list[str] = []
    cue_counter = -1
    in_cue_body = False    # inside the text lines that follow a timing line
    dropping_body = False  # ...and those lines belong to the cue being replaced

    for current in vtt_content.strip().split('\n'):
        stripped = current.strip()

        if in_cue_body and stripped != "":
            # Text line of a cue: copy it unless it belongs to the edited cue.
            if not dropping_body:
                output.append(current)
            continue

        in_cue_body = False
        dropping_body = False

        is_timing = (
            " --> " in current
            and stripped != "WEBVTT"
            and not stripped.startswith("NOTE")
        )
        output.append(current)

        if is_timing:
            cue_counter += 1
            in_cue_body = True
            if cue_counter == cue_index:
                dropping_body = True
                output.append(new_text)

    return '\n'.join(output)
|
|
|
|
|
|
async def update_vtt_in_gcs(
    job_id: str,
    language: str,
    cue_index: int,
    new_text: str
) -> str:
    """
    Replace one cue's text in the AD VTT file stored in GCS and write it back.

    Reads {job_id}/{language}/ad.vtt from the GCS bucket, applies
    update_vtt_cue_text(), and uploads the result with content type "text/vtt".

    The GCS client calls are synchronous, so they run in a worker thread via
    asyncio.to_thread to avoid blocking the event loop during network I/O.

    NOTE(review): this is an unconditional read-modify-write; two concurrent
    edits of the same file can overwrite each other.

    Args:
        job_id: Job identifier
        language: Language code
        cue_index: Index of cue to update
        new_text: New text for the cue

    Returns:
        Updated VTT content
    """
    vtt_blob_path = f"{job_id}/{language}/ad.vtt"
    blob = gcs_service.bucket.blob(vtt_blob_path)

    current_vtt = await asyncio.to_thread(blob.download_as_text)

    updated_vtt = update_vtt_cue_text(current_vtt, cue_index, new_text)

    await asyncio.to_thread(blob.upload_from_string, updated_vtt, content_type="text/vtt")

    logger.info(f"Updated VTT cue {cue_index} in GCS: {vtt_blob_path}")
    return updated_vtt
|