video-accessibility/backend/app/services/gemini_tts.py
michael 865fcdc246 feat: add TTS settings panel with model, speed, and style options
- Add model selection (flash vs pro) for quality control
- Add speed slider (0.5x - 2.0x) for pacing adjustment
- Add style presets (neutral, calm, energetic, professional, warm, documentary)
- Add custom style prompt option for advanced customization
- New /tts/options endpoint returns available TTS options
- Voice preview now tests all settings so users hear exact output
- Backward compatible: all new fields have sensible defaults

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-22 15:22:14 -06:00

319 lines
11 KiB
Python

import asyncio
import io
import wave

from google import genai
from google.genai import types
from pydub import AudioSegment

from ..core.config import settings
from ..core.logging import get_logger
# Module-level logger named after this module, obtained from the project's get_logger helper.
logger = get_logger(__name__)
class GeminiTTSService:
"""Text-to-Speech service using Gemini TTS API"""
def __init__(self):
self.client = genai.Client(api_key=settings.gemini_api_key)
self.model = settings.gemini_tts_model
self.default_voice = settings.gemini_tts_default_voice
logger.info(f"Gemini TTS service initialized with model: {self.model}")
async def synthesize_text(
self,
text: str,
voice_name: str,
language: str = "en",
model: str = "flash",
speed: float = 1.0,
style_prompt: str = ""
) -> bytes:
"""
Synthesize text to audio using Gemini TTS.
Returns MP3 audio bytes.
Args:
text: The text to synthesize
voice_name: Name of the voice to use
language: Language code (e.g., "en", "es")
model: Model variant - "flash" (fast) or "pro" (quality)
speed: Speech rate multiplier (0.5 to 2.0)
style_prompt: Style instructions to prepend (e.g., "Speak calmly...")
"""
if not text.strip():
raise ValueError("Text cannot be empty")
# Validate voice
if voice_name not in settings.gemini_tts_voices:
logger.warning(f"Unknown voice '{voice_name}', using default '{self.default_voice}'")
voice_name = self.default_voice
# Select model from config
model_id = settings.gemini_tts_models.get(model, settings.gemini_tts_model)
# Build the full prompt with style and speed instructions
prompt_parts = []
# Add style prompt if provided
if style_prompt:
prompt_parts.append(style_prompt)
# Add speed instruction if not default
if speed != 1.0:
speed_pct = int(speed * 100)
if speed < 1.0:
prompt_parts.append(f"Speak slowly at approximately {speed_pct}% of normal speed. ")
else:
prompt_parts.append(f"Speak quickly at approximately {speed_pct}% of normal speed. ")
# Combine prompts with actual text
full_text = "".join(prompt_parts) + text
try:
# Generate audio using Gemini TTS
response = self.client.models.generate_content(
model=model_id,
contents=full_text,
config=types.GenerateContentConfig(
response_modalities=["AUDIO"],
speech_config=types.SpeechConfig(
voice_config=types.VoiceConfig(
prebuilt_voice_config=types.PrebuiltVoiceConfig(
voice_name=voice_name,
)
)
),
)
)
# Extract PCM audio data from response
if not response.candidates or not response.candidates[0].content.parts:
raise ValueError("No audio data in Gemini TTS response")
pcm_data = response.candidates[0].content.parts[0].inline_data.data
# Convert PCM to MP3
mp3_data = self._pcm_to_mp3(pcm_data)
return mp3_data
except Exception as e:
logger.error(f"Gemini TTS synthesis failed: {e}")
raise
async def synthesize_preview(
self,
voice_name: str,
language: str = "en",
model: str = "flash",
speed: float = 1.0,
style_prompt: str = ""
) -> bytes:
"""
Generate a preview audio sample for voice selection.
Uses language-specific sample text and applies all TTS settings.
"""
# Get preview sample text for the language
sample_text = settings.gemini_tts_preview_samples.get(
language,
settings.gemini_tts_preview_samples.get("en", "This is a voice preview.")
)
return await self.synthesize_text(
sample_text,
voice_name,
language,
model=model,
speed=speed,
style_prompt=style_prompt
)
async def synthesize_audio_description(
self,
ad_vtt_content: str,
language: str = "en",
voice_name: str | None = None,
model: str = "flash",
speed: float = 1.0,
style_prompt: str = ""
) -> bytes:
"""
Synthesize full audio description from VTT content.
Maintains timing alignment with original VTT cues.
Args:
ad_vtt_content: VTT content with audio description cues
language: Language code (e.g., "en", "es")
voice_name: Name of the voice to use (defaults to service default)
model: Model variant - "flash" (fast) or "pro" (quality)
speed: Speech rate multiplier (0.5 to 2.0)
style_prompt: Style instructions to prepend to each cue
"""
if voice_name is None:
voice_name = self.default_voice
# Validate voice
if voice_name not in settings.gemini_tts_voices:
logger.warning(f"Unknown voice '{voice_name}', using default '{self.default_voice}'")
voice_name = self.default_voice
# Parse VTT cues
cues = self._parse_ad_cues(ad_vtt_content)
if not cues:
raise ValueError("No audio description cues found in VTT content")
logger.info(
f"Synthesizing {len(cues)} audio description cues with voice '{voice_name}', "
f"model '{model}', speed {speed}x"
)
# Synthesize each cue with precise timing anchoring
audio_segments = []
current_audio_position = 0.0
for i, cue in enumerate(cues):
target_start_time = cue["start_time"]
# Add silence to reach the exact VTT start time
if target_start_time > current_audio_position:
silence_duration = target_start_time - current_audio_position
silence = AudioSegment.silent(duration=int(silence_duration * 1000))
audio_segments.append(silence)
current_audio_position = target_start_time
# Synthesize this cue's text
text = cue["text"].strip()
if text:
# Ensure proper punctuation for natural TTS flow
if not text.endswith(('.', '!', '?')):
text += "."
try:
audio_data = await self.synthesize_text(
text,
voice_name,
language,
model=model,
speed=speed,
style_prompt=style_prompt
)
# Convert to AudioSegment and get actual duration
audio_segment = AudioSegment.from_file(io.BytesIO(audio_data), format="mp3")
audio_segments.append(audio_segment)
# Update position based on actual audio duration
actual_audio_duration = len(audio_segment) / 1000.0
current_audio_position += actual_audio_duration
except Exception as e:
logger.warning(f"Failed to synthesize cue {i}: {e}")
# Add silence for failed cue
cue_duration = cue["end_time"] - cue["start_time"]
silence = AudioSegment.silent(duration=int(cue_duration * 1000))
audio_segments.append(silence)
current_audio_position += cue_duration
# Combine all segments
if audio_segments:
final_audio = sum(audio_segments, AudioSegment.empty())
else:
final_audio = AudioSegment.silent(duration=1000)
# Export to MP3
output_buffer = io.BytesIO()
final_audio.export(output_buffer, format="mp3", bitrate="128k")
logger.info(f"Audio description synthesized: {len(output_buffer.getvalue())} bytes")
return output_buffer.getvalue()
def _pcm_to_mp3(self, pcm_data: bytes) -> bytes:
"""
Convert raw PCM audio (24kHz, 16-bit, mono) to MP3.
Gemini TTS outputs PCM at 24000 Hz sample rate.
"""
# Create WAV from PCM data
wav_buffer = io.BytesIO()
with wave.open(wav_buffer, "wb") as wf:
wf.setnchannels(1) # Mono
wf.setsampwidth(2) # 16-bit (2 bytes)
wf.setframerate(24000) # 24kHz
wf.writeframes(pcm_data)
# Convert WAV to MP3 using pydub
wav_buffer.seek(0)
audio_segment = AudioSegment.from_wav(wav_buffer)
# Export as MP3
mp3_buffer = io.BytesIO()
audio_segment.export(mp3_buffer, format="mp3", bitrate="128k")
return mp3_buffer.getvalue()
def _parse_ad_cues(self, vtt_content: str) -> list[dict]:
"""Parse audio description VTT and extract timing + text"""
lines = vtt_content.strip().split('\n')
cues = []
i = 0
while i < len(lines):
line = lines[i].strip()
# Skip header and empty lines
if line == "WEBVTT" or line == "" or line.startswith("NOTE"):
i += 1
continue
# Check for timing line
if " --> " in line:
timing_parts = line.split(" --> ")
start_time = self._parse_timestamp(timing_parts[0].strip())
end_time = self._parse_timestamp(timing_parts[1].strip())
# Get text from next line(s)
i += 1
text_lines = []
while i < len(lines) and lines[i].strip() != "":
text_lines.append(lines[i].strip())
i += 1
if text_lines:
cues.append({
"start_time": start_time,
"end_time": end_time,
"text": " ".join(text_lines)
})
else:
i += 1
return cues
def _parse_timestamp(self, timestamp: str) -> float:
"""Convert VTT timestamp to seconds"""
parts = timestamp.split(":")
if len(parts) == 3: # HH:MM:SS.mmm
hours, minutes, seconds = parts
elif len(parts) == 2: # MM:SS.mmm
hours, minutes, seconds = "0", parts[0], parts[1]
else:
raise ValueError(f"Invalid timestamp format: {timestamp}")
sec_parts = seconds.split(".")
seconds_val = int(sec_parts[0])
milliseconds = int(sec_parts[1]) if len(sec_parts) > 1 else 0
total_seconds = (
int(hours) * 3600 +
int(minutes) * 60 +
seconds_val +
milliseconds / 1000.0
)
return total_seconds
# Global service instance, created when this module is imported. Constructing it
# builds the Gemini client from settings.gemini_api_key in GeminiTTSService.__init__.
gemini_tts_service = GeminiTTSService()