feat(gemini): add model fallback chain on 429 quota errors

Routes all generate_content calls through _generate() which retries
gemini-3.1-flash-preview then gemini-2.5-pro when primary model hits
RESOURCE_EXHAUSTED. Cost tracker records actual model used.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Vadym Samoilenko 2026-05-08 12:02:59 +01:00
parent f38325b461
commit 56a3a62368

View file

@ -44,10 +44,35 @@ async def _record_gemini_usage(
class GeminiService:
_fallback_models: list[str] = [
"gemini-3.1-flash-preview",
"gemini-2.5-pro",
]
def __init__(self):
self.model_name = 'gemini-3.1-pro-preview'
self.prompts_dir = Path(__file__).parent.parent / "prompts"
async def _generate(self, contents: Any, config: Any = None) -> tuple[Any, str]:
"""Call generate_content, falling back on 429/quota errors. Returns (response, model_used)."""
for model in [self.model_name, *self._fallback_models]:
try:
kw: dict[str, Any] = {"model": model, "contents": contents}
if config is not None:
kw["config"] = config
response = await asyncio.to_thread(client.models.generate_content, **kw)
if model != self.model_name:
logger.warning(f"Used fallback model {model!r} (primary quota exceeded)")
return response, model
except Exception as exc:
msg = str(exc)
if "429" in msg or "RESOURCE_EXHAUSTED" in msg:
logger.warning(f"Model {model!r} quota exceeded, trying next fallback")
last_exc: Exception = exc
continue
raise
raise last_exc # noqa: F821 — set in loop above when all models exhausted
def _load_prompt(self, prompt_file: str) -> str:
"""Load prompt template from prompts directory"""
prompt_path = self.prompts_dir / prompt_file
@ -174,9 +199,7 @@ Generate sdh_captions_vtt using the same cue timings as captions_vtt, enriched w
# Generate content using new API - use asyncio.to_thread to avoid blocking
logger.info("Generating content with Gemini model...")
_t0 = time.monotonic()
response = await asyncio.to_thread(
client.models.generate_content,
model=self.model_name,
response, _model_used = await self._generate(
contents=[
genai.types.Part.from_text(text=prompt),
genai.types.Part.from_uri(
@ -185,13 +208,13 @@ Generate sdh_captions_vtt using the same cue timings as captions_vtt, enriched w
)
],
config=genai.types.GenerateContentConfig(
temperature=0.2, # Lower temperature for consistent, deterministic AD output
temperature=0.2,
top_p=0.8,
top_k=40,
),
)
if _cost_ctx:
asyncio.create_task(_record_gemini_usage(response, self.model_name, _cost_ctx.get("user_id", "system"), _cost_ctx.get("job_id", ""), _cost_ctx.get("project_id"), int((time.monotonic() - _t0) * 1000)))
asyncio.create_task(_record_gemini_usage(response, _model_used, _cost_ctx.get("user_id", "system"), _cost_ctx.get("job_id", ""), _cost_ctx.get("project_id"), int((time.monotonic() - _t0) * 1000)))
# Parse JSON response
response_text = response.text.strip()
@ -292,9 +315,7 @@ Fix the JSON and return it:
"""
try:
response = await asyncio.to_thread(
client.models.generate_content,
model=self.model_name,
response, _ = await self._generate(
contents=[genai.types.Part.from_text(text=self_heal_prompt)]
)
@ -394,9 +415,7 @@ Fix the JSON and return it:
# Generate content using new API
logger.info(f"Generating content with Gemini model for {target_language}...")
_t0 = time.monotonic()
response = await asyncio.to_thread(
client.models.generate_content,
model=self.model_name,
response, _model_used = await self._generate(
contents=[
genai.types.Part.from_text(text=prompt),
genai.types.Part.from_uri(
@ -406,7 +425,7 @@ Fix the JSON and return it:
]
)
if _cost_ctx:
asyncio.create_task(_record_gemini_usage(response, self.model_name, _cost_ctx.get("user_id", "system"), _cost_ctx.get("job_id", ""), _cost_ctx.get("project_id"), int((time.monotonic() - _t0) * 1000)))
asyncio.create_task(_record_gemini_usage(response, _model_used, _cost_ctx.get("user_id", "system"), _cost_ctx.get("job_id", ""), _cost_ctx.get("project_id"), int((time.monotonic() - _t0) * 1000)))
# Parse JSON response
response_text = response.text.strip()
@ -509,9 +528,7 @@ Fix the JSON and return it:
"""
try:
response = await asyncio.to_thread(
client.models.generate_content,
model=self.model_name,
response, _ = await self._generate(
contents=[genai.types.Part.from_text(text=self_heal_prompt)]
)
@ -668,9 +685,7 @@ Fix the JSON and return it:
# Generate content with video and prompt
logger.info("Analyzing video with Gemini for accessible video placement...")
response = await asyncio.to_thread(
client.models.generate_content,
model=self.model_name,
response, _ = await self._generate(
contents=[
genai.types.Part.from_text(text=prompt),
genai.types.Part.from_uri(
@ -752,9 +767,7 @@ Fix the JSON and return it:
"""
try:
response = await asyncio.to_thread(
client.models.generate_content,
model=self.model_name,
response, _ = await self._generate(
contents=[genai.types.Part.from_text(text=self_heal_prompt)]
)
@ -802,15 +815,11 @@ JSON:
try:
_t0 = time.monotonic()
response = await asyncio.to_thread(
client.models.generate_content,
model=self.model_name,
contents=[
genai.types.Part.from_text(text=prompt + "\n\n" + user_prompt)
]
response, _model_used = await self._generate(
contents=[genai.types.Part.from_text(text=prompt + "\n\n" + user_prompt)]
)
if _cost_ctx:
asyncio.create_task(_record_gemini_usage(response, self.model_name, _cost_ctx.get("user_id", "system"), _cost_ctx.get("job_id", ""), _cost_ctx.get("project_id"), int((time.monotonic() - _t0) * 1000)))
asyncio.create_task(_record_gemini_usage(response, _model_used, _cost_ctx.get("user_id", "system"), _cost_ctx.get("job_id", ""), _cost_ctx.get("project_id"), int((time.monotonic() - _t0) * 1000)))
response_text = response.text.strip()
@ -892,13 +901,11 @@ Segments to translate:
{numbered_texts}"""
_t0 = time.monotonic()
response = await asyncio.to_thread(
client.models.generate_content,
model=self.model_name,
response, _model_used = await self._generate(
contents=[genai.types.Part.from_text(text=prompt)]
)
if _cost_ctx:
asyncio.create_task(_record_gemini_usage(response, self.model_name, _cost_ctx.get("user_id", "system"), _cost_ctx.get("job_id", ""), _cost_ctx.get("project_id"), int((time.monotonic() - _t0) * 1000)))
asyncio.create_task(_record_gemini_usage(response, _model_used, _cost_ctx.get("user_id", "system"), _cost_ctx.get("job_id", ""), _cost_ctx.get("project_id"), int((time.monotonic() - _t0) * 1000)))
return self._parse_numbered_translation(response.text.strip(), cue_count)
try:
@ -985,13 +992,11 @@ Segments to translate:
logger.info(f"Rewriting TTS cue for safety: '{original_text[:50]}...'")
_t0 = time.monotonic()
response = await asyncio.to_thread(
client.models.generate_content,
model=self.model_name,
response, _model_used = await self._generate(
contents=[genai.types.Part.from_text(text=prompt)]
)
if _cost_ctx:
asyncio.create_task(_record_gemini_usage(response, self.model_name, _cost_ctx.get("user_id", "system"), _cost_ctx.get("job_id", ""), _cost_ctx.get("project_id"), int((time.monotonic() - _t0) * 1000)))
asyncio.create_task(_record_gemini_usage(response, _model_used, _cost_ctx.get("user_id", "system"), _cost_ctx.get("job_id", ""), _cost_ctx.get("project_id"), int((time.monotonic() - _t0) * 1000)))
result = response.text.strip()