diff --git a/backend/app/services/gemini.py b/backend/app/services/gemini.py
index 323b2d1..ec84a1a 100644
--- a/backend/app/services/gemini.py
+++ b/backend/app/services/gemini.py
@@ -44,10 +44,65 @@ async def _record_gemini_usage(
 
 
 class GeminiService:
+    # Tried in order after the primary model reports a quota error.
+    _fallback_models: tuple[str, ...] = (
+        "gemini-3.1-flash-preview",
+        "gemini-2.5-pro",
+    )
+
     def __init__(self):
         self.model_name = 'gemini-3.1-pro-preview'
         self.prompts_dir = Path(__file__).parent.parent / "prompts"
 
+    async def _generate(self, contents: Any, config: Any = None) -> tuple[Any, str]:
+        """Call generate_content, falling back to other models on 429/quota errors.
+
+        Tries ``self.model_name`` first, then each entry of ``_fallback_models``
+        in order, and stops at the first model that answers.
+
+        Args:
+            contents: Forwarded unchanged as ``contents``.
+            config: Forwarded as ``config`` only when it is not ``None``.
+
+        Returns:
+            ``(response, model_used)``.
+
+        Raises:
+            Exception: A non-quota error is re-raised immediately; when every
+                model is quota-limited, the last quota error is raised.
+        """
+        kw: dict[str, Any] = {"contents": contents}
+        if config is not None:
+            kw["config"] = config
+        candidates = [self.model_name, *self._fallback_models]
+        # Never raised in practice (candidates always holds the primary model).
+        last_exc: Exception = RuntimeError("No Gemini model configured")
+        for attempt, model in enumerate(candidates, start=1):
+            try:
+                # Guard only the API call so logging/return cannot be mistaken for a quota error.
+                response = await asyncio.to_thread(client.models.generate_content, model=model, **kw)
+            except Exception as exc:
+                # Trust structured fields when the error carries an integer code
+                # (assumed: google-genai APIError.code/.status); otherwise match the text.
+                code = getattr(exc, "code", None)
+                if isinstance(code, int):
+                    quota_hit = code == 429 or getattr(exc, "status", None) == "RESOURCE_EXHAUSTED"
+                else:
+                    msg = str(exc)
+                    quota_hit = "429" in msg or "RESOURCE_EXHAUSTED" in msg
+                if not quota_hit:
+                    raise
+                last_exc = exc
+                if attempt < len(candidates):
+                    logger.warning(f"Model {model!r} quota exceeded, trying next fallback")
+                else:
+                    logger.warning(f"Model {model!r} quota exceeded, no fallback models left")
+                continue
+            if model != self.model_name:
+                logger.warning(f"Used fallback model {model!r} (primary quota exceeded)")
+            return response, model
+        raise last_exc
+
     def _load_prompt(self, prompt_file: str) -> str:
         """Load prompt template from prompts directory"""
         prompt_path = self.prompts_dir / prompt_file
@@ -174,9 +229,7 @@ Generate sdh_captions_vtt using the same cue timings as captions_vtt, enriched w
             # Generate content using new API - use asyncio.to_thread to avoid blocking
             logger.info("Generating content with Gemini model...")
             _t0 = time.monotonic()
-            response = await asyncio.to_thread(
-                client.models.generate_content,
-                model=self.model_name,
+            response, _model_used = await self._generate(
                 contents=[
                     genai.types.Part.from_text(text=prompt),
                     genai.types.Part.from_uri(
@@ -185,13 +238,13 @@ Generate sdh_captions_vtt using the same cue timings as captions_vtt, enriched w
                     )
                 ],
                 config=genai.types.GenerateContentConfig(
-                    temperature=0.2,  # Lower temperature for consistent, deterministic AD output
+                    temperature=0.2,
                     top_p=0.8,
                     top_k=40,
                 ),
             )
             if _cost_ctx:
-                asyncio.create_task(_record_gemini_usage(response, self.model_name, _cost_ctx.get("user_id", "system"), _cost_ctx.get("job_id", ""), _cost_ctx.get("project_id"), int((time.monotonic() - _t0) * 1000)))
+                asyncio.create_task(_record_gemini_usage(response, _model_used, _cost_ctx.get("user_id", "system"), _cost_ctx.get("job_id", ""), _cost_ctx.get("project_id"), int((time.monotonic() - _t0) * 1000)))
 
             # Parse JSON response
             response_text = response.text.strip()
@@ -292,9 +345,7 @@ Fix the JSON and return it:
 """
 
                 try:
-                    response = await asyncio.to_thread(
-                        client.models.generate_content,
-                        model=self.model_name,
+                    response, _ = await self._generate(
                         contents=[genai.types.Part.from_text(text=self_heal_prompt)]
                     )
 
@@ -394,9 +445,7 @@ Fix the JSON and return it:
             # Generate content using new API
             logger.info(f"Generating content with Gemini model for {target_language}...")
             _t0 = time.monotonic()
-            response = await asyncio.to_thread(
-                client.models.generate_content,
-                model=self.model_name,
+            response, _model_used = await self._generate(
                 contents=[
                     genai.types.Part.from_text(text=prompt),
                     genai.types.Part.from_uri(
@@ -406,7 +455,7 @@ Fix the JSON and return it:
                 ]
             )
             if _cost_ctx:
-                asyncio.create_task(_record_gemini_usage(response, self.model_name, _cost_ctx.get("user_id", "system"), _cost_ctx.get("job_id", ""), _cost_ctx.get("project_id"), int((time.monotonic() - _t0) * 1000)))
+                asyncio.create_task(_record_gemini_usage(response, _model_used, _cost_ctx.get("user_id", "system"), _cost_ctx.get("job_id", ""), _cost_ctx.get("project_id"), int((time.monotonic() - _t0) * 1000)))
 
             # Parse JSON response
             response_text = response.text.strip()
@@ -509,9 +558,7 @@ Fix the JSON and return it:
 """
 
                 try:
-                    response = await asyncio.to_thread(
-                        client.models.generate_content,
-                        model=self.model_name,
+                    response, _ = await self._generate(
                         contents=[genai.types.Part.from_text(text=self_heal_prompt)]
                     )
 
@@ -668,9 +715,7 @@ Fix the JSON and return it:
 
             # Generate content with video and prompt
             logger.info("Analyzing video with Gemini for accessible video placement...")
-            response = await asyncio.to_thread(
-                client.models.generate_content,
-                model=self.model_name,
+            response, _ = await self._generate(
                 contents=[
                     genai.types.Part.from_text(text=prompt),
                     genai.types.Part.from_uri(
@@ -752,9 +797,7 @@ Fix the JSON and return it:
 """
 
                 try:
-                    response = await asyncio.to_thread(
-                        client.models.generate_content,
-                        model=self.model_name,
+                    response, _ = await self._generate(
                         contents=[genai.types.Part.from_text(text=self_heal_prompt)]
                     )
 
@@ -802,15 +845,11 @@ JSON:
 
         try:
             _t0 = time.monotonic()
-            response = await asyncio.to_thread(
-                client.models.generate_content,
-                model=self.model_name,
-                contents=[
-                    genai.types.Part.from_text(text=prompt + "\n\n" + user_prompt)
-                ]
+            response, _model_used = await self._generate(
+                contents=[genai.types.Part.from_text(text=prompt + "\n\n" + user_prompt)]
             )
             if _cost_ctx:
-                asyncio.create_task(_record_gemini_usage(response, self.model_name, _cost_ctx.get("user_id", "system"), _cost_ctx.get("job_id", ""), _cost_ctx.get("project_id"), int((time.monotonic() - _t0) * 1000)))
+                asyncio.create_task(_record_gemini_usage(response, _model_used, _cost_ctx.get("user_id", "system"), _cost_ctx.get("job_id", ""), _cost_ctx.get("project_id"), int((time.monotonic() - _t0) * 1000)))
 
             response_text = response.text.strip()
 
@@ -892,13 +931,11 @@ Segments to translate:
 {numbered_texts}"""
 
             _t0 = time.monotonic()
-            response = await asyncio.to_thread(
-                client.models.generate_content,
-                model=self.model_name,
+            response, _model_used = await self._generate(
                 contents=[genai.types.Part.from_text(text=prompt)]
             )
             if _cost_ctx:
-                asyncio.create_task(_record_gemini_usage(response, self.model_name, _cost_ctx.get("user_id", "system"), _cost_ctx.get("job_id", ""), _cost_ctx.get("project_id"), int((time.monotonic() - _t0) * 1000)))
+                asyncio.create_task(_record_gemini_usage(response, _model_used, _cost_ctx.get("user_id", "system"), _cost_ctx.get("job_id", ""), _cost_ctx.get("project_id"), int((time.monotonic() - _t0) * 1000)))
             return self._parse_numbered_translation(response.text.strip(), cue_count)
 
         try:
@@ -985,13 +1022,11 @@ Segments to translate:
 
             logger.info(f"Rewriting TTS cue for safety: '{original_text[:50]}...'")
             _t0 = time.monotonic()
-            response = await asyncio.to_thread(
-                client.models.generate_content,
-                model=self.model_name,
+            response, _model_used = await self._generate(
                 contents=[genai.types.Part.from_text(text=prompt)]
             )
             if _cost_ctx:
-                asyncio.create_task(_record_gemini_usage(response, self.model_name, _cost_ctx.get("user_id", "system"), _cost_ctx.get("job_id", ""), _cost_ctx.get("project_id"), int((time.monotonic() - _t0) * 1000)))
+                asyncio.create_task(_record_gemini_usage(response, _model_used, _cost_ctx.get("user_id", "system"), _cost_ctx.get("job_id", ""), _cost_ctx.get("project_id"), int((time.monotonic() - _t0) * 1000)))
 
             result = response.text.strip()
 