"""Gemini text-to-speech service producing MP3 audio for text and WebVTT audio descriptions."""

import asyncio
import io
import random
import wave

from google import genai
from google.genai import types
from pydub import AudioSegment

from ..core.config import settings
from ..core.logging import get_logger

logger = get_logger(__name__)

# PCM sample rate used when the API response does not advertise one.
# The service treats Gemini TTS output as raw 16-bit mono PCM.
_DEFAULT_PCM_SAMPLE_RATE = 24000


class TTSSynthesisError(Exception):
    """Raised when TTS synthesis fails after all retries.

    Attributes:
        cue_index: Index of the VTT cue that could not be synthesized.
        cue_text: Text of that cue.
        api_response_info: String form of the last error seen, if any.
    """

    def __init__(
        self,
        message: str,
        cue_index: int,
        cue_text: str,
        api_response_info: str | None = None,
    ):
        super().__init__(message)
        self.cue_index = cue_index
        self.cue_text = cue_text
        self.api_response_info = api_response_info


class GeminiTTSService:
    """Text-to-Speech service using the Gemini TTS API.

    Produces MP3 bytes for a single text, for a voice preview, or for a
    full audio-description track assembled from timed WebVTT cues.
    """

    def __init__(self):
        self.client = genai.Client(api_key=settings.gemini_api_key)
        self.model = settings.gemini_tts_model
        self.default_voice = settings.gemini_tts_default_voice
        logger.info(f"Gemini TTS service initialized with model: {self.model}")

    # ------------------------------------------------------------------
    # Small helpers
    # ------------------------------------------------------------------

    def _resolve_voice(self, voice_name: str | None) -> str:
        """Return ``voice_name`` if it is a configured voice, otherwise the default voice."""
        if voice_name is None:
            voice_name = self.default_voice
        if voice_name not in settings.gemini_tts_voices:
            logger.warning(f"Unknown voice '{voice_name}', using default '{self.default_voice}'")
            voice_name = self.default_voice
        return voice_name

    @staticmethod
    def _build_prompt(text: str, speed: float, style_prompt: str) -> str:
        """Prefix ``text`` with the optional style prompt and a speaking-rate instruction.

        Speed is conveyed as a natural-language instruction inside the prompt,
        not as a request parameter.
        """
        prompt_parts = []
        if style_prompt:
            # Guarantee a separator so the style instruction cannot run straight
            # into the first word of the spoken text.
            prompt_parts.append(style_prompt.rstrip() + " ")
        if speed != 1.0:
            speed_pct = int(speed * 100)
            pace = "slowly" if speed < 1.0 else "quickly"
            prompt_parts.append(f"Speak {pace} at approximately {speed_pct}% of normal speed. ")
        return "".join(prompt_parts) + text

    @staticmethod
    def _extract_inline_audio(response) -> tuple[bytes, str | None]:
        """Return ``(audio_bytes, mime_type)`` from the first candidate's first part.

        Raises:
            ValueError: If any level of the response (candidates, content,
                parts, inline_data) is missing.
        """
        if not response.candidates:
            logger.error(
                f"Gemini TTS response missing candidates. "
                f"Response type: {type(response)}, Response: {response}"
            )
            raise ValueError("No candidates in Gemini TTS response")

        candidate = response.candidates[0]
        if candidate.content is None:
            logger.error(
                f"Gemini TTS candidate has no content. "
                f"Finish reason: {getattr(candidate, 'finish_reason', 'unknown')}, "
                f"Safety ratings: {getattr(candidate, 'safety_ratings', 'unknown')}"
            )
            raise ValueError(
                f"Candidate content is None in Gemini TTS response. "
                f"Finish reason: {getattr(candidate, 'finish_reason', 'unknown')}"
            )

        if not candidate.content.parts:
            logger.error(
                f"Gemini TTS content has no parts. "
                f"Content role: {getattr(candidate.content, 'role', 'unknown')}"
            )
            raise ValueError("No parts in Gemini TTS response content")

        part = candidate.content.parts[0]
        if getattr(part, "inline_data", None) is None:
            logger.error(
                f"Gemini TTS part missing inline_data. "
                f"Part type: {type(part)}, Part: {part}"
            )
            raise ValueError("No inline_data in Gemini TTS response part")

        return part.inline_data.data, getattr(part.inline_data, "mime_type", None)

    @staticmethod
    def _sample_rate_from_mime(
        mime_type: str | None, default: int = _DEFAULT_PCM_SAMPLE_RATE
    ) -> int:
        """Read a ``rate=<hz>`` parameter from a MIME type, falling back to ``default``.

        Assumes a form like ``audio/L16;codec=pcm;rate=24000`` — TODO confirm
        against live API responses. Anything unparseable yields ``default``.
        """
        if mime_type:
            for param in mime_type.split(";")[1:]:
                key, _, value = param.strip().partition("=")
                if key.strip().lower() == "rate" and value.strip().isdigit():
                    return int(value.strip())
        return default

    @staticmethod
    def _export_mp3(segment: AudioSegment) -> bytes:
        """Encode an AudioSegment as 128 kbps MP3 and return the bytes."""
        buffer = io.BytesIO()
        segment.export(buffer, format="mp3", bitrate="128k")
        return buffer.getvalue()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def synthesize_text(
        self,
        text: str,
        voice_name: str,
        language: str = "en",
        model: str = "flash",
        speed: float = 1.0,
        style_prompt: str = "",
    ) -> bytes:
        """
        Synthesize text to audio using Gemini TTS. Returns MP3 audio bytes.

        Args:
            text: The text to synthesize
            voice_name: Name of the voice to use (falls back to the default if unknown)
            language: Language code (e.g., "en", "es"); only used for error logging here
            model: Model variant key into settings.gemini_tts_models - "flash" (fast) or "pro" (quality)
            speed: Speech rate multiplier (documented range 0.5 to 2.0; not enforced)
            style_prompt: Style instructions to prepend (e.g., "Speak calmly...")

        Raises:
            ValueError: If ``text`` is blank or the API response contains no audio.
            Exception: Any error raised by the Gemini client is logged and re-raised.
        """
        if not text.strip():
            raise ValueError("Text cannot be empty")

        voice_name = self._resolve_voice(voice_name)
        model_id = settings.gemini_tts_models.get(model, settings.gemini_tts_model)
        full_text = self._build_prompt(text, speed, style_prompt)

        try:
            # Use the SDK's async client so the network call does not block the event loop.
            response = await self.client.aio.models.generate_content(
                model=model_id,
                contents=full_text,
                config=types.GenerateContentConfig(
                    response_modalities=["AUDIO"],
                    speech_config=types.SpeechConfig(
                        voice_config=types.VoiceConfig(
                            prebuilt_voice_config=types.PrebuiltVoiceConfig(
                                voice_name=voice_name,
                            )
                        )
                    ),
                ),
            )

            pcm_data, mime_type = self._extract_inline_audio(response)
            sample_rate = self._sample_rate_from_mime(mime_type)

            # WAV wrapping + MP3 encoding (pydub) is blocking work; run it off the loop.
            return await asyncio.to_thread(self._pcm_to_mp3, pcm_data, sample_rate)

        except Exception as e:
            # Log enough context to debug the failure, then let the caller decide.
            error_context = {
                "text_length": len(text),
                "text_preview": text[:100] + "..." if len(text) > 100 else text,
                "voice_name": voice_name,
                "language": language,
                "model_id": model_id,
            }
            logger.error(
                f"Gemini TTS synthesis failed: {e}. Context: {error_context}"
            )
            raise

    async def synthesize_preview(
        self,
        voice_name: str,
        language: str = "en",
        model: str = "flash",
        speed: float = 1.0,
        style_prompt: str = "",
    ) -> bytes:
        """
        Generate a preview audio sample for voice selection.

        Uses the configured preview sample text for ``language`` (falling back to
        the English sample, then a fixed sentence) and applies all TTS settings.
        """
        sample_text = settings.gemini_tts_preview_samples.get(
            language,
            settings.gemini_tts_preview_samples.get("en", "This is a voice preview."),
        )
        return await self.synthesize_text(
            sample_text, voice_name, language, model=model, speed=speed, style_prompt=style_prompt
        )

    async def _synthesize_cue_with_retry(
        self,
        cue_index: int,
        text: str,
        voice_name: str,
        language: str,
        model: str,
        speed: float,
        style_prompt: str,
        max_attempts: int = 3,
        base_delay: float = 1.0,
    ) -> bytes:
        """
        Synthesize a single cue, retrying with exponential backoff plus jitter.

        Args:
            cue_index: Index of the cue (for error reporting)
            text: Text to synthesize
            voice_name: TTS voice name
            language: Language code
            model: Model variant
            speed: Speech rate
            style_prompt: Style instructions
            max_attempts: Total attempts (1 initial + retries); values below 1 are treated as 1
            base_delay: Base delay in seconds; attempt k waits base_delay * 2**k + U(0, 1)

        Returns:
            MP3 audio bytes

        Raises:
            TTSSynthesisError: If all attempts fail (the last error is chained as ``__cause__``).
        """
        # Always make at least one attempt; otherwise the loop would never run and
        # the raised error would carry no underlying cause.
        max_attempts = max(1, max_attempts)
        last_exception: Exception | None = None

        for attempt in range(max_attempts):
            try:
                return await self.synthesize_text(
                    text, voice_name, language, model=model, speed=speed, style_prompt=style_prompt
                )
            except Exception as e:
                last_exception = e
                if attempt < max_attempts - 1:
                    delay = base_delay * (2 ** attempt) + random.uniform(0, 1)
                    logger.warning(
                        f"TTS synthesis attempt {attempt + 1}/{max_attempts} failed for cue {cue_index}. "
                        f"Retrying in {delay:.2f}s. Error: {e}"
                    )
                    await asyncio.sleep(delay)
                else:
                    logger.error(
                        f"TTS synthesis FAILED after {max_attempts} attempts for cue {cue_index}. "
                        f"Text: {text[:50]}{'...' if len(text) > 50 else ''}. "
                        f"Error: {e}"
                    )

        raise TTSSynthesisError(
            message=f"TTS synthesis failed after {max_attempts} attempts: {last_exception}",
            cue_index=cue_index,
            cue_text=text,
            api_response_info=str(last_exception) if last_exception is not None else None,
        ) from last_exception

    async def synthesize_audio_description(
        self,
        ad_vtt_content: str,
        language: str = "en",
        voice_name: str | None = None,
        model: str = "flash",
        speed: float = 1.0,
        style_prompt: str = "",
    ) -> bytes:
        """
        Synthesize a full audio-description track from VTT content.

        Each cue is placed at its VTT start time by inserting silence before it.
        When a cue's audio runs past the next cue's start, that next cue is
        delayed (a warning is logged), and later cues re-anchor to their own
        start times.

        Args:
            ad_vtt_content: VTT content with audio description cues
            language: Language code (e.g., "en", "es")
            voice_name: Name of the voice to use (defaults to service default)
            model: Model variant - "flash" (fast) or "pro" (quality)
            speed: Speech rate multiplier (0.5 to 2.0)
            style_prompt: Style instructions to prepend to each cue

        Returns:
            MP3 bytes (128 kbps) of the assembled track.

        Raises:
            ValueError: If no cues are found in the VTT content.
            TTSSynthesisError: If any cue fails after its retries.
        """
        voice_name = self._resolve_voice(voice_name)

        cues = self._parse_ad_cues(ad_vtt_content)
        if not cues:
            raise ValueError("No audio description cues found in VTT content")

        logger.info(
            f"Synthesizing {len(cues)} audio description cues with voice '{voice_name}', "
            f"model '{model}', speed {speed}x"
        )

        audio_segments: list[AudioSegment] = []
        # Length of the assembled track so far, in integer milliseconds (pydub's
        # unit). Tracking it as an int avoids float/truncation drift across cues.
        current_ms = 0

        for i, cue in enumerate(cues):
            target_ms = round(cue["start_time"] * 1000)

            if target_ms > current_ms:
                audio_segments.append(AudioSegment.silent(duration=target_ms - current_ms))
                current_ms = target_ms
            elif target_ms < current_ms:
                logger.warning(
                    f"Cue {i} is scheduled at {target_ms / 1000:.3f}s but preceding audio "
                    f"runs until {current_ms / 1000:.3f}s; cue will start late"
                )

            text = cue["text"].strip()
            if not text:
                continue
            # A terminal punctuation mark gives the TTS model a natural sentence end.
            if not text.endswith((".", "!", "?")):
                text += "."

            audio_data = await self._synthesize_cue_with_retry(
                cue_index=i,
                text=text,
                voice_name=voice_name,
                language=language,
                model=model,
                speed=speed,
                style_prompt=style_prompt,
                max_attempts=3,
                base_delay=1.0,
            )

            # Decode off the event loop (pydub shells out for MP3 decoding).
            audio_segment = await asyncio.to_thread(
                AudioSegment.from_file, io.BytesIO(audio_data), format="mp3"
            )
            audio_segments.append(audio_segment)
            # Advance by the real decoded duration, not the cue's nominal end time.
            current_ms += len(audio_segment)

        if audio_segments:
            final_audio = sum(audio_segments, AudioSegment.empty())
        else:
            final_audio = AudioSegment.silent(duration=1000)

        mp3_bytes = await asyncio.to_thread(self._export_mp3, final_audio)
        logger.info(f"Audio description synthesized: {len(mp3_bytes)} bytes")
        return mp3_bytes

    # ------------------------------------------------------------------
    # Audio / VTT utilities
    # ------------------------------------------------------------------

    def _pcm_to_mp3(self, pcm_data: bytes, sample_rate: int = _DEFAULT_PCM_SAMPLE_RATE) -> bytes:
        """
        Convert raw 16-bit mono PCM to 128 kbps MP3.

        Args:
            pcm_data: Raw little-endian 16-bit mono samples.
            sample_rate: Sample rate of ``pcm_data`` in Hz (default 24000).
        """
        wav_buffer = io.BytesIO()
        with wave.open(wav_buffer, "wb") as wf:
            wf.setnchannels(1)  # Mono
            wf.setsampwidth(2)  # 16-bit (2 bytes)
            wf.setframerate(sample_rate)
            wf.writeframes(pcm_data)

        wav_buffer.seek(0)
        return self._export_mp3(AudioSegment.from_wav(wav_buffer))

    def _parse_ad_cues(self, vtt_content: str) -> list[dict]:
        """Parse audio-description VTT into ``{"start_time", "end_time", "text"}`` dicts.

        Times are in seconds. Multi-line cue text is joined with single spaces.
        Lines that are not timing lines or cue text (header, NOTE lines, cue
        identifiers) are skipped, as are cues with no text.
        """
        lines = vtt_content.strip().split("\n")
        cues = []
        i = 0

        while i < len(lines):
            line = lines[i].strip()

            if line == "WEBVTT" or line == "" or line.startswith("NOTE"):
                i += 1
                continue

            if " --> " in line:
                start_raw, _, end_raw = line.partition(" --> ")
                # Valid WebVTT may append cue settings after the end timestamp
                # (e.g. "align:start position:10%"); keep only the timestamp token.
                end_fields = end_raw.split()
                start_time = self._parse_timestamp(start_raw.strip())
                end_time = self._parse_timestamp(end_fields[0] if end_fields else "")

                # Cue text runs until the next blank line.
                i += 1
                text_lines = []
                while i < len(lines) and lines[i].strip() != "":
                    text_lines.append(lines[i].strip())
                    i += 1

                if text_lines:
                    cues.append({
                        "start_time": start_time,
                        "end_time": end_time,
                        "text": " ".join(text_lines),
                    })
            else:
                i += 1

        return cues

    def _parse_timestamp(self, timestamp: str) -> float:
        """Convert a VTT timestamp (``HH:MM:SS.fff`` or ``MM:SS.fff``) to seconds.

        Raises:
            ValueError: If the timestamp does not have 2 or 3 colon-separated
                fields, or a field is not an integer.
        """
        parts = timestamp.split(":")
        if len(parts) == 3:
            hours, minutes, seconds = parts
        elif len(parts) == 2:
            hours, minutes, seconds = "0", parts[0], parts[1]
        else:
            raise ValueError(f"Invalid timestamp format: {timestamp}")

        whole, _, fraction = seconds.partition(".")
        # Scale by the number of fractional digits so "01.5" is 1.5 s, not 1.005 s.
        fraction_val = int(fraction) / (10 ** len(fraction)) if fraction else 0.0

        return int(hours) * 3600 + int(minutes) * 60 + int(whole) + fraction_val


# Global service instance
gemini_tts_service = GeminiTTSService()