import asyncio
import base64
import json
import os
import uuid
from urllib.parse import urlparse

import aiohttp
from fastapi import HTTPException
from google import genai
from google.genai import types
from openai import NOT_GIVEN, AsyncOpenAI

from models.image_prompt import ImagePrompt
from models.sql.image_asset import ImageAsset
from utils.get_env import (
    get_comfyui_url_env,
    get_comfyui_workflow_env,
    get_dall_e_3_quality_env,
    get_gpt_image_1_5_quality_env,
    get_open_webui_image_api_key_env,
    get_open_webui_image_url_env,
    get_pexels_api_key_env,
    get_pixabay_api_key_env,
)
from utils.image_provider import (
    is_comfyui_selected,
    is_dalle3_selected,
    is_gemini_flash_selected,
    is_gpt_image_1_5_selected,
    is_image_generation_disabled,
    is_nanobanana_pro_selected,
    is_open_webui_selected,
    is_pixabay_selected,
    is_pixels_selected,
)


class ImageGenerationService:
    """Produce an image for a prompt using the configured provider.

    The provider is resolved once, at construction time, from the
    ``utils.image_provider`` selectors. Stock providers (Pexels, Pixabay)
    return remote URLs; generative providers write a file into
    ``output_directory`` and return its path.
    """

    # Returned whenever generation is disabled, unconfigured, or fails.
    _PLACEHOLDER_IMAGE = "/static/images/placeholder.jpg"

    def __init__(self, output_directory: str):
        self.output_directory = output_directory
        self.is_image_generation_disabled = is_image_generation_disabled()
        self.image_gen_func = self.get_image_gen_func()

    def get_image_gen_func(self):
        """Return the bound method for the selected provider, or ``None``.

        Selectors are evaluated in the order listed below and the first one
        that reports itself selected wins.
        """
        if self.is_image_generation_disabled:
            return None

        providers = (
            (is_pixabay_selected, self.get_image_from_pixabay),
            (is_pixels_selected, self.get_image_from_pexels),
            (is_gemini_flash_selected, self.generate_image_gemini_flash),
            (is_nanobanana_pro_selected, self.generate_image_nanobanana_pro),
            (is_dalle3_selected, self.generate_image_openai_dalle3),
            (is_gpt_image_1_5_selected, self.generate_image_openai_gpt_image_1_5),
            (is_comfyui_selected, self.generate_image_comfyui),
            (is_open_webui_selected, self.generate_image_open_webui),
        )
        for is_selected, image_gen_func in providers:
            if is_selected():
                return image_gen_func
        return None

    def is_stock_provider_selected(self):
        """Return True when a stock-photo provider (Pexels/Pixabay) is active."""
        return is_pixels_selected() or is_pixabay_selected()

    @staticmethod
    def _save_image_bytes(
        output_directory: str, image_bytes: bytes, extension: str = "png"
    ) -> str:
        """Write *image_bytes* to a uuid-named file and return its path."""
        image_path = os.path.join(output_directory, f"{uuid.uuid4()}.{extension}")
        with open(image_path, "wb") as image_file:
            image_file.write(image_bytes)
        return image_path

    async def generate_image(self, prompt: ImagePrompt) -> str | ImageAsset:
        """
        Generates an image based on the provided prompt.

        - If no image generation function is available, returns a placeholder image.
        - If the stock provider is selected, it uses the prompt directly,
          otherwise it uses the full image prompt with theme.
        - Output Directory is used for saving the generated image, not for
          the stock provider.

        Returns:
            A URL or static path string, or an ``ImageAsset`` wrapping a file
            that exists on disk. Any failure is logged and the placeholder
            path is returned instead of raising.
        """
        if self.is_image_generation_disabled:
            print("Image generation is disabled. Using placeholder image.")
            return self._PLACEHOLDER_IMAGE

        if not self.image_gen_func:
            print("No image generation function found. Using placeholder image.")
            return self._PLACEHOLDER_IMAGE

        use_stock_provider = self.is_stock_provider_selected()
        image_prompt = prompt.get_image_prompt(with_theme=not use_stock_provider)
        print(f"Request - Generating Image for {image_prompt}")

        # Top-level boundary: every provider failure degrades to the placeholder.
        try:
            if use_stock_provider:
                image_path = await self.image_gen_func(image_prompt)
            else:
                image_path = await self.image_gen_func(
                    image_prompt, self.output_directory
                )

            if image_path:
                if image_path.startswith("http"):
                    return image_path
                if os.path.exists(image_path):
                    return ImageAsset(
                        path=image_path,
                        is_uploaded=False,
                        extras={
                            "prompt": prompt.prompt,
                            "theme_prompt": prompt.theme_prompt,
                        },
                    )
                if image_path.startswith(("/app_data/", "/static/")):
                    return image_path
            raise Exception(f"Image not found at {image_path}")
        except Exception as e:
            print(f"Error generating image: {e}")
            return self._PLACEHOLDER_IMAGE

    async def generate_image_openai(
        self, prompt: str, output_directory: str, model: str, quality: str
    ) -> str:
        """Generate one 1024x1024 image with an OpenAI model and save it as PNG.

        Raises:
            Exception: if the response carries no base64 image payload.
        """
        client = AsyncOpenAI()
        result = await client.images.generate(
            model=model,
            prompt=prompt,
            n=1,
            quality=quality,
            # Only requested explicitly for dall-e-3; omitted for other models.
            response_format="b64_json" if model == "dall-e-3" else NOT_GIVEN,
            size="1024x1024",
        )
        b64_image = result.data[0].b64_json if result.data else None
        if not b64_image:
            # Fail with a clear message instead of a TypeError from b64decode(None).
            raise Exception(f"OpenAI {model} returned no image data")
        return self._save_image_bytes(output_directory, base64.b64decode(b64_image))

    async def generate_image_openai_dalle3(
        self, prompt: str, output_directory: str
    ) -> str:
        """Generate an image with dall-e-3 (quality from env, default "standard")."""
        return await self.generate_image_openai(
            prompt,
            output_directory,
            "dall-e-3",
            get_dall_e_3_quality_env() or "standard",
        )

    async def generate_image_openai_gpt_image_1_5(
        self, prompt: str, output_directory: str
    ) -> str:
        """Generate an image with gpt-image-1.5 (quality from env, default "medium")."""
        return await self.generate_image_openai(
            prompt,
            output_directory,
            "gpt-image-1.5",
            get_gpt_image_1_5_quality_env() or "medium",
        )

    async def generate_image_open_webui(
        self, prompt: str, output_directory: str
    ) -> str:
        """Generate an image through an Open WebUI ``/images/generations`` endpoint.

        Accepts either a bare list response or a ``{"data": [...]}`` envelope,
        and either an inline ``b64_json`` payload or a (possibly relative)
        ``url`` that is then downloaded.

        Raises:
            ValueError: if the Open WebUI image URL is not configured.
            Exception: on a non-200 response or an unusable response body.
        """
        base_url = get_open_webui_image_url_env()
        if not base_url:
            raise ValueError("OPEN_WEBUI_IMAGE_URL environment variable is not set")
        base_url = base_url.rstrip("/")
        api_key = get_open_webui_image_api_key_env() or ""

        parsed = urlparse(base_url)
        origin = f"{parsed.scheme}://{parsed.netloc}"

        auth_headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
        headers = {"Content-Type": "application/json", **auth_headers}
        payload = {
            "prompt": prompt,
            "n": 1,
            "size": "1024x1024",
        }

        async with aiohttp.ClientSession(trust_env=True) as session:
            # `async with` releases the connection on every exit path.
            async with session.post(
                f"{base_url}/images/generations",
                json=payload,
                headers=headers,
                timeout=aiohttp.ClientTimeout(total=300),
            ) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    raise Exception(
                        f"Open WebUI image generation returned {resp.status}: {error_text}"
                    )
                body = await resp.json()

            # Open WebUI returns a bare [...] array instead of {"data": [...]}.
            if isinstance(body, list):
                items = body
            elif isinstance(body, dict) and "data" in body:
                items = body["data"]
            else:
                raise Exception(f"Unexpected response format: {type(body)}")

            if not items:
                raise Exception("Open WebUI returned empty results")

            item = items[0]
            if not isinstance(item, dict):
                raise Exception("Open WebUI returned no image data")

            if item.get("b64_json"):
                return self._save_image_bytes(
                    output_directory, base64.b64decode(item["b64_json"])
                )

            if item.get("url"):
                image_url = item["url"]
                # Open WebUI returns relative URLs like /api/v1/files/.../content
                if image_url.startswith("/"):
                    image_url = origin + image_url
                async with session.get(
                    image_url,
                    headers=auth_headers,
                    timeout=aiohttp.ClientTimeout(total=120),
                ) as dl_resp:
                    if dl_resp.status != 200:
                        raise Exception(
                            f"Failed to download image: {dl_resp.status}"
                        )
                    image_bytes = await dl_resp.read()
                return self._save_image_bytes(output_directory, image_bytes)

            raise Exception("Open WebUI returned no image data")

    async def _generate_image_google(
        self, prompt: str, output_directory: str, model: str
    ) -> str:
        """Base method for Google image generation models.

        Saves every image part found in the response and returns the path of
        the last one written.

        Raises:
            HTTPException: (500) if no image part could be saved.
        """
        client = genai.Client()
        # The SDK call is synchronous; run it off the event loop.
        response = await asyncio.to_thread(
            client.models.generate_content,
            model=model,
            contents=prompt,
            config=types.GenerateContentConfig(
                response_modalities=["IMAGE"],
            ),
        )

        # Latest SDK docs expose images in response.parts; fall back to the
        # first candidate's content for older response shapes.
        response_parts = getattr(response, "parts", None)
        if not response_parts:
            candidates = getattr(response, "candidates", None)
            if candidates:
                content = getattr(candidates[0], "content", None)
                response_parts = getattr(content, "parts", None) if content else None

        image_path = None
        for part in response_parts or []:
            inline_data = getattr(part, "inline_data", None)
            if inline_data is None:
                continue

            mime_type = getattr(inline_data, "mime_type", "") or ""
            ext = mime_type.split("/")[-1] if mime_type.startswith("image/") else "png"

            if hasattr(part, "as_image"):
                candidate_path = os.path.join(
                    output_directory, f"{uuid.uuid4()}.{ext}"
                )
                part.as_image().save(candidate_path)
                image_path = candidate_path
                continue

            # Backward-compatible fallback if helper method is unavailable.
            image_data = getattr(inline_data, "data", None)
            if image_data is None:
                # Nothing to write: do not record a path for a file that
                # was never created.
                continue
            image_bytes = (
                base64.b64decode(image_data)
                if isinstance(image_data, str)
                else image_data
            )
            image_path = self._save_image_bytes(output_directory, image_bytes, ext)

        if not image_path:
            raise HTTPException(
                status_code=500, detail=f"No image generated by google {model}"
            )

        return image_path

    async def generate_image_gemini_flash(
        self, prompt: str, output_directory: str
    ) -> str:
        """Generate image using Gemini Flash (gemini-2.5-flash-image)."""
        return await self._generate_image_google(
            prompt, output_directory, "gemini-2.5-flash-image"
        )

    async def generate_image_nanobanana_pro(
        self, prompt: str, output_directory: str
    ) -> str:
        """Generate image using NanoBanana Pro (gemini-3-pro-image-preview)."""
        return await self._generate_image_google(
            prompt, output_directory, "gemini-3-pro-image-preview"
        )

    async def get_image_from_pexels(
        self, prompt: str, api_key: str | None = None, limit: int = 1
    ) -> str | list[str]:
        """Search Pexels and return "large" image URL(s) for *prompt*.

        Returns:
            A single URL (``""`` when nothing matched) if ``limit <= 1``,
            otherwise a list of at most ``limit`` URLs.

        Raises:
            HTTPException: 401 for a rejected API key, 502 for other failures.
        """
        per_page = max(1, min(limit, 80))
        resolved_api_key = (api_key or get_pexels_api_key_env() or "").strip()
        async with aiohttp.ClientSession(trust_env=True) as session:
            async with session.get(
                "https://api.pexels.com/v1/search",
                params={"query": prompt, "per_page": per_page},
                headers={"Authorization": resolved_api_key}
                if resolved_api_key
                else {},
                timeout=aiohttp.ClientTimeout(total=20),
            ) as response:
                if response.status in {401, 403}:
                    raise HTTPException(
                        status_code=401, detail="Invalid Pexels API key"
                    )
                if response.status != 200:
                    error_text = await response.text()
                    raise HTTPException(
                        status_code=502,
                        detail=f"Pexels request failed: {error_text}",
                    )
                data = await response.json()

        image_urls = [
            photo.get("src", {}).get("large")
            for photo in data.get("photos", [])
            if photo.get("src", {}).get("large")
        ]
        if limit <= 1:
            return image_urls[0] if image_urls else ""
        return image_urls[:limit]

    async def get_image_from_pixabay(
        self, prompt: str, api_key: str | None = None, limit: int = 1
    ) -> str | list[str]:
        """Search Pixabay photos and return ``largeImageURL`` value(s).

        The query is truncated to 99 characters and at least 3 results are
        requested per page regardless of ``limit``.

        Returns:
            A single URL (``""`` when nothing matched) if ``limit <= 1``,
            otherwise a list of at most ``limit`` URLs.

        Raises:
            HTTPException: 401 for a rejected API key, 400 for an invalid
                request, 502 for other failures.
        """
        per_page = max(3, min(limit, 200))
        resolved_api_key = (api_key or get_pixabay_api_key_env() or "").strip()
        async with aiohttp.ClientSession(trust_env=True) as session:
            async with session.get(
                "https://pixabay.com/api/",
                params={
                    "key": resolved_api_key,
                    "q": prompt[:99],
                    "image_type": "photo",
                    "per_page": per_page,
                },
                timeout=aiohttp.ClientTimeout(total=20),
            ) as response:
                if response.status in {401, 403}:
                    error_text = await response.text()
                    raise HTTPException(
                        status_code=401,
                        detail=f"Invalid Pixabay API key: {error_text}",
                    )
                if response.status == 400:
                    error_text = await response.text()
                    # A 400 whose body mentions the API key is treated as
                    # an authentication failure.
                    if "api key" in error_text.lower():
                        raise HTTPException(
                            status_code=401,
                            detail=f"Invalid Pixabay API key: {error_text}",
                        )
                    raise HTTPException(
                        status_code=400,
                        detail=f"Pixabay request invalid: {error_text}",
                    )
                if response.status != 200:
                    error_text = await response.text()
                    raise HTTPException(
                        status_code=502,
                        detail=f"Pixabay request failed: {error_text}",
                    )
                data = await response.json()

        image_urls = [
            hit.get("largeImageURL")
            for hit in data.get("hits", [])
            if hit.get("largeImageURL")
        ]
        if limit <= 1:
            return image_urls[0] if image_urls else ""
        return image_urls[:limit]

    async def generate_image_comfyui(self, prompt: str, output_directory: str) -> str:
        """
        Generate image using ComfyUI workflow API.

        User provides:
        - COMFYUI_URL: ComfyUI server URL (e.g., http://192.168.1.7:8188)
        - COMFYUI_WORKFLOW: Workflow JSON exported from ComfyUI

        The workflow must contain a node whose title is "Input Prompt"
        (case-insensitive). The prompt is written into that node's text
        field, or into a node it links to.

        Args:
            prompt: The text prompt for image generation
            output_directory: Directory to save the generated image

        Returns:
            Path to the generated image file

        Raises:
            ValueError: if the URL/workflow is missing, the workflow is not
                valid JSON, or no writable prompt field can be found.
            Exception: if submission, polling, or download fails.
        """
        comfyui_url = get_comfyui_url_env()
        workflow_json = get_comfyui_workflow_env()

        if not comfyui_url:
            raise ValueError("COMFYUI_URL environment variable is not set")

        if not workflow_json:
            raise ValueError(
                "COMFYUI_WORKFLOW environment variable is not set. Please provide a ComfyUI workflow JSON."
            )

        # Ensure URL doesn't have trailing slash
        comfyui_url = comfyui_url.rstrip("/")

        try:
            workflow = json.loads(workflow_json)
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid workflow JSON: {str(e)}") from e

        workflow = self._inject_prompt_into_workflow(workflow, prompt)

        async with aiohttp.ClientSession(trust_env=True) as session:
            # Step 1: Submit workflow
            prompt_id = await self._submit_comfyui_workflow(
                session, comfyui_url, workflow
            )

            # Step 2: Wait for completion
            status_data = await self._wait_for_comfyui_completion(
                session, comfyui_url, prompt_id
            )

            # Step 3: Download the generated image
            image_path = await self._download_comfyui_image(
                session, comfyui_url, status_data, prompt_id, output_directory
            )

        return image_path

    def _inject_prompt_into_workflow(self, workflow: dict, prompt: str) -> dict:
        """Write *prompt* into the workflow's "Input Prompt" node (in place).

        Starting from each node titled "Input Prompt", the first writable
        string input is replaced: preferred prompt-like keys first, then a
        single unambiguous string field, then recursively through linked
        nodes.

        Returns:
            The same ``workflow`` object, mutated.

        Raises:
            ValueError: if the workflow is not a JSON object, has no
                "Input Prompt" node, or no writable string field is reachable.
        """
        if not isinstance(workflow, dict):
            raise ValueError(
                "ComfyUI workflow must be a JSON object mapping node ids to nodes."
            )

        def norm(x) -> str:
            return str(x or "").strip().lower()

        def is_link(v) -> bool:
            # Link values have the shape [source_node_id, output_index].
            return (
                isinstance(v, (list, tuple))
                and len(v) >= 2
                and isinstance(v[0], str)
                and isinstance(v[1], int)
            )

        def title_of(node) -> str:
            # Tolerate non-dict nodes and missing/null "_meta".
            if not isinstance(node, dict):
                return ""
            meta = node.get("_meta")
            return norm(meta.get("title")) if isinstance(meta, dict) else ""

        preferred_keys = (
            "text", "value", "prompt", "string",
            "content", "instruction", "input", "query",
        )

        # string inputs that are usually NOT prompt text
        ignore_keys = {
            "filename_prefix", "ckpt_name", "clip_name", "vae_name",
            "unet_name", "sampler_name", "scheduler", "type",
            "device", "model", "lora_name",
        }

        # Guards against cycles in the link graph.
        visited = set()

        def try_set(node_id: str) -> bool:
            node_id = str(node_id)
            if node_id in visited:
                return False
            visited.add(node_id)

            node = workflow.get(node_id)
            if not isinstance(node, dict):
                return False

            inputs = node.get("inputs")
            if not isinstance(inputs, dict):
                return False

            # 1) preferred prompt-like keys
            for key in preferred_keys:
                if isinstance(inputs.get(key), str):
                    inputs[key] = prompt
                    return True

            # 2) fallback: exactly one unambiguous writable string field
            string_candidates = [
                key
                for key, value in inputs.items()
                if isinstance(value, str) and key not in ignore_keys
            ]
            if len(string_candidates) == 1:
                inputs[string_candidates[0]] = prompt
                return True

            # 3) follow links from ANY input key (node-type agnostic)
            for value in inputs.values():
                if is_link(value):
                    if try_set(value[0]):
                        return True
                elif isinstance(value, list):
                    for item in value:
                        if is_link(item) and try_set(item[0]):
                            return True

            return False

        input_prompt_nodes = [
            node_id
            for node_id, node_data in workflow.items()
            if title_of(node_data) == "input prompt"
        ]

        if not input_prompt_nodes:
            raise ValueError(
                "Could not find node with title 'Input Prompt'. Rename your prompt node to 'Input Prompt'."
            )

        for nid in input_prompt_nodes:
            if try_set(nid):
                return workflow

        raise ValueError(
            "Found 'Input Prompt', but no writable prompt string field was found directly or through linked nodes."
        )

    async def _submit_comfyui_workflow(
        self, session: aiohttp.ClientSession, comfyui_url: str, workflow: dict
    ) -> str:
        """Submit workflow to ComfyUI and return the prompt_id.

        Raises:
            Exception: on a non-200 response or a response without a prompt_id.
        """
        client_id = str(uuid.uuid4())
        payload = {"prompt": workflow, "client_id": client_id}

        async with session.post(
            f"{comfyui_url}/prompt",
            json=payload,
            timeout=aiohttp.ClientTimeout(total=30),
        ) as response:
            if response.status != 200:
                error_text = await response.text()
                raise Exception(f"Failed to submit workflow to ComfyUI: {error_text}")
            data = await response.json()

        prompt_id = data.get("prompt_id")
        if not prompt_id:
            raise Exception("No prompt_id returned from ComfyUI")

        print(f"ComfyUI workflow submitted. Prompt ID: {prompt_id}")
        return prompt_id

    async def _wait_for_comfyui_completion(
        self,
        session: aiohttp.ClientSession,
        comfyui_url: str,
        prompt_id: str,
        timeout: int = 300,
        poll_interval: int = 4,
    ) -> dict:
        """Poll ComfyUI history endpoint until workflow completes.

        Returns:
            The full history payload (keyed by prompt_id) once the entry is
            marked completed or has non-empty outputs.

        Raises:
            Exception: if the workflow reports an error or ``timeout``
                seconds elapse.
        """
        loop = asyncio.get_running_loop()
        start_time = loop.time()

        while True:
            elapsed = loop.time() - start_time
            if elapsed > timeout:
                raise Exception(f"ComfyUI workflow timed out after {timeout} seconds")

            await asyncio.sleep(poll_interval)

            # `async with` returns the connection to the pool even when a
            # poll is skipped; otherwise each failed poll would hold one.
            async with session.get(
                f"{comfyui_url}/history/{prompt_id}",
                timeout=aiohttp.ClientTimeout(total=30),
            ) as response:
                if response.status != 200:
                    continue
                try:
                    status_data = await response.json()
                except Exception:
                    # Best-effort polling: an unreadable body is retried.
                    continue

            execution_data = (
                status_data.get(prompt_id) if isinstance(status_data, dict) else None
            )
            if isinstance(execution_data, dict):
                status = execution_data.get("status")
                if isinstance(status, dict):
                    if status.get("completed", False):
                        print("ComfyUI workflow completed successfully")
                        return status_data
                    if "error" in status:
                        raise Exception(f"ComfyUI workflow error: {status['error']}")

                # Also check if outputs exist (alternative completion check)
                if execution_data.get("outputs"):
                    print("ComfyUI workflow completed (outputs found)")
                    return status_data

                # NOTE(review): ComfyUI is expected to report failed runs via
                # status_str == "error"; fail now instead of at the timeout.
                if isinstance(status, dict) and status.get("status_str") == "error":
                    raise Exception(
                        f"ComfyUI workflow error: {status.get('messages')}"
                    )

            print(f"Waiting for ComfyUI workflow... ({int(elapsed)}s)")

    async def _download_comfyui_image(
        self,
        session: aiohttp.ClientSession,
        comfyui_url: str,
        status_data: dict,
        prompt_id: str,
        output_directory: str,
    ) -> str:
        """Download the generated image from ComfyUI.

        Only the first image listed in the outputs is fetched.

        Returns:
            Path of the saved file, keeping the source file's extension
            (``png`` when the filename has none).

        Raises:
            Exception: if the prompt or its outputs are missing, no image is
                listed, or the download does not return 200.
        """
        if prompt_id not in status_data:
            raise Exception("Prompt ID not found in status data")

        outputs = status_data[prompt_id].get("outputs", {})
        if not outputs:
            raise Exception("No outputs found in ComfyUI response")

        # Find the first image in outputs
        for node_output in outputs.values():
            if not isinstance(node_output, dict):
                continue
            for image_info in node_output.get("images") or []:
                filename = image_info["filename"]
                subfolder = image_info.get("subfolder", "")

                params = {"filename": filename, "type": "output"}
                if subfolder:
                    params["subfolder"] = subfolder

                async with session.get(
                    f"{comfyui_url}/view",
                    params=params,
                    timeout=aiohttp.ClientTimeout(total=60),
                ) as response:
                    if response.status != 200:
                        raise Exception(
                            f"Failed to download image: {response.status}"
                        )
                    image_data = await response.read()

                ext = filename.rsplit(".", 1)[-1] if "." in filename else "png"
                image_path = self._save_image_bytes(output_directory, image_data, ext)
                print(f"Downloaded image from ComfyUI: {image_path}")
                return image_path

        raise Exception("No images found in ComfyUI outputs")