Add comprehensive error handling for TTS synthesis failures: Backend: - Add TTS_FAILED status to JobStatus enum for failed synthesis jobs - Add TTSSynthesisError exception with cue index and context tracking - Improve null-safe error handling in Gemini TTS response parsing - Add _synthesize_cue_with_retry() with exponential backoff (3 attempts) - Enhanced error logging with text preview and model context Frontend: - Add TTS_FAILED status styling (red badge) in StatusBadge component - Add tts_failed to JobStatus TypeScript type 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
440 lines
15 KiB
Python
440 lines
15 KiB
Python
import io
|
|
import wave
|
|
|
|
from google import genai
|
|
from google.genai import types
|
|
from pydub import AudioSegment
|
|
|
|
from ..core.config import settings
|
|
from ..core.logging import get_logger
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
class TTSSynthesisError(Exception):
|
|
"""Raised when TTS synthesis fails after all retries."""
|
|
|
|
def __init__(self, message: str, cue_index: int, cue_text: str, api_response_info: str = None):
|
|
super().__init__(message)
|
|
self.cue_index = cue_index
|
|
self.cue_text = cue_text
|
|
self.api_response_info = api_response_info
|
|
|
|
|
|
class GeminiTTSService:
|
|
"""Text-to-Speech service using Gemini TTS API"""
|
|
|
|
def __init__(self):
|
|
self.client = genai.Client(api_key=settings.gemini_api_key)
|
|
self.model = settings.gemini_tts_model
|
|
self.default_voice = settings.gemini_tts_default_voice
|
|
logger.info(f"Gemini TTS service initialized with model: {self.model}")
|
|
|
|
async def synthesize_text(
|
|
self,
|
|
text: str,
|
|
voice_name: str,
|
|
language: str = "en",
|
|
model: str = "flash",
|
|
speed: float = 1.0,
|
|
style_prompt: str = ""
|
|
) -> bytes:
|
|
"""
|
|
Synthesize text to audio using Gemini TTS.
|
|
Returns MP3 audio bytes.
|
|
|
|
Args:
|
|
text: The text to synthesize
|
|
voice_name: Name of the voice to use
|
|
language: Language code (e.g., "en", "es")
|
|
model: Model variant - "flash" (fast) or "pro" (quality)
|
|
speed: Speech rate multiplier (0.5 to 2.0)
|
|
style_prompt: Style instructions to prepend (e.g., "Speak calmly...")
|
|
"""
|
|
if not text.strip():
|
|
raise ValueError("Text cannot be empty")
|
|
|
|
# Validate voice
|
|
if voice_name not in settings.gemini_tts_voices:
|
|
logger.warning(f"Unknown voice '{voice_name}', using default '{self.default_voice}'")
|
|
voice_name = self.default_voice
|
|
|
|
# Select model from config
|
|
model_id = settings.gemini_tts_models.get(model, settings.gemini_tts_model)
|
|
|
|
# Build the full prompt with style and speed instructions
|
|
prompt_parts = []
|
|
|
|
# Add style prompt if provided
|
|
if style_prompt:
|
|
prompt_parts.append(style_prompt)
|
|
|
|
# Add speed instruction if not default
|
|
if speed != 1.0:
|
|
speed_pct = int(speed * 100)
|
|
if speed < 1.0:
|
|
prompt_parts.append(f"Speak slowly at approximately {speed_pct}% of normal speed. ")
|
|
else:
|
|
prompt_parts.append(f"Speak quickly at approximately {speed_pct}% of normal speed. ")
|
|
|
|
# Combine prompts with actual text
|
|
full_text = "".join(prompt_parts) + text
|
|
|
|
try:
|
|
# Generate audio using Gemini TTS
|
|
response = self.client.models.generate_content(
|
|
model=model_id,
|
|
contents=full_text,
|
|
config=types.GenerateContentConfig(
|
|
response_modalities=["AUDIO"],
|
|
speech_config=types.SpeechConfig(
|
|
voice_config=types.VoiceConfig(
|
|
prebuilt_voice_config=types.PrebuiltVoiceConfig(
|
|
voice_name=voice_name,
|
|
)
|
|
)
|
|
),
|
|
)
|
|
)
|
|
|
|
# Extract PCM audio data from response with proper null-safe checks
|
|
if not response.candidates:
|
|
logger.error(
|
|
f"Gemini TTS response missing candidates. "
|
|
f"Response type: {type(response)}, Response: {response}"
|
|
)
|
|
raise ValueError("No candidates in Gemini TTS response")
|
|
|
|
candidate = response.candidates[0]
|
|
|
|
if candidate.content is None:
|
|
logger.error(
|
|
f"Gemini TTS candidate has no content. "
|
|
f"Finish reason: {getattr(candidate, 'finish_reason', 'unknown')}, "
|
|
f"Safety ratings: {getattr(candidate, 'safety_ratings', 'unknown')}"
|
|
)
|
|
raise ValueError(
|
|
f"Candidate content is None in Gemini TTS response. "
|
|
f"Finish reason: {getattr(candidate, 'finish_reason', 'unknown')}"
|
|
)
|
|
|
|
if not candidate.content.parts:
|
|
logger.error(
|
|
f"Gemini TTS content has no parts. "
|
|
f"Content role: {getattr(candidate.content, 'role', 'unknown')}"
|
|
)
|
|
raise ValueError("No parts in Gemini TTS response content")
|
|
|
|
part = candidate.content.parts[0]
|
|
if not hasattr(part, 'inline_data') or part.inline_data is None:
|
|
logger.error(
|
|
f"Gemini TTS part missing inline_data. "
|
|
f"Part type: {type(part)}, Part: {part}"
|
|
)
|
|
raise ValueError("No inline_data in Gemini TTS response part")
|
|
|
|
pcm_data = part.inline_data.data
|
|
|
|
# Convert PCM to MP3
|
|
mp3_data = self._pcm_to_mp3(pcm_data)
|
|
|
|
return mp3_data
|
|
|
|
except Exception as e:
|
|
# Log comprehensive error information for debugging
|
|
error_context = {
|
|
"text_length": len(text),
|
|
"text_preview": text[:100] + "..." if len(text) > 100 else text,
|
|
"voice_name": voice_name,
|
|
"language": language,
|
|
"model_id": model_id,
|
|
}
|
|
logger.error(
|
|
f"Gemini TTS synthesis failed: {e}. Context: {error_context}"
|
|
)
|
|
raise
|
|
|
|
async def synthesize_preview(
|
|
self,
|
|
voice_name: str,
|
|
language: str = "en",
|
|
model: str = "flash",
|
|
speed: float = 1.0,
|
|
style_prompt: str = ""
|
|
) -> bytes:
|
|
"""
|
|
Generate a preview audio sample for voice selection.
|
|
Uses language-specific sample text and applies all TTS settings.
|
|
"""
|
|
# Get preview sample text for the language
|
|
sample_text = settings.gemini_tts_preview_samples.get(
|
|
language,
|
|
settings.gemini_tts_preview_samples.get("en", "This is a voice preview.")
|
|
)
|
|
|
|
return await self.synthesize_text(
|
|
sample_text,
|
|
voice_name,
|
|
language,
|
|
model=model,
|
|
speed=speed,
|
|
style_prompt=style_prompt
|
|
)
|
|
|
|
async def _synthesize_cue_with_retry(
|
|
self,
|
|
cue_index: int,
|
|
text: str,
|
|
voice_name: str,
|
|
language: str,
|
|
model: str,
|
|
speed: float,
|
|
style_prompt: str,
|
|
max_attempts: int = 3,
|
|
base_delay: float = 1.0
|
|
) -> bytes:
|
|
"""
|
|
Synthesize a single cue with exponential backoff retry.
|
|
|
|
Args:
|
|
cue_index: Index of the cue (for error reporting)
|
|
text: Text to synthesize
|
|
voice_name: TTS voice name
|
|
language: Language code
|
|
model: Model variant
|
|
speed: Speech rate
|
|
style_prompt: Style instructions
|
|
max_attempts: Total attempts (1 initial + retries)
|
|
base_delay: Base delay in seconds for backoff
|
|
|
|
Returns:
|
|
MP3 audio bytes
|
|
|
|
Raises:
|
|
TTSSynthesisError: If all attempts fail
|
|
"""
|
|
import asyncio
|
|
import random
|
|
|
|
last_exception = None
|
|
api_response_info = None
|
|
|
|
for attempt in range(max_attempts):
|
|
try:
|
|
return await self.synthesize_text(
|
|
text,
|
|
voice_name,
|
|
language,
|
|
model=model,
|
|
speed=speed,
|
|
style_prompt=style_prompt
|
|
)
|
|
except Exception as e:
|
|
last_exception = e
|
|
api_response_info = str(e)
|
|
|
|
if attempt < max_attempts - 1:
|
|
# Exponential backoff with jitter
|
|
delay = base_delay * (2 ** attempt) + random.uniform(0, 1)
|
|
logger.warning(
|
|
f"TTS synthesis attempt {attempt + 1}/{max_attempts} failed for cue {cue_index}. "
|
|
f"Retrying in {delay:.2f}s. Error: {e}"
|
|
)
|
|
await asyncio.sleep(delay)
|
|
else:
|
|
logger.error(
|
|
f"TTS synthesis FAILED after {max_attempts} attempts for cue {cue_index}. "
|
|
f"Text: {text[:50]}{'...' if len(text) > 50 else ''}. Error: {e}"
|
|
)
|
|
|
|
# All retries exhausted - raise hard failure
|
|
raise TTSSynthesisError(
|
|
message=f"TTS synthesis failed after {max_attempts} attempts: {last_exception}",
|
|
cue_index=cue_index,
|
|
cue_text=text,
|
|
api_response_info=api_response_info
|
|
)
|
|
|
|
async def synthesize_audio_description(
|
|
self,
|
|
ad_vtt_content: str,
|
|
language: str = "en",
|
|
voice_name: str | None = None,
|
|
model: str = "flash",
|
|
speed: float = 1.0,
|
|
style_prompt: str = ""
|
|
) -> bytes:
|
|
"""
|
|
Synthesize full audio description from VTT content.
|
|
Maintains timing alignment with original VTT cues.
|
|
|
|
Args:
|
|
ad_vtt_content: VTT content with audio description cues
|
|
language: Language code (e.g., "en", "es")
|
|
voice_name: Name of the voice to use (defaults to service default)
|
|
model: Model variant - "flash" (fast) or "pro" (quality)
|
|
speed: Speech rate multiplier (0.5 to 2.0)
|
|
style_prompt: Style instructions to prepend to each cue
|
|
"""
|
|
if voice_name is None:
|
|
voice_name = self.default_voice
|
|
|
|
# Validate voice
|
|
if voice_name not in settings.gemini_tts_voices:
|
|
logger.warning(f"Unknown voice '{voice_name}', using default '{self.default_voice}'")
|
|
voice_name = self.default_voice
|
|
|
|
# Parse VTT cues
|
|
cues = self._parse_ad_cues(ad_vtt_content)
|
|
|
|
if not cues:
|
|
raise ValueError("No audio description cues found in VTT content")
|
|
|
|
logger.info(
|
|
f"Synthesizing {len(cues)} audio description cues with voice '{voice_name}', "
|
|
f"model '{model}', speed {speed}x"
|
|
)
|
|
|
|
# Synthesize each cue with precise timing anchoring
|
|
audio_segments = []
|
|
current_audio_position = 0.0
|
|
|
|
for i, cue in enumerate(cues):
|
|
target_start_time = cue["start_time"]
|
|
|
|
# Add silence to reach the exact VTT start time
|
|
if target_start_time > current_audio_position:
|
|
silence_duration = target_start_time - current_audio_position
|
|
silence = AudioSegment.silent(duration=int(silence_duration * 1000))
|
|
audio_segments.append(silence)
|
|
current_audio_position = target_start_time
|
|
|
|
# Synthesize this cue's text
|
|
text = cue["text"].strip()
|
|
if text:
|
|
# Ensure proper punctuation for natural TTS flow
|
|
if not text.endswith(('.', '!', '?')):
|
|
text += "."
|
|
|
|
# Use retry helper - will raise TTSSynthesisError on failure after retries
|
|
audio_data = await self._synthesize_cue_with_retry(
|
|
cue_index=i,
|
|
text=text,
|
|
voice_name=voice_name,
|
|
language=language,
|
|
model=model,
|
|
speed=speed,
|
|
style_prompt=style_prompt,
|
|
max_attempts=3,
|
|
base_delay=1.0
|
|
)
|
|
|
|
# Convert to AudioSegment and get actual duration
|
|
audio_segment = AudioSegment.from_file(io.BytesIO(audio_data), format="mp3")
|
|
audio_segments.append(audio_segment)
|
|
|
|
# Update position based on actual audio duration
|
|
actual_audio_duration = len(audio_segment) / 1000.0
|
|
current_audio_position += actual_audio_duration
|
|
|
|
# Combine all segments
|
|
if audio_segments:
|
|
final_audio = sum(audio_segments, AudioSegment.empty())
|
|
else:
|
|
final_audio = AudioSegment.silent(duration=1000)
|
|
|
|
# Export to MP3
|
|
output_buffer = io.BytesIO()
|
|
final_audio.export(output_buffer, format="mp3", bitrate="128k")
|
|
|
|
logger.info(f"Audio description synthesized: {len(output_buffer.getvalue())} bytes")
|
|
return output_buffer.getvalue()
|
|
|
|
def _pcm_to_mp3(self, pcm_data: bytes) -> bytes:
|
|
"""
|
|
Convert raw PCM audio (24kHz, 16-bit, mono) to MP3.
|
|
Gemini TTS outputs PCM at 24000 Hz sample rate.
|
|
"""
|
|
# Create WAV from PCM data
|
|
wav_buffer = io.BytesIO()
|
|
with wave.open(wav_buffer, "wb") as wf:
|
|
wf.setnchannels(1) # Mono
|
|
wf.setsampwidth(2) # 16-bit (2 bytes)
|
|
wf.setframerate(24000) # 24kHz
|
|
wf.writeframes(pcm_data)
|
|
|
|
# Convert WAV to MP3 using pydub
|
|
wav_buffer.seek(0)
|
|
audio_segment = AudioSegment.from_wav(wav_buffer)
|
|
|
|
# Export as MP3
|
|
mp3_buffer = io.BytesIO()
|
|
audio_segment.export(mp3_buffer, format="mp3", bitrate="128k")
|
|
|
|
return mp3_buffer.getvalue()
|
|
|
|
def _parse_ad_cues(self, vtt_content: str) -> list[dict]:
|
|
"""Parse audio description VTT and extract timing + text"""
|
|
lines = vtt_content.strip().split('\n')
|
|
cues = []
|
|
|
|
i = 0
|
|
while i < len(lines):
|
|
line = lines[i].strip()
|
|
|
|
# Skip header and empty lines
|
|
if line == "WEBVTT" or line == "" or line.startswith("NOTE"):
|
|
i += 1
|
|
continue
|
|
|
|
# Check for timing line
|
|
if " --> " in line:
|
|
timing_parts = line.split(" --> ")
|
|
start_time = self._parse_timestamp(timing_parts[0].strip())
|
|
end_time = self._parse_timestamp(timing_parts[1].strip())
|
|
|
|
# Get text from next line(s)
|
|
i += 1
|
|
text_lines = []
|
|
while i < len(lines) and lines[i].strip() != "":
|
|
text_lines.append(lines[i].strip())
|
|
i += 1
|
|
|
|
if text_lines:
|
|
cues.append({
|
|
"start_time": start_time,
|
|
"end_time": end_time,
|
|
"text": " ".join(text_lines)
|
|
})
|
|
else:
|
|
i += 1
|
|
|
|
return cues
|
|
|
|
def _parse_timestamp(self, timestamp: str) -> float:
|
|
"""Convert VTT timestamp to seconds"""
|
|
parts = timestamp.split(":")
|
|
|
|
if len(parts) == 3: # HH:MM:SS.mmm
|
|
hours, minutes, seconds = parts
|
|
elif len(parts) == 2: # MM:SS.mmm
|
|
hours, minutes, seconds = "0", parts[0], parts[1]
|
|
else:
|
|
raise ValueError(f"Invalid timestamp format: {timestamp}")
|
|
|
|
sec_parts = seconds.split(".")
|
|
seconds_val = int(sec_parts[0])
|
|
milliseconds = int(sec_parts[1]) if len(sec_parts) > 1 else 0
|
|
|
|
total_seconds = (
|
|
int(hours) * 3600 +
|
|
int(minutes) * 60 +
|
|
seconds_val +
|
|
milliseconds / 1000.0
|
|
)
|
|
|
|
return total_seconds
|
|
|
|
|
|
# Global service instance
|
|
# NOTE: constructed at import time, so importing this module runs
# GeminiTTSService.__init__ (builds the genai client from settings and logs).
gemini_tts_service = GeminiTTSService()
|