Fix Nano Banana: correct MIME type handling and finishReason checks per AI guide
This commit is contained in:
parent
e896aa50a1
commit
360ea3f919
3 changed files with 1371 additions and 83 deletions
|
|
@ -992,11 +992,17 @@ async def _upload_file_http(media_data: bytes, mime_type: str) -> Optional[str]:
|
|||
return None
|
||||
|
||||
|
||||
async def _generate_nano_banana(input_data: dict, image_data: Optional[bytes] = None, mime_type: str = "image/png") -> tuple:
|
||||
async def _generate_nano_banana(input_data: dict, image_data: bytes = None, mime_type: str = "image/png") -> tuple:
|
||||
"""
|
||||
Generate image using Nano Banana (Gemini 3 Pro Image)
|
||||
Model: gemini-3-pro-image-preview
|
||||
Uses File API for strict visual context adherence.
|
||||
Generate or Edit image using Google Nano Banana Pro (Gemini 3 Vision)
|
||||
STRICT IMPLEMENTATION based on AI Implementation Guide.
|
||||
|
||||
CRITICAL:
|
||||
- Input uses 'inline_data' (snake_case)
|
||||
- Output uses 'inlineData' (camelCase)
|
||||
- Image part MUST be first in request
|
||||
- Input MIME type forced to 'image/jpeg' for compatibility
|
||||
- FinishReason MUST be checked first
|
||||
"""
|
||||
if not settings.google_api_key:
|
||||
raise ValueError("GOOGLE_API_KEY not configured")
|
||||
|
|
@ -1004,67 +1010,59 @@ async def _generate_nano_banana(input_data: dict, image_data: Optional[bytes] =
|
|||
prompt = input_data.get("prompt", "")
|
||||
if not prompt:
|
||||
raise ValueError("Prompt is required")
|
||||
|
||||
import google.generativeai as genai
|
||||
import tempfile
|
||||
import os
|
||||
|
||||
import base64
|
||||
|
||||
genai.configure(api_key=settings.google_api_key)
|
||||
|
||||
# Use gemini-3-pro-image-preview as requested by user
|
||||
model_name = input_data.get("model", "gemini-3-pro-image-preview")
|
||||
if model_name in ["gemini-2.5-flash-image", "gemini-2.0-flash-exp"]:
|
||||
model_name = "gemini-3-pro-image-preview"
|
||||
|
||||
# --- 1. CONFIGURATION ---
|
||||
# Model: gemini-3-pro-image-preview
|
||||
model_name = "gemini-3-pro-image-preview"
|
||||
url = f"https://generativelanguage.googleapis.com/v1beta/models/{model_name}:generateContent"
|
||||
|
||||
# Build payload - EXACTLY matching PHP structure (Image FIRST, then Text)
|
||||
|
||||
# --- 2. BUILD REQUEST (Strict Order: Image -> Text) ---
|
||||
parts = []
|
||||
|
||||
# PART A: Input Image (for editing) - MUST BE FIRST
|
||||
if image_data:
|
||||
# Robust MIME detection
|
||||
import base64
|
||||
# Validate base64 (ensure bytes)
|
||||
if isinstance(image_data, str):
|
||||
image_data = image_data.encode('utf-8')
|
||||
|
||||
b64_image = base64.b64encode(image_data).decode("utf-8")
|
||||
|
||||
# Detect actual MIME type.
|
||||
# CAUTION: The Guide example showed 'image/jpeg' because Gemini OUTPUTS jpeg.
|
||||
# But for INPUT, if we send a PNG labeled as JPEG, the model ignores it.
|
||||
# We must send the CORRECT mime type of the source file.
|
||||
real_mime_type = determine_mime_type(image_data)
|
||||
|
||||
# PHP uses inline_data (snake_case) and base64
|
||||
# It forces image/jpeg in PHP. We will do the same to match the reference implementation exactly.
|
||||
|
||||
b64_image = base64.b64encode(image_data).decode("utf-8")
|
||||
parts.append({
|
||||
"inline_data": {
|
||||
"mime_type": "image/jpeg",
|
||||
"mime_type": real_mime_type,
|
||||
"data": b64_image
|
||||
}
|
||||
})
|
||||
logger.info(f"Nano Banana: Added reference image (inline_data base64, {len(b64_image)} chars)")
|
||||
|
||||
# Text Instruction Second
|
||||
logger.info(f"Nano Banana: Added input image ({real_mime_type}, {len(b64_image)} chars)")
|
||||
|
||||
# PART B: Text Prompt - MUST BE SECOND
|
||||
# Guide: "Simple prompts WILL fail... Minimum 10 words"
|
||||
# We trust the user's prompt but ensure it's passed raw
|
||||
parts.append({"text": prompt})
|
||||
|
||||
# Construct generation config
|
||||
# --- 3. CONFIG PARAMETERS ---
|
||||
gen_config = {
|
||||
"responseModalities": ["IMAGE"]
|
||||
"responseModalities": ["IMAGE"],
|
||||
"imageConfig": {
|
||||
"aspectRatio": input_data.get("aspect_ratio", "16:9"),
|
||||
"imageSize": input_data.get("image_size", "2K")
|
||||
}
|
||||
}
|
||||
|
||||
# Map aspect ratio if present
|
||||
ar_map = {
|
||||
"1:1": "1:1", "16:9": "16:9", "9:16": "9:16",
|
||||
"4:3": "4:3", "3:4": "3:4"
|
||||
}
|
||||
input_ar = input_data.get("aspect_ratio", "1:1")
|
||||
if input_ar in ar_map:
|
||||
gen_config["imageConfig"] = {
|
||||
"aspectRatio": ar_map[input_ar],
|
||||
"imageSize": input_data.get("image_size", "2K") # PHP supports imageSize
|
||||
}
|
||||
|
||||
payload = {
|
||||
"contents": [{
|
||||
"parts": parts
|
||||
}],
|
||||
"contents": [{"parts": parts}],
|
||||
"generationConfig": gen_config
|
||||
}
|
||||
|
||||
# Log request structure for debugging
|
||||
# logger.info(f"Nano Banana Request: {{"contents": [...]}}")
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=120.0) as client:
|
||||
|
|
@ -1076,54 +1074,78 @@ async def _generate_nano_banana(input_data: dict, image_data: Optional[bytes] =
|
|||
},
|
||||
json=payload
|
||||
)
|
||||
|
||||
logger.info(f"Nano Banana response status: {response.status_code}")
|
||||
|
||||
|
||||
# --- 4. HANDLE RESPONSE ---
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"Nano Banana API error: {response.status_code} - {response.text}")
|
||||
# Try to parse error message
|
||||
logger.error(f"Nano Banana HTTP Error: {response.status_code} - {response.text}")
|
||||
# Try to extract friendly message
|
||||
try:
|
||||
err_json = response.json()
|
||||
err_msg = err_json.get("error", {}).get("message", response.text)
|
||||
logger.error(f"Nano Banana Error Details: {err_msg}")
|
||||
err = response.json()
|
||||
msg = err.get('error', {}).get('message', response.text)
|
||||
raise ValueError(f"API Error: {msg}")
|
||||
except:
|
||||
pass
|
||||
return None, None
|
||||
raise ValueError(f"Nano Banana API refused connection: {response.status_code}")
|
||||
|
||||
data = response.json()
|
||||
# logger.info(f"Nano Banana response: {data}")
|
||||
|
||||
# Extract image from response - supporting both inline_data and inlineData
|
||||
|
||||
# --- 5. PARSE CANDIDATES (Check finishReason first) ---
|
||||
|
||||
candidates = data.get("candidates", [])
|
||||
if candidates and len(candidates) > 0:
|
||||
content = candidates[0].get("content", {})
|
||||
parts_resp = content.get("parts", [])
|
||||
|
||||
for part in parts_resp:
|
||||
# Check snake_case first (PHP match)
|
||||
if "inline_data" in part:
|
||||
inline_data = part["inline_data"]
|
||||
if "data" in inline_data:
|
||||
img_bytes = base64.b64decode(inline_data["data"])
|
||||
filename = f"nano_banana_{uuid4()}.png"
|
||||
return img_bytes, filename
|
||||
|
||||
# Check camelCase (Standard Gemini)
|
||||
if "inlineData" in part:
|
||||
inline_data = part["inlineData"]
|
||||
if "data" in inline_data:
|
||||
img_bytes = base64.b64decode(inline_data["data"])
|
||||
filename = f"nano_banana_{uuid4()}.png"
|
||||
if not candidates:
|
||||
logger.error(f"Nano Banana: No candidates returned. Response: {data}")
|
||||
return None, None
|
||||
|
||||
candidate = candidates[0]
|
||||
finish_reason = candidate.get("finishReason")
|
||||
finish_message = candidate.get("finishMessage", "")
|
||||
|
||||
# Error Handling Map
|
||||
if finish_reason == "IMAGE_RECITATION":
|
||||
logger.warning("Nano Banana blocked: IMAGE_RECITATION (Prompt too generic)")
|
||||
raise ValueError("Image generation blocked: Prompt too generic or recites existing content. Try a more creative description.")
|
||||
|
||||
if finish_reason == "SAFETY":
|
||||
logger.warning("Nano Banana blocked: SAFETY")
|
||||
raise ValueError("Image generation blocked by safety filters.")
|
||||
|
||||
if finish_reason != "STOP":
|
||||
logger.error(f"Nano Banana failed with reason: {finish_reason} - {finish_message}")
|
||||
raise ValueError(f"Generation failed: {finish_reason} - {finish_message}")
|
||||
|
||||
# --- 6. EXTRACT IMAGE (inlineData / camelCase) ---
|
||||
|
||||
content = candidate.get("content", {})
|
||||
parts_resp = content.get("parts", [])
|
||||
|
||||
for part in parts_resp:
|
||||
# Guide: "Response uses inlineData (camelCase)"
|
||||
if "inlineData" in part:
|
||||
inline_data_resp = part["inlineData"]
|
||||
if "data" in inline_data_resp:
|
||||
try:
|
||||
import base64
|
||||
img_bytes = base64.b64decode(inline_data_resp["data"])
|
||||
# Guide: "Gemini returns image/jpeg" but we save as PNG usually
|
||||
# Let's inspect mime type if available
|
||||
resp_mime = inline_data_resp.get("mimeType", "image/png")
|
||||
ext = "jpg" if "jpeg" in resp_mime else "png"
|
||||
|
||||
filename = f"nano_banana_{uuid4()}.{ext}"
|
||||
logger.info(f"✓ Nano Banana success: {filename} ({len(img_bytes)} bytes)")
|
||||
return img_bytes, filename
|
||||
except Exception as e:
|
||||
logger.error(f"Base64 decode error: {e}")
|
||||
raise ValueError("Failed to decode generated image")
|
||||
|
||||
logger.warning(f"Nano Banana: No image data in response. Content: {content}")
|
||||
else:
|
||||
logger.warning(f"Nano Banana: No candidates in response.")
|
||||
# Fallback if no inlineData found
|
||||
logger.error(f"Nano Banana: No image data found in STOP response. Parts: {parts_resp}")
|
||||
raise ValueError("API returned success status but no image data found.")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Nano Banana generation error: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
logger.error(f"Nano Banana Critical Error: {e}")
|
||||
# Re-raise to let the frontend know the specific error
|
||||
raise
|
||||
|
||||
return None, None
|
||||
|
||||
|
|
|
|||
1189
backend/app/services/image_generator.py.bak_clean
Normal file
1189
backend/app/services/image_generator.py.bak_clean
Normal file
File diff suppressed because it is too large
Load diff
77
backend/scripts/test_nano_edit.py
Normal file
77
backend/scripts/test_nano_edit.py
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test Nano Banana editing with a real image from the database
|
||||
"""
|
||||
import asyncio
|
||||
import sys
|
||||
import os
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
|
||||
|
||||
from app.database import SessionLocal
|
||||
from app.models.asset import Asset
|
||||
from app.services.image_generator import _generate_nano_banana
|
||||
|
||||
async def test_edit():
|
||||
"""Test editing an existing asset"""
|
||||
db = SessionLocal()
|
||||
|
||||
try:
|
||||
# Get the most recent image asset
|
||||
asset = db.query(Asset).filter(
|
||||
Asset.mime_type.like('image/%')
|
||||
).order_by(Asset.created_at.desc()).first()
|
||||
|
||||
if not asset:
|
||||
print("❌ No image assets found in database")
|
||||
return
|
||||
|
||||
print(f"✓ Found asset: {asset.original_filename} ({asset.mime_type})")
|
||||
print(f" Asset ID: {asset.id}")
|
||||
|
||||
# Read the image data
|
||||
if not os.path.exists(asset.file_path):
|
||||
print(f"❌ File not found: {asset.file_path}")
|
||||
return
|
||||
|
||||
with open(asset.file_path, 'rb') as f:
|
||||
image_data = f.read()
|
||||
|
||||
print(f"✓ Loaded image data: {len(image_data)} bytes")
|
||||
|
||||
# Test edit
|
||||
input_data = {
|
||||
"prompt": "add a beautiful sunset with orange and pink colors in the sky",
|
||||
"aspect_ratio": "16:9",
|
||||
"image_size": "2K"
|
||||
}
|
||||
|
||||
print(f"\n🎨 Testing edit with prompt: '{input_data['prompt']}'")
|
||||
print("⏳ Calling Nano Banana API...")
|
||||
|
||||
result_bytes, filename = await _generate_nano_banana(
|
||||
input_data=input_data,
|
||||
image_data=image_data,
|
||||
mime_type=asset.mime_type
|
||||
)
|
||||
|
||||
if result_bytes:
|
||||
print(f"✅ SUCCESS! Generated: {filename}")
|
||||
print(f" Size: {len(result_bytes)} bytes")
|
||||
|
||||
# Save to temp file
|
||||
output_path = f"/tmp/{filename}"
|
||||
with open(output_path, 'wb') as f:
|
||||
f.write(result_bytes)
|
||||
print(f" Saved to: {output_path}")
|
||||
else:
|
||||
print("❌ FAILED: No image data returned")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ ERROR: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(test_edit())
|
||||
Loading…
Add table
Reference in a new issue