presenton/servers/fastapi/services/image_generation_service.py

666 lines
24 KiB
Python

import asyncio
import base64
import json
import os
import aiohttp
from fastapi import HTTPException
from google import genai
from google.genai import types
from openai import NOT_GIVEN, AsyncOpenAI
from models.image_prompt import ImagePrompt
from models.sql.image_asset import ImageAsset
from utils.get_env import (
get_dall_e_3_quality_env,
get_gpt_image_1_5_quality_env,
get_pexels_api_key_env,
get_open_webui_image_url_env,
get_open_webui_image_api_key_env,
)
from utils.get_env import get_pixabay_api_key_env
from utils.get_env import get_comfyui_url_env
from utils.get_env import get_comfyui_workflow_env
from utils.image_provider import (
is_gpt_image_1_5_selected,
is_image_generation_disabled,
is_pixels_selected,
is_pixabay_selected,
is_gemini_flash_selected,
is_nanobanana_pro_selected,
is_dalle3_selected,
is_comfyui_selected,
is_open_webui_selected,
)
import uuid
class ImageGenerationService:
    """Produce images through whichever image provider is currently selected."""

    def __init__(self, output_directory: str):
        """Remember the output directory and resolve the provider callable.

        Args:
            output_directory: Directory stored on the instance; it is the
                value later passed to non-stock providers as the place to
                save generated files.
        """
        self.output_directory = output_directory
        # The disabled flag has to exist before get_image_gen_func() runs,
        # because that method reads it.
        self.is_image_generation_disabled = is_image_generation_disabled()
        self.image_gen_func = self.get_image_gen_func()
def get_image_gen_func(self):
    """Pick the bound provider method that matches the current selection.

    The selection predicates are consulted in a fixed priority order and
    the first one that reports true decides the result.

    Returns:
        The coroutine method to use for images, or None when image
        generation is disabled or no provider predicate matches.
    """
    if self.is_image_generation_disabled:
        return None

    # (predicate, handler) pairs, highest priority first.
    provider_table = (
        (is_pixabay_selected, self.get_image_from_pixabay),
        (is_pixels_selected, self.get_image_from_pexels),
        (is_gemini_flash_selected, self.generate_image_gemini_flash),
        (is_nanobanana_pro_selected, self.generate_image_nanobanana_pro),
        (is_dalle3_selected, self.generate_image_openai_dalle3),
        (is_gpt_image_1_5_selected, self.generate_image_openai_gpt_image_1_5),
        (is_comfyui_selected, self.generate_image_comfyui),
        (is_open_webui_selected, self.generate_image_open_webui),
    )
    for is_selected, handler in provider_table:
        if is_selected():
            return handler
    return None
def is_stock_provider_selected(self):
    """Report whether a stock-photo provider (Pexels or Pixabay) is selected."""
    pexels_chosen = is_pixels_selected()
    # Pixabay is only consulted when Pexels is not the selected provider.
    return pexels_chosen or is_pixabay_selected()
async def generate_image(self, prompt: ImagePrompt) -> str | ImageAsset:
    """
    Produce an image for ``prompt`` with the configured provider.

    - When generation is disabled or no provider callable was resolved,
      the placeholder image path is returned.
    - Stock providers receive the bare prompt; every other provider
      receives the prompt with its theme plus the output directory.
    - A result starting with "http" is returned as-is, an existing local
      file is wrapped in an ImageAsset, and "/app_data/" or "/static/"
      paths are returned unchanged.
    - Any failure is printed and answered with the placeholder image.
    """
    placeholder = "/static/images/placeholder.jpg"

    if self.is_image_generation_disabled:
        print("Image generation is disabled. Using placeholder image.")
        return placeholder
    if not self.image_gen_func:
        print("No image generation function found. Using placeholder image.")
        return placeholder

    use_theme = not self.is_stock_provider_selected()
    image_prompt = prompt.get_image_prompt(with_theme=use_theme)
    print(f"Request - Generating Image for {image_prompt}")

    try:
        # Stock providers take only the query; the others also need the
        # directory in which to save their output.
        provider_args = (
            (image_prompt,)
            if self.is_stock_provider_selected()
            else (image_prompt, self.output_directory)
        )
        image_path = await self.image_gen_func(*provider_args)

        if image_path:
            if image_path.startswith("http"):
                return image_path
            if os.path.exists(image_path):
                return ImageAsset(
                    path=image_path,
                    is_uploaded=False,
                    extras={
                        "prompt": prompt.prompt,
                        "theme_prompt": prompt.theme_prompt,
                    },
                )
            if image_path.startswith(("/app_data/", "/static/")):
                return image_path
        raise Exception(f"Image not found at {image_path}")
    except Exception as e:
        print(f"Error generating image: {e}")
        return placeholder
async def generate_image_openai(
    self, prompt: str, output_directory: str, model: str, quality: str
) -> str:
    """Generate one 1024x1024 image through the OpenAI images API.

    Args:
        prompt: Text prompt sent to the model.
        output_directory: Directory in which the PNG file is written.
        model: OpenAI image model name.
        quality: Quality value forwarded to the API.

    Returns:
        Path of the PNG file written under ``output_directory``.
    """
    # Only "dall-e-3" is asked explicitly for base64 output; for every
    # other model the response_format argument is left out.
    wants_explicit_b64 = model == "dall-e-3"

    openai_client = AsyncOpenAI()
    generation = await openai_client.images.generate(
        model=model,
        prompt=prompt,
        n=1,
        quality=quality,
        response_format="b64_json" if wants_explicit_b64 else NOT_GIVEN,
        size="1024x1024",
    )

    destination = os.path.join(output_directory, f"{uuid.uuid4()}.png")
    with open(destination, "wb") as image_file:
        image_file.write(base64.b64decode(generation.data[0].b64_json))
    return destination
async def generate_image_openai_dalle3(
    self, prompt: str, output_directory: str
) -> str:
    """Generate an image with the "dall-e-3" model.

    The quality comes from get_dall_e_3_quality_env(), falling back to
    "standard" when that returns a falsy value.

    Returns:
        Path of the saved image file.
    """
    dalle_quality = get_dall_e_3_quality_env() or "standard"
    return await self.generate_image_openai(
        prompt, output_directory, "dall-e-3", dalle_quality
    )
async def generate_image_openai_gpt_image_1_5(
    self, prompt: str, output_directory: str
) -> str:
    """Generate an image with the "gpt-image-1.5" model.

    The quality comes from get_gpt_image_1_5_quality_env(), falling back
    to "medium" when that returns a falsy value.

    Returns:
        Path of the saved image file.
    """
    gpt_image_quality = get_gpt_image_1_5_quality_env() or "medium"
    return await self.generate_image_openai(
        prompt, output_directory, "gpt-image-1.5", gpt_image_quality
    )
async def generate_image_open_webui(
    self, prompt: str, output_directory: str
) -> str:
    """Generate an image through an Open WebUI image endpoint.

    Posts the prompt to ``{OPEN_WEBUI_IMAGE_URL}/images/generations`` and
    saves the first returned item, which may carry either inline base64
    data ("b64_json") or a "url" that is then downloaded.

    Args:
        prompt: Text prompt for the image.
        output_directory: Directory in which the PNG file is written.

    Returns:
        Path of the saved image file.

    Raises:
        ValueError: If the Open WebUI image URL is not configured.
        Exception: If the request fails, the response has an unexpected
            shape, or it contains no usable image data.
    """
    from urllib.parse import urlparse

    base_url = get_open_webui_image_url_env()
    if not base_url:
        raise ValueError("OPEN_WEBUI_IMAGE_URL environment variable is not set")
    base_url = base_url.rstrip("/")
    api_key = get_open_webui_image_api_key_env() or ""

    # scheme://host[:port] of the configured URL, used to resolve the
    # relative download URLs handled below.
    parsed = urlparse(base_url)
    origin = f"{parsed.scheme}://{parsed.netloc}"

    auth_headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
    payload = {
        "prompt": prompt,
        "n": 1,
        "size": "1024x1024",
    }

    async with aiohttp.ClientSession(trust_env=True) as session:
        # "async with" releases the response on every exit path,
        # including the error raise below.
        async with session.post(
            f"{base_url}/images/generations",
            json=payload,
            headers={"Content-Type": "application/json", **auth_headers},
            timeout=aiohttp.ClientTimeout(total=300),
        ) as resp:
            if resp.status != 200:
                error_text = await resp.text()
                raise Exception(
                    f"Open WebUI image generation returned {resp.status}: {error_text}"
                )
            body = await resp.json()

        # Open WebUI returns a bare [...] array instead of {"data": [...]}.
        if isinstance(body, list):
            items = body
        elif isinstance(body, dict) and "data" in body:
            items = body["data"]
        else:
            raise Exception(f"Unexpected response format: {type(body)}")
        if not items:
            raise Exception("Open WebUI returned empty results")

        item = items[0]
        if not isinstance(item, dict):
            # Fail with a clear message instead of an AttributeError on .get().
            raise Exception(f"Unexpected result item format: {type(item)}")

        image_path = os.path.join(output_directory, f"{uuid.uuid4()}.png")
        if item.get("b64_json"):
            with open(image_path, "wb") as f:
                f.write(base64.b64decode(item["b64_json"]))
        elif item.get("url"):
            image_url = item["url"]
            # Open WebUI returns relative URLs like /api/v1/files/.../content
            if image_url.startswith("/"):
                image_url = origin + image_url
            async with session.get(
                image_url,
                headers=auth_headers,
                timeout=aiohttp.ClientTimeout(total=120),
            ) as dl_resp:
                if dl_resp.status != 200:
                    raise Exception(
                        f"Failed to download image: {dl_resp.status}"
                    )
                # Read the body before opening the file so a failed read
                # does not leave an empty file behind.
                image_bytes = await dl_resp.read()
            with open(image_path, "wb") as f:
                f.write(image_bytes)
        else:
            raise Exception("Open WebUI returned no image data")
    return image_path
async def _generate_image_google(
    self, prompt: str, output_directory: str, model: str
) -> str:
    """Base method for Google image generation models.

    Calls ``generate_content`` in a worker thread, requesting the IMAGE
    modality, and saves every response part that carries inline image
    data. When several parts carry images, the path of the last one
    successfully saved is returned.

    Args:
        prompt: Text prompt sent as the request contents.
        output_directory: Directory in which image files are written.
        model: Google model name.

    Returns:
        Path of the saved image file.

    Raises:
        HTTPException: With status 500 when no image could be saved.
    """
    client = genai.Client()
    response = await asyncio.to_thread(
        client.models.generate_content,
        model=model,
        contents=prompt,
        config=types.GenerateContentConfig(
            response_modalities=["IMAGE"],
        ),
    )

    # Latest SDK docs expose images in response.parts; otherwise fall
    # back to the parts of the first candidate's content.
    response_parts = getattr(response, "parts", None)
    if not response_parts:
        candidates = getattr(response, "candidates", None)
        first_candidate = candidates[0] if candidates else None
        content = getattr(first_candidate, "content", None)
        response_parts = getattr(content, "parts", None)

    image_path = None
    for part in response_parts or []:
        inline_data = getattr(part, "inline_data", None)
        if inline_data is None:
            continue

        mime_type = getattr(inline_data, "mime_type", "") or ""
        ext = mime_type.split("/")[-1] if mime_type.startswith("image/") else "png"
        candidate_path = os.path.join(output_directory, f"{uuid.uuid4()}.{ext}")

        if hasattr(part, "as_image"):
            part.as_image().save(candidate_path)
        else:
            # Backward-compatible fallback if helper method is unavailable.
            image_data = getattr(inline_data, "data", None)
            if image_data is None:
                # Nothing to write for this part; do not report a path
                # for a file that was never created.
                continue
            image_bytes = (
                base64.b64decode(image_data)
                if isinstance(image_data, str)
                else image_data
            )
            with open(candidate_path, "wb") as image_file:
                image_file.write(image_bytes)

        # Record the path only once the file has actually been written.
        image_path = candidate_path

    if not image_path:
        raise HTTPException(
            status_code=500, detail=f"No image generated by google {model}"
        )
    return image_path
async def generate_image_gemini_flash(
    self, prompt: str, output_directory: str
) -> str:
    """Generate an image with the "gemini-2.5-flash-image" model.

    Returns:
        Path of the saved image file.
    """
    model_name = "gemini-2.5-flash-image"
    return await self._generate_image_google(prompt, output_directory, model_name)
async def generate_image_nanobanana_pro(
    self, prompt: str, output_directory: str
) -> str:
    """Generate an image with the "gemini-3-pro-image-preview" model.

    Returns:
        Path of the saved image file.
    """
    model_name = "gemini-3-pro-image-preview"
    return await self._generate_image_google(prompt, output_directory, model_name)
async def get_image_from_pexels(
    self, prompt: str, api_key: str | None = None, limit: int = 1
) -> str | list[str]:
    """Search Pexels and return "large" image URLs for the query.

    Args:
        prompt: Search query.
        api_key: Key to use; when falsy, get_pexels_api_key_env() is used.
        limit: Number of URLs wanted. The request's per_page value is
            this number clamped to the range 1..80.

    Returns:
        For ``limit <= 1`` a single URL string ("" when nothing matched),
        otherwise a list of at most ``limit`` URLs.

    Raises:
        HTTPException: 401 when Pexels answers 401/403, 502 for any other
            non-200 answer.
    """
    resolved_api_key = (api_key or get_pexels_api_key_env() or "").strip()
    per_page = max(1, min(limit, 80))
    # The Authorization header is sent only when a key is available.
    request_headers = {"Authorization": resolved_api_key} if resolved_api_key else {}

    async with aiohttp.ClientSession(trust_env=True) as session:
        response = await session.get(
            "https://api.pexels.com/v1/search",
            params={"query": prompt, "per_page": per_page},
            headers=request_headers,
            timeout=aiohttp.ClientTimeout(total=20),
        )
        if response.status in {401, 403}:
            raise HTTPException(status_code=401, detail="Invalid Pexels API key")
        if response.status != 200:
            error_text = await response.text()
            raise HTTPException(
                status_code=502,
                detail=f"Pexels request failed: {error_text}",
            )
        data = await response.json()

    image_urls = []
    for photo in data.get("photos", []):
        large_url = photo.get("src", {}).get("large")
        if large_url:
            image_urls.append(large_url)

    if limit > 1:
        return image_urls[:limit]
    return image_urls[0] if image_urls else ""
async def get_image_from_pixabay(
    self, prompt: str, api_key: str | None = None, limit: int = 1
) -> str | list[str]:
    """Search Pixabay photos and return "largeImageURL" values for the query.

    Args:
        prompt: Search query; only its first 99 characters are sent.
        api_key: Key to use; when falsy, get_pixabay_api_key_env() is used.
        limit: Number of URLs wanted. The request's per_page value is
            this number clamped to the range 3..200.

    Returns:
        For ``limit <= 1`` a single URL string ("" when nothing matched),
        otherwise a list of at most ``limit`` URLs.

    Raises:
        HTTPException: 401 for a 401/403 answer or a 400 answer whose text
            mentions "api key", 400 for any other 400 answer, and 502 for
            every remaining non-200 answer.
    """
    resolved_api_key = (api_key or get_pixabay_api_key_env() or "").strip()
    per_page = max(3, min(limit, 200))
    search_params = {
        "key": resolved_api_key,
        "q": prompt[:99],
        "image_type": "photo",
        "per_page": per_page,
    }

    async with aiohttp.ClientSession(trust_env=True) as session:
        response = await session.get(
            "https://pixabay.com/api/",
            params=search_params,
            timeout=aiohttp.ClientTimeout(total=20),
        )
        status = response.status

        if status in {401, 403}:
            error_text = await response.text()
            raise HTTPException(
                status_code=401,
                detail=f"Invalid Pixabay API key: {error_text}",
            )
        if status == 400:
            error_text = await response.text()
            # A 400 whose text mentions the API key is reported as an
            # authentication problem rather than a bad request.
            looks_like_key_problem = "api key" in error_text.lower()
            if looks_like_key_problem:
                raise HTTPException(
                    status_code=401,
                    detail=f"Invalid Pixabay API key: {error_text}",
                )
            raise HTTPException(
                status_code=400,
                detail=f"Pixabay request invalid: {error_text}",
            )
        if status != 200:
            error_text = await response.text()
            raise HTTPException(
                status_code=502,
                detail=f"Pixabay request failed: {error_text}",
            )
        data = await response.json()

    image_urls = []
    for hit in data.get("hits", []):
        large_url = hit.get("largeImageURL")
        if large_url:
            image_urls.append(large_url)

    if limit > 1:
        return image_urls[:limit]
    return image_urls[0] if image_urls else ""
async def generate_image_comfyui(self, prompt: str, output_directory: str) -> str:
    """
    Generate image using ComfyUI workflow API.

    User provides:
    - COMFYUI_URL: ComfyUI server URL (e.g., http://192.168.1.7:8188)
    - COMFYUI_WORKFLOW: Workflow JSON exported from ComfyUI

    The workflow must contain a node whose title is "Input Prompt"; the
    prompt is written into that node, or into a node linked from it.

    Args:
        prompt: The text prompt for image generation
        output_directory: Directory to save the generated image

    Returns:
        Path to the generated image file

    Raises:
        ValueError: If the URL or workflow is not configured, the workflow
            is not valid JSON, it is not a JSON object, or no writable
            prompt field can be found in it.
        Exception: If submitting, waiting for, or downloading the result
            fails.
    """
    comfyui_url = get_comfyui_url_env()
    workflow_json = get_comfyui_workflow_env()
    if not comfyui_url:
        raise ValueError("COMFYUI_URL environment variable is not set")
    if not workflow_json:
        raise ValueError(
            "COMFYUI_WORKFLOW environment variable is not set. Please provide a ComfyUI workflow JSON."
        )

    # Ensure URL doesn't have trailing slash
    comfyui_url = comfyui_url.rstrip("/")

    # Parse the workflow JSON
    try:
        workflow = json.loads(workflow_json)
    except json.JSONDecodeError as e:
        # Chain the decode error so the position details stay in the traceback.
        raise ValueError(f"Invalid workflow JSON: {str(e)}") from e
    if not isinstance(workflow, dict):
        # Prompt injection iterates node-id -> node pairs, so anything
        # other than a JSON object cannot be processed.
        raise ValueError(
            "Invalid workflow JSON: expected a JSON object mapping node ids to nodes"
        )

    # Find and update the prompt node
    workflow = self._inject_prompt_into_workflow(workflow, prompt)

    async with aiohttp.ClientSession(trust_env=True) as session:
        # Step 1: Submit workflow
        prompt_id = await self._submit_comfyui_workflow(
            session, comfyui_url, workflow
        )
        # Step 2: Wait for completion
        status_data = await self._wait_for_comfyui_completion(
            session, comfyui_url, prompt_id
        )
        # Step 3: Download the generated image
        image_path = await self._download_comfyui_image(
            session, comfyui_url, status_data, prompt_id, output_directory
        )
    return image_path
def _inject_prompt_into_workflow(self, workflow: dict, prompt: str) -> dict:
    """Write ``prompt`` into the workflow's "Input Prompt" node.

    Nodes whose ``_meta.title`` equals "input prompt" (ignoring case and
    surrounding whitespace) are the starting points. For each visited
    node the prompt is written to, in order of preference:

    1. the first preferred key (text, value, prompt, ...) holding a string;
    2. the only string input that is not a known non-prompt setting;
    3. a node reached by following a link found in the inputs.

    The workflow is modified in place and also returned. Entries that are
    not dicts, or whose ``_meta`` / ``inputs`` are not dicts, are skipped,
    and nodes without inputs are left untouched.

    Args:
        workflow: Mapping of node id to node definition.
        prompt: Text to inject.

    Returns:
        The same ``workflow`` object with the prompt injected.

    Raises:
        ValueError: If no "Input Prompt" node exists, or none of them leads
            to a writable string field.
    """

    def norm(x) -> str:
        return str(x or "").strip().lower()

    def is_link(v) -> bool:
        # A link is a sequence starting with (node id string, int index).
        return (
            isinstance(v, (list, tuple))
            and len(v) >= 2
            and isinstance(v[0], str)
            and isinstance(v[1], int)
        )

    def node_title(node) -> str:
        # Tolerate non-dict entries and a missing or null "_meta".
        if not isinstance(node, dict):
            return ""
        meta = node.get("_meta")
        return norm(meta.get("title")) if isinstance(meta, dict) else ""

    preferred_keys = (
        "text", "value", "prompt", "string", "content", "instruction", "input", "query"
    )
    # string inputs that are usually NOT prompt text
    ignore_keys = {
        "filename_prefix", "ckpt_name", "clip_name", "vae_name", "unet_name",
        "sampler_name", "scheduler", "type", "device", "model", "lora_name"
    }
    # Shared across all starting nodes so link cycles terminate.
    visited = set()

    def try_set(node_id) -> bool:
        node_id = str(node_id)
        if node_id in visited:
            return False
        visited.add(node_id)

        node = workflow.get(node_id)
        if not isinstance(node, dict):
            return False
        inputs = node.get("inputs")
        if not isinstance(inputs, dict):
            # Nothing writable here; leave the node exactly as it was.
            return False

        # 1) preferred prompt-like keys
        for k in preferred_keys:
            if isinstance(inputs.get(k), str):
                inputs[k] = prompt
                return True

        # 2) fallback: exactly one unambiguous writable string field
        string_candidates = [
            k for k, v in inputs.items()
            if isinstance(v, str) and k not in ignore_keys
        ]
        if len(string_candidates) == 1:
            inputs[string_candidates[0]] = prompt
            return True

        # 3) follow links from ANY input key (node-type agnostic)
        for v in inputs.values():
            if is_link(v):
                if try_set(v[0]):
                    return True
            elif isinstance(v, list):
                for item in v:
                    if is_link(item) and try_set(item[0]):
                        return True
        return False

    input_prompt_nodes = [
        node_id
        for node_id, node_data in workflow.items()
        if node_title(node_data) == "input prompt"
    ]
    if not input_prompt_nodes:
        raise ValueError(
            "Could not find node with title 'Input Prompt'. Rename your prompt node to 'Input Prompt'."
        )

    for nid in input_prompt_nodes:
        if try_set(nid):
            return workflow

    raise ValueError(
        "Found 'Input Prompt', but no writable prompt string field was found directly or through linked nodes."
    )
async def _submit_comfyui_workflow(
    self, session: aiohttp.ClientSession, comfyui_url: str, workflow: dict
) -> str:
    """Queue the workflow on ComfyUI and return the prompt id it reports.

    Args:
        session: Open HTTP session used for the request.
        comfyui_url: Base URL of the ComfyUI server.
        workflow: Workflow to submit as the "prompt" field.

    Returns:
        The "prompt_id" value from the response.

    Raises:
        Exception: If the server does not answer 200 or the answer has no
            prompt id.
    """
    # A fresh client id is generated for every submission.
    request_body = {"prompt": workflow, "client_id": str(uuid.uuid4())}
    response = await session.post(
        f"{comfyui_url}/prompt",
        json=request_body,
        timeout=aiohttp.ClientTimeout(total=30),
    )
    if response.status != 200:
        failure_detail = await response.text()
        raise Exception(f"Failed to submit workflow to ComfyUI: {failure_detail}")

    submission = await response.json()
    prompt_id = submission.get("prompt_id")
    if prompt_id:
        print(f"ComfyUI workflow submitted. Prompt ID: {prompt_id}")
        return prompt_id
    raise Exception("No prompt_id returned from ComfyUI")
async def _wait_for_comfyui_completion(
    self,
    session: aiohttp.ClientSession,
    comfyui_url: str,
    prompt_id: str,
    timeout: int = 300,
    poll_interval: int = 4,
) -> dict:
    """Poll ComfyUI history endpoint until workflow completes.

    Each round sleeps ``poll_interval`` seconds and then fetches
    ``/history/{prompt_id}``. Non-200 answers and bodies that cannot be
    parsed are ignored and polling continues.

    Args:
        session: Open HTTP session used for polling.
        comfyui_url: Base URL of the ComfyUI server.
        prompt_id: Id returned when the workflow was submitted.
        timeout: Seconds after which waiting is abandoned. It is checked
            before each sleep, so the last round may finish slightly later.
        poll_interval: Seconds to sleep before each poll.

    Returns:
        The parsed history response containing ``prompt_id``.

    Raises:
        Exception: If the workflow reports an error or the timeout passes.
    """
    loop = asyncio.get_running_loop()
    start_time = loop.time()
    history_url = f"{comfyui_url}/history/{prompt_id}"

    while True:
        elapsed = loop.time() - start_time
        if elapsed > timeout:
            raise Exception(f"ComfyUI workflow timed out after {timeout} seconds")
        await asyncio.sleep(poll_interval)

        # "async with" releases the response even when this round is
        # skipped without reading the body.
        async with session.get(
            history_url,
            timeout=aiohttp.ClientTimeout(total=30),
        ) as response:
            if response.status != 200:
                continue
            try:
                status_data = await response.json()
            except Exception:
                # Best effort: an unreadable body just means "poll again".
                continue

        execution_data = (
            status_data.get(prompt_id) if isinstance(status_data, dict) else None
        )
        if isinstance(execution_data, dict):
            # Check for completion
            status = execution_data.get("status")
            if isinstance(status, dict):
                if status.get("completed", False):
                    print("ComfyUI workflow completed successfully")
                    return status_data
                if "error" in status:
                    raise Exception(f"ComfyUI workflow error: {status['error']}")
                if status.get("status_str") == "error":
                    # A failed run is reported through status_str; stop
                    # now instead of polling until the timeout.
                    raise Exception(
                        f"ComfyUI workflow error: {status.get('messages')}"
                    )
            # Also check if outputs exist (alternative completion check)
            if execution_data.get("outputs"):
                print("ComfyUI workflow completed (outputs found)")
                return status_data

        print(f"Waiting for ComfyUI workflow... ({int(elapsed)}s)")
async def _download_comfyui_image(
    self,
    session: aiohttp.ClientSession,
    comfyui_url: str,
    status_data: dict,
    prompt_id: str,
    output_directory: str,
) -> str:
    """Download the generated image from ComfyUI.

    Takes the first image listed in the outputs of ``prompt_id``, fetches
    it from ``/view`` and stores it under a fresh UUID file name that
    keeps the original file extension ("png" when there is none).

    Args:
        session: Open HTTP session used for the download.
        comfyui_url: Base URL of the ComfyUI server.
        status_data: History response keyed by prompt id.
        prompt_id: Id of the finished workflow run.
        output_directory: Directory in which the image is written.

    Returns:
        Path of the saved image file.

    Raises:
        Exception: If the prompt id or its outputs are missing, no output
            lists an image, or the download does not answer 200.
    """
    if prompt_id not in status_data:
        raise Exception("Prompt ID not found in status data")
    outputs = status_data[prompt_id].get("outputs", {})
    if not outputs:
        raise Exception("No outputs found in ComfyUI response")

    # Find the first image in outputs
    for node_output in outputs.values():
        if not isinstance(node_output, dict):
            continue
        for image_info in node_output.get("images") or []:
            filename = image_info["filename"]
            subfolder = image_info.get("subfolder", "")

            # Build view params. Use the storage type recorded for the
            # image when present rather than always assuming "output".
            params = {
                "filename": filename,
                "type": image_info.get("type") or "output",
            }
            if subfolder:
                params["subfolder"] = subfolder

            # Download the image; only the first listed image is tried.
            async with session.get(
                f"{comfyui_url}/view",
                params=params,
                timeout=aiohttp.ClientTimeout(total=60),
            ) as response:
                if response.status != 200:
                    raise Exception(f"Failed to download image: {response.status}")
                image_data = await response.read()

            # Determine extension
            ext = filename.split(".")[-1] if "." in filename else "png"
            image_path = os.path.join(output_directory, f"{uuid.uuid4()}.{ext}")
            with open(image_path, "wb") as f:
                f.write(image_data)
            print(f"Downloaded image from ComfyUI: {image_path}")
            return image_path

    raise Exception("No images found in ComfyUI outputs")