feat(gemini): add model fallback chain on 429 quota errors
Routes all generate_content calls through _generate() which retries gemini-3.1-flash-preview then gemini-2.5-pro when primary model hits RESOURCE_EXHAUSTED. Cost tracker records actual model used. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
f38325b461
commit
56a3a62368
1 changed files with 41 additions and 36 deletions
|
|
@ -44,10 +44,35 @@ async def _record_gemini_usage(
|
|||
|
||||
|
||||
class GeminiService:
|
||||
_fallback_models: list[str] = [
|
||||
"gemini-3.1-flash-preview",
|
||||
"gemini-2.5-pro",
|
||||
]
|
||||
|
||||
def __init__(self):
|
||||
self.model_name = 'gemini-3.1-pro-preview'
|
||||
self.prompts_dir = Path(__file__).parent.parent / "prompts"
|
||||
|
||||
async def _generate(self, contents: Any, config: Any = None) -> tuple[Any, str]:
|
||||
"""Call generate_content, falling back on 429/quota errors. Returns (response, model_used)."""
|
||||
for model in [self.model_name, *self._fallback_models]:
|
||||
try:
|
||||
kw: dict[str, Any] = {"model": model, "contents": contents}
|
||||
if config is not None:
|
||||
kw["config"] = config
|
||||
response = await asyncio.to_thread(client.models.generate_content, **kw)
|
||||
if model != self.model_name:
|
||||
logger.warning(f"Used fallback model {model!r} (primary quota exceeded)")
|
||||
return response, model
|
||||
except Exception as exc:
|
||||
msg = str(exc)
|
||||
if "429" in msg or "RESOURCE_EXHAUSTED" in msg:
|
||||
logger.warning(f"Model {model!r} quota exceeded, trying next fallback")
|
||||
last_exc: Exception = exc
|
||||
continue
|
||||
raise
|
||||
raise last_exc # noqa: F821 — set in loop above when all models exhausted
|
||||
|
||||
def _load_prompt(self, prompt_file: str) -> str:
|
||||
"""Load prompt template from prompts directory"""
|
||||
prompt_path = self.prompts_dir / prompt_file
|
||||
|
|
@ -174,9 +199,7 @@ Generate sdh_captions_vtt using the same cue timings as captions_vtt, enriched w
|
|||
# Generate content using new API - use asyncio.to_thread to avoid blocking
|
||||
logger.info("Generating content with Gemini model...")
|
||||
_t0 = time.monotonic()
|
||||
response = await asyncio.to_thread(
|
||||
client.models.generate_content,
|
||||
model=self.model_name,
|
||||
response, _model_used = await self._generate(
|
||||
contents=[
|
||||
genai.types.Part.from_text(text=prompt),
|
||||
genai.types.Part.from_uri(
|
||||
|
|
@ -185,13 +208,13 @@ Generate sdh_captions_vtt using the same cue timings as captions_vtt, enriched w
|
|||
)
|
||||
],
|
||||
config=genai.types.GenerateContentConfig(
|
||||
temperature=0.2, # Lower temperature for consistent, deterministic AD output
|
||||
temperature=0.2,
|
||||
top_p=0.8,
|
||||
top_k=40,
|
||||
),
|
||||
)
|
||||
if _cost_ctx:
|
||||
asyncio.create_task(_record_gemini_usage(response, self.model_name, _cost_ctx.get("user_id", "system"), _cost_ctx.get("job_id", ""), _cost_ctx.get("project_id"), int((time.monotonic() - _t0) * 1000)))
|
||||
asyncio.create_task(_record_gemini_usage(response, _model_used, _cost_ctx.get("user_id", "system"), _cost_ctx.get("job_id", ""), _cost_ctx.get("project_id"), int((time.monotonic() - _t0) * 1000)))
|
||||
|
||||
# Parse JSON response
|
||||
response_text = response.text.strip()
|
||||
|
|
@ -292,9 +315,7 @@ Fix the JSON and return it:
|
|||
"""
|
||||
|
||||
try:
|
||||
response = await asyncio.to_thread(
|
||||
client.models.generate_content,
|
||||
model=self.model_name,
|
||||
response, _ = await self._generate(
|
||||
contents=[genai.types.Part.from_text(text=self_heal_prompt)]
|
||||
)
|
||||
|
||||
|
|
@ -394,9 +415,7 @@ Fix the JSON and return it:
|
|||
# Generate content using new API
|
||||
logger.info(f"Generating content with Gemini model for {target_language}...")
|
||||
_t0 = time.monotonic()
|
||||
response = await asyncio.to_thread(
|
||||
client.models.generate_content,
|
||||
model=self.model_name,
|
||||
response, _model_used = await self._generate(
|
||||
contents=[
|
||||
genai.types.Part.from_text(text=prompt),
|
||||
genai.types.Part.from_uri(
|
||||
|
|
@ -406,7 +425,7 @@ Fix the JSON and return it:
|
|||
]
|
||||
)
|
||||
if _cost_ctx:
|
||||
asyncio.create_task(_record_gemini_usage(response, self.model_name, _cost_ctx.get("user_id", "system"), _cost_ctx.get("job_id", ""), _cost_ctx.get("project_id"), int((time.monotonic() - _t0) * 1000)))
|
||||
asyncio.create_task(_record_gemini_usage(response, _model_used, _cost_ctx.get("user_id", "system"), _cost_ctx.get("job_id", ""), _cost_ctx.get("project_id"), int((time.monotonic() - _t0) * 1000)))
|
||||
|
||||
# Parse JSON response
|
||||
response_text = response.text.strip()
|
||||
|
|
@ -509,9 +528,7 @@ Fix the JSON and return it:
|
|||
"""
|
||||
|
||||
try:
|
||||
response = await asyncio.to_thread(
|
||||
client.models.generate_content,
|
||||
model=self.model_name,
|
||||
response, _ = await self._generate(
|
||||
contents=[genai.types.Part.from_text(text=self_heal_prompt)]
|
||||
)
|
||||
|
||||
|
|
@ -668,9 +685,7 @@ Fix the JSON and return it:
|
|||
|
||||
# Generate content with video and prompt
|
||||
logger.info("Analyzing video with Gemini for accessible video placement...")
|
||||
response = await asyncio.to_thread(
|
||||
client.models.generate_content,
|
||||
model=self.model_name,
|
||||
response, _ = await self._generate(
|
||||
contents=[
|
||||
genai.types.Part.from_text(text=prompt),
|
||||
genai.types.Part.from_uri(
|
||||
|
|
@ -752,9 +767,7 @@ Fix the JSON and return it:
|
|||
"""
|
||||
|
||||
try:
|
||||
response = await asyncio.to_thread(
|
||||
client.models.generate_content,
|
||||
model=self.model_name,
|
||||
response, _ = await self._generate(
|
||||
contents=[genai.types.Part.from_text(text=self_heal_prompt)]
|
||||
)
|
||||
|
||||
|
|
@ -802,15 +815,11 @@ JSON:
|
|||
|
||||
try:
|
||||
_t0 = time.monotonic()
|
||||
response = await asyncio.to_thread(
|
||||
client.models.generate_content,
|
||||
model=self.model_name,
|
||||
contents=[
|
||||
genai.types.Part.from_text(text=prompt + "\n\n" + user_prompt)
|
||||
]
|
||||
response, _model_used = await self._generate(
|
||||
contents=[genai.types.Part.from_text(text=prompt + "\n\n" + user_prompt)]
|
||||
)
|
||||
if _cost_ctx:
|
||||
asyncio.create_task(_record_gemini_usage(response, self.model_name, _cost_ctx.get("user_id", "system"), _cost_ctx.get("job_id", ""), _cost_ctx.get("project_id"), int((time.monotonic() - _t0) * 1000)))
|
||||
asyncio.create_task(_record_gemini_usage(response, _model_used, _cost_ctx.get("user_id", "system"), _cost_ctx.get("job_id", ""), _cost_ctx.get("project_id"), int((time.monotonic() - _t0) * 1000)))
|
||||
|
||||
response_text = response.text.strip()
|
||||
|
||||
|
|
@ -892,13 +901,11 @@ Segments to translate:
|
|||
{numbered_texts}"""
|
||||
|
||||
_t0 = time.monotonic()
|
||||
response = await asyncio.to_thread(
|
||||
client.models.generate_content,
|
||||
model=self.model_name,
|
||||
response, _model_used = await self._generate(
|
||||
contents=[genai.types.Part.from_text(text=prompt)]
|
||||
)
|
||||
if _cost_ctx:
|
||||
asyncio.create_task(_record_gemini_usage(response, self.model_name, _cost_ctx.get("user_id", "system"), _cost_ctx.get("job_id", ""), _cost_ctx.get("project_id"), int((time.monotonic() - _t0) * 1000)))
|
||||
asyncio.create_task(_record_gemini_usage(response, _model_used, _cost_ctx.get("user_id", "system"), _cost_ctx.get("job_id", ""), _cost_ctx.get("project_id"), int((time.monotonic() - _t0) * 1000)))
|
||||
return self._parse_numbered_translation(response.text.strip(), cue_count)
|
||||
|
||||
try:
|
||||
|
|
@ -985,13 +992,11 @@ Segments to translate:
|
|||
logger.info(f"Rewriting TTS cue for safety: '{original_text[:50]}...'")
|
||||
|
||||
_t0 = time.monotonic()
|
||||
response = await asyncio.to_thread(
|
||||
client.models.generate_content,
|
||||
model=self.model_name,
|
||||
response, _model_used = await self._generate(
|
||||
contents=[genai.types.Part.from_text(text=prompt)]
|
||||
)
|
||||
if _cost_ctx:
|
||||
asyncio.create_task(_record_gemini_usage(response, self.model_name, _cost_ctx.get("user_id", "system"), _cost_ctx.get("job_id", ""), _cost_ctx.get("project_id"), int((time.monotonic() - _t0) * 1000)))
|
||||
asyncio.create_task(_record_gemini_usage(response, _model_used, _cost_ctx.get("user_id", "system"), _cost_ctx.get("job_id", ""), _cost_ctx.get("project_id"), int((time.monotonic() - _t0) * 1000)))
|
||||
|
||||
result = response.text.strip()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue