Add toast notification when primary Gemini model falls back to backup

Backend: thread an on_fallback callback through the analysis chain
(gemini_service → agents → analysis_service → handlers). The handler
sends a 'model_fallback' WebSocket message at most once per analysis,
when the primary model fails or times out and the fallback model is used.

Frontend: handle the 'model_fallback' WS message and show a dismissible
dark toast with a yellow warning icon at the bottom of the screen; it
auto-dismisses after 8 seconds.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Vadym Samoilenko 2026-03-02 13:00:12 +00:00
parent 9ecabafa2b
commit efa6e772e0
11 changed files with 87 additions and 20 deletions

View file

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple
from typing import Awaitable, Callable, List, Optional, Tuple
from app.models.schemas import PreviousReviewContext, SubReview
@ -14,6 +14,7 @@ class BaseAgent(ABC):
self,
images: List[Tuple[bytes, str]],
previous_review: Optional[PreviousReviewContext] = None,
on_fallback: Optional[Callable[[], Awaitable[None]]] = None,
) -> SubReview:
"""
Analyze the proof and return a SubReview.

View file

@ -74,6 +74,7 @@ Your response MUST include:
channel: Optional[str] = None,
sub_channel: Optional[str] = None,
proof_type: Optional[str] = None,
on_fallback=None,
) -> SubReview:
"""
Analyze the proof for brand guideline adherence.
@ -169,9 +170,9 @@ If the proof is nonsensical, not a marketing material, or cannot be analyzed, se
if len(images) == 1:
file_data, file_type = images[0]
return await self.gemini.analyze_with_image(
prompt, file_data, file_type, include_revision_fields=include_revision_fields
prompt, file_data, file_type, include_revision_fields=include_revision_fields, on_fallback=on_fallback
)
else:
return await self.gemini.analyze_with_images(
prompt, images, include_revision_fields=include_revision_fields
prompt, images, include_revision_fields=include_revision_fields, on_fallback=on_fallback
)

View file

@ -57,6 +57,7 @@ Your response MUST include:
channel: Optional[str] = None,
sub_channel: Optional[str] = None,
proof_type: Optional[str] = None,
on_fallback=None,
) -> SubReview:
"""
Analyze the proof for channel best practices and content strategy.
@ -162,9 +163,9 @@ If the proof is nonsensical, not a marketing material, or cannot be analyzed, se
if len(images) == 1:
file_data, file_type = images[0]
return await self.gemini.analyze_with_image(
prompt, file_data, file_type, include_revision_fields=include_revision_fields
prompt, file_data, file_type, include_revision_fields=include_revision_fields, on_fallback=on_fallback
)
else:
return await self.gemini.analyze_with_images(
prompt, images, include_revision_fields=include_revision_fields
prompt, images, include_revision_fields=include_revision_fields, on_fallback=on_fallback
)

View file

@ -57,6 +57,7 @@ Your response MUST include:
channel: Optional[str] = None,
sub_channel: Optional[str] = None,
proof_type: Optional[str] = None,
on_fallback=None,
) -> SubReview:
"""
Analyze the proof for technical specifications compliance.
@ -170,9 +171,9 @@ If the proof is nonsensical, not a marketing material, or cannot be analyzed, se
if len(images) == 1:
file_data, file_type = images[0]
return await self.gemini.analyze_with_image(
prompt, file_data, file_type, include_revision_fields=include_revision_fields
prompt, file_data, file_type, include_revision_fields=include_revision_fields, on_fallback=on_fallback
)
else:
return await self.gemini.analyze_with_images(
prompt, images, include_revision_fields=include_revision_fields
prompt, images, include_revision_fields=include_revision_fields, on_fallback=on_fallback
)

View file

@ -79,6 +79,7 @@ In your summary:
channel: Optional[str] = None,
sub_channel: Optional[str] = None,
proof_type: Optional[str] = None,
on_fallback=None,
) -> tuple[OverallStatus, str, str | None]:
"""
Synthesize specialist reviews into final status and summary.
@ -175,7 +176,7 @@ Here are the specialist reviews:
Now, provide your final status and summary as a JSON object.
"""
result = await self.gemini.generate_summary(prompt)
result = await self.gemini.generate_summary(prompt, on_fallback=on_fallback)
overall_status = OverallStatus(result.get("overallStatus", "Analysis Error"))
summary = result.get("summary", "Unable to generate summary.")

View file

@ -57,6 +57,7 @@ Your response MUST include:
channel: Optional[str] = None,
sub_channel: Optional[str] = None,
proof_type: Optional[str] = None,
on_fallback=None,
) -> SubReview:
"""
Analyze the proof for legal compliance.
@ -172,9 +173,9 @@ If the proof is nonsensical, not a marketing material, or cannot be analyzed, se
if len(images) == 1:
file_data, file_type = images[0]
return await self.gemini.analyze_with_image(
prompt, file_data, file_type, include_revision_fields=include_revision_fields
prompt, file_data, file_type, include_revision_fields=include_revision_fields, on_fallback=on_fallback
)
else:
return await self.gemini.analyze_with_images(
prompt, images, include_revision_fields=include_revision_fields
prompt, images, include_revision_fields=include_revision_fields, on_fallback=on_fallback
)

View file

@ -108,6 +108,7 @@ class AnalysisService:
channel: Optional[str] = None,
sub_channel: Optional[str] = None,
proof_type: Optional[str] = None,
on_fallback=None,
) -> Tuple[str, SubReview]:
"""Run a single agent with callback notifications."""
agent = self.agents[agent_name]
@ -117,9 +118,9 @@ class AnalysisService:
await on_agent_update(agent_name, None)
if agent_name == "Brand Agent":
review = await agent.analyze(images, previous_review=previous_review, brand=brand, channel=channel, sub_channel=sub_channel, proof_type=proof_type)
review = await agent.analyze(images, previous_review=previous_review, brand=brand, channel=channel, sub_channel=sub_channel, proof_type=proof_type, on_fallback=on_fallback)
else:
review = await agent.analyze(images, previous_review=previous_review, channel=channel, sub_channel=sub_channel, proof_type=proof_type)
review = await agent.analyze(images, previous_review=previous_review, channel=channel, sub_channel=sub_channel, proof_type=proof_type, on_fallback=on_fallback)
logger.info(f"[ANALYSIS] Agent completed: {agent_name} - ragStatus: {review.ragStatus}")
if on_agent_update:
@ -138,6 +139,7 @@ class AnalysisService:
channel: Optional[str] = None,
sub_channel: Optional[str] = None,
proof_type: Optional[str] = None,
on_fallback=None,
) -> Tuple[AgentReview, Optional[List[Tuple[bytes, int, int]]]]:
"""
Analyze a proof using all agents in parallel.
@ -207,6 +209,7 @@ class AnalysisService:
channel=channel,
sub_channel=sub_channel,
proof_type=proof_type,
on_fallback=on_fallback,
)
for agent_name in self.AGENT_ORDER
]
@ -221,6 +224,7 @@ class AnalysisService:
overall_status, summary, financial_promotion_reason = await self.lead_agent.synthesize(
reviews, previous_analysis=previous_analysis,
channel=channel, sub_channel=sub_channel, proof_type=proof_type,
on_fallback=on_fallback,
)
logger.info(f"[ANALYSIS] Analysis complete - overallStatus: {overall_status}")

View file

@ -1,7 +1,7 @@
import asyncio
import json
import logging
from typing import List, Tuple
from typing import Awaitable, Callable, List, Optional, Tuple
from google import genai
from google.genai import types
@ -41,7 +41,12 @@ class GeminiService:
self.model = "gemini-3.1-pro-preview"
self.fallback_model = "gemini-3-flash-preview"
async def _generate_content(self, contents, config) -> any:
async def _generate_content(
self,
contents,
config,
on_fallback: Optional[Callable[[], Awaitable[None]]] = None,
) -> any:
"""Call generate_content, falling back to fallback_model if the primary fails or times out."""
try:
return await self.primary_client.aio.models.generate_content(
@ -54,6 +59,8 @@ class GeminiService:
f"[GEMINI API] Primary model {self.model} failed: {e}. "
f"Retrying with fallback {self.fallback_model}"
)
if on_fallback:
await on_fallback()
return await self.fallback_client.aio.models.generate_content(
model=self.fallback_model,
contents=contents,
@ -66,6 +73,7 @@ class GeminiService:
file_data: bytes,
file_type: str,
include_revision_fields: bool = False,
on_fallback: Optional[Callable[[], Awaitable[None]]] = None,
) -> SubReview:
"""
Analyze an image/file with Gemini and return a structured SubReview.
@ -142,8 +150,9 @@ class GeminiService:
contents=[file_part, prompt],
config=types.GenerateContentConfig(
response_mime_type="application/json",
response_schema=response_schema
response_schema=response_schema,
),
on_fallback=on_fallback,
)
logger.info(f"[GEMINI API] Response received from Gemini")
@ -194,6 +203,7 @@ class GeminiService:
prompt: str,
images: List[Tuple[bytes, str]],
include_revision_fields: bool = False,
on_fallback: Optional[Callable[[], Awaitable[None]]] = None,
) -> SubReview:
"""
Analyze multiple images with Gemini and return a structured SubReview.
@ -275,8 +285,9 @@ class GeminiService:
contents=contents,
config=types.GenerateContentConfig(
response_mime_type="application/json",
response_schema=response_schema
response_schema=response_schema,
),
on_fallback=on_fallback,
)
logger.info(f"[GEMINI API] Response received from Gemini (multi-image)")
@ -325,6 +336,7 @@ class GeminiService:
async def generate_summary(
self,
prompt: str,
on_fallback: Optional[Callable[[], Awaitable[None]]] = None,
) -> dict:
"""
Generate a text summary (for lead agent).
@ -356,8 +368,9 @@ class GeminiService:
contents=prompt,
config=types.GenerateContentConfig(
response_mime_type="application/json",
response_schema=response_schema
response_schema=response_schema,
),
on_fallback=on_fallback,
)
result = json.loads(response.text.strip())

View file

@ -133,6 +133,18 @@ async def handle_analyze_message(
sub_channel = data.get("sub_channel")
proof_type = data.get("proof_type")
# Build a once-only callback that notifies the client when the primary
# Gemini model is unavailable and the fallback model is used instead.
fallback_notified = False
async def on_model_fallback() -> None:
nonlocal fallback_notified
if fallback_notified:
return
fallback_notified = True
if manager.is_connected(client_id):
await manager.send_message(client_id, {"type": "model_fallback"})
# Run the analysis
logger.info("[WEBSOCKET] Starting analysis...")
result, pdf_pages = await analysis_service.analyze_proof(
@ -145,6 +157,7 @@ async def handle_analyze_message(
channel=channel,
sub_channel=sub_channel,
proof_type=proof_type,
on_fallback=on_model_fallback,
)
# Build the result dict

View file

@ -85,6 +85,14 @@ const AppContent: React.FC<{ msalInstance: any }> = ({ msalInstance }) => {
const [pendingProofId, setPendingProofId] = useState<string | null>(initialUrlState.proofId);
const [error, setError] = useState<string | null>(null);
const [isLoadingData, setIsLoadingData] = useState(true);
const [notification, setNotification] = useState<string | null>(null);
const notificationTimerRef = React.useRef<ReturnType<typeof setTimeout> | null>(null);
// Display `message` in the toast and (re)start its 8-second auto-dismiss
// timer, cancelling any timer left over from a previous notification.
const showNotification = (message: string) => {
  const pendingTimer = notificationTimerRef.current;
  setNotification(message);
  if (pendingTimer) {
    clearTimeout(pendingTimer);
  }
  notificationTimerRef.current = setTimeout(() => {
    setNotification(null);
  }, 8000);
};
// Agency filter state (session-level, for oversight_admin / super_admin)
const [selectedAgencyId, setSelectedAgencyId] = useState<string | null>(null);
@ -339,7 +347,7 @@ const AppContent: React.FC<{ msalInstance: any }> = ({ msalInstance }) => {
subChannel,
proofType,
brand: campaign.brandGuidelines,
});
}, showNotification);
const feedback = result.review;
@ -451,7 +459,7 @@ const AppContent: React.FC<{ msalInstance: any }> = ({ msalInstance }) => {
subChannel,
proofType,
brand: campaign.brandGuidelines,
});
}, showNotification);
// Refresh proofs from API to get the persisted data
try {
@ -971,6 +979,24 @@ const AppContent: React.FC<{ msalInstance: any }> = ({ msalInstance }) => {
{renderContent()}
</main>
</div>
{/* Model fallback notification toast */}
{notification && (
<div className="fixed bottom-6 left-1/2 -translate-x-1/2 z-50 flex items-start gap-3 bg-gray-800 text-white text-sm px-5 py-3.5 rounded-xl shadow-2xl max-w-md w-full mx-4 animate-fade-in">
<svg className="w-5 h-5 text-yellow-400 shrink-0 mt-0.5" fill="none" viewBox="0 0 24 24" stroke="currentColor" strokeWidth={2}>
<path strokeLinecap="round" strokeLinejoin="round" d="M12 9v3.75m-9.303 3.376c-.866 1.5.217 3.374 1.948 3.374h14.71c1.73 0 2.813-1.874 1.948-3.374L13.949 3.378c-.866-1.5-3.032-1.5-3.898 0L2.697 16.126zM12 15.75h.007v.008H12v-.008z" />
</svg>
<div className="flex-1">
<p className="font-semibold text-yellow-300 mb-0.5">AI Model Notice</p>
<p className="text-gray-200 leading-snug">{notification}</p>
</div>
<button onClick={() => setNotification(null)} className="text-gray-400 hover:text-white shrink-0">
<svg className="w-4 h-4" fill="none" viewBox="0 0 24 24" stroke="currentColor" strokeWidth={2}>
<path strokeLinecap="round" strokeLinejoin="round" d="M6 18L18 6M6 6l12 12" />
</svg>
</button>
</div>
)}
</div>
);
};

View file

@ -40,7 +40,8 @@ export const analyzeProof = async (
file: File,
onAgentUpdate: (name: AgentName | 'Summary', review?: SubReview) => void,
msalInstance: IPublicClientApplication,
options?: AnalyzeProofOptions
options?: AnalyzeProofOptions,
onNotification?: (message: string) => void,
): Promise<AnalyzeProofResult> => {
// Acquire token before connecting
const accessToken = await getAccessToken(msalInstance);
@ -126,6 +127,10 @@ export const analyzeProof = async (
});
break;
case 'model_fallback':
onNotification?.('The primary AI model is currently unavailable. Analysis is continuing with the backup model and may take longer than usual.');
break;
case 'error':
// Error occurred
resolved = true;