video-accessibility/backend/app/tasks/tts_synthesis.py
Vadym Samoilenko ea21cace96 feat: replace SDK with direct HTTP integration to centralized cost tracker
- New services/cost_tracker.py: sync httpx preflight()/record() + async wrappers;
  BudgetExceeded exception; no-op when COST_TRACKER_BASE_URL is empty
- Preflight budget check added before ingestion (Gemini), per-language translation
  (video-native + traditional), and per-language TTS dispatch
- _record_gemini_usage and _record_tts_cost now call cost_tracker directly;
  removes broken asyncio.get_event_loop() hack from sync Celery worker
- Fix: _cost_ctx now threaded into extract_accessibility_targeted (video-native path)
- Fix: user_id/cost_project_id now propagated through dispatch_language_tts →
  synthesize_cue_task.s() and the rerender_accessible_video.py re-render path
- Remove oliver-cost-tracker SDK dependency (was commented-out/never installed)
- Drop cost_tracker_outbox_path setting and get_cost_tracker() factory
- Update COST_TRACKER_BASE_URL default to optical-dev.oliver.solutions in
  .env.prod.example, docker-compose.yml, and all Cloud Run service yamls
- Cloud Run yamls use Secret Manager ref (cost-tracker-api-key) for the API key

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-27 13:36:15 +01:00

569 lines
18 KiB
Python

"""
TTS Synthesis Tasks - Per-cue parallel synthesis for audio descriptions.
This module provides Celery tasks for synthesizing audio description cues
in parallel using a dedicated TTS worker with concurrency=8.
"""
import asyncio
import hashlib
import io
import time
from typing import Any, Optional
from celery import group
from celery.result import AsyncResult
from pydub import AudioSegment
from ..core.config import settings
from ..core.logging import get_logger
from ..services.gcs import gcs_service
from ..services.gemini_tts import gemini_tts_service, TTSSynthesisError
from ..services.tts import tts_service
from . import celery_app
logger = get_logger(__name__)

# Maps this module's TTS provider ids to the provider names reported to the
# cost tracker. Keys are provider ids only (not (provider, model) pairs).
_TTS_PROVIDER_MODEL_MAP = {
    "gemini": "google",
    "google": "google_tts",
    "elevenlabs": "elevenlabs",
}

# Maps this module's TTS model ids to the model names reported to the cost
# tracker. "flash"/"pro" are the Gemini TTS variants; "standard"/"wavenet"/
# "neural2" and "elevenlabs" name the other providers' models.
_TTS_MODEL_STRINGS = {
    "flash": "gemini-2.5-flash-preview-tts",
    "pro": "gemini-2.5-pro-preview-tts",
    "standard": "standard",
    "wavenet": "wavenet",
    "neural2": "neural2",
    "elevenlabs": "eleven_multilingual_v2",
}
def _record_tts_cost(
    provider: str,
    model: str,
    text: str,
    user_id: str,
    job_id: str,
    project_id: Optional[str],
    latency_ms: int,
) -> None:
    """
    Report one TTS synthesis to the cost tracker without ever raising.

    Provider and model ids are translated through the module lookup tables;
    ids without an entry are forwarded unchanged. Usage is measured in
    characters of ``text``.

    Any failure, including failing to import the cost tracker client, is
    logged as a warning and swallowed, so cost accounting cannot fail the
    synthesis that triggered it.
    """
    try:
        from ..services.cost_tracker import record as record_usage

        tracked_model = _TTS_MODEL_STRINGS.get(model, model)
        tracked_provider = _TTS_PROVIDER_MODEL_MAP.get(provider, provider)
        record_usage(
            provider=tracked_provider,
            model=tracked_model,
            chars=len(text),
            latency_ms=latency_ms,
            user_external_id=user_id,
            job_external_id=job_id,
            project_id=project_id,
        )
    except Exception as exc:
        logger.warning(f"Cost tracker TTS record failed (non-fatal): {exc}")
@celery_app.task(
    bind=True,
    queue="tts",
    time_limit=120,  # 2 minutes max per cue
    soft_time_limit=100,
    max_retries=3,
    default_retry_delay=2
)
def synthesize_cue_task(
    self,
    job_id: str,
    language: str,
    cue_index: int,
    text: str,
    start_time: float,
    end_time: float,
    voice_name: Optional[str],
    provider: str,
    model: str,
    speed: float,
    style_prompt: str,
    stability: float = 0.5,
    similarity_boost: float = 0.5,
    user_id: Optional[str] = None,
    cost_project_id: Optional[str] = None,
) -> dict:
    """
    Synthesize a single AD cue and upload it to GCS immediately.

    Runs on the "tts" queue so the dedicated TTS worker can synthesize many
    cues in parallel.

    Args:
        job_id: Job identifier, used for the GCS path and cost attribution.
        language: Language code (e.g. "en", "es"); a region suffix such as
            "en-US" is accepted.
        cue_index: Zero-based cue index (part of the GCS object name).
        text: AD text to synthesize.
        start_time: VTT cue start time in seconds (echoed in the result).
        end_time: VTT cue end time in seconds (echoed in the result).
        voice_name: TTS voice name, or None for the provider default.
        provider: TTS provider ("gemini", "google", "elevenlabs").
        model: Model variant (e.g. "flash", "pro").
        speed: Speech rate multiplier.
        style_prompt: Style instructions.
        stability: ElevenLabs stability setting.
        similarity_boost: ElevenLabs similarity-boost setting.
        user_id: User the cost is attributed to; "system" when None.
        cost_project_id: Cost-tracker project id, if any.

    Returns:
        dict with keys cue_index, job_id, language, gcs_uri, content_hash,
        start_time, end_time, duration, text, success, error_message. Success
        and failure results have the same keys; a failure (returned only after
        all retries are exhausted) has gcs_uri=None, content_hash=None,
        duration=0.0, success=False and the last error message.

    Raises:
        celery.exceptions.Retry: via ``self.retry`` while retries remain.
    """
    start_ts = time.time()
    logger.info(
        f"TTS cue synthesis started: job={job_id}, lang={language}, "
        f"cue={cue_index}, provider={provider}, attempt={self.request.retries + 1}/{self.max_retries + 1}"
    )
    try:
        # Run async synthesis in sync context
        audio_bytes, duration = _run_async(
            _synthesize_single_cue(
                text=text,
                voice_name=voice_name,
                language=language,
                provider=provider,
                model=model,
                speed=speed,
                style_prompt=style_prompt,
                stability=stability,
                similarity_boost=similarity_boost,
            )
        )
        # Upload to GCS immediately
        gcs_uri, content_hash = _upload_cue_to_gcs(job_id, language, cue_index, audio_bytes, text)
        # Measured after the upload, so this includes GCS time as well as TTS time.
        elapsed_ms = (time.time() - start_ts) * 1000
        logger.info(
            f"TTS cue synthesis complete: job={job_id}, lang={language}, "
            f"cue={cue_index}, duration={duration:.2f}s, elapsed={elapsed_ms:.0f}ms"
        )
        # Record TTS cost (fire-and-forget; _record_tts_cost never raises).
        # NOTE(review): _synthesize_single_cue may fall back to another provider,
        # but cost is attributed to the configured `provider` here.
        _record_tts_cost(provider, model, text, user_id or "system", job_id, cost_project_id, int(elapsed_ms))
        return {
            "cue_index": cue_index,
            "job_id": job_id,
            "language": language,
            "gcs_uri": gcs_uri,
            "content_hash": content_hash,
            "start_time": start_time,
            "end_time": end_time,
            "duration": duration,
            "text": text,
            "success": True,
            "error_message": None
        }
    except Exception as e:
        error_message = str(e)
        logger.warning(
            f"TTS cue synthesis attempt failed: job={job_id}, lang={language}, "
            f"cue={cue_index}, attempt={self.request.retries + 1}/{self.max_retries + 1}, error={e}"
        )
        if self.request.retries < self.max_retries:
            # Exponential backoff (1s, 2s, 4s, ...) plus up to 1s of jitter;
            # this overrides default_retry_delay.
            import random
            delay = (2 ** self.request.retries) + random.uniform(0, 1)
            logger.info(
                f"Retrying TTS cue {cue_index} in {delay:.1f}s "
                f"(attempt {self.request.retries + 2}/{self.max_retries + 1})"
            )
            raise self.retry(exc=e, countdown=delay)
        # Max retries exhausted - return a failure result instead of raising, so
        # the group completes and callers can handle the failed cue.
        logger.error(
            f"TTS cue synthesis FAILED after {self.max_retries + 1} attempts: "
            f"job={job_id}, lang={language}, cue={cue_index}, error={e}"
        )
        return {
            "cue_index": cue_index,
            "job_id": job_id,
            "language": language,
            "gcs_uri": None,
            # Same key set as the success result so consumers can index it safely.
            "content_hash": None,
            "start_time": start_time,
            "end_time": end_time,
            "duration": 0.0,
            "text": text,
            "success": False,
            "error_message": error_message
        }
# Characters that already end a sentence, so no "." is appended after them.
# Includes the ellipsis, CJK full-width marks, the Arabic question mark and the
# Devanagari danda, since cues may be written in non-Latin scripts.
_TERMINAL_PUNCTUATION = ('.', '!', '?', '…', '。', '！', '？', '؟', '।')

# Model ids that correspond to Gemini TTS variants ("flash" is the dispatch default).
_GEMINI_TTS_MODELS = ("flash", "pro")

# Default region per language, used to build a full locale for the Google Cloud
# TTS and ElevenLabs paths. Languages not listed use "<lang>-<LANG>", which is
# right when the language code doubles as the country code (es-ES, fr-FR,
# de-DE, ...) but gives wrong or invalid locales (ja-JA, ko-KO, sv-SV, ...)
# for the languages below.
_TTS_DEFAULT_REGIONS = {
    "en": "US",
    "ar": "XA",
    "bn": "IN",
    "cs": "CZ",
    "da": "DK",
    "el": "GR",
    "fil": "PH",
    "he": "IL",
    "hi": "IN",
    "ja": "JP",
    "ko": "KR",
    "ms": "MY",
    "nb": "NO",
    "sv": "SE",
    "ta": "IN",
    "te": "IN",
    "uk": "UA",
    "vi": "VN",
    # NOTE(review): Google lists Mandarin voices as cmn-CN; confirm zh-CN is accepted.
    "zh": "CN",
}


def _normalize_cue_text(text: str) -> str:
    """Strip cue text and end it with "." unless it already ends a sentence."""
    text = text.strip()
    if text and not text.endswith(_TERMINAL_PUNCTUATION):
        text += "."
    return text


def _tts_locale(simple_lang: str) -> str:
    """Build a full locale (e.g. "ja" -> "ja-JP") from a bare language code."""
    region = _TTS_DEFAULT_REGIONS.get(simple_lang, simple_lang.upper())
    return f"{simple_lang}-{region}"


async def _synthesize_single_cue(
    text: str,
    voice_name: Optional[str],
    language: str,
    provider: str,
    model: str,
    speed: float,
    style_prompt: str,
    stability: float = 0.5,
    similarity_boost: float = 0.5,
) -> tuple[bytes, float]:
    """
    Synthesize a single cue's text to MP3 audio, falling back across providers.

    The configured provider is tried first, then its fallbacks:
    gemini -> (none); google -> gemini; elevenlabs -> google -> gemini.
    The first provider that returns non-empty audio wins.

    Returns:
        Tuple of (audio_bytes, duration_seconds), with the duration measured
        by decoding the returned audio as MP3.

    Raises:
        ValueError: if ``provider`` is unknown, or if no provider raised but
            none returned audio.
        Exception: the most recent provider exception, if any provider raised
            and none succeeded.
    """
    # Ensure proper punctuation for natural TTS flow
    text = _normalize_cue_text(text)
    # Bare language code (e.g. "en-US" -> "en"); Gemini takes this directly.
    simple_lang = language.split("-")[0]
    # Build ordered provider list — configured provider first, then fallbacks
    if provider == "gemini":
        providers_to_try = ["gemini"]
    elif provider == "google":
        providers_to_try = ["google", "gemini"]
    elif provider == "elevenlabs":
        providers_to_try = ["elevenlabs", "google", "gemini"]
    else:
        raise ValueError(f"Unknown TTS provider: {provider}")
    audio_bytes: Optional[bytes] = None
    last_error: Optional[Exception] = None
    for attempt_provider in providers_to_try:
        try:
            if attempt_provider == "gemini":
                # When Gemini is only a fallback, `model` may be another
                # provider's model id (e.g. "neural2"); use a Gemini model then.
                gemini_model = (
                    model if provider == "gemini" or model in _GEMINI_TTS_MODELS else "flash"
                )
                audio_bytes = await gemini_tts_service.synthesize_text(
                    text,
                    voice_name or gemini_tts_service.default_voice,
                    simple_lang,
                    model=gemini_model,
                    speed=speed,
                    style_prompt=style_prompt
                )
            elif attempt_provider == "google":
                audio_bytes = await tts_service._synthesize_text_google(
                    text, _tts_locale(simple_lang), voice_name
                )
            elif attempt_provider == "elevenlabs":
                voice_id = tts_service._get_elevenlabs_voice(_tts_locale(simple_lang), voice_name)
                audio_bytes = await tts_service._synthesize_text_elevenlabs(
                    text, voice_id,
                    stability=stability, similarity_boost=similarity_boost,
                )
            if audio_bytes:
                if attempt_provider != provider:
                    logger.warning(
                        f"TTS provider '{provider}' failed — used fallback '{attempt_provider}' successfully"
                    )
                break
        except Exception as e:
            last_error = e
            logger.warning(f"TTS provider '{attempt_provider}' failed: {e}")
    if not audio_bytes:
        raise last_error or ValueError("All TTS providers failed")
    # Get actual duration from audio
    audio_segment = AudioSegment.from_file(io.BytesIO(audio_bytes), format="mp3")
    duration = len(audio_segment) / 1000.0  # pydub lengths are in ms
    return audio_bytes, duration
def _upload_cue_to_gcs(job_id: str, language: str, cue_index: int, audio_bytes: bytes, text: str = "") -> tuple[str, str]:
    """
    Upload a cue's MP3 audio to GCS.

    Path convention: gs://{bucket}/{job_id}/{language}/ad_cues/cue_{index}_{hash}.mp3

    ``hash`` is the first 12 hex characters of the SHA-256 of the cue text
    (UTF-8). Re-uploading the same index with the same text therefore targets
    the same object, while editing the text yields a new object name. Because
    the index is also part of the name, inserting or deleting earlier cues
    changes the names of later cues even when their text is unchanged.

    Args:
        job_id: Job identifier (first path segment).
        language: Language code (second path segment).
        cue_index: Zero-based cue index.
        audio_bytes: Audio payload, uploaded with content type audio/mpeg.
        text: Cue text the hash is derived from.

    Returns:
        Tuple of (full GCS URI, content_hash)
    """
    content_hash = hashlib.sha256(text.encode("utf-8")).hexdigest()[:12]
    blob_path = f"{job_id}/{language}/ad_cues/cue_{cue_index}_{content_hash}.mp3"
    blob = gcs_service.bucket.blob(blob_path)
    blob.content_type = "audio/mpeg"
    blob.upload_from_string(audio_bytes, content_type="audio/mpeg")
    # NOTE(review): the URI is built from settings.gcs_bucket while the upload
    # goes through gcs_service.bucket; these are assumed to be the same bucket.
    gcs_uri = f"gs://{settings.gcs_bucket}/{blob_path}"
    logger.debug(f"Uploaded TTS cue to {gcs_uri}")
    return gcs_uri, content_hash
def parse_cue_index_from_blob_name(blob_name: str) -> Optional[int]:
    """
    Extract the cue index from an ad-cue GCS object name.

    Both naming schemes are understood:
      - legacy:  ``.../ad_cues/cue_0.mp3``              -> 0
      - current: ``.../ad_cues/cue_0_abc123def456.mp3`` -> 0

    Returns None when the final path segment is not a ``cue_*.mp3`` file or
    its leading field is not an integer.
    """
    base = blob_name.rsplit("/", 1)[-1]
    is_cue_file = base.startswith("cue_") and base.endswith(".mp3")
    if not is_cue_file:
        return None
    stem = base.removeprefix("cue_").removesuffix(".mp3")
    index_text, _sep, _rest = stem.partition("_")
    try:
        return int(index_text)
    except ValueError:
        return None
def _run_async(coro):
    """
    Run a coroutine to completion from synchronous code and return its result.

    - If this thread's event loop is already running (we were called from sync
      code inside a coroutine), the coroutine runs on a fresh loop in a
      short-lived worker thread, because the running loop cannot be re-entered.
    - If there is no current loop, or it is closed, ``asyncio.run`` is used.
    - Otherwise the thread's idle loop is reused via ``run_until_complete``
      and left open for later calls.

    Exceptions raised by the coroutine itself, including ``RuntimeError``,
    propagate to the caller unchanged.
    """
    # Keep this try narrow: wrapping run_until_complete in `except RuntimeError`
    # would also catch errors raised *by the coroutine* and then re-run the
    # already-awaited coroutine, hiding the real error.
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        # No event loop for this thread.
        return asyncio.run(coro)
    if loop.is_running():
        import concurrent.futures
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(asyncio.run, coro).result()
    if loop.is_closed():
        return asyncio.run(coro)
    return loop.run_until_complete(coro)
def dispatch_language_tts(
    job_id: str,
    language: str,
    cues: list[dict],
    tts_preferences: dict,
    user_id: Optional[str] = None,
    cost_project_id: Optional[str] = None,
) -> Optional[AsyncResult]:
    """
    Dispatch a group of cue synthesis tasks for one language.

    Builds one ``synthesize_cue_task`` signature per cue with non-empty text
    and sends them as a Celery group, so cues run in parallel up to the TTS
    worker's concurrency limit.

    Cue indices are positions in ``cues``. Empty cues are skipped but still
    consume an index, so indices stay aligned with the caller's cue list.

    Preference keys read from ``tts_preferences`` (a missing key and an
    explicit None both select the default shown):
        voices_per_language ({}), default_voice (None), provider ("gemini"),
        model ("flash"), speed (1.0), style_preset ("neutral"),
        custom_style_prompt (None), stability (0.5), similarity_boost (0.5).

    Args:
        job_id: Job identifier
        language: Language code
        cues: Parsed VTT cues, each with start_time, end_time, text
        tts_preferences: TTS configuration dict
        user_id: User the TTS cost is attributed to
        cost_project_id: Cost-tracker project id

    Returns:
        The ``GroupResult`` from ``group(...).apply_async()``, which can be
        polled for completion, or None when no cue has text to synthesize.
    """
    tts_preferences = tts_preferences or {}

    def pref(key: str, default: Any) -> Any:
        # Treat an explicit None like a missing key so None never reaches the tasks.
        value = tts_preferences.get(key)
        return default if value is None else value

    # Extract TTS settings
    voices_per_language = pref("voices_per_language", {})
    voice_name = voices_per_language.get(language, tts_preferences.get("default_voice"))
    provider = pref("provider", "gemini")
    model = pref("model", "flash")
    speed = pref("speed", 1.0)
    style_preset = pref("style_preset", "neutral")
    custom_style_prompt = tts_preferences.get("custom_style_prompt")
    stability = pref("stability", 0.5)
    similarity_boost = pref("similarity_boost", 0.5)
    # Resolve style prompt from preset or custom
    if style_preset == "custom" and custom_style_prompt:
        style_prompt = custom_style_prompt
    else:
        style_prompt = settings.gemini_tts_style_prompts.get(style_preset, "")
    logger.info(
        f"Dispatching {len(cues)} TTS cue tasks for job={job_id}, lang={language}, "
        f"provider={provider}, model={model}, speed={speed}x"
    )
    # Build list of cue task signatures
    cue_tasks = [
        synthesize_cue_task.s(
            job_id=job_id,
            language=language,
            cue_index=i,
            text=cue["text"],
            start_time=cue["start_time"],
            end_time=cue["end_time"],
            voice_name=voice_name,
            provider=provider,
            model=model,
            speed=speed,
            style_prompt=style_prompt,
            stability=stability,
            similarity_boost=similarity_boost,
            user_id=user_id,
            cost_project_id=cost_project_id,
        )
        for i, cue in enumerate(cues)
        if (cue.get("text") or "").strip()  # Skip empty / missing text
    ]
    if not cue_tasks:
        logger.warning(f"No valid cues to synthesize for job={job_id}, lang={language}")
        return None
    # Create and dispatch group (parallel execution)
    task_group = group(cue_tasks)
    result = task_group.apply_async()
    logger.info(
        f"Dispatched TTS group: job={job_id}, lang={language}, "
        f"task_count={len(cue_tasks)}, group_id={result.id}"
    )
    return result
def parse_ad_cues(vtt_content: str) -> list[dict]:
    """
    Parse audio description VTT and extract timing + text.

    Timing lines may carry WebVTT cue settings after the end timestamp
    (e.g. ``00:01.000 --> 00:04.000 align:start``); the settings are ignored.
    Multi-line cue text is joined with single spaces. Cues with no text are
    dropped. Header, NOTE, identifier and other non-timing lines are skipped.

    Returns:
        List of dicts with start_time, end_time (seconds) and text.
    """
    lines = vtt_content.strip().split('\n')
    cues = []
    i = 0
    while i < len(lines):
        line = lines[i].strip()
        # Skip header and empty lines
        if line == "WEBVTT" or line == "" or line.startswith("NOTE"):
            i += 1
            continue
        if " --> " in line:
            start_raw, _, end_raw = line.partition(" --> ")
            start_time = _parse_timestamp(start_raw.strip())
            # Only the first token is the end time; the rest are cue settings.
            end_time = _parse_timestamp(end_raw.split()[0])
            # Get text from next line(s)
            i += 1
            text_lines = []
            while i < len(lines) and lines[i].strip() != "":
                text_lines.append(lines[i].strip())
                i += 1
            if text_lines:
                cues.append({
                    "start_time": start_time,
                    "end_time": end_time,
                    "text": " ".join(text_lines)
                })
        else:
            i += 1
    return cues


def _parse_timestamp(timestamp: str) -> float:
    """
    Convert a VTT timestamp to seconds.

    Accepts ``HH:MM:SS.mmm`` and ``MM:SS.mmm``. The fractional part is read
    as a decimal fraction of any length (``01.5`` is 1.5 s, not 1.005 s), and
    a comma is accepted as the decimal separator (SRT style).

    Raises:
        ValueError: if the timestamp does not have two or three
            ':'-separated fields, or a field is not an integer.
    """
    parts = timestamp.split(":")
    if len(parts) == 3:  # HH:MM:SS.mmm
        hours, minutes, seconds = parts
    elif len(parts) == 2:  # MM:SS.mmm
        hours, minutes, seconds = "0", parts[0], parts[1]
    else:
        raise ValueError(f"Invalid timestamp format: {timestamp}")
    whole, _, fraction = seconds.replace(",", ".").partition(".")
    seconds_val = int(whole)
    # Scale by the number of digits; for the standard 3-digit millisecond field
    # this is exactly ms / 1000.
    fraction_val = int(fraction) / (10 ** len(fraction)) if fraction else 0.0
    total_seconds = (
        int(hours) * 3600 +
        int(minutes) * 60 +
        seconds_val +
        fraction_val
    )
    return total_seconds
def update_vtt_cue_text(vtt_content: str, cue_index: int, new_text: str) -> str:
    """
    Update a specific cue's text in VTT content.

    Cues are counted by their timing lines (lines containing " --> "), in
    order from 0. The matching cue's text lines are replaced; every other
    line is copied through unchanged. If no cue has ``cue_index`` the content
    is returned unchanged apart from surrounding whitespace being stripped.

    A blank line ends a cue in WebVTT, so blank lines inside ``new_text`` are
    dropped; otherwise the text after them would be detached from the cue.

    Args:
        vtt_content: Original VTT file content
        cue_index: Zero-based index of cue to update
        new_text: New text for the cue (may span several lines)

    Returns:
        Updated VTT content (leading/trailing whitespace stripped)
    """
    lines = vtt_content.strip().split('\n')
    # Keep only non-blank lines so the replacement stays one contiguous cue body.
    replacement = "\n".join(line for line in new_text.splitlines() if line.strip())
    result_lines = []
    current_cue = -1
    i = 0
    while i < len(lines):
        line = lines[i]
        stripped = line.strip()
        # Header, blank and NOTE lines pass through untouched
        if stripped == "WEBVTT" or stripped == "" or stripped.startswith("NOTE"):
            result_lines.append(line)
            i += 1
            continue
        if " --> " not in line:
            result_lines.append(line)
            i += 1
            continue
        # Timing line: starts a new cue
        current_cue += 1
        result_lines.append(line)
        i += 1
        replacing = current_cue == cue_index
        # Walk this cue's text lines, keeping them unless we are replacing them
        while i < len(lines) and lines[i].strip() != "":
            if not replacing:
                result_lines.append(lines[i])
            i += 1
        if replacing:
            result_lines.append(replacement)
    return '\n'.join(result_lines)
async def update_vtt_in_gcs(
    job_id: str,
    language: str,
    cue_index: int,
    new_text: str
) -> str:
    """
    Update a cue in the AD VTT file stored in GCS.

    Downloads ``{job_id}/{language}/ad.vtt``, replaces the cue's text with
    ``update_vtt_cue_text`` and uploads the result with content type text/vtt.

    Args:
        job_id: Job identifier
        language: Language code
        cue_index: Index of cue to update
        new_text: New text for the cue

    Returns:
        Updated VTT content
    """
    vtt_blob_path = f"{job_id}/{language}/ad.vtt"
    blob = gcs_service.bucket.blob(vtt_blob_path)
    # The GCS client calls are synchronous; run them in a worker thread so this
    # coroutine does not block the event loop for the duration of the I/O.
    current_vtt = await asyncio.to_thread(blob.download_as_text)
    updated_vtt = update_vtt_cue_text(current_vtt, cue_index, new_text)
    # NOTE(review): read-modify-write without a generation precondition; two
    # concurrent edits to the same file can overwrite each other.
    await asyncio.to_thread(blob.upload_from_string, updated_vtt, content_type="text/vtt")
    logger.info(f"Updated VTT cue {cue_index} in GCS: {vtt_blob_path}")
    return updated_vtt