From 360ea3f91978f35336bd71b4577f51e9233ef816 Mon Sep 17 00:00:00 2001 From: DJP Date: Fri, 12 Dec 2025 16:21:58 -0500 Subject: [PATCH] Fix Nano Banana: correct MIME type handling and finishReason checks per AI guide --- backend/app/services/image_generator.py | 188 +-- .../app/services/image_generator.py.bak_clean | 1189 +++++++++++++++++ backend/scripts/test_nano_edit.py | 77 ++ 3 files changed, 1371 insertions(+), 83 deletions(-) create mode 100644 backend/app/services/image_generator.py.bak_clean create mode 100644 backend/scripts/test_nano_edit.py diff --git a/backend/app/services/image_generator.py b/backend/app/services/image_generator.py index db4f79d..d8e15c2 100644 --- a/backend/app/services/image_generator.py +++ b/backend/app/services/image_generator.py @@ -992,11 +992,17 @@ async def _upload_file_http(media_data: bytes, mime_type: str) -> Optional[str]: return None -async def _generate_nano_banana(input_data: dict, image_data: Optional[bytes] = None, mime_type: str = "image/png") -> tuple: +async def _generate_nano_banana(input_data: dict, image_data: bytes = None, mime_type: str = "image/png") -> tuple: """ - Generate image using Nano Banana (Gemini 3 Pro Image) - Model: gemini-3-pro-image-preview - Uses File API for strict visual context adherence. + Generate or Edit image using Google Nano Banana Pro (Gemini 3 Vision) + STRICT IMPLEMENTATION based on AI Implementation Guide. + + CRITICAL: + - Input uses 'inline_data' (snake_case) + - Output uses 'inlineData' (camelCase) + - Image part MUST be first in request + - Input MIME type forced to 'image/jpeg' for compatibility + - FinishReason MUST be checked first """ if not settings.google_api_key: raise ValueError("GOOGLE_API_KEY not configured") @@ -1004,67 +1010,59 @@ async def _generate_nano_banana(input_data: dict, image_data: Optional[bytes] = prompt = input_data.get("prompt", "") if not prompt: raise ValueError("Prompt is required") - - import google.generativeai as genai - import tempfile - import os - - import base64 - - genai.configure(api_key=settings.google_api_key) - - # Use gemini-3-pro-image-preview as requested by user - model_name = input_data.get("model", "gemini-3-pro-image-preview") - if model_name in ["gemini-2.5-flash-image", "gemini-2.0-flash-exp"]: - model_name = "gemini-3-pro-image-preview" + # --- 1. CONFIGURATION --- + # Model: gemini-3-pro-image-preview + model_name = "gemini-3-pro-image-preview" url = f"https://generativelanguage.googleapis.com/v1beta/models/{model_name}:generateContent" - - # Build payload - EXACTLY matching PHP structure (Image FIRST, then Text) + + # --- 2. BUILD REQUEST (Strict Order: Image -> Text) --- parts = [] + # PART A: Input Image (for editing) - MUST BE FIRST if image_data: - # Robust MIME detection + import base64 + # Validate base64 (ensure bytes) + if isinstance(image_data, str): + image_data = image_data.encode('utf-8') + + b64_image = base64.b64encode(image_data).decode("utf-8") + + # Detect actual MIME type. + # CAUTION: The Guide example showed 'image/jpeg' because Gemini OUTPUTS jpeg. + # But for INPUT, if we send a PNG labeled as JPEG, the model ignores it. + # We must send the CORRECT mime type of the source file. real_mime_type = determine_mime_type(image_data) - # PHP uses inline_data (snake_case) and base64 - # It forces image/jpeg in PHP. We will do the same to match the reference implementation exactly. - - b64_image = base64.b64encode(image_data).decode("utf-8") parts.append({ "inline_data": { - "mime_type": "image/jpeg", + "mime_type": real_mime_type, "data": b64_image } }) - logger.info(f"Nano Banana: Added reference image (inline_data base64, {len(b64_image)} chars)") - - # Text Instruction Second + logger.info(f"Nano Banana: Added input image ({real_mime_type}, {len(b64_image)} chars)") + + # PART B: Text Prompt - MUST BE SECOND + # Guide: "Simple prompts WILL fail... Minimum 10 words" + # We trust the user's prompt but ensure it's passed raw parts.append({"text": prompt}) - # Construct generation config + # --- 3. CONFIG PARAMETERS --- gen_config = { - "responseModalities": ["IMAGE"] + "responseModalities": ["IMAGE"], + "imageConfig": { + "aspectRatio": input_data.get("aspect_ratio", "16:9"), + "imageSize": input_data.get("image_size", "2K") + } } - # Map aspect ratio if present - ar_map = { - "1:1": "1:1", "16:9": "16:9", "9:16": "9:16", - "4:3": "4:3", "3:4": "3:4" - } - input_ar = input_data.get("aspect_ratio", "1:1") - if input_ar in ar_map: - gen_config["imageConfig"] = { - "aspectRatio": ar_map[input_ar], - "imageSize": input_data.get("image_size", "2K") # PHP supports imageSize - } - payload = { - "contents": [{ - "parts": parts - }], + "contents": [{"parts": parts}], "generationConfig": gen_config } + + # Log request structure for debugging + # logger.info(f"Nano Banana Request: {{"contents": [...]}}") try: async with httpx.AsyncClient(timeout=120.0) as client: @@ -1076,54 +1074,78 @@ async def _generate_nano_banana(input_data: dict, image_data: Optional[bytes] = }, json=payload ) - - logger.info(f"Nano Banana response status: {response.status_code}") - + + # --- 4. HANDLE RESPONSE --- + if response.status_code != 200: - logger.error(f"Nano Banana API error: {response.status_code} - {response.text}") - # Try to parse error message + logger.error(f"Nano Banana HTTP Error: {response.status_code} - {response.text}") + # Try to extract friendly message try: - err_json = response.json() - err_msg = err_json.get("error", {}).get("message", response.text) - logger.error(f"Nano Banana Error Details: {err_msg}") + err = response.json() + msg = err.get('error', {}).get('message', response.text) + raise ValueError(f"API Error: {msg}") except: - pass - return None, None + raise ValueError(f"Nano Banana API refused connection: {response.status_code}") data = response.json() - # logger.info(f"Nano Banana response: {data}") - - # Extract image from response - supporting both inline_data and inlineData + + # --- 5. PARSE CANDIDATES (Check finishReason first) --- + candidates = data.get("candidates", []) - if candidates and len(candidates) > 0: - content = candidates[0].get("content", {}) - parts_resp = content.get("parts", []) - - for part in parts_resp: - # Check snake_case first (PHP match) - if "inline_data" in part: - inline_data = part["inline_data"] - if "data" in inline_data: - img_bytes = base64.b64decode(inline_data["data"]) - filename = f"nano_banana_{uuid4()}.png" - return img_bytes, filename - - # Check camelCase (Standard Gemini) - if "inlineData" in part: - inline_data = part["inlineData"] - if "data" in inline_data: - img_bytes = base64.b64decode(inline_data["data"]) - filename = f"nano_banana_{uuid4()}.png" + if not candidates: + logger.error(f"Nano Banana: No candidates returned. Response: {data}") + return None, None + + candidate = candidates[0] + finish_reason = candidate.get("finishReason") + finish_message = candidate.get("finishMessage", "") + + # Error Handling Map + if finish_reason == "IMAGE_RECITATION": + logger.warning("Nano Banana blocked: IMAGE_RECITATION (Prompt too generic)") + raise ValueError("Image generation blocked: Prompt too generic or recites existing content. Try a more creative description.") + + if finish_reason == "SAFETY": + logger.warning("Nano Banana blocked: SAFETY") + raise ValueError("Image generation blocked by safety filters.") + + if finish_reason != "STOP": + logger.error(f"Nano Banana failed with reason: {finish_reason} - {finish_message}") + raise ValueError(f"Generation failed: {finish_reason} - {finish_message}") + + # --- 6. EXTRACT IMAGE (inlineData / camelCase) --- + + content = candidate.get("content", {}) + parts_resp = content.get("parts", []) + + for part in parts_resp: + # Guide: "Response uses inlineData (camelCase)" + if "inlineData" in part: + inline_data_resp = part["inlineData"] + if "data" in inline_data_resp: + try: + import base64 + img_bytes = base64.b64decode(inline_data_resp["data"]) + # Guide: "Gemini returns image/jpeg" but we save as PNG usually + # Let's inspect mime type if available + resp_mime = inline_data_resp.get("mimeType", "image/png") + ext = "jpg" if "jpeg" in resp_mime else "png" + + filename = f"nano_banana_{uuid4()}.{ext}" + logger.info(f"✓ Nano Banana success: {filename} ({len(img_bytes)} bytes)") return img_bytes, filename + except Exception as e: + logger.error(f"Base64 decode error: {e}") + raise ValueError("Failed to decode generated image") - logger.warning(f"Nano Banana: No image data in response. Content: {content}") - else: - logger.warning(f"Nano Banana: No candidates in response.") + # Fallback if no inlineData found + logger.error(f"Nano Banana: No image data found in STOP response. Parts: {parts_resp}") + raise ValueError("API returned success status but no image data found.") except Exception as e: - logger.error(f"Nano Banana generation error: {e}") - import traceback - traceback.print_exc() + logger.error(f"Nano Banana Critical Error: {e}") + # Re-raise to let the frontend know the specific error + raise return None, None diff --git a/backend/app/services/image_generator.py.bak_clean b/backend/app/services/image_generator.py.bak_clean new file mode 100644 index 0000000..75ea71a --- /dev/null +++ b/backend/app/services/image_generator.py.bak_clean @@ -0,0 +1,1189 @@ +"""Image Generator Service - Multiple AI Providers + +Supported Providers: +- openai: GPT-Image-1 (latest) or DALL-E 3 +- imagen: Google Imagen 4 (Standard, Ultra, Fast) +- nano-banana: Gemini 2.5 Flash Image / Nano Banana Pro +- stable-diffusion: Stability AI SDXL, SD3, image-to-image +- leonardo: Leonardo.ai models +- ideogram: Ideogram v2 with text rendering +- flux: Black Forest Labs Flux Pro + +OpenAI GPT-Image-1 (April 2025): +- model: 'gpt-image-1' (default) or 'dall-e-3' +- quality: 'low', 'medium', 'high' (default high) +- size: 1024x1024, 1024x1536, 1536x1024 +- background: 'transparent', 'opaque', 'auto' (for PNG/WebP) +- output_format: 'png', 'jpeg', 'webp' +- n: 1-10 images per request +- Pricing: ~$0.02 (low), $0.07 (medium), $0.19 (high) per image + +Google Imagen 4 (December 2025): +- model: 'imagen-4.0-generate-001' (default), 'imagen-4.0-ultra-generate-001', 'imagen-4.0-fast-generate-001' +- image_size: '1K', '2K' (Ultra/Standard only) +- aspect_ratio: '1:1', '3:4', '4:3', '9:16', '16:9' +- number_of_images: 1-4 +- enhance_prompt: true/false (LLM prompt enhancement) +- person_generation: 'dont_allow', 'allow_adult', 'allow_all' +- Pricing: $0.02 (Fast), $0.04 (Standard), $0.06 (Ultra) per image + +Nano Banana / Gemini Image (December 2025): +- model: 'gemini-2.5-flash-image' (Nano Banana), 'gemini-3-pro-image-preview' (Nano Banana Pro) +- aspect_ratio: '1:1', '2:3', '3:2', '3:4', '4:3', '4:5', '5:4', '9:16', '16:9', '21:9' +- image_size: '1K', '2K', '4K' (Pro only for 4K) +- Features: Text rendering, image editing, multi-turn conversation +- Pricing: ~$0.04 per 1MP image + +DALL-E 3 Options: +- quality: 'standard' or 'hd' (default hd) +- style: 'vivid' (hyper-real) or 'natural' (more realistic) +- size: 1024x1024, 1024x1792, 1792x1024 + +Stability AI Options: +- model: sd3.5-large, sd3.5-medium, sd3-large, sd3-medium, sdxl-1.0 +- aspect_ratio: 1:1, 16:9, 9:16, 4:3, 3:4, 21:9, 9:21 +- negative_prompt: What to avoid in generation +- image_to_image: Use input image as starting point +- strength: 0.0-1.0 for image-to-image (how much to change) +- style_preset: enhance, anime, photographic, digital-art, etc. +""" +import httpx +import os +import base64 +import logging +from uuid import uuid4 +from datetime import datetime +from typing import Optional, Dict, Any, Tuple + +logger = logging.getLogger(__name__) + +from app.database import SessionLocal +from app.models.job import Job +from app.models.asset import Asset +from app.config import settings + +def determine_mime_type(data: bytes) -> str: + """Detect MIME type from magic bytes""" + if data.startswith(b'\x89PNG\r\n\x1a\n'): + return 'image/png' + elif data.startswith(b'\xff\xd8'): + return 'image/jpeg' + elif data.startswith(b'RIFF') and data[8:12] == b'WEBP': + return 'image/webp' + return 'image/png' # Default fallback + +# Provider configurations +IMAGE_PROVIDERS = { + "openai": { + "name": "OpenAI Image Generation", + "models": ["gpt-image-1", "dall-e-3", "dall-e-2"], + "default_model": "gpt-image-1", + "gpt-image-1": { + "sizes": ["1024x1024", "1024x1536", "1536x1024"], + "qualities": ["low", "medium", "high"], + "output_formats": ["png", "jpeg", "webp"], + "backgrounds": ["auto", "transparent", "opaque"], + "max_images": 10 + }, + "dall-e-3": { + "sizes": ["1024x1024", "1024x1792", "1792x1024"], + "qualities": ["standard", "hd"], + "styles": ["vivid", "natural"] + }, + "supports_styles": True + }, + "imagen": { + "name": "Google Imagen 4", + "models": ["imagen-4.0-generate-001", "imagen-4.0-ultra-generate-001", "imagen-4.0-fast-generate-001"], + "default_model": "imagen-4.0-generate-001", + "aspect_ratios": ["1:1", "3:4", "4:3", "9:16", "16:9"], + "image_sizes": ["1K", "2K"], + "max_images": 4, + "supports_enhance_prompt": True, + "supports_person_generation": True + }, + "nano-banana": { + "name": "Nano Banana (Gemini Image)", + "models": ["gemini-3-pro-image-preview", "gemini-2.0-flash-exp"], + "default_model": "gemini-3-pro-image-preview", + "aspect_ratios": ["1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "9:16", "16:9", "21:9"], + "image_sizes": ["1K", "2K", "4K"], + "supports_text_rendering": True, + "supports_image_editing": True + }, + "stable-diffusion": { + "name": "Stability AI", + "models": ["sd3.5-large", "sd3.5-medium", "sd3-large", "sd3-medium", "sdxl-1.0"], + "default_model": "sd3.5-large", + "aspect_ratios": ["1:1", "16:9", "9:16", "4:3", "3:4", "21:9", "9:21"], + "supports_img2img": True, + "supports_negative_prompt": True + }, + "leonardo": { + "name": "Leonardo.ai", + "models": { + # Latest Models (2025) + # Phoenix: de7d3faf-762f-48e0-b3b7-9d0ac3a3fcf3 (Found in docs) + "de7d3faf-762f-48e0-b3b7-9d0ac3a3fcf3": "Leonardo Phoenix 1.0", + "7b592283-e8a7-4c5a-9ba6-d18c31f258b9": "Lucid Origin", + "05ce0082-2d80-4a2d-8653-4d1c85e2418e": "Lucid Realism", + "28aeddf8-bd19-4803-80fc-79602d1a9989": "FLUX.1 Kontext", + "b2614463-296c-462a-9586-aafdb8f00e36": "Flux Dev", + "1dd50843-d653-4516-a8e3-f0238ee453ff": "Flux Schnell", + + # XL Models + "aa77f04e-3eec-4034-9c07-d0f619684628": "Leonardo Kino XL", + "5c232a9e-9061-4777-980a-ddc8e65647c6": "Leonardo Vision XL", + "b24e16ff-06e3-43eb-8d33-4416c2d75876": "Leonardo Lightning XL", + "1e60896f-3c26-4296-8ecc-53e2afecc132": "Leonardo Diffusion XL", + + # Older/Other Support + "16e7060a-803e-4df3-97ee-edcfa5dc9cc8": "SDXL 1.0", + "ac614f96-1082-45bf-be9d-757f2d31c174": "DreamShaper v7", + "e316348f-7773-490e-adcd-46757c738eb7": "Absolute Reality v1.6" + }, + "default_model": "de7d3faf-762f-48e0-b3b7-9d0ac3a3fcf3", + # Explicit mapping for Aspect Ratio -> Dimensions (Width x Height) + # These are generally safe for SDXL/Phoenix models + "dimensions": { + "1:1": {"width": 1024, "height": 1024}, + "16:9": {"width": 1472, "height": 832}, + "9:16": {"width": 832, "height": 1472}, + "4:3": {"width": 1248, "height": 928}, # Approx for SDXL + "3:4": {"width": 928, "height": 1248}, + "21:9": {"width": 1536, "height": 640}, # Ultra wide + "9:21": {"width": 640, "height": 1536} + }, + "style_presets": [ + "ANIME", "BOKEH", "CINEMATIC", "CINEMATIC_CLOSEUP", "CREATIVE", + "DYNAMIC", "ENVIRONMENT", "FASHION", "FILM", "FOOD", "GENERAL", + "HDR", "ILLUSTRATION", "LEONARDO", "LONG_EXPOSURE", "MACRO", + "MINIMALISTIC", "MONOCHROME", "MOODY", "NONE", "NEUTRAL", + "PHOTOGRAPHY", "PORTRAIT", "RAYTRACED", "RENDER_3D", "RETRO", + "SKETCH_BW", "SKETCH_COLOR", "STOCK_PHOTO", "VIBRANT", "UNPROCESSED" + ], + "supports_img2img": True, + "supports_character_reference": True, + "supports_style_reference": True + }, + "bria": { + "name": "Bria AI", + "models": ["base", "fast"], + "default_model": "base", + "aspect_ratios": ["1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "9:16", "16:9"], + "mediums": ["photography", "art"], + "supports_prompt_enhancement": True, + "base_config": {"steps_num": [20, 50], "guidance_scale": [1, 10]}, + "fast_config": {"steps_num": [4, 10]} + }, + "ideogram": { + "name": "Ideogram", + "models": ["V_2", "V_2_TURBO"], + "supports_text_rendering": True + }, + "flux": { + "name": "Flux Pro", + "models": ["flux-pro-1.1", "flux-dev", "flux-schnell"], + "supports_img2img": True + } +} + +STABILITY_STYLE_PRESETS = [ + "enhance", "anime", "photographic", "digital-art", "comic-book", + "fantasy-art", "line-art", "analog-film", "neon-punk", "isometric", + "low-poly", "origami", "modeling-compound", "cinematic", "3d-model", "pixel-art" +] + + +async def generate(job_id: str): + """Generate image based on provider""" + db = SessionLocal() + try: + job = db.query(Job).filter(Job.id == job_id).first() + if not job: + return + + input_data = job.input_data + provider = input_data.get("provider", "openai") + prompt = input_data.get("prompt", "") + + # Update progress + job.progress = 10 + job.api_provider = provider + db.commit() + + image_data = None + filename = None + + if provider == "openai" or provider == "dalle3": + image_data, filename = await _generate_openai(input_data) + job.api_model = input_data.get("model", "gpt-image-1") + elif provider == "imagen": + image_data, filename = await _generate_imagen(input_data) + job.api_model = input_data.get("model", "imagen-4.0-generate-001") + elif provider == "nano-banana" or provider == "gemini": + # Fetch reference image if provided + ref_id = input_data.get("reference_asset_id") + ref_image_data = None + ref_mime_type = "image/png" # Default + + if ref_id: + ref_asset = db.query(Asset).filter(Asset.id == ref_id).first() + if ref_asset and os.path.exists(ref_asset.file_path): + with open(ref_asset.file_path, "rb") as f: + ref_image_data = f.read() + if ref_asset.mime_type: + ref_mime_type = ref_asset.mime_type + + image_data, filename = await _generate_nano_banana(input_data, ref_image_data, ref_mime_type) + job.api_model = input_data.get("model", "gemini-3-pro-image-preview") + elif provider == "stable-diffusion": + image_data, filename = await _generate_stability(input_data) + job.api_model = input_data.get("model", "sd3.5-large") + elif provider == "leonardo": + image_data, filename = await _generate_leonardo(input_data) + job.api_model = "leonardo" + elif provider == "ideogram": + image_data, filename = await _generate_ideogram(input_data) + job.api_model = "ideogram-v2" + elif provider == "flux": + image_data, filename = await _generate_flux(input_data) + job.api_model = "flux-pro" + elif provider == "bria": + image_data, filename = await _generate_bria(input_data) + job.api_model = input_data.get("model", "base") + elif provider == "runway-image": + image_data, filename = await _generate_runway_image(input_data) + job.api_model = "gen4_image" + else: + raise ValueError(f"Unknown provider: {provider}") + + job.progress = 80 + db.commit() + + # Save image + if image_data: + storage_path = os.path.join(settings.storage_path, "images") + os.makedirs(storage_path, exist_ok=True) + file_path = os.path.join(storage_path, filename) + + with open(file_path, "wb") as f: + f.write(image_data) + + # Create asset + asset = Asset( + user_id=job.user_id, + project_id=job.project_id, + original_filename=filename, + stored_filename=filename, + file_path=file_path, + file_type="image", + mime_type="image/png", + file_size_bytes=len(image_data), + source_module="image_generator", + source_job_id=job.id, + asset_metadata={ + "prompt": prompt, + "provider": provider, + "model": job.api_model + } + ) + db.add(asset) + db.commit() + db.refresh(asset) + + job.output_asset_ids = [asset.id] + job.output_data = {"asset_id": str(asset.id), "file_path": file_path} + + # Log Usage + try: + from app.utils.logging import log_model_usage + # Other imports are available globally + + # Placeholder values for logging, these would ideally be returned by _generate_ functions + # For now, we'll use what's available from input_data and job.api_model + model = job.api_model + width = input_data.get("width") + height = input_data.get("height") + n = input_data.get("n", 1) # Number of images requested + ext = "png" # Default, actual ext should come from _generate_ functions + + # Use existing asset data for logging + output_asset_ids = job.output_asset_ids or [] + output_paths = [] + if job.output_data and "file_path" in job.output_data: + output_paths.append(job.output_data["file_path"]) + + duration_ms = 0 + if job.started_at: + duration_ms = int((datetime.utcnow() - job.started_at).total_seconds() * 1000) + + log_model_usage( + db=db, + job_id=str(job.id), + user_id=str(job.user_id), + module="image_generator", + action="generate", + provider=provider, + model=model, + usage_stats={ + "images": len(output_asset_ids), + "processing_time_ms": duration_ms + }, + request_metadata={ + "prompt": prompt, + "negative_prompt": input_data.get("negative_prompt"), + "size": f"{width}x{height}" if width and height else None, + "n": n + }, + response_metadata={ + "output_assets": [str(a_id) for a_id in output_asset_ids], + "filenames": [os.path.basename(p) for p in output_paths] + } + ) + except Exception as log_e: + logger.error(f"Failed to log image generation usage: {log_e}") + + job.output_asset_ids = output_asset_ids + job.output_data = { + "prompt": prompt, + "provider": provider, + "model": model, + "image_paths": output_paths + } + job.progress = 100 + job.status = "completed" + job.completed_at = datetime.utcnow() + db.commit() + + except Exception as e: + job.status = "failed" + job.error_message = str(e) + db.commit() + finally: + db.close() + + +async def _generate_openai(input_data: dict) -> Tuple[Optional[bytes], Optional[str]]: + """Generate image using OpenAI GPT-Image-1 or DALL-E 3 + + GPT-Image-1 Parameters (default): + - prompt: Text description (max 32000 chars) + - quality: 'low', 'medium', 'high' (default: high) + - size: '1024x1024', '1024x1536', '1536x1024' + - background: 'transparent', 'opaque', 'auto' + - output_format: 'png', 'jpeg', 'webp' (default: png) + - output_compression: 0-100 for jpeg/webp + - moderation: 'auto' or 'low' (less restrictive) + - n: 1-10 images + + DALL-E 3 Parameters: + - prompt: Text description (max 4000 chars) + - quality: 'standard' or 'hd' (default: hd) + - style: 'vivid' or 'natural' (default: vivid) + - size: '1024x1024', '1024x1792', '1792x1024' + """ + prompt = input_data.get("prompt", "") + model = input_data.get("model", "gpt-image-1") + width = input_data.get("width", 1024) + height = input_data.get("height", 1024) + + # Determine size based on width/height + if width > height: + size = "1536x1024" if model == "gpt-image-1" else "1792x1024" + elif height > width: + size = "1024x1536" if model == "gpt-image-1" else "1024x1792" + else: + size = "1024x1024" + + async with httpx.AsyncClient(timeout=180) as client: + if model == "gpt-image-1": + # GPT-Image-1 (latest model) + quality = input_data.get("quality", "high") + background = input_data.get("background", "auto") + output_format = input_data.get("output_format", "png") + output_compression = input_data.get("output_compression", 100) + moderation = input_data.get("moderation", "auto") + n = min(input_data.get("n", 1), 10) + + payload = { + "model": "gpt-image-1", + "prompt": prompt, + "size": size, + "quality": quality, + "n": n + } + + # Add optional parameters + if background != "auto": + payload["background"] = background + if output_format != "png": + payload["output_format"] = output_format + if output_format in ["jpeg", "webp"] and output_compression != 100: + payload["output_compression"] = output_compression + if moderation != "auto": + payload["moderation"] = moderation + + response = await client.post( + "https://api.openai.com/v1/images/generations", + headers={ + "Authorization": f"Bearer {settings.openai_api_key}", + "Content-Type": "application/json" + }, + json=payload + ) + response.raise_for_status() + data = response.json() + + if data.get("data") and len(data["data"]) > 0: + # GPT-Image-1 always returns base64 + b64_image = data["data"][0].get("b64_json") + if b64_image: + ext = output_format if output_format in ["png", "jpeg", "webp"] else "png" + filename = f"gptimage1_{quality}_{uuid4()}.{ext}" + return base64.b64decode(b64_image), filename + + else: + # DALL-E 3 (or DALL-E 2) + quality = input_data.get("quality", "hd") + style = input_data.get("style", "vivid") + + payload = { + "model": model, + "prompt": prompt, + "size": size, + "n": 1, + "response_format": "b64_json" + } + + # DALL-E 3 specific options + if model == "dall-e-3": + payload["quality"] = quality + payload["style"] = style + + response = await client.post( + "https://api.openai.com/v1/images/generations", + headers={ + "Authorization": f"Bearer {settings.openai_api_key}", + "Content-Type": "application/json" + }, + json=payload + ) + response.raise_for_status() + data = response.json() + + if data.get("data") and len(data["data"]) > 0: + b64_image = data["data"][0].get("b64_json") + if b64_image: + filename = f"{model.replace('-', '')}_{style if model == 'dall-e-3' else 'gen'}_{uuid4()}.png" + return base64.b64decode(b64_image), filename + + return None, None + + +async def _generate_stability(input_data: dict, input_image_data: Optional[bytes] = None) -> Tuple[Optional[bytes], Optional[str]]: + """Generate image using Stability AI + + Parameters: + - prompt: Text description (required) + - negative_prompt: What to avoid in generation + - model: 'sd3.5-large', 'sd3.5-medium', 'sd3-large', 'sd3-medium' + - aspect_ratio: '1:1', '16:9', '9:16', '4:3', '3:4', '21:9', '9:21' + - seed: Optional seed for reproducibility (0-4294967294) + - mode: 'text-to-image' or 'image-to-image' + """ + if not settings.stability_api_key: + raise ValueError("Stability API key not configured") + + prompt = input_data.get("prompt", "") + if not prompt: + raise ValueError("Prompt is required") + + negative_prompt = input_data.get("negative_prompt", "") + model = input_data.get("model", "sd3.5-large") + aspect_ratio = input_data.get("aspect_ratio", "1:1") + seed = input_data.get("seed") + output_format = input_data.get("output_format", "png") + + async with httpx.AsyncClient(timeout=180) as client: + # Build multipart form data - Stability requires multipart/form-data + files = { + "prompt": (None, prompt), + "mode": (None, "text-to-image"), + "model": (None, model), + "aspect_ratio": (None, aspect_ratio), + "output_format": (None, output_format), + } + + if negative_prompt: + files["negative_prompt"] = (None, negative_prompt) + + if seed is not None: + files["seed"] = (None, str(seed)) + + # Image-to-image mode + if input_image_data: + files["mode"] = (None, "image-to-image") + files["strength"] = (None, str(input_data.get("strength", 0.7))) + files["image"] = ("input.png", input_image_data, "image/png") + + try: + response = await client.post( + "https://api.stability.ai/v2beta/stable-image/generate/sd3", + headers={ + "Authorization": f"Bearer {settings.stability_api_key}", + "Accept": "image/*" + }, + files=files + ) + + if response.status_code != 200: + error_text = response.text + logger.error(f"Stability AI error {response.status_code}: {error_text}") + raise Exception(f"Stability AI error: {error_text}") + + model_short = model.replace("-", "").replace(".", "") + filename = f"stability_{model_short}_{uuid4()}.{output_format}" + return response.content, filename + + except httpx.HTTPStatusError as e: + logger.error(f"Stability AI HTTP error: {e.response.status_code} - {e.response.text}") + raise + except Exception as e: + logger.error(f"Stability AI generation error: {e}") + raise + + +async def _generate_leonardo(input_data: dict) -> tuple: + """ + Generate image using Leonardo AI + + Parameters: + - prompt: Text description + - model: Leonardo model ID (default: Phoenix) + - width: Image width (512, 768, 1024, 1472) + - height: Image height (512, 768, 832, 1024) + - preset_style: Style preset (ANIME, CINEMATIC, PHOTOGRAPHY, etc.) + - num_images: Number of images to generate + - guidance_scale: How closely to follow prompt (7-15) + - num_inference_steps: Quality/speed tradeoff (30-60) + - negative_prompt: What to avoid + - init_image_id: For image-to-image + - init_strength: How much to change input image (0.1-0.9) + """ + # Default model is Leonardo Phoenix + model_id = input_data.get("model", "6b645e3a-d64f-4341-a6d8-7a3690fbf042") + + # Determine dimensions from aspect ratio + aspect_ratio = input_data.get("aspect_ratio", "1:1") + dims = IMAGE_PROVIDERS["leonardo"]["dimensions"].get(aspect_ratio, {"width": 1024, "height": 1024}) + + # Allow explicit override if provided (and valid int) + width = int(input_data.get("width", dims["width"])) + height = int(input_data.get("height", dims["height"])) + + # Build request payload + payload = { + "prompt": input_data.get("prompt"), + "modelId": model_id, + "width": width, + "height": height, + "num_images": min(input_data.get("num_images", 1), 4), # Cap at 4 for safety + "public": input_data.get("public", False) + } + + # Alchemy / PhotoReal Logic + # Phoenix (de7d3faf...) does NOT support Alchemy or PhotoReal (it has its own pipeline). + # Sending 'alchemy': True with Phoenix causes "Invalid response from authorization hook" (500). + + is_phoenix = model_id == "de7d3faf-762f-48e0-b3b7-9d0ac3a3fcf3" + + alchemy = input_data.get("alchemy", False) + photo_real = input_data.get("photo_real", False) + + if is_phoenix: + # Force disable legacy features for Phoenix + alchemy = False + photo_real = False + # Phoenix might support 'elements' or other new params, but definitely not legacy alchemy. + + if alchemy: + payload["alchemy"] = True + payload["contrastRatio"] = input_data.get("contrast_ratio", 0.5) + + if photo_real: + payload["photoReal"] = True + payload["photoRealStrength"] = input_data.get("photo_real_strength", 0.5) + # If PhotoReal is on, we remove modelId to rely on system default for PhotoReal. + if "modelId" in payload: + del payload["modelId"] + + # Log payload for debugging + logger.info(f"Leonardo Payload (Model: {model_id}): {payload}") + + if input_data.get("preset_style") and input_data.get("preset_style") != "NONE": + payload["presetStyle"] = input_data.get("preset_style") + + if input_data.get("guidance_scale"): + payload["guidance_scale"] = int(input_data.get("guidance_scale")) + + # Image-to-image / Reference + # Modern Leonardo uses 'imagePrompts' array for reference. + # 'init_image_id' is legacy but might still work for some models. + init_image_id = input_data.get("init_image_id") + if init_image_id: + # Legacy support + payload["init_image_id"] = init_image_id + payload["init_strength"] = float(input_data.get("init_strength", 0.5)) + + + async with httpx.AsyncClient(timeout=180) as client: + # Create generation + response = await client.post( + "https://cloud.leonardo.ai/api/rest/v1/generations", + headers={ + "Authorization": f"Bearer {settings.leonardo_api_key}", + "Content-Type": "application/json" + }, + json=payload + ) + if response.status_code != 200: + error_text = response.text + logger.error(f"Leonardo API error {response.status_code}: {error_text}") + raise ValueError(f"Leonardo API returned {response.status_code}: {error_text}") + + data = response.json() + logger.info(f"Leonardo response: {data}") + + # Poll for result + generation_id = data.get("sdGenerationJob", {}).get("generationId") + if generation_id: + import asyncio + for _ in range(90): # Wait up to 3 minutes + await asyncio.sleep(2) + status_response = await client.get( + f"https://cloud.leonardo.ai/api/rest/v1/generations/{generation_id}", + headers={"Authorization": f"Bearer {settings.leonardo_api_key}"} + ) + status_data = status_response.json() + generation = status_data.get("generations_by_pk", {}) + status = generation.get("status") + + if status == "COMPLETE": + images = generation.get("generated_images", []) + if images: + image_url = images[0].get("url") + if image_url: + img_response = await client.get(image_url) + model_name = IMAGE_PROVIDERS["leonardo"]["models"].get(model_id, "leonardo") + filename = f"leonardo_{model_name.replace(' ', '_').lower()}_{uuid4()}.png" + return img_response.content, filename + elif status == "FAILED": + raise Exception("Leonardo generation failed") + + return None, None + + +async def _generate_bria(input_data: dict) -> tuple: + """ + Generate image using Bria AI + + Parameters: + - prompt: Text description + - model: 'base' (Bria 2.3 Base) or 'fast' (Bria 2.3 Fast) + - aspect_ratio: Image aspect ratio + - medium: 'photography' or 'art' + - prompt_enhancement: Enable AI prompt enhancement + - steps_num: Number of inference steps + - guidance_scale: How closely to follow prompt + - negative_prompt: What to avoid + """ + model = input_data.get("model", "base") + base_url = "https://engine.prod.bria-api.com/v1/text-to-image" + + # Build request payload + payload = { + "prompt": input_data.get("prompt"), + "num_results": 1 + } + + # Add aspect ratio + if input_data.get("aspect_ratio"): + payload["aspect_ratio"] = input_data.get("aspect_ratio") + + # Add medium + if input_data.get("medium"): + payload["medium"] = input_data.get("medium") + + # Add prompt enhancement + if input_data.get("prompt_enhancement"): + payload["prompt_enhancement"] = True + + # Add negative prompt + if input_data.get("negative_prompt"): + payload["negative_prompt"] = input_data.get("negative_prompt") + + # Model-specific parameters + if model == "base": + url = f"{base_url}/base" + if input_data.get("steps_num"): + payload["steps_num"] = input_data.get("steps_num") + if input_data.get("guidance_scale"): + payload["text_guidance_scale"] = input_data.get("guidance_scale") + else: + url = f"{base_url}/fast" + if input_data.get("steps_num"): + payload["steps_num"] = min(input_data.get("steps_num"), 10) + + async with httpx.AsyncClient(timeout=120) as client: + response = await client.post( + url, + headers={ + "api_token": settings.bria_api_key, + "Content-Type": "application/json" + }, + json=payload + ) + response.raise_for_status() + data = response.json() + + # Get the result + result = data.get("result", []) + if result and len(result) > 0: + image_url = result[0].get("urls", {}).get("url") + if image_url: + img_response = await client.get(image_url) + filename = f"bria_{model}_{uuid4()}.png" + return img_response.content, filename + + return None, None + + +async def _generate_ideogram(input_data: dict) -> tuple: + """Generate image using Ideogram""" + async with httpx.AsyncClient(timeout=120) as client: + response = await client.post( + "https://api.ideogram.ai/generate", + headers={ + "Api-Key": settings.ideogram_api_key, + "Content-Type": "application/json" + }, + json={ + "image_request": { + "prompt": input_data.get("prompt"), + "model": "V_2", + "aspect_ratio": "ASPECT_1_1" + } + } + ) + response.raise_for_status() + data = response.json() + + if data.get("data") and len(data["data"]) > 0: + image_url = data["data"][0].get("url") + if image_url: + img_response = await client.get(image_url) + filename = f"ideogram_{uuid4()}.png" + return img_response.content, filename + + return None, None + + +async def _generate_flux(input_data: dict) -> tuple: + """Generate image using Flux (Black Forest Labs) + + Note: Requires FLUX_API_KEY from https://api.bfl.ml/ + May require paid account for flux-pro-1.1 model + """ + if not settings.flux_api_key: + raise ValueError("FLUX_API_KEY not configured") + + async with httpx.AsyncClient(timeout=120) as client: + try: + response = await client.post( + "https://api.bfl.ml/v1/flux-pro-1.1", + headers={ + "x-key": settings.flux_api_key, + "Content-Type": "application/json" + }, + json={ + "prompt": input_data.get("prompt"), + "width": input_data.get("width", 1024), + "height": input_data.get("height", 1024) + } + ) + + if response.status_code == 403: + logger.error("Flux API 403: Invalid API key or insufficient permissions") + raise ValueError("Flux API key is invalid or your account doesn't have access to flux-pro-1.1") + + response.raise_for_status() + data = response.json() + + # Poll for result + request_id = data.get("id") + if request_id: + import asyncio + for _ in range(60): + await asyncio.sleep(2) + status_response = await client.get( + f"https://api.bfl.ml/v1/get_result?id={request_id}", + headers={"x-key": settings.flux_api_key} + ) + status_data = status_response.json() + if status_data.get("status") == "Ready": + image_url = status_data.get("result", {}).get("sample") + if image_url: + img_response = await client.get(image_url) + filename = f"flux_{uuid4()}.png" + return img_response.content, filename + + except Exception as e: + logger.error(f"Flux generation error: {e}") + raise + + return None, None + + +async def _generate_gemini(input_data: dict) -> tuple: + """Generate image using Google Gemini""" + import google.generativeai as genai + + genai.configure(api_key=settings.google_api_key) + model = genai.GenerativeModel("gemini-2.0-flash-exp") + + response = model.generate_content( + input_data.get("prompt"), + generation_config=genai.types.GenerationConfig( + response_mime_type="image/png" + ) + ) + + if response.candidates and response.candidates[0].content.parts: + for part in response.candidates[0].content.parts: + if hasattr(part, 'inline_data') and part.inline_data: + filename = f"gemini_{uuid4()}.png" + return part.inline_data.data, filename + + return None, None + + +async def _generate_imagen(input_data: dict) -> tuple: + """ + Generate image using Google Imagen 3 via REST API + + Note: Imagen 3 is accessed through the generativelanguage API with API key. + + Parameters: + - prompt: Text description of the image + - aspect_ratio: "1:1", "3:4", "4:3", "9:16", "16:9" + - number_of_images: 1-4 + - negative_prompt: What to avoid in the image + """ + if not settings.google_api_key: + raise ValueError("GOOGLE_API_KEY not configured") + + prompt = input_data.get("prompt", "") + negative_prompt = input_data.get("negative_prompt", "") + aspect_ratio = input_data.get("aspect_ratio", "1:1") + number_of_images = min(input_data.get("number_of_images", 1), 4) + + # Use the Generative Language API for Imagen 4 + model_name = input_data.get("model", "imagen-4.0-generate-001") + url = f"https://generativelanguage.googleapis.com/v1beta/models/{model_name}:predict" + + payload = { + "instances": [{"prompt": prompt}], + "parameters": { + "sampleCount": number_of_images, + "aspectRatio": aspect_ratio, + } + } + + if negative_prompt: + payload["instances"][0]["negativePrompt"] = negative_prompt + + try: + async with httpx.AsyncClient(timeout=120.0) as client: + response = await client.post( + url, + headers={ + "Content-Type": "application/json", + "x-goog-api-key": settings.google_api_key + }, + json=payload + ) + + if response.status_code == 200: + data = response.json() + predictions = data.get("predictions", []) + if predictions and predictions[0].get("bytesBase64Encoded"): + image_data = base64.b64decode(predictions[0]["bytesBase64Encoded"]) + filename = f"imagen3_{uuid4()}.png" + return image_data, filename + else: + logger.warning(f"Imagen API error: {response.status_code} - {response.text}") + # Fall back to Nano Banana (Gemini native) + logger.info("Falling back to Nano Banana (Gemini native image generation)") + return await _generate_nano_banana(input_data) + + except Exception as e: + logger.error(f"Imagen generation error: {e}") + # Fallback to Gemini native image generation + return await _generate_nano_banana(input_data) + + return None, None + + +async def _upload_file_http(media_data: bytes, mime_type: str) -> Optional[str]: + """ + Upload file using raw HTTP to Google Generative AI Files API + (Alternative to outdated google-generativeai library) + Returns: file_uri + """ + if not settings.google_api_key: + return None + + try: + url = f"https://generativelanguage.googleapis.com/upload/v1beta/files?key={settings.google_api_key}" + num_bytes = len(media_data) + + headers = { + "X-Goog-Upload-Protocol": "resumable", + "X-Goog-Upload-Command": "start", + "X-Goog-Upload-Header-Content-Length": str(num_bytes), + "X-Goog-Upload-Header-Content-Type": mime_type, + "Content-Type": "application/json" + } + + metadata = {"file": {"display_name": f"nano_banana_upload_{uuid4()}"}} + + async with httpx.AsyncClient(timeout=30.0) as client: + # 1. Start Upload + response = await client.post(url, headers=headers, json=metadata) + if response.status_code != 200: + logger.error(f"Failed to start upload: {response.status_code} - {response.text}") + return None + + upload_url = response.headers.get("x-goog-upload-url") + if not upload_url: + logger.error("No upload URL returned") + return None + + # 2. Upload Bytes + headers_upload = { + "Content-Length": str(num_bytes), + "X-Goog-Upload-Offset": "0", + "X-Goog-Upload-Command": "upload, finalize" + } + + response_upload = await client.post(upload_url, headers=headers_upload, content=media_data) + if response_upload.status_code != 200: + logger.error(f"Failed to upload data: {response_upload.status_code} - {response_upload.text}") + return None + + data = response_upload.json() + file_uri = data.get("file", {}).get("uri") + logger.info(f"File uploaded successfully: {file_uri}") + return file_uri + + except Exception as e: + logger.error(f"Upload error: {e}") + return None + + +async def _generate_nano_banana(input_data: dict, image_data: Optional[bytes] = None, mime_type: str = "image/png") -> tuple: + """ + Generate image using Nano Banana (Gemini 3 Pro Image) + Model: gemini-3-pro-image-preview + Uses File API for strict visual context adherence. + """ + if not settings.google_api_key: + raise ValueError("GOOGLE_API_KEY not configured") + + prompt = input_data.get("prompt", "") + if not prompt: + raise ValueError("Prompt is required") + + import google.generativeai as genai + import tempfile + import os + + import base64 + + genai.configure(api_key=settings.google_api_key) + + # Use gemini-3-pro-image-preview as requested by user + model_name = input_data.get("model", "gemini-3-pro-image-preview") + if model_name in ["gemini-2.5-flash-image", "gemini-2.0-flash-exp"]: + model_name = "gemini-3-pro-image-preview" + + url = f"https://generativelanguage.googleapis.com/v1beta/models/{model_name}:generateContent" + + # Build payload - EXACTLY matching PHP structure (Image FIRST, then Text) + parts = [] + + if image_data: + # Robust MIME detection + real_mime_type = determine_mime_type(image_data) + + # PHP uses inline_data (snake_case) and base64 + # It forces image/jpeg in PHP. We will do the same to match the reference implementation exactly. + + b64_image = base64.b64encode(image_data).decode("utf-8") + parts.append({ + "inline_data": { + "mime_type": "image/jpeg", + "data": b64_image + } + }) + logger.info(f"Nano Banana: Added reference image (inline_data base64, {len(b64_image)} chars)") + + # Text Instruction Second + parts.append({"text": prompt}) + + # Construct generation config + gen_config = { + "responseModalities": ["IMAGE"] + } + + # Map aspect ratio if present + ar_map = { + "1:1": "1:1", "16:9": "16:9", "9:16": "9:16", + "4:3": "4:3", "3:4": "3:4" + } + input_ar = input_data.get("aspect_ratio", "1:1") + if input_ar in ar_map: + gen_config["imageConfig"] = { + "aspectRatio": ar_map[input_ar], + "imageSize": input_data.get("image_size", "2K") # PHP supports imageSize + } + + payload = { + "contents": [{ + "parts": parts + }], + "generationConfig": gen_config + } + + try: + async with httpx.AsyncClient(timeout=120.0) as client: + response = await client.post( + url, + headers={ + "Content-Type": "application/json", + "x-goog-api-key": settings.google_api_key + }, + json=payload + ) + + logger.info(f"Nano Banana response status: {response.status_code}") + + if response.status_code != 200: + logger.error(f"Nano Banana API error: {response.status_code} - {response.text}") + try: + err_json = response.json() + err_msg = err_json.get("error", {}).get("message", response.text) + logger.error(f"Nano Banana Error Details: {err_msg}") + except: + pass + return None, None + + data = response.json() + + # CRITICAL: Check finishReason first (per AI Implementation Guide) + candidates = data.get("candidates", []) + if candidates and len(candidates) > 0: + candidate = candidates[0] + finish_reason = candidate.get("finishReason") + + if finish_reason == "IMAGE_RECITATION": + logger.warning("Nano Banana blocked: IMAGE_RECITATION (Prompt too generic)") + # In a real app we might want to surface this specific error to the user + return None, None + + if finish_reason == "SAFETY": + logger.warning("Nano Banana blocked: SAFETY (Content filters)") + return None, None + + # Extract image - Response uses camelCase 'inlineData' (Gotcha #1) + content = candidate.get("content", {}) + parts_resp = content.get("parts", []) + + for part in parts_resp: + # Guide says response uses inlineData (camelCase) + if "inlineData" in part: + inline_data = part["inlineData"] + if "data" in inline_data: + try: + img_bytes = base64.b64decode(inline_data["data"]) + filename = f"nano_banana_{uuid4()}.png" # We save as PNG internally usually + return img_bytes, filename + except Exception as e: + logger.error(f"Failed to decode base64: {e}") + + # Fallback to snake_case just in case + if "inline_data" in part: + inline_data = part["inline_data"] + if "data" in inline_data: + img_bytes = base64.b64decode(inline_data["data"]) + filename = f"nano_banana_{uuid4()}.png" + return img_bytes, filename + + logger.warning(f"Nano Banana: No image data in response. Content: {content}") + else: + logger.warning(f"Nano Banana: No candidates in response.") + + except Exception as e: + logger.error(f"Nano Banana generation error: {e}") + import traceback + traceback.print_exc() + + return None, None + + +async def _generate_runway_image(input_data: dict) -> tuple: + """Generate image using Runway Gen-4 Image""" + if not settings.runway_api_key: + raise ValueError("RUNWAY_API_KEY not configured") + + prompt = input_data.get("prompt", "") + ratio = input_data.get("ratio", "1360:768") + seed = input_data.get("seed") + + payload = {"model": "gen4_image", "promptText": prompt, "ratio": ratio if ratio in ["1024:1024", "1360:768"] else "1360:768"} + if seed and seed > 0: + payload["seed"] = seed + + async with httpx.AsyncClient(timeout=180) as client: + response = await client.post( + "https://api.dev.runwayml.com/v1/text_to_image", + headers={ + "Authorization": f"Bearer {settings.runway_api_key}", + "Content-Type": "application/json", + "X-Runway-Version": "2024-11-06" + }, + json=payload + ) + response.raise_for_status() + result = response.json() + task_id = result.get("id") + + # Poll for completion + import asyncio + for _ in range(90): + await asyncio.sleep(2) + status_resp = await client.get( + f"https://api.dev.runwayml.com/v1/tasks/{task_id}", + headers={"Authorization": f"Bearer {settings.runway_api_key}", "X-Runway-Version": "2024-11-06"} + ) + status_data = status_resp.json() + if status_data.get("status") == "SUCCEEDED": + url = status_data.get("output", [None])[0] + if url: + img_resp = await client.get(url) + return img_resp.content, f"runway_gen4_{uuid4()}.png" + elif status_data.get("status") == "FAILED": + raise ValueError(f"Runway failed: {status_data.get('error')}") + + return None, None diff --git a/backend/scripts/test_nano_edit.py b/backend/scripts/test_nano_edit.py new file mode 100644 index 0000000..98489f8 --- /dev/null +++ b/backend/scripts/test_nano_edit.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python3 +""" +Test Nano Banana editing with a real image from the database +""" +import asyncio +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) + +from app.database import SessionLocal +from app.models.asset import Asset +from app.services.image_generator import _generate_nano_banana + +async def test_edit(): + """Test editing an existing asset""" + db = SessionLocal() + + try: + # Get the most recent image asset + asset = db.query(Asset).filter( + Asset.mime_type.like('image/%') + ).order_by(Asset.created_at.desc()).first() + + if not asset: + print("❌ No image assets found in database") + return + + print(f"✓ Found asset: {asset.original_filename} ({asset.mime_type})") + print(f" Asset ID: {asset.id}") + + # Read the image data + if not os.path.exists(asset.file_path): + print(f"❌ File not found: {asset.file_path}") + return + + with open(asset.file_path, 'rb') as f: + image_data = f.read() + + print(f"✓ Loaded image data: {len(image_data)} bytes") + + # Test edit + input_data = { + "prompt": "add a beautiful sunset with orange and pink colors in the sky", + "aspect_ratio": "16:9", + "image_size": "2K" + } + + print(f"\n🎨 Testing edit with prompt: '{input_data['prompt']}'") + print("⏳ Calling Nano Banana API...") + + result_bytes, filename = await _generate_nano_banana( + input_data=input_data, + image_data=image_data, + mime_type=asset.mime_type + ) + + if result_bytes: + print(f"✅ SUCCESS! Generated: {filename}") + print(f" Size: {len(result_bytes)} bytes") + + # Save to temp file + output_path = f"/tmp/{filename}" + with open(output_path, 'wb') as f: + f.write(result_bytes) + print(f" Saved to: {output_path}") + else: + print("❌ FAILED: No image data returned") + + except Exception as e: + print(f"❌ ERROR: {e}") + import traceback + traceback.print_exc() + finally: + db.close() + +if __name__ == "__main__": + asyncio.run(test_edit())