1200 lines
46 KiB
Python
1200 lines
46 KiB
Python
"""Image Generator Service - Multiple AI Providers
|
|
|
|
Supported Providers:
|
|
- openai: GPT-Image-1 (latest) or DALL-E 3
|
|
- imagen: Google Imagen 4 (Standard, Ultra, Fast)
|
|
- nano-banana: Gemini 2.5 Flash Image / Nano Banana Pro
|
|
- stable-diffusion: Stability AI SDXL, SD3, image-to-image
|
|
- leonardo: Leonardo.ai models
|
|
- ideogram: Ideogram v2 with text rendering
|
|
- flux: Black Forest Labs Flux Pro
|
|
|
|
OpenAI GPT-Image-1 (April 2025):
|
|
- model: 'gpt-image-1' (default) or 'dall-e-3'
|
|
- quality: 'low', 'medium', 'high' (default high)
|
|
- size: 1024x1024, 1024x1536, 1536x1024
|
|
- background: 'transparent', 'opaque', 'auto' (for PNG/WebP)
|
|
- output_format: 'png', 'jpeg', 'webp'
|
|
- n: 1-10 images per request
|
|
- Pricing: ~$0.02 (low), $0.07 (medium), $0.19 (high) per image
|
|
|
|
Google Imagen 4 (December 2025):
|
|
- model: 'imagen-4.0-generate-001' (default), 'imagen-4.0-ultra-generate-001', 'imagen-4.0-fast-generate-001'
|
|
- image_size: '1K', '2K' (Ultra/Standard only)
|
|
- aspect_ratio: '1:1', '3:4', '4:3', '9:16', '16:9'
|
|
- number_of_images: 1-4
|
|
- enhance_prompt: true/false (LLM prompt enhancement)
|
|
- person_generation: 'dont_allow', 'allow_adult', 'allow_all'
|
|
- Pricing: $0.02 (Fast), $0.04 (Standard), $0.06 (Ultra) per image
|
|
|
|
Nano Banana / Gemini Image (December 2025):
|
|
- model: 'gemini-2.5-flash-image' (Nano Banana), 'gemini-3-pro-image-preview' (Nano Banana Pro)
|
|
- aspect_ratio: '1:1', '2:3', '3:2', '3:4', '4:3', '4:5', '5:4', '9:16', '16:9', '21:9'
|
|
- image_size: '1K', '2K', '4K' (Pro only for 4K)
|
|
- Features: Text rendering, image editing, multi-turn conversation
|
|
- Pricing: ~$0.04 per 1MP image
|
|
|
|
DALL-E 3 Options:
|
|
- quality: 'standard' or 'hd' (default hd)
|
|
- style: 'vivid' (hyper-real) or 'natural' (more realistic)
|
|
- size: 1024x1024, 1024x1792, 1792x1024
|
|
|
|
Stability AI Options:
|
|
- model: sd3.5-large, sd3.5-medium, sd3-large, sd3-medium, sdxl-1.0
|
|
- aspect_ratio: 1:1, 16:9, 9:16, 4:3, 3:4, 21:9, 9:21
|
|
- negative_prompt: What to avoid in generation
|
|
- image_to_image: Use input image as starting point
|
|
- strength: 0.0-1.0 for image-to-image (how much to change)
|
|
- style_preset: enhance, anime, photographic, digital-art, etc.
|
|
"""
|
|
import httpx
|
|
import os
|
|
import base64
|
|
import logging
|
|
from uuid import uuid4
|
|
from datetime import datetime
|
|
from typing import Optional, Dict, Any, Tuple
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
from app.database import SessionLocal
|
|
from app.models.job import Job
|
|
from app.models.asset import Asset
|
|
from app.config import settings
|
|
|
|
def determine_mime_type(data: bytes) -> str:
    """Sniff an image MIME type from the leading magic bytes of ``data``.

    PNG, JPEG and WebP signatures are recognised. Anything else, including
    empty input, is reported as ``image/png``.
    """
    png_signature = b'\x89PNG\r\n\x1a\n'
    jpeg_signature = b'\xff\xd8'

    if data[:len(png_signature)] == png_signature:
        return 'image/png'
    if data[:len(jpeg_signature)] == jpeg_signature:
        return 'image/jpeg'
    # WebP is a RIFF container: "RIFF", a 4-byte size field, then "WEBP".
    if data[:4] == b'RIFF' and data[8:12] == b'WEBP':
        return 'image/webp'
    return 'image/png'  # Default fallback
|
|
|
|
# Provider configurations
#
# Static capability table keyed by provider id (the ids generate() matches
# against input_data["provider"]). Each entry carries a display "name", the
# selectable "models", and provider-specific option lists / feature flags.
# Not every provider handled by generate() has an entry here ("runway-image"
# does not), and not every entry has a "default_model".
IMAGE_PROVIDERS = {
    "openai": {
        "name": "OpenAI Image Generation",
        "models": ["gpt-image-1", "dall-e-3", "dall-e-2"],
        "default_model": "gpt-image-1",
        # Per-model option sets, keyed by model name.
        "gpt-image-1": {
            "sizes": ["1024x1024", "1024x1536", "1536x1024"],
            "qualities": ["low", "medium", "high"],
            "output_formats": ["png", "jpeg", "webp"],
            "backgrounds": ["auto", "transparent", "opaque"],
            "max_images": 10
        },
        "dall-e-3": {
            "sizes": ["1024x1024", "1024x1792", "1792x1024"],
            "qualities": ["standard", "hd"],
            "styles": ["vivid", "natural"]
        },
        "supports_styles": True
    },
    "imagen": {
        "name": "Google Imagen 4",
        "models": ["imagen-4.0-generate-001", "imagen-4.0-ultra-generate-001", "imagen-4.0-fast-generate-001"],
        "default_model": "imagen-4.0-generate-001",
        "aspect_ratios": ["1:1", "3:4", "4:3", "9:16", "16:9"],
        "image_sizes": ["1K", "2K"],
        "max_images": 4,
        "supports_enhance_prompt": True,
        "supports_person_generation": True
    },
    "nano-banana": {
        "name": "Nano Banana (Gemini Image)",
        "models": ["gemini-3-pro-image-preview", "gemini-2.0-flash-exp"],
        "default_model": "gemini-3-pro-image-preview",
        "aspect_ratios": ["1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "9:16", "16:9", "21:9"],
        "image_sizes": ["1K", "2K", "4K"],
        "supports_text_rendering": True,
        "supports_image_editing": True
    },
    "stable-diffusion": {
        "name": "Stability AI",
        "models": ["sd3.5-large", "sd3.5-medium", "sd3-large", "sd3-medium", "sdxl-1.0"],
        "default_model": "sd3.5-large",
        "aspect_ratios": ["1:1", "16:9", "9:16", "4:3", "3:4", "21:9", "9:21"],
        "supports_img2img": True,
        "supports_negative_prompt": True
    },
    "leonardo": {
        "name": "Leonardo.ai",
        # Unlike the other providers, "models" is a mapping of model id ->
        # display name; _generate_leonardo uses the display name when it
        # builds the output filename.
        "models": {
            # Latest Models (2025)
            # Phoenix: de7d3faf-762f-48e0-b3b7-9d0ac3a3fcf3 (Found in docs)
            "de7d3faf-762f-48e0-b3b7-9d0ac3a3fcf3": "Leonardo Phoenix 1.0",
            "7b592283-e8a7-4c5a-9ba6-d18c31f258b9": "Lucid Origin",
            "05ce0082-2d80-4a2d-8653-4d1c85e2418e": "Lucid Realism",
            "28aeddf8-bd19-4803-80fc-79602d1a9989": "FLUX.1 Kontext",
            "b2614463-296c-462a-9586-aafdb8f00e36": "Flux Dev",
            "1dd50843-d653-4516-a8e3-f0238ee453ff": "Flux Schnell",

            # XL Models
            "aa77f04e-3eec-4034-9c07-d0f619684628": "Leonardo Kino XL",
            "5c232a9e-9061-4777-980a-ddc8e65647c6": "Leonardo Vision XL",
            "b24e16ff-06e3-43eb-8d33-4416c2d75876": "Leonardo Lightning XL",
            "1e60896f-3c26-4296-8ecc-53e2afecc132": "Leonardo Diffusion XL",

            # Older/Other Support
            "16e7060a-803e-4df3-97ee-edcfa5dc9cc8": "SDXL 1.0",
            "ac614f96-1082-45bf-be9d-757f2d31c174": "DreamShaper v7",
            "e316348f-7773-490e-adcd-46757c738eb7": "Absolute Reality v1.6"
        },
        "default_model": "de7d3faf-762f-48e0-b3b7-9d0ac3a3fcf3",
        # Explicit mapping for Aspect Ratio -> Dimensions (Width x Height)
        # These are generally safe for SDXL/Phoenix models
        # _generate_leonardo looks up the requested aspect_ratio here.
        "dimensions": {
            "1:1": {"width": 1024, "height": 1024},
            "16:9": {"width": 1472, "height": 832},
            "9:16": {"width": 832, "height": 1472},
            "4:3": {"width": 1248, "height": 928},  # Approx for SDXL
            "3:4": {"width": 928, "height": 1248},
            "21:9": {"width": 1536, "height": 640},  # Ultra wide
            "9:21": {"width": 640, "height": 1536}
        },
        # "NONE" is treated by _generate_leonardo as "send no presetStyle".
        "style_presets": [
            "ANIME", "BOKEH", "CINEMATIC", "CINEMATIC_CLOSEUP", "CREATIVE",
            "DYNAMIC", "ENVIRONMENT", "FASHION", "FILM", "FOOD", "GENERAL",
            "HDR", "ILLUSTRATION", "LEONARDO", "LONG_EXPOSURE", "MACRO",
            "MINIMALISTIC", "MONOCHROME", "MOODY", "NONE", "NEUTRAL",
            "PHOTOGRAPHY", "PORTRAIT", "RAYTRACED", "RENDER_3D", "RETRO",
            "SKETCH_BW", "SKETCH_COLOR", "STOCK_PHOTO", "VIBRANT", "UNPROCESSED"
        ],
        "supports_img2img": True,
        "supports_character_reference": True,
        "supports_style_reference": True
    },
    "bria": {
        "name": "Bria AI",
        "models": ["base", "fast"],
        "default_model": "base",
        "aspect_ratios": ["1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "9:16", "16:9"],
        "mediums": ["photography", "art"],
        "supports_prompt_enhancement": True,
        # Two-element lists; presumably [min, max] bounds - TODO confirm.
        "base_config": {"steps_num": [20, 50], "guidance_scale": [1, 10]},
        "fast_config": {"steps_num": [4, 10]}
    },
    "ideogram": {
        "name": "Ideogram",
        "models": ["V_2", "V_2_TURBO"],
        "supports_text_rendering": True
    },
    "flux": {
        "name": "Flux Pro",
        "models": ["flux-pro-1.1", "flux-dev", "flux-schnell"],
        "supports_img2img": True
    }
}
|
|
|
|
# Style preset names for Stability AI (the ``style_preset`` option listed in
# the module docstring). Not referenced by the functions visible in this part
# of the module; presumably consumed by API/validation code elsewhere.
STABILITY_STYLE_PRESETS = [
    "enhance", "anime", "photographic", "digital-art", "comic-book",
    "fantasy-art", "line-art", "analog-film", "neon-punk", "isometric",
    "low-poly", "origami", "modeling-compound", "cinematic", "3d-model", "pixel-art"
]
|
|
|
|
|
|
async def generate(job_id: str):
    """Run one image-generation job end to end.

    Loads the ``Job`` row for ``job_id``, dispatches to the provider named in
    ``job.input_data["provider"]`` (default ``"openai"``), writes the returned
    bytes under ``<settings.storage_path>/images``, records an ``Asset`` row,
    logs model usage on a best-effort basis and marks the job ``completed``.
    Any failure is logged and the job is marked ``failed`` with the error text.

    Args:
        job_id: Primary key of the job to process. An unknown id is a no-op.

    Returns:
        None. All results are persisted on the job row.
    """
    db = SessionLocal()
    # Bound up front so the error handler below can tell "lookup failed"
    # apart from "job loaded, generation failed".
    job = None
    try:
        job = db.query(Job).filter(Job.id == job_id).first()
        if not job:
            return

        input_data = job.input_data
        provider = input_data.get("provider", "openai")
        prompt = input_data.get("prompt", "")

        # Update progress
        job.progress = 10
        job.api_provider = provider
        db.commit()

        image_data = None
        filename = None

        if provider in ("openai", "dalle3"):
            image_data, filename = await _generate_openai(input_data)
            job.api_model = input_data.get("model", "gpt-image-1")
        elif provider == "imagen":
            image_data, filename = await _generate_imagen(input_data)
            job.api_model = input_data.get("model", "imagen-4.0-generate-001")
        elif provider in ("nano-banana", "gemini"):
            # The API endpoint converts reference_asset_id to reference_image (base64)
            # So we use that directly instead of loading from database
            ref_image_b64 = input_data.get("reference_image")
            ref_image_data = None
            ref_mime_type = "image/png"  # Default

            if ref_image_b64:
                try:
                    ref_image_data = base64.b64decode(ref_image_b64)
                    # Detect MIME type from decoded data
                    ref_mime_type = determine_mime_type(ref_image_data)
                    logger.info(f"✓ Decoded reference image from base64 ({ref_mime_type}, {len(ref_image_data)} bytes)")
                except (ValueError, TypeError) as e:
                    # Best effort: a bad reference image is dropped and
                    # generation continues without it.
                    logger.error(f"Failed to decode reference_image base64: {e}")

            image_data, filename = await _generate_nano_banana(input_data, ref_image_data, ref_mime_type)
            job.api_model = input_data.get("model", "gemini-3-pro-image-preview")
        elif provider == "stable-diffusion":
            image_data, filename = await _generate_stability(input_data)
            job.api_model = input_data.get("model", "sd3.5-large")
        elif provider == "leonardo":
            image_data, filename = await _generate_leonardo(input_data)
            job.api_model = "leonardo"
        elif provider == "ideogram":
            image_data, filename = await _generate_ideogram(input_data)
            job.api_model = "ideogram-v2"
        elif provider == "flux":
            image_data, filename = await _generate_flux(input_data)
            job.api_model = "flux-pro"
        elif provider == "bria":
            image_data, filename = await _generate_bria(input_data)
            job.api_model = input_data.get("model", "base")
        elif provider == "runway-image":
            image_data, filename = await _generate_runway_image(input_data)
            job.api_model = "gen4_image"
        else:
            raise ValueError(f"Unknown provider: {provider}")

        job.progress = 80
        db.commit()

        # Provider helpers signal "nothing produced" with (None, None). That
        # must not be reported as a completed job with zero images.
        if not image_data or not filename:
            raise RuntimeError(f"Provider '{provider}' returned no image data")

        # Save image
        storage_path = os.path.join(settings.storage_path, "images")
        os.makedirs(storage_path, exist_ok=True)
        # Filenames are assembled from request fields (model, output_format,
        # quality); keep only the last path component so the write cannot
        # land outside the storage directory.
        filename = os.path.basename(filename)
        file_path = os.path.join(storage_path, filename)

        with open(file_path, "wb") as f:
            f.write(image_data)

        # Create asset
        asset = Asset(
            user_id=job.user_id,
            project_id=job.project_id,
            original_filename=filename,
            stored_filename=filename,
            file_path=file_path,
            file_type="image",
            # Providers may return JPEG/WebP; sniff the bytes rather than
            # assuming PNG (the sniffer itself falls back to image/png).
            mime_type=determine_mime_type(image_data),
            file_size_bytes=len(image_data),
            source_module="image_generator",
            source_job_id=job.id,
            asset_metadata={
                "prompt": prompt,
                "provider": provider,
                "model": job.api_model
            }
        )
        db.add(asset)
        db.commit()
        db.refresh(asset)

        # Bound outside the logging try-block: the final job update needs
        # them even when usage logging fails.
        model = job.api_model
        output_asset_ids = [asset.id]
        output_paths = [file_path]
        job.output_asset_ids = output_asset_ids

        # Log Usage (best effort - never fails the job)
        try:
            from app.utils.logging import log_model_usage

            width = input_data.get("width")
            height = input_data.get("height")
            n = input_data.get("n", 1)  # Number of images requested

            duration_ms = 0
            if job.started_at:
                duration_ms = int((datetime.utcnow() - job.started_at).total_seconds() * 1000)

            log_model_usage(
                db=db,
                job_id=str(job.id),
                user_id=str(job.user_id),
                module="image_generator",
                action="generate",
                provider=provider,
                model=model,
                usage_stats={
                    "images": len(output_asset_ids),
                    "processing_time_ms": duration_ms
                },
                request_metadata={
                    "prompt": prompt,
                    "negative_prompt": input_data.get("negative_prompt"),
                    "size": f"{width}x{height}" if width and height else None,
                    "n": n
                },
                response_metadata={
                    "output_assets": [str(a_id) for a_id in output_asset_ids],
                    "filenames": [os.path.basename(p) for p in output_paths]
                }
            )
        except Exception as log_e:
            logger.error(f"Failed to log image generation usage: {log_e}")

        job.output_asset_ids = output_asset_ids
        job.output_data = {
            "prompt": prompt,
            "provider": provider,
            "model": model,
            "image_paths": output_paths
        }
        job.progress = 100
        job.status = "completed"
        job.completed_at = datetime.utcnow()
        db.commit()

    except Exception as e:
        logger.exception("Image generation job %s failed", job_id)
        # Discard any half-applied changes so the failure status below can
        # be committed on a clean session.
        db.rollback()
        if job is not None:
            job.status = "failed"
            job.error_message = str(e)
            db.commit()
    finally:
        db.close()
|
|
|
|
|
|
async def _generate_openai(input_data: dict) -> Tuple[Optional[bytes], Optional[str]]:
    """Generate image using OpenAI GPT-Image-1 or DALL-E 3 / DALL-E 2

    GPT-Image-1 Parameters (default):
    - prompt: Text description (max 32000 chars)
    - quality: 'low', 'medium', 'high' (default: high)
    - size: '1024x1024', '1024x1536', '1536x1024'
    - background: 'transparent', 'opaque', 'auto'
    - output_format: 'png', 'jpeg', 'webp' (default: png)
    - output_compression: 0-100 for jpeg/webp
    - moderation: 'auto' or 'low' (less restrictive)
    - n: 1-10 images

    DALL-E 3 Parameters:
    - prompt: Text description (max 4000 chars)
    - quality: 'standard' or 'hd' (default: hd)
    - style: 'vivid' or 'natural' (default: vivid)
    - size: '1024x1024', '1024x1792', '1792x1024'

    The size is derived from the ``width``/``height`` orientation in
    ``input_data`` (landscape, portrait or square), not from their exact values.

    Returns:
        ``(image_bytes, filename)`` for the first image in the response, or
        ``(None, None)`` when the response contains no base64 image.

    Raises:
        ValueError: If the OpenAI API key is not configured.
        httpx.HTTPStatusError: If the API answers with an error status.
    """
    if not settings.openai_api_key:
        raise ValueError("OpenAI API key not configured")

    prompt = input_data.get("prompt", "")
    model = input_data.get("model", "gpt-image-1")
    # `or` also covers an explicit null, which would break the comparison.
    width = input_data.get("width") or 1024
    height = input_data.get("height") or 1024

    # Determine size based on width/height
    if model == "dall-e-2" or width == height:
        # DALL-E 2 only offers square sizes; the 1792px variants are
        # DALL-E 3 only.
        size = "1024x1024"
    elif width > height:
        size = "1536x1024" if model == "gpt-image-1" else "1792x1024"
    else:
        size = "1024x1536" if model == "gpt-image-1" else "1024x1792"

    if model == "gpt-image-1":
        # GPT-Image-1 (latest model)
        quality = input_data.get("quality", "high")
        background = input_data.get("background", "auto")
        output_format = input_data.get("output_format", "png")
        output_compression = input_data.get("output_compression", 100)
        moderation = input_data.get("moderation", "auto")
        # NOTE: only the first returned image is used below, even when n > 1.
        n = max(1, min(int(input_data.get("n") or 1), 10))

        payload = {
            "model": "gpt-image-1",
            "prompt": prompt,
            "size": size,
            "quality": quality,
            "n": n
        }

        # Add optional parameters (only when they differ from API defaults)
        if background != "auto":
            payload["background"] = background
        if output_format != "png":
            payload["output_format"] = output_format
        if output_format in ["jpeg", "webp"] and output_compression != 100:
            payload["output_compression"] = output_compression
        if moderation != "auto":
            payload["moderation"] = moderation

        # GPT-Image-1 always returns base64
        ext = output_format if output_format in ["png", "jpeg", "webp"] else "png"
        file_stem = f"gptimage1_{quality}"
    else:
        # DALL-E 3 (or DALL-E 2)
        quality = input_data.get("quality", "hd")
        style = input_data.get("style", "vivid")

        payload = {
            "model": model,
            "prompt": prompt,
            "size": size,
            "n": 1,
            "response_format": "b64_json"
        }

        # DALL-E 3 specific options
        if model == "dall-e-3":
            payload["quality"] = quality
            payload["style"] = style

        ext = "png"
        file_stem = f"{model.replace('-', '')}_{style if model == 'dall-e-3' else 'gen'}"

    async with httpx.AsyncClient(timeout=180) as client:
        response = await client.post(
            "https://api.openai.com/v1/images/generations",
            headers={
                "Authorization": f"Bearer {settings.openai_api_key}",
                "Content-Type": "application/json"
            },
            json=payload
        )
        response.raise_for_status()
        data = response.json()

    images = data.get("data") or []
    b64_image = images[0].get("b64_json") if images else None
    if not b64_image:
        return None, None

    return base64.b64decode(b64_image), f"{file_stem}_{uuid4()}.{ext}"
|
|
|
|
|
|
async def _generate_stability(input_data: dict, input_image_data: Optional[bytes] = None) -> Tuple[Optional[bytes], Optional[str]]:
    """Generate image using Stability AI

    Parameters:
    - prompt: Text description (required)
    - negative_prompt: What to avoid in generation
    - model: 'sd3.5-large', 'sd3.5-medium', 'sd3-large', 'sd3-medium'
    - aspect_ratio: '1:1', '16:9', '9:16', '4:3', '3:4', '21:9', '9:21'
      (text-to-image only)
    - seed: Optional seed for reproducibility (0-4294967294)
    - strength: 0.0-1.0, image-to-image only (default 0.7)
    - output_format: Image format requested from the API (default 'png')

    The mode is 'image-to-image' when ``input_image_data`` is given and
    'text-to-image' otherwise.

    Returns:
        ``(image_bytes, filename)``.

    Raises:
        ValueError: If the API key is missing or the prompt is empty.
        Exception: If the API answers with a non-200 status.
        httpx.HTTPError: On transport-level failures.
    """
    if not settings.stability_api_key:
        raise ValueError("Stability API key not configured")

    prompt = input_data.get("prompt", "")
    if not prompt:
        raise ValueError("Prompt is required")

    negative_prompt = input_data.get("negative_prompt", "")
    model = input_data.get("model", "sd3.5-large")
    aspect_ratio = input_data.get("aspect_ratio", "1:1")
    seed = input_data.get("seed")
    output_format = input_data.get("output_format", "png")

    # Build multipart form data - Stability requires multipart/form-data
    files = {
        "prompt": (None, prompt),
        "model": (None, model),
        "output_format": (None, output_format),
    }

    if negative_prompt:
        files["negative_prompt"] = (None, negative_prompt)

    if seed is not None:
        files["seed"] = (None, str(seed))

    if input_image_data:
        # Image-to-image mode: the output follows the input image's
        # dimensions, so aspect_ratio is a text-to-image-only field.
        files["mode"] = (None, "image-to-image")
        files["strength"] = (None, str(input_data.get("strength", 0.7)))
        files["image"] = ("input.png", input_image_data, determine_mime_type(input_image_data))
    else:
        files["mode"] = (None, "text-to-image")
        files["aspect_ratio"] = (None, aspect_ratio)

    async with httpx.AsyncClient(timeout=180) as client:
        try:
            response = await client.post(
                "https://api.stability.ai/v2beta/stable-image/generate/sd3",
                headers={
                    "Authorization": f"Bearer {settings.stability_api_key}",
                    "Accept": "image/*"
                },
                files=files
            )
        except httpx.HTTPError as e:
            logger.error(f"Stability AI generation error: {e}")
            raise

    # Status is checked by hand (no raise_for_status) so the API's error
    # body ends up in the exception message.
    if response.status_code != 200:
        error_text = response.text
        logger.error(f"Stability AI error {response.status_code}: {error_text}")
        raise Exception(f"Stability AI error: {error_text}")

    model_short = model.replace("-", "").replace(".", "")
    filename = f"stability_{model_short}_{uuid4()}.{output_format}"
    return response.content, filename
|
|
|
|
|
|
async def _generate_leonardo(input_data: dict) -> tuple:
    """
    Generate image using Leonardo AI

    Submits a generation request, then polls every 2 seconds (up to 90
    times, about 3 minutes) until the generation completes or fails.

    Parameters:
    - prompt: Text description
    - model: Leonardo model ID (default: IMAGE_PROVIDERS["leonardo"]["default_model"])
    - aspect_ratio: Looked up in IMAGE_PROVIDERS["leonardo"]["dimensions"]
    - width / height: Explicit override of the aspect-ratio dimensions
    - preset_style: Style preset (ANIME, CINEMATIC, PHOTOGRAPHY, etc.)
    - num_images: Number of images to generate (capped at 4; only the first is returned)
    - guidance_scale: How closely to follow prompt (7-15)
    - alchemy / contrast_ratio, photo_real / photo_real_strength: legacy
      pipelines, always disabled for Phoenix
    - init_image_id: For image-to-image
    - init_strength: How much to change input image (0.1-0.9)

    Returns:
        ``(image_bytes, filename)``, or ``(None, None)`` when no generation id
        is returned, the generation completes without an image URL, or
        polling times out.

    Raises:
        ValueError: If the API key is missing or the create call is not 200.
        Exception: If Leonardo reports the generation as FAILED.
        httpx.HTTPStatusError: If polling or the image download fails.
    """
    import asyncio

    if not settings.leonardo_api_key:
        raise ValueError("Leonardo API key not configured")

    leonardo_config = IMAGE_PROVIDERS["leonardo"]

    # Fall back to the configured default so the default model, the Phoenix
    # guard below and the filename lookup all agree.
    model_id = input_data.get("model") or leonardo_config["default_model"]

    # Determine dimensions from aspect ratio
    aspect_ratio = input_data.get("aspect_ratio", "1:1")
    dims = leonardo_config["dimensions"].get(aspect_ratio, {"width": 1024, "height": 1024})

    # Allow explicit override if provided; `or` also covers explicit nulls.
    width = int(input_data.get("width") or dims["width"])
    height = int(input_data.get("height") or dims["height"])

    # Build request payload
    payload = {
        "prompt": input_data.get("prompt"),
        "modelId": model_id,
        "width": width,
        "height": height,
        "num_images": max(1, min(int(input_data.get("num_images") or 1), 4)),  # Cap at 4 for safety
        "public": input_data.get("public", False)
    }

    # Alchemy / PhotoReal Logic
    # Phoenix (de7d3faf...) does NOT support Alchemy or PhotoReal (it has its own pipeline).
    # Sending 'alchemy': True with Phoenix causes "Invalid response from authorization hook" (500).
    is_phoenix = model_id == "de7d3faf-762f-48e0-b3b7-9d0ac3a3fcf3"

    # Force disable legacy features for Phoenix
    alchemy = False if is_phoenix else input_data.get("alchemy", False)
    photo_real = False if is_phoenix else input_data.get("photo_real", False)

    if alchemy:
        payload["alchemy"] = True
        payload["contrastRatio"] = input_data.get("contrast_ratio", 0.5)

    if photo_real:
        payload["photoReal"] = True
        payload["photoRealStrength"] = input_data.get("photo_real_strength", 0.5)
        # If PhotoReal is on, we remove modelId to rely on system default for PhotoReal.
        payload.pop("modelId", None)

    preset_style = input_data.get("preset_style")
    if preset_style and preset_style != "NONE":
        payload["presetStyle"] = preset_style

    if input_data.get("guidance_scale"):
        payload["guidance_scale"] = int(input_data.get("guidance_scale"))

    # Image-to-image / Reference
    # Modern Leonardo uses 'imagePrompts' array for reference.
    # 'init_image_id' is legacy but might still work for some models.
    init_image_id = input_data.get("init_image_id")
    if init_image_id:
        # Legacy support
        payload["init_image_id"] = init_image_id
        payload["init_strength"] = float(input_data.get("init_strength", 0.5))

    # Log payload for debugging (after every optional field has been added)
    logger.info(f"Leonardo Payload (Model: {model_id}): {payload}")

    auth_header = {"Authorization": f"Bearer {settings.leonardo_api_key}"}

    async with httpx.AsyncClient(timeout=180) as client:
        # Create generation
        response = await client.post(
            "https://cloud.leonardo.ai/api/rest/v1/generations",
            headers={**auth_header, "Content-Type": "application/json"},
            json=payload
        )
        if response.status_code != 200:
            error_text = response.text
            logger.error(f"Leonardo API error {response.status_code}: {error_text}")
            raise ValueError(f"Leonardo API returned {response.status_code}: {error_text}")

        data = response.json()
        logger.info(f"Leonardo response: {data}")

        generation_id = (data.get("sdGenerationJob") or {}).get("generationId")
        if not generation_id:
            return None, None

        # Poll for result
        for _ in range(90):  # Wait up to 3 minutes
            await asyncio.sleep(2)
            status_response = await client.get(
                f"https://cloud.leonardo.ai/api/rest/v1/generations/{generation_id}",
                headers=auth_header
            )
            if status_response.status_code >= 500:
                # Treat server-side hiccups as transient and keep polling.
                logger.warning(f"Leonardo status check returned {status_response.status_code}; retrying")
                continue
            status_response.raise_for_status()

            # The key can be present with a null value while the job is queued.
            generation = status_response.json().get("generations_by_pk") or {}
            status = generation.get("status")

            if status == "FAILED":
                raise Exception("Leonardo generation failed")
            if status != "COMPLETE":
                continue

            images = generation.get("generated_images") or []
            image_url = images[0].get("url") if images else None
            if not image_url:
                # Terminal state without an image: polling further cannot help.
                logger.warning(f"Leonardo generation {generation_id} completed without an image URL")
                return None, None

            img_response = await client.get(image_url)
            # Never store an error page as if it were the image.
            img_response.raise_for_status()
            model_name = leonardo_config["models"].get(model_id, "leonardo")
            filename = f"leonardo_{model_name.replace(' ', '_').lower()}_{uuid4()}.png"
            return img_response.content, filename

        logger.warning(f"Leonardo generation {generation_id} did not finish within the polling window")

    return None, None
|
|
|
|
|
|
async def _generate_bria(input_data: dict) -> tuple:
    """
    Generate image using Bria AI

    Parameters:
    - prompt: Text description
    - model: 'base' (Bria 2.3 Base) or 'fast' (Bria 2.3 Fast); any value
      other than 'base' selects the fast endpoint
    - aspect_ratio: Image aspect ratio
    - medium: 'photography' or 'art'
    - prompt_enhancement: Enable AI prompt enhancement
    - steps_num: Number of inference steps (capped at 10 for 'fast')
    - guidance_scale: How closely to follow prompt ('base' only)
    - negative_prompt: What to avoid

    Returns:
        ``(image_bytes, filename)``, or ``(None, None)`` when the response
        carries no image URL.

    Raises:
        ValueError: If the Bria API key is not configured.
        httpx.HTTPStatusError: If the API call or the image download fails.
    """
    if not settings.bria_api_key:
        raise ValueError("Bria API key not configured")

    model = input_data.get("model", "base")
    base_url = "https://engine.prod.bria-api.com/v1/text-to-image"

    # Build request payload
    payload = {
        "prompt": input_data.get("prompt"),
        "num_results": 1
    }

    # Optional fields that are forwarded verbatim when set
    for field in ("aspect_ratio", "medium", "negative_prompt"):
        value = input_data.get(field)
        if value:
            payload[field] = value

    # Add prompt enhancement
    if input_data.get("prompt_enhancement"):
        payload["prompt_enhancement"] = True

    # Model-specific parameters
    steps_num = input_data.get("steps_num")
    if model == "base":
        url = f"{base_url}/base"
        if steps_num:
            payload["steps_num"] = steps_num
        if input_data.get("guidance_scale"):
            payload["text_guidance_scale"] = input_data.get("guidance_scale")
    else:
        url = f"{base_url}/fast"
        if steps_num:
            payload["steps_num"] = min(steps_num, 10)

    async with httpx.AsyncClient(timeout=120) as client:
        response = await client.post(
            url,
            headers={
                "api_token": settings.bria_api_key,
                "Content-Type": "application/json"
            },
            json=payload
        )
        response.raise_for_status()
        data = response.json()

        # Get the result
        result = data.get("result") or []
        if not result:
            return None, None

        # Accept both shapes for "urls": a list of URL strings or a
        # {"url": ...} mapping. Calling .get() on a list would crash.
        urls = result[0].get("urls")
        if isinstance(urls, dict):
            image_url = urls.get("url")
        elif isinstance(urls, (list, tuple)) and urls:
            image_url = urls[0]
        else:
            image_url = None

        if not image_url:
            return None, None

        img_response = await client.get(image_url)
        # Never store an error page as if it were the image.
        img_response.raise_for_status()
        filename = f"bria_{model}_{uuid4()}.png"
        return img_response.content, filename
|
|
|
|
|
|
async def _generate_ideogram(input_data: dict) -> tuple:
    """Generate image using Ideogram

    Parameters:
    - prompt: Text description
    - model: One of IMAGE_PROVIDERS["ideogram"]["models"]; anything else
      falls back to 'V_2'
    - aspect_ratio: 'W:H' string such as '16:9'; unsupported or missing
      values fall back to 1:1

    Returns:
        ``(image_bytes, filename)``, or ``(None, None)`` when the response
        carries no image URL.

    Raises:
        ValueError: If the Ideogram API key is not configured.
        httpx.HTTPStatusError: If the API call or the image download fails.
    """
    if not settings.ideogram_api_key:
        raise ValueError("Ideogram API key not configured")

    # Only forward a model the configuration knows about, so a stray model
    # name meant for another provider cannot break the request.
    model = input_data.get("model")
    if model not in IMAGE_PROVIDERS["ideogram"]["models"]:
        model = "V_2"

    # Ideogram expects enum names of the form ASPECT_<W>_<H>.
    supported_ratios = (
        "1:1", "16:9", "9:16", "4:3", "3:4", "3:2", "2:3",
        "16:10", "10:16", "1:3", "3:1",
    )
    requested_ratio = input_data.get("aspect_ratio")
    if isinstance(requested_ratio, str) and requested_ratio in supported_ratios:
        aspect_ratio = "ASPECT_" + requested_ratio.replace(":", "_")
    else:
        aspect_ratio = "ASPECT_1_1"

    async with httpx.AsyncClient(timeout=120) as client:
        response = await client.post(
            "https://api.ideogram.ai/generate",
            headers={
                "Api-Key": settings.ideogram_api_key,
                "Content-Type": "application/json"
            },
            json={
                "image_request": {
                    "prompt": input_data.get("prompt"),
                    "model": model,
                    "aspect_ratio": aspect_ratio
                }
            }
        )
        response.raise_for_status()
        data = response.json()

        images = data.get("data") or []
        image_url = images[0].get("url") if images else None
        if not image_url:
            return None, None

        img_response = await client.get(image_url)
        # Never store an error page as if it were the image.
        img_response.raise_for_status()
        filename = f"ideogram_{uuid4()}.png"
        return img_response.content, filename
|
|
|
|
|
|
async def _generate_flux(input_data: dict) -> tuple:
    """Generate image using Flux (Black Forest Labs)

    Submits a flux-pro-1.1 request, then polls the result endpoint every
    2 seconds (up to 60 times, about 2 minutes).

    Note: Requires FLUX_API_KEY from https://api.bfl.ml/
    May require paid account for flux-pro-1.1 model

    Parameters:
    - prompt: Text description
    - width / height: Output dimensions (default 1024 each)

    Returns:
        ``(image_bytes, filename)``, or ``(None, None)`` when no request id
        is returned, the result has no image URL, or polling times out.

    Raises:
        ValueError: If the API key is missing or rejected (HTTP 403), or the
            service reports a terminal failure status.
        httpx.HTTPStatusError: If any HTTP call returns an error status.
    """
    import asyncio

    if not settings.flux_api_key:
        raise ValueError("FLUX_API_KEY not configured")

    # Statuses after which further polling cannot succeed.
    terminal_failures = {"Error", "Failed", "Request Moderated", "Content Moderated"}

    async with httpx.AsyncClient(timeout=120) as client:
        try:
            response = await client.post(
                "https://api.bfl.ml/v1/flux-pro-1.1",
                headers={
                    "x-key": settings.flux_api_key,
                    "Content-Type": "application/json"
                },
                json={
                    "prompt": input_data.get("prompt"),
                    # `or` also covers explicit nulls in the job input.
                    "width": input_data.get("width") or 1024,
                    "height": input_data.get("height") or 1024
                }
            )

            if response.status_code == 403:
                logger.error("Flux API 403: Invalid API key or insufficient permissions")
                raise ValueError("Flux API key is invalid or your account doesn't have access to flux-pro-1.1")

            response.raise_for_status()
            data = response.json()

            request_id = data.get("id")
            if not request_id:
                return None, None

            # Poll for result
            for _ in range(60):
                await asyncio.sleep(2)
                status_response = await client.get(
                    "https://api.bfl.ml/v1/get_result",
                    params={"id": request_id},
                    headers={"x-key": settings.flux_api_key}
                )
                status_response.raise_for_status()
                status_data = status_response.json()
                status = status_data.get("status")

                if status in terminal_failures:
                    # Fail fast instead of polling until the timeout.
                    raise ValueError(f"Flux generation ended with status '{status}'")
                if status != "Ready":
                    continue

                image_url = (status_data.get("result") or {}).get("sample")
                if not image_url:
                    logger.warning(f"Flux request {request_id} is ready but has no image URL")
                    return None, None

                img_response = await client.get(image_url)
                # Never store an error page as if it were the image.
                img_response.raise_for_status()
                filename = f"flux_{uuid4()}.png"
                return img_response.content, filename

            logger.warning(f"Flux request {request_id} did not finish within the polling window")

        except Exception as e:
            logger.error(f"Flux generation error: {e}")
            raise

    return None, None
|
|
|
|
|
|
async def _generate_gemini(input_data: dict) -> tuple:
    """Generate image using Google Gemini (google-generativeai SDK)

    Sends ``input_data["prompt"]`` to the ``gemini-2.0-flash-exp`` model and
    returns the first inline-data part of the first candidate.

    Returns:
        ``(image_bytes, filename)``, or ``(None, None)`` when the response
        has no inline image data.

    Raises:
        ValueError: If GOOGLE_API_KEY is not configured.
    """
    import asyncio
    import functools

    import google.generativeai as genai

    if not settings.google_api_key:
        raise ValueError("GOOGLE_API_KEY not configured")

    genai.configure(api_key=settings.google_api_key)
    model = genai.GenerativeModel("gemini-2.0-flash-exp")

    # generate_content() is a blocking network call; run it in the default
    # executor so it does not stall the event loop for other jobs.
    request = functools.partial(
        model.generate_content,
        input_data.get("prompt"),
        generation_config=genai.types.GenerationConfig(
            response_mime_type="image/png"
        )
    )
    response = await asyncio.get_running_loop().run_in_executor(None, request)

    if not response.candidates:
        return None, None

    for part in response.candidates[0].content.parts or []:
        if getattr(part, "inline_data", None):
            filename = f"gemini_{uuid4()}.png"
            return part.inline_data.data, filename

    return None, None
|
|
|
|
|
|
async def _generate_imagen(input_data: dict) -> tuple:
    """
    Generate image using Google Imagen 4 via REST API

    Note: Imagen is accessed through the generativelanguage API with API key.
    If the request fails (non-200 status or any exception), generation falls
    back once to Nano Banana (Gemini native image generation).

    Parameters:
    - prompt: Text description of the image
    - model: Imagen model name (default 'imagen-4.0-generate-001')
    - aspect_ratio: "1:1", "3:4", "4:3", "9:16", "16:9"
    - number_of_images: 1-4 (only the first prediction is returned)
    - negative_prompt: What to avoid in the image

    Returns:
        ``(image_bytes, filename)``; ``(None, None)`` when the API answers
        200 without an image; otherwise whatever the fallback returns.

    Raises:
        ValueError: If GOOGLE_API_KEY is not configured.
    """
    if not settings.google_api_key:
        raise ValueError("GOOGLE_API_KEY not configured")

    prompt = input_data.get("prompt", "")
    negative_prompt = input_data.get("negative_prompt", "")
    aspect_ratio = input_data.get("aspect_ratio", "1:1")
    number_of_images = max(1, min(int(input_data.get("number_of_images") or 1), 4))

    # Use the Generative Language API for Imagen 4
    model_name = input_data.get("model", "imagen-4.0-generate-001")
    url = f"https://generativelanguage.googleapis.com/v1beta/models/{model_name}:predict"

    payload = {
        "instances": [{"prompt": prompt}],
        "parameters": {
            "sampleCount": number_of_images,
            "aspectRatio": aspect_ratio,
        }
    }

    if negative_prompt:
        payload["instances"][0]["negativePrompt"] = negative_prompt

    try:
        async with httpx.AsyncClient(timeout=120.0) as client:
            response = await client.post(
                url,
                headers={
                    "Content-Type": "application/json",
                    "x-goog-api-key": settings.google_api_key
                },
                json=payload
            )

        if response.status_code == 200:
            predictions = response.json().get("predictions", [])
            if predictions and predictions[0].get("bytesBase64Encoded"):
                image_data = base64.b64decode(predictions[0]["bytesBase64Encoded"])
                # "imagen3_" prefix kept as-is so existing filenames stay consistent.
                filename = f"imagen3_{uuid4()}.png"
                return image_data, filename
            # Successful call without an image: report "nothing generated"
            # and do not fall back.
            return None, None

        logger.warning(f"Imagen API error: {response.status_code} - {response.text}")
    except Exception as e:
        logger.error(f"Imagen generation error: {e}")

    # The fallback sits outside the try-block on purpose: if it raises, the
    # error propagates instead of triggering a second fallback call.
    # NOTE(review): input_data still carries the Imagen "model" value here;
    # confirm _generate_nano_banana copes with that.
    logger.info("Falling back to Nano Banana (Gemini native image generation)")
    return await _generate_nano_banana(input_data)
|
|
|
|
|
|
async def _upload_file_http(media_data: bytes, mime_type: str) -> Optional[str]:
    """
    Upload file using raw HTTP to Google Generative AI Files API
    (Alternative to outdated google-generativeai library)

    Two requests are made: a "start" request that announces the size and
    content type and yields an upload URL in the ``x-goog-upload-url``
    response header, then an "upload, finalize" request that sends the bytes.

    Args:
        media_data: Raw bytes of the file to upload.
        mime_type: Content type announced for the upload.

    Returns:
        The ``file.uri`` value from the finalize response, or ``None`` when
        no API key is configured, a step fails, or no URI is returned.
        Errors are logged, never raised.
    """
    if not settings.google_api_key:
        return None

    try:
        # The key travels in a header rather than the query string so it
        # does not end up in URLs written to request logs.
        url = "https://generativelanguage.googleapis.com/upload/v1beta/files"
        num_bytes = len(media_data)

        headers = {
            "x-goog-api-key": settings.google_api_key,
            "X-Goog-Upload-Protocol": "resumable",
            "X-Goog-Upload-Command": "start",
            "X-Goog-Upload-Header-Content-Length": str(num_bytes),
            "X-Goog-Upload-Header-Content-Type": mime_type,
            "Content-Type": "application/json",
        }

        metadata = {"file": {"display_name": f"nano_banana_upload_{uuid4()}"}}

        async with httpx.AsyncClient(timeout=30.0) as client:
            # 1. Start Upload
            response = await client.post(url, headers=headers, json=metadata)
            if response.status_code != 200:
                logger.error(f"Failed to start upload: {response.status_code} - {response.text}")
                return None

            upload_url = response.headers.get("x-goog-upload-url")
            if not upload_url:
                logger.error("No upload URL returned")
                return None

            # 2. Upload Bytes
            headers_upload = {
                "Content-Length": str(num_bytes),
                "X-Goog-Upload-Offset": "0",
                "X-Goog-Upload-Command": "upload, finalize",
            }

            response_upload = await client.post(upload_url, headers=headers_upload, content=media_data)
            if response_upload.status_code != 200:
                logger.error(f"Failed to upload data: {response_upload.status_code} - {response_upload.text}")
                return None

            file_uri = response_upload.json().get("file", {}).get("uri")
            if not file_uri:
                # Do not report success when the response carries no URI.
                logger.error("Upload finished but no file URI was returned")
                return None

            logger.info(f"File uploaded successfully: {file_uri}")
            return file_uri

    except Exception as e:
        logger.error(f"Upload error: {e}")
        return None
|
|
|
|
|
|
async def _generate_nano_banana(input_data: dict, image_data: bytes = None, mime_type: str = "image/png") -> tuple:
    """
    Generate or edit an image with Google Nano Banana Pro
    (model ``gemini-3-pro-image-preview``) via the ``generateContent`` REST call.

    Request/response conventions this function relies on:
    - Input uses 'inline_data' (snake_case)
    - Output uses 'inlineData' (camelCase)
    - Image part is placed first in the request, the text prompt second
    - Input MIME type is detected from the bytes with ``determine_mime_type``
    - finishReason is checked before the image is extracted

    Args:
        input_data: Request options. Reads ``prompt`` (required),
            ``aspect_ratio`` (default "16:9") and ``image_size`` (default "2K").
        image_data: Optional source image to edit. A ``str`` is UTF-8 encoded
            to bytes before being base64-encoded.
        mime_type: Used only when ``determine_mime_type`` yields a falsy value.

    Returns:
        ``(image_bytes, filename)``, or ``(None, None)`` when the response
        contains no candidates.

    Raises:
        ValueError: Missing API key or prompt, non-200 HTTP status, a
            finishReason other than "STOP", an undecodable image, or a STOP
            response without image data.
    """
    import base64

    if not settings.google_api_key:
        raise ValueError("GOOGLE_API_KEY not configured")

    prompt = input_data.get("prompt", "")
    if not prompt:
        raise ValueError("Prompt is required")

    # --- 1. CONFIGURATION ---
    # The model is fixed rather than read from input_data: _generate_imagen
    # forwards its own input_data here as a fallback, and a "model" key in it
    # may name an Imagen model.
    model_name = "gemini-3-pro-image-preview"
    url = f"https://generativelanguage.googleapis.com/v1beta/models/{model_name}:generateContent"

    # --- 2. BUILD REQUEST (Strict Order: Image -> Text) ---
    parts = []

    # PART A: Input Image (for editing) - MUST BE FIRST
    if image_data:
        if isinstance(image_data, str):
            image_data = image_data.encode('utf-8')

        b64_image = base64.b64encode(image_data).decode("utf-8")

        # Send the MIME type of the actual source bytes; the original author
        # notes that a PNG labelled as JPEG is ignored by the model.
        # NOTE(review): the `or mime_type` fallback assumes determine_mime_type
        # may return a falsy value for unknown data - confirm against the helper.
        real_mime_type = determine_mime_type(image_data) or mime_type

        parts.append({
            "inline_data": {
                "mime_type": real_mime_type,
                "data": b64_image,
            }
        })
        logger.info(f"Nano Banana: Added input image ({real_mime_type}, {len(b64_image)} chars)")

    # PART B: Text Prompt - MUST BE SECOND (passed through unmodified)
    parts.append({"text": prompt})

    # --- 3. CONFIG PARAMETERS ---
    payload = {
        "contents": [{"parts": parts}],
        "generationConfig": {
            "responseModalities": ["IMAGE"],
            "imageConfig": {
                "aspectRatio": input_data.get("aspect_ratio", "16:9"),
                "imageSize": input_data.get("image_size", "2K"),
            },
        },
    }

    try:
        async with httpx.AsyncClient(timeout=120.0) as client:
            response = await client.post(
                url,
                headers={
                    "Content-Type": "application/json",
                    "x-goog-api-key": settings.google_api_key,
                },
                json=payload,
            )

        # --- 4. HANDLE RESPONSE ---
        if response.status_code != 200:
            logger.error(f"Nano Banana HTTP Error: {response.status_code} - {response.text}")
            # Extract the API's message first and raise afterwards. Raising
            # inside the try would let the handler swallow the friendly
            # message and always report the generic one.
            try:
                msg = response.json().get('error', {}).get('message', response.text)
            except (ValueError, AttributeError):
                msg = None
            if msg is None:
                raise ValueError(f"Nano Banana API refused connection: {response.status_code}")
            raise ValueError(f"API Error: {msg}")

        data = response.json()

        # --- 5. PARSE CANDIDATES (Check finishReason first) ---
        candidates = data.get("candidates", [])
        if not candidates:
            logger.error(f"Nano Banana: No candidates returned. Response: {data}")
            return None, None

        candidate = candidates[0]
        finish_reason = candidate.get("finishReason")
        finish_message = candidate.get("finishMessage", "")

        if finish_reason == "IMAGE_RECITATION":
            logger.warning("Nano Banana blocked: IMAGE_RECITATION (Prompt too generic)")
            raise ValueError("Image generation blocked: Prompt too generic or recites existing content. Try a more creative description.")

        if finish_reason == "SAFETY":
            logger.warning("Nano Banana blocked: SAFETY")
            raise ValueError("Image generation blocked by safety filters.")

        if finish_reason != "STOP":
            logger.error(f"Nano Banana failed with reason: {finish_reason} - {finish_message}")
            raise ValueError(f"Generation failed: {finish_reason} - {finish_message}")

        # --- 6. EXTRACT IMAGE (inlineData / camelCase) ---
        parts_resp = candidate.get("content", {}).get("parts", [])

        for part in parts_resp:
            inline_data_resp = part.get("inlineData") if isinstance(part, dict) else None
            if not inline_data_resp or "data" not in inline_data_resp:
                continue

            try:
                img_bytes = base64.b64decode(inline_data_resp["data"])
            except Exception as e:
                logger.error(f"Base64 decode error: {e}")
                raise ValueError("Failed to decode generated image") from e

            # Pick the extension from the reported MIME type (PNG if absent).
            resp_mime = inline_data_resp.get("mimeType", "image/png")
            ext = "jpg" if "jpeg" in resp_mime else "png"

            filename = f"nano_banana_{uuid4()}.{ext}"
            logger.info(f"✓ Nano Banana success: {filename} ({len(img_bytes)} bytes)")
            return img_bytes, filename

        # STOP was reported but no part carried image data.
        logger.error(f"Nano Banana: No image data found in STOP response. Parts: {parts_resp}")
        raise ValueError("API returned success status but no image data found.")

    except Exception as e:
        logger.error(f"Nano Banana Critical Error: {e}")
        # Re-raise to let the frontend know the specific error
        raise
|
|
|
|
|
|
async def _generate_runway_image(input_data: dict) -> tuple:
    """Generate image using Runway Gen-4 Image.

    Submits a ``text_to_image`` task, then polls the task endpoint every
    2 seconds for up to 90 attempts and downloads the first output URL.

    Args:
        input_data: Request options. Reads ``prompt``, ``ratio`` (only
            "1024:1024" and "1360:768" are accepted; anything else becomes
            "1360:768") and ``seed`` (sent only when positive).

    Returns:
        ``(image_bytes, filename)`` on success, or ``(None, None)`` when the
        task succeeds without an output URL or polling runs out of attempts.

    Raises:
        ValueError: Missing API key, no task id in the submit response, or a
            task that reports status "FAILED".
        httpx.HTTPStatusError: The submit request or the image download
            returns an error status.
    """
    if not settings.runway_api_key:
        raise ValueError("RUNWAY_API_KEY not configured")

    import asyncio

    prompt = input_data.get("prompt", "")
    ratio = input_data.get("ratio", "1360:768")
    seed = input_data.get("seed")

    payload = {
        "model": "gen4_image",
        "promptText": prompt,
        "ratio": ratio if ratio in ("1024:1024", "1360:768") else "1360:768",
    }
    if seed and seed > 0:
        payload["seed"] = seed

    auth_headers = {
        "Authorization": f"Bearer {settings.runway_api_key}",
        "X-Runway-Version": "2024-11-06",
    }

    async with httpx.AsyncClient(timeout=180) as client:
        response = await client.post(
            "https://api.dev.runwayml.com/v1/text_to_image",
            headers={**auth_headers, "Content-Type": "application/json"},
            json=payload,
        )
        response.raise_for_status()
        task_id = response.json().get("id")
        if not task_id:
            # Without an id the poll loop would query ".../tasks/None".
            raise ValueError("Runway did not return a task id")

        # Poll for completion
        for _ in range(90):
            await asyncio.sleep(2)
            status_resp = await client.get(
                f"https://api.dev.runwayml.com/v1/tasks/{task_id}",
                headers=auth_headers,
            )
            status_data = status_resp.json()
            status = status_data.get("status")

            if status == "SUCCEEDED":
                # Tolerate a missing, null, or empty "output" list.
                outputs = status_data.get("output") or []
                image_url = outputs[0] if outputs else None
                if not image_url:
                    # The task is finished; polling again cannot produce a URL.
                    logger.error("Runway task succeeded but returned no output URL")
                    return None, None
                img_resp = await client.get(image_url)
                # Fail loudly instead of returning an error body as image bytes.
                img_resp.raise_for_status()
                return img_resp.content, f"runway_gen4_{uuid4()}.png"

            if status == "FAILED":
                raise ValueError(f"Runway failed: {status_data.get('error')}")

    logger.warning(f"Runway task {task_id} did not finish within the polling window")
    return None, None
|