Merge pull request #424 from presenton/feat/oauthV2
Feat/oauth v2 Introduces Login with ChatGPT Support for Electron and Docker Both
This commit is contained in:
commit
ebab3ba13e
48 changed files with 4467 additions and 28 deletions
|
|
@ -7,6 +7,8 @@ services:
|
|||
ports:
|
||||
# You can replace 5000 with any other port number of your choice to run Presenton on a different port number.
|
||||
- "5000:80"
|
||||
# Required for Codex OAuth callback (OpenAI redirects browser directly to localhost:1455)
|
||||
- "1455:1455"
|
||||
volumes:
|
||||
- ./app_data:/app_data
|
||||
environment:
|
||||
|
|
@ -23,6 +25,7 @@ services:
|
|||
- CUSTOM_LLM_URL=${CUSTOM_LLM_URL}
|
||||
- CUSTOM_LLM_API_KEY=${CUSTOM_LLM_API_KEY}
|
||||
- CUSTOM_MODEL=${CUSTOM_MODEL}
|
||||
- CODEX_MODEL=${CODEX_MODEL}
|
||||
- PEXELS_API_KEY=${PEXELS_API_KEY}
|
||||
- EXTENDED_REASONING=${EXTENDED_REASONING}
|
||||
- TOOL_CALLS=${TOOL_CALLS}
|
||||
|
|
@ -48,6 +51,8 @@ services:
|
|||
ports:
|
||||
# You can replace 5000 with any other port number of your choice to run Presenton on a different port number.
|
||||
- "5000:80"
|
||||
# Required for Codex OAuth callback (OpenAI redirects browser directly to localhost:1455)
|
||||
- "1455:1455"
|
||||
volumes:
|
||||
- ./app_data:/app_data
|
||||
environment:
|
||||
|
|
@ -64,6 +69,7 @@ services:
|
|||
- CUSTOM_LLM_URL=${CUSTOM_LLM_URL}
|
||||
- CUSTOM_LLM_API_KEY=${CUSTOM_LLM_API_KEY}
|
||||
- CUSTOM_MODEL=${CUSTOM_MODEL}
|
||||
- CODEX_MODEL=${CODEX_MODEL}
|
||||
- PEXELS_API_KEY=${PEXELS_API_KEY}
|
||||
- EXTENDED_REASONING=${EXTENDED_REASONING}
|
||||
- TOOL_CALLS=${TOOL_CALLS}
|
||||
|
|
@ -80,6 +86,8 @@ services:
|
|||
dockerfile: Dockerfile.dev
|
||||
ports:
|
||||
- "5000:80"
|
||||
# Required for Codex OAuth callback (OpenAI redirects browser directly to localhost:1455)
|
||||
- "1455:1455"
|
||||
volumes:
|
||||
- .:/app
|
||||
- ./app_data:/app_data
|
||||
|
|
@ -97,6 +105,7 @@ services:
|
|||
- CUSTOM_LLM_URL=${CUSTOM_LLM_URL}
|
||||
- CUSTOM_LLM_API_KEY=${CUSTOM_LLM_API_KEY}
|
||||
- CUSTOM_MODEL=${CUSTOM_MODEL}
|
||||
- CODEX_MODEL=${CODEX_MODEL}
|
||||
- PEXELS_API_KEY=${PEXELS_API_KEY}
|
||||
- EXTENDED_REASONING=${EXTENDED_REASONING}
|
||||
- TOOL_CALLS=${TOOL_CALLS}
|
||||
|
|
@ -120,6 +129,8 @@ services:
|
|||
capabilities: [gpu]
|
||||
ports:
|
||||
- "5000:80"
|
||||
# Required for Codex OAuth callback (OpenAI redirects browser directly to localhost:1455)
|
||||
- "1455:1455"
|
||||
volumes:
|
||||
- .:/app
|
||||
- ./app_data:/app_data
|
||||
|
|
@ -137,6 +148,7 @@ services:
|
|||
- CUSTOM_LLM_URL=${CUSTOM_LLM_URL}
|
||||
- CUSTOM_LLM_API_KEY=${CUSTOM_LLM_API_KEY}
|
||||
- CUSTOM_MODEL=${CUSTOM_MODEL}
|
||||
- CODEX_MODEL=${CODEX_MODEL}
|
||||
- PEXELS_API_KEY=${PEXELS_API_KEY}
|
||||
- EXTENDED_REASONING=${EXTENDED_REASONING}
|
||||
- TOOL_CALLS=${TOOL_CALLS}
|
||||
|
|
@ -145,4 +157,4 @@ services:
|
|||
- DATABASE_URL=${DATABASE_URL}
|
||||
- DISABLE_ANONYMOUS_TRACKING=${DISABLE_ANONYMOUS_TRACKING}
|
||||
- COMFYUI_URL=${COMFYUI_URL}
|
||||
- COMFYUI_WORKFLOW=${COMFYUI_WORKFLOW}
|
||||
- COMFYUI_WORKFLOW=${COMFYUI_WORKFLOW}
|
||||
|
|
|
|||
5
electron/app/types/index.d.ts
vendored
5
electron/app/types/index.d.ts
vendored
|
|
@ -69,6 +69,11 @@ interface UserConfig {
|
|||
COMFYUI_WORKFLOW?: string,
|
||||
DALL_E_3_QUALITY?: string,
|
||||
GPT_IMAGE_1_5_QUALITY?: string,
|
||||
CODEX_MODEL?: string,
|
||||
CODEX_ACCESS_TOKEN?: string,
|
||||
CODEX_REFRESH_TOKEN?: string,
|
||||
CODEX_TOKEN_EXPIRES?: string,
|
||||
CODEX_ACCOUNT_ID?: string,
|
||||
}
|
||||
|
||||
interface IPCStatus {
|
||||
|
|
|
|||
|
|
@ -38,6 +38,11 @@ export function setUserConfig(userConfig: UserConfig) {
|
|||
COMFYUI_WORKFLOW: userConfig.COMFYUI_WORKFLOW || existingConfig.COMFYUI_WORKFLOW,
|
||||
DALL_E_3_QUALITY: userConfig.DALL_E_3_QUALITY || existingConfig.DALL_E_3_QUALITY,
|
||||
GPT_IMAGE_1_5_QUALITY: userConfig.GPT_IMAGE_1_5_QUALITY || existingConfig.GPT_IMAGE_1_5_QUALITY,
|
||||
CODEX_MODEL: userConfig.CODEX_MODEL || existingConfig.CODEX_MODEL,
|
||||
CODEX_ACCESS_TOKEN: existingConfig.CODEX_ACCESS_TOKEN,
|
||||
CODEX_REFRESH_TOKEN: existingConfig.CODEX_REFRESH_TOKEN,
|
||||
CODEX_TOKEN_EXPIRES: existingConfig.CODEX_TOKEN_EXPIRES,
|
||||
CODEX_ACCOUNT_ID: existingConfig.CODEX_ACCOUNT_ID,
|
||||
}
|
||||
fs.writeFileSync(userConfigPath, JSON.stringify(mergedConfig))
|
||||
}
|
||||
|
|
|
|||
278
electron/servers/fastapi/api/v1/ppt/endpoints/codex_auth.py
Normal file
278
electron/servers/fastapi/api/v1/ppt/endpoints/codex_auth.py
Normal file
|
|
@ -0,0 +1,278 @@
|
|||
"""
|
||||
OpenAI Codex OAuth endpoints.
|
||||
|
||||
Flow:
|
||||
1. POST /codex/auth/initiate — start the flow, get back an auth URL + session_id
|
||||
2. Browser opens the URL, user authenticates with OpenAI
|
||||
3. OpenAI redirects to http://localhost:1455/auth/callback (captured by local server)
|
||||
4. GET /codex/auth/status/{session_id} — poll until code captured; exchanges and stores tokens
|
||||
5. POST /codex/auth/exchange — manual fallback if browser callback didn't fire
|
||||
6. POST /codex/auth/refresh — refresh a stored token
|
||||
"""
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from utils.oauth.openai_codex import (
|
||||
OAuthCallbackServer,
|
||||
TokenSuccess,
|
||||
create_authorization_flow,
|
||||
exchange_authorization_code,
|
||||
get_account_id,
|
||||
parse_authorization_input,
|
||||
refresh_access_token,
|
||||
)
|
||||
from utils.get_env import (
|
||||
get_codex_access_token_env,
|
||||
get_codex_refresh_token_env,
|
||||
get_codex_token_expires_env,
|
||||
)
|
||||
from utils.set_env import (
|
||||
set_codex_access_token_env,
|
||||
set_codex_account_id_env,
|
||||
set_codex_refresh_token_env,
|
||||
set_codex_token_expires_env,
|
||||
set_codex_model_env,
|
||||
)
|
||||
from utils.user_config import save_codex_tokens_to_user_config
|
||||
|
||||
CODEX_AUTH_ROUTER = APIRouter(prefix="/codex/auth", tags=["Codex OAuth"])
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# In-memory session store {session_id: {"verifier": str, "state": str, "server": OAuthCallbackServer}}
|
||||
# Sessions are short-lived; garbage-collected when consumed.
|
||||
# ---------------------------------------------------------------------------
|
||||
_sessions: dict[str, dict] = {}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Request / Response models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class InitiateResponse(BaseModel):
|
||||
session_id: str
|
||||
url: str
|
||||
instructions: str
|
||||
|
||||
|
||||
class StatusResponse(BaseModel):
|
||||
status: str # "pending" | "success" | "failed"
|
||||
account_id: Optional[str] = None
|
||||
detail: Optional[str] = None
|
||||
|
||||
|
||||
class ExchangeRequest(BaseModel):
|
||||
session_id: str
|
||||
code: str # raw code OR full redirect URL OR code#state shorthand
|
||||
|
||||
|
||||
class ExchangeResponse(BaseModel):
|
||||
account_id: str
|
||||
|
||||
|
||||
class RefreshResponse(BaseModel):
|
||||
account_id: Optional[str]
|
||||
detail: str
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _store_token(result: TokenSuccess) -> Optional[str]:
|
||||
"""Persist token fields in env vars and userConfig.json. Returns account_id or None."""
|
||||
set_codex_access_token_env(result.access)
|
||||
set_codex_refresh_token_env(result.refresh)
|
||||
set_codex_token_expires_env(str(result.expires))
|
||||
account_id = get_account_id(result.access)
|
||||
if account_id:
|
||||
set_codex_account_id_env(account_id)
|
||||
save_codex_tokens_to_user_config()
|
||||
return account_id
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@CODEX_AUTH_ROUTER.post("/initiate", response_model=InitiateResponse)
|
||||
async def initiate_codex_auth():
|
||||
"""
|
||||
Start the OpenAI Codex OAuth flow.
|
||||
|
||||
Returns an authorization URL to open in the browser and a session_id to use
|
||||
when polling /status or calling /exchange. A local HTTP server is started
|
||||
on port 1455 to receive the redirect automatically.
|
||||
"""
|
||||
flow = create_authorization_flow()
|
||||
server = OAuthCallbackServer(state=flow.state)
|
||||
server_started = server.start()
|
||||
|
||||
session_id = str(uuid.uuid4())
|
||||
_sessions[session_id] = {
|
||||
"verifier": flow.verifier,
|
||||
"state": flow.state,
|
||||
"server": server,
|
||||
"server_started": server_started,
|
||||
}
|
||||
|
||||
instructions = (
|
||||
"Open the URL in your browser and complete the OpenAI login. "
|
||||
+ (
|
||||
"The callback will be captured automatically."
|
||||
if server_started
|
||||
else "Port 1455 could not be bound — paste the redirect URL or code into /exchange."
|
||||
)
|
||||
)
|
||||
|
||||
return InitiateResponse(
|
||||
session_id=session_id,
|
||||
url=flow.url,
|
||||
instructions=instructions,
|
||||
)
|
||||
|
||||
|
||||
@CODEX_AUTH_ROUTER.get("/status/{session_id}", response_model=StatusResponse)
|
||||
async def poll_codex_auth_status(session_id: str):
|
||||
"""
|
||||
Poll for the result of an ongoing OAuth flow.
|
||||
|
||||
Returns {"status": "pending"} until the callback server captures the code.
|
||||
On success the tokens are stored in environment variables and the session
|
||||
is cleaned up.
|
||||
"""
|
||||
session = _sessions.get(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Session not found or already consumed")
|
||||
|
||||
server: OAuthCallbackServer = session["server"]
|
||||
|
||||
# Non-blocking peek — check whether the callback server already received a code
|
||||
code = server.get_code_nowait() if session.get("server_started") else None
|
||||
|
||||
if code is None:
|
||||
return StatusResponse(status="pending")
|
||||
|
||||
# We have a code — exchange it
|
||||
verifier: str = session["verifier"]
|
||||
result = exchange_authorization_code(code, verifier)
|
||||
|
||||
# Clean up session
|
||||
server.close()
|
||||
_sessions.pop(session_id, None)
|
||||
|
||||
if not isinstance(result, TokenSuccess):
|
||||
return StatusResponse(status="failed", detail=result.reason)
|
||||
|
||||
account_id = _store_token(result)
|
||||
return StatusResponse(status="success", account_id=account_id)
|
||||
|
||||
|
||||
@CODEX_AUTH_ROUTER.post("/exchange", response_model=ExchangeResponse)
|
||||
async def exchange_codex_code(body: ExchangeRequest):
|
||||
"""
|
||||
Manual code exchange fallback.
|
||||
|
||||
Accepts the session_id from /initiate and either:
|
||||
- a bare authorization code
|
||||
- the full redirect URL (http://localhost:1455/auth/callback?code=…&state=…)
|
||||
- the code#state shorthand
|
||||
|
||||
Exchanges the code for tokens and stores them in environment variables.
|
||||
"""
|
||||
session = _sessions.get(body.session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Session not found or already consumed")
|
||||
|
||||
parsed = parse_authorization_input(body.code)
|
||||
code = parsed.get("code")
|
||||
incoming_state = parsed.get("state")
|
||||
|
||||
if not code:
|
||||
raise HTTPException(status_code=400, detail="Could not extract authorization code from input")
|
||||
|
||||
if incoming_state and incoming_state != session["state"]:
|
||||
raise HTTPException(status_code=400, detail="State mismatch — possible CSRF")
|
||||
|
||||
verifier: str = session["verifier"]
|
||||
server: OAuthCallbackServer = session["server"]
|
||||
|
||||
result = exchange_authorization_code(code, verifier)
|
||||
|
||||
server.close()
|
||||
_sessions.pop(body.session_id, None)
|
||||
|
||||
if not isinstance(result, TokenSuccess):
|
||||
raise HTTPException(status_code=502, detail=f"Token exchange failed: {result.reason}")
|
||||
|
||||
account_id = _store_token(result)
|
||||
if not account_id:
|
||||
raise HTTPException(status_code=502, detail="Token exchanged but could not extract account ID")
|
||||
|
||||
return ExchangeResponse(account_id=account_id)
|
||||
|
||||
|
||||
@CODEX_AUTH_ROUTER.post("/refresh", response_model=RefreshResponse)
|
||||
async def refresh_codex_token():
|
||||
"""
|
||||
Refresh the stored Codex OAuth access token using the refresh token.
|
||||
|
||||
Updates environment variables with the new tokens.
|
||||
"""
|
||||
refresh_token = get_codex_refresh_token_env()
|
||||
if not refresh_token:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No Codex refresh token stored. Please authenticate first via /initiate",
|
||||
)
|
||||
|
||||
result = refresh_access_token(refresh_token)
|
||||
if not isinstance(result, TokenSuccess):
|
||||
raise HTTPException(status_code=502, detail=f"Token refresh failed: {result.reason}")
|
||||
|
||||
account_id = _store_token(result)
|
||||
return RefreshResponse(
|
||||
account_id=account_id,
|
||||
detail="Token refreshed successfully",
|
||||
)
|
||||
|
||||
|
||||
@CODEX_AUTH_ROUTER.get("/status", response_model=StatusResponse)
|
||||
async def get_codex_auth_status():
|
||||
"""
|
||||
Return whether a valid Codex OAuth token is currently stored.
|
||||
"""
|
||||
import time
|
||||
|
||||
access_token = get_codex_access_token_env()
|
||||
if not access_token:
|
||||
return StatusResponse(status="not_authenticated", detail="No access token stored")
|
||||
|
||||
expires_str = get_codex_token_expires_env()
|
||||
if expires_str:
|
||||
try:
|
||||
expires_ms = int(expires_str)
|
||||
now_ms = int(time.time() * 1000)
|
||||
if now_ms >= expires_ms:
|
||||
return StatusResponse(status="expired", detail="Access token has expired — call /refresh")
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
account_id = get_account_id(access_token)
|
||||
return StatusResponse(status="authenticated", account_id=account_id)
|
||||
|
||||
|
||||
@CODEX_AUTH_ROUTER.post("/logout")
|
||||
async def logout_codex():
|
||||
"""
|
||||
Clear all stored Codex OAuth credentials from environment variables and userConfig.json.
|
||||
"""
|
||||
set_codex_access_token_env("")
|
||||
set_codex_refresh_token_env("")
|
||||
set_codex_token_expires_env("")
|
||||
set_codex_account_id_env("")
|
||||
set_codex_model_env("")
|
||||
save_codex_tokens_to_user_config()
|
||||
return {"detail": "Logged out successfully"}
|
||||
|
|
@ -14,6 +14,7 @@ from api.v1.ppt.endpoints.images import IMAGES_ROUTER
|
|||
from api.v1.ppt.endpoints.ollama import OLLAMA_ROUTER
|
||||
from api.v1.ppt.endpoints.outlines import OUTLINES_ROUTER
|
||||
from api.v1.ppt.endpoints.slide import SLIDE_ROUTER
|
||||
from api.v1.ppt.endpoints.codex_auth import CODEX_AUTH_ROUTER
|
||||
from api.v1.ppt.endpoints.pptx_slides import PPTX_FONTS_ROUTER
|
||||
|
||||
|
||||
|
|
@ -36,4 +37,5 @@ API_V1_PPT_ROUTER.include_router(PDF_SLIDES_ROUTER)
|
|||
API_V1_PPT_ROUTER.include_router(OPENAI_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(ANTHROPIC_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(GOOGLE_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(CODEX_AUTH_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(PPTX_FONTS_ROUTER)
|
||||
|
|
|
|||
|
|
@ -4,3 +4,4 @@ OPENAI_URL = "https://api.openai.com/v1"
|
|||
DEFAULT_OPENAI_MODEL = "gpt-4.1"
|
||||
DEFAULT_GOOGLE_MODEL = "models/gemini-2.5-flash"
|
||||
DEFAULT_ANTHROPIC_MODEL = "claude-sonnet-4-20250514"
|
||||
DEFAULT_CODEX_MODEL = "gpt-5.3-codex-spark"
|
||||
|
|
|
|||
|
|
@ -7,3 +7,4 @@ class LLMProvider(Enum):
|
|||
GOOGLE = "google"
|
||||
ANTHROPIC = "anthropic"
|
||||
CUSTOM = "custom"
|
||||
CODEX = "codex"
|
||||
|
|
|
|||
|
|
@ -48,3 +48,10 @@ class UserConfig(BaseModel):
|
|||
|
||||
# Web Search
|
||||
WEB_GROUNDING: Optional[bool] = None
|
||||
|
||||
# Codex OAuth (ChatGPT)
|
||||
CODEX_MODEL: Optional[str] = None
|
||||
CODEX_ACCESS_TOKEN: Optional[str] = None
|
||||
CODEX_REFRESH_TOKEN: Optional[str] = None
|
||||
CODEX_TOKEN_EXPIRES: Optional[str] = None
|
||||
CODEX_ACCOUNT_ID: Optional[str] = None
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
import asyncio
|
||||
import dirtyjson
|
||||
import json
|
||||
from typing import AsyncGenerator, List, Optional
|
||||
from typing import AsyncGenerator, List, Optional, Dict, Any
|
||||
from fastapi import HTTPException
|
||||
from openai import AsyncOpenAI
|
||||
from openai import APIStatusError, AsyncOpenAI, OpenAIError
|
||||
from openai.types.chat.chat_completion_chunk import (
|
||||
ChatCompletionChunk as OpenAIChatCompletionChunk,
|
||||
)
|
||||
|
|
@ -44,6 +44,10 @@ from utils.async_iterator import iterator_to_async
|
|||
from utils.dummy_functions import do_nothing_async
|
||||
from utils.get_env import (
|
||||
get_anthropic_api_key_env,
|
||||
get_codex_access_token_env,
|
||||
get_codex_account_id_env,
|
||||
get_codex_refresh_token_env,
|
||||
get_codex_token_expires_env,
|
||||
get_custom_llm_api_key_env,
|
||||
get_custom_llm_url_env,
|
||||
get_disable_thinking_env,
|
||||
|
|
@ -53,6 +57,12 @@ from utils.get_env import (
|
|||
get_tool_calls_env,
|
||||
get_web_grounding_env,
|
||||
)
|
||||
from utils.set_env import (
|
||||
set_codex_access_token_env,
|
||||
set_codex_account_id_env,
|
||||
set_codex_refresh_token_env,
|
||||
set_codex_token_expires_env,
|
||||
)
|
||||
from utils.llm_provider import get_llm_provider, get_model
|
||||
from utils.parsers import parse_bool_or_none
|
||||
from utils.schema_utils import (
|
||||
|
|
@ -62,6 +72,7 @@ from utils.schema_utils import (
|
|||
)
|
||||
|
||||
|
||||
|
||||
class LLMClient:
|
||||
def __init__(self):
|
||||
self.llm_provider = get_llm_provider()
|
||||
|
|
@ -100,10 +111,12 @@ class LLMClient:
|
|||
return self._get_ollama_client()
|
||||
case LLMProvider.CUSTOM:
|
||||
return self._get_custom_client()
|
||||
case LLMProvider.CODEX:
|
||||
return self._get_codex_client()
|
||||
case _:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="LLM Provider must be either openai, google, anthropic, ollama, or custom",
|
||||
detail="LLM Provider must be either openai, google, anthropic, ollama, custom, or codex",
|
||||
)
|
||||
|
||||
def _get_openai_client(self):
|
||||
|
|
@ -147,6 +160,74 @@ class LLMClient:
|
|||
api_key=get_custom_llm_api_key_env() or "null",
|
||||
)
|
||||
|
||||
def _get_codex_headers(self) -> dict:
|
||||
"""Return the HTTP headers required for Codex Responses API requests.
|
||||
|
||||
Handles token auto-refresh if the stored token is expired or within
|
||||
60 s of expiry before building the header dict.
|
||||
"""
|
||||
access_token = get_codex_access_token_env()
|
||||
if not access_token:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Codex OAuth access token is not set. Please authenticate via /api/v1/ppt/codex/auth/initiate",
|
||||
)
|
||||
|
||||
# Auto-refresh if the token is expired or about to expire (within 60 s)
|
||||
expires_str = get_codex_token_expires_env()
|
||||
if expires_str:
|
||||
try:
|
||||
expires_ms = int(expires_str)
|
||||
now_ms = int(__import__("time").time() * 1000)
|
||||
if now_ms >= expires_ms - 60_000:
|
||||
refresh_token = get_codex_refresh_token_env()
|
||||
if refresh_token:
|
||||
from utils.oauth.openai_codex import (
|
||||
get_account_id,
|
||||
refresh_access_token,
|
||||
TokenSuccess,
|
||||
)
|
||||
result = refresh_access_token(refresh_token)
|
||||
if isinstance(result, TokenSuccess):
|
||||
set_codex_access_token_env(result.access)
|
||||
set_codex_refresh_token_env(result.refresh)
|
||||
set_codex_token_expires_env(str(result.expires))
|
||||
account_id = get_account_id(result.access)
|
||||
if account_id:
|
||||
set_codex_account_id_env(account_id)
|
||||
access_token = result.access
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
account_id = get_codex_account_id_env() or ""
|
||||
return {
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"chatgpt-account-id": account_id,
|
||||
"OpenAI-Beta": "responses=experimental",
|
||||
"originator": "pi",
|
||||
"content-type": "application/json",
|
||||
"accept": "text/event-stream",
|
||||
}
|
||||
|
||||
def _get_codex_client(self) -> AsyncOpenAI:
|
||||
"""Return an AsyncOpenAI client configured for the Codex Responses API.
|
||||
Client is built per call so headers/token are fresh after refresh.
|
||||
Only Codex-specific headers are passed; content-type and accept are left
|
||||
to the SDK so the server does not reject the request.
|
||||
"""
|
||||
headers = self._get_codex_headers()
|
||||
access_token = (headers.get("Authorization") or "").replace("Bearer ", "").strip()
|
||||
skip = {"authorization", "content-type", "accept"}
|
||||
default_headers = {
|
||||
k: v for k, v in headers.items() if k.lower() not in skip
|
||||
}
|
||||
return AsyncOpenAI(
|
||||
base_url="https://chatgpt.com/backend-api/codex",
|
||||
api_key=access_token or "codex",
|
||||
default_headers=default_headers,
|
||||
timeout=120.0,
|
||||
)
|
||||
|
||||
# ? Prompts
|
||||
def _get_system_prompt(self, messages: List[LLMMessage]) -> str:
|
||||
for message in messages:
|
||||
|
|
@ -401,6 +482,147 @@ class LLMClient:
|
|||
depth=depth,
|
||||
)
|
||||
|
||||
async def _generate_codex(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
max_tokens: Optional[int] = None,
|
||||
tools: Optional[List[dict]] = None,
|
||||
depth: int = 0,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Generate plain text using the Codex Responses API. On tool calls, run
|
||||
handlers and recurse (same pattern as _generate_openai).
|
||||
"""
|
||||
_MAX_RECURSION_DEPTH = 5
|
||||
client: AsyncOpenAI = self._client
|
||||
|
||||
# Flatten tools to Responses API format
|
||||
responses_tools: Optional[List[dict]] = None
|
||||
if tools:
|
||||
responses_tools = []
|
||||
for tool in tools:
|
||||
fn = (tool.get("function") or tool) if isinstance(tool, dict) else {}
|
||||
if isinstance(fn, dict):
|
||||
responses_tools.append({
|
||||
"type": "function",
|
||||
"name": fn.get("name", ""),
|
||||
"description": fn.get("description", ""),
|
||||
"parameters": fn.get("parameters", {}),
|
||||
})
|
||||
else:
|
||||
responses_tools.append(tool)
|
||||
|
||||
# Build instructions + input (same shape as _stream_codex_structured)
|
||||
instructions = self._get_system_prompt(messages) or None
|
||||
input_payload: List[Dict[str, Any]] = []
|
||||
for m in messages:
|
||||
if isinstance(m, LLMSystemMessage):
|
||||
continue
|
||||
if isinstance(m, LLMUserMessage):
|
||||
input_payload.append({
|
||||
"role": "user",
|
||||
"content": [{"type": "input_text", "text": m.content}],
|
||||
})
|
||||
elif isinstance(m, OpenAIAssistantMessage):
|
||||
text = m.content or ""
|
||||
if text:
|
||||
input_payload.append({
|
||||
"role": "assistant",
|
||||
"content": [{"type": "output_text", "text": text}],
|
||||
})
|
||||
else:
|
||||
text = getattr(m, "content", "") or ""
|
||||
if text:
|
||||
input_payload.append({
|
||||
"role": "user",
|
||||
"content": [{"type": "input_text", "text": text}],
|
||||
})
|
||||
|
||||
create_kwargs: Dict[str, Any] = {
|
||||
"model": model,
|
||||
"store": False,
|
||||
"stream": True,
|
||||
"text": {"verbosity": "medium"},
|
||||
"include": ["reasoning.encrypted_content"],
|
||||
"tool_choice": "auto",
|
||||
"parallel_tool_calls": True,
|
||||
}
|
||||
if instructions:
|
||||
create_kwargs["instructions"] = instructions
|
||||
if input_payload:
|
||||
create_kwargs["input"] = input_payload
|
||||
if responses_tools:
|
||||
create_kwargs["tools"] = responses_tools
|
||||
if max_tokens is not None:
|
||||
create_kwargs["max_output_tokens"] = max_tokens
|
||||
|
||||
stream = await client.responses.create(**create_kwargs)
|
||||
|
||||
def _event_dict(ev: Any) -> dict:
|
||||
if hasattr(ev, "model_dump"):
|
||||
return ev.model_dump()
|
||||
return {
|
||||
"type": getattr(ev, "type", None),
|
||||
"delta": getattr(ev, "delta", None),
|
||||
"item": getattr(ev, "item", None),
|
||||
"message": getattr(ev, "message", None),
|
||||
}
|
||||
|
||||
text_parts: List[str] = []
|
||||
tool_calls_by_id: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
async for ev in stream:
|
||||
event = _event_dict(ev) if not isinstance(ev, dict) else ev
|
||||
event_type = event.get("type") or ""
|
||||
|
||||
if event_type == "response.output_text.delta":
|
||||
delta = event.get("delta") or ""
|
||||
if delta:
|
||||
text_parts.append(delta)
|
||||
elif event_type == "response.output_item.done":
|
||||
item = event.get("item") or {}
|
||||
if item.get("type") == "function_call":
|
||||
cid = item.get("call_id") or item.get("id", "")
|
||||
tool_calls_by_id[cid] = item
|
||||
elif event_type in ("response.error", "response.failed", "error"):
|
||||
err = event.get("message") or event.get("error") or str(event)
|
||||
raise HTTPException(status_code=502, detail=f"Codex error: {err}"[:400])
|
||||
|
||||
if tool_calls_by_id and responses_tools and depth < _MAX_RECURSION_DEPTH:
|
||||
parsed_tool_calls = [
|
||||
OpenAIToolCall(
|
||||
id=cid,
|
||||
type="function",
|
||||
function=OpenAIToolCallFunction(
|
||||
name=data.get("name", ""),
|
||||
arguments=data.get("arguments", ""),
|
||||
),
|
||||
)
|
||||
for cid, data in tool_calls_by_id.items()
|
||||
]
|
||||
tool_call_messages = await self.tool_calls_handler.handle_tool_calls_openai(
|
||||
parsed_tool_calls
|
||||
)
|
||||
new_messages = [
|
||||
*messages,
|
||||
OpenAIAssistantMessage(
|
||||
role="assistant",
|
||||
content=None,
|
||||
tool_calls=[tc.model_dump() for tc in parsed_tool_calls],
|
||||
),
|
||||
*tool_call_messages,
|
||||
]
|
||||
return await self._generate_codex(
|
||||
model=model,
|
||||
messages=new_messages,
|
||||
max_tokens=max_tokens,
|
||||
tools=tools,
|
||||
depth=depth + 1,
|
||||
)
|
||||
|
||||
return "".join(text_parts) or None
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
model: str,
|
||||
|
|
@ -419,6 +641,13 @@ class LLMClient:
|
|||
max_tokens=max_tokens,
|
||||
tools=parsed_tools,
|
||||
)
|
||||
case LLMProvider.CODEX:
|
||||
content = await self._generate_codex(
|
||||
model=model,
|
||||
messages=messages,
|
||||
max_tokens=max_tokens,
|
||||
tools=parsed_tools,
|
||||
)
|
||||
case LLMProvider.GOOGLE:
|
||||
content = await self._generate_google(
|
||||
model=model,
|
||||
|
|
@ -566,6 +795,48 @@ class LLMClient:
|
|||
return content
|
||||
return None
|
||||
|
||||
async def _generate_codex_structured(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
response_format: dict,
|
||||
strict: bool = False,
|
||||
max_tokens: Optional[int] = None,
|
||||
tools: Optional[List[dict]] = None,
|
||||
extra_body: Optional[dict] = None,
|
||||
depth: int = 0,
|
||||
) -> dict | None:
|
||||
"""
|
||||
Generate structured Codex output using the Responses API.
|
||||
|
||||
This reuses the streaming Codex structured implementation and simply
|
||||
accumulates the streamed JSON chunks into a single string, then parses
|
||||
it at the root call.
|
||||
"""
|
||||
# Reuse the Responses API streaming implementation for Codex.
|
||||
accumulated: List[str] = []
|
||||
async for chunk in self._stream_codex_structured(
|
||||
model=model,
|
||||
messages=messages,
|
||||
response_format=response_format,
|
||||
strict=strict,
|
||||
max_tokens=max_tokens,
|
||||
tools=tools,
|
||||
extra_body=extra_body,
|
||||
depth=depth,
|
||||
):
|
||||
accumulated.append(chunk)
|
||||
|
||||
raw = "".join(accumulated)
|
||||
if not raw:
|
||||
return None
|
||||
|
||||
# At the root level we parse into a dict; recursive calls just
|
||||
# propagate the raw JSON/text, mirroring other providers.
|
||||
if depth == 0:
|
||||
return dict(dirtyjson.loads(raw))
|
||||
return {"raw": raw}
|
||||
|
||||
async def _generate_google_structured(
|
||||
self,
|
||||
model: str,
|
||||
|
|
@ -795,6 +1066,15 @@ class LLMClient:
|
|||
tools=parsed_tools,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
case LLMProvider.CODEX:
|
||||
content = await self._generate_codex_structured(
|
||||
model=model,
|
||||
messages=messages,
|
||||
response_format=response_format,
|
||||
strict=strict,
|
||||
tools=parsed_tools,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
case LLMProvider.GOOGLE:
|
||||
content = await self._generate_google_structured(
|
||||
model=model,
|
||||
|
|
@ -1068,6 +1348,157 @@ class LLMClient:
|
|||
):
|
||||
yield event
|
||||
|
||||
async def _stream_codex(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
max_tokens: Optional[int] = None,
|
||||
tools: Optional[List[dict]] = None,
|
||||
depth: int = 0,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Stream plain text from Codex (Responses API). On tool calls, execute tools
|
||||
and recurse, mirroring _stream_openai but using Responses events.
|
||||
"""
|
||||
_MAX_RECURSION_DEPTH = 5
|
||||
client: AsyncOpenAI = (
|
||||
self._get_codex_client()
|
||||
if self.llm_provider == LLMProvider.CODEX
|
||||
else self._client
|
||||
)
|
||||
|
||||
# Flatten tools to Responses API format
|
||||
responses_tools: Optional[List[dict]] = None
|
||||
if tools:
|
||||
responses_tools = []
|
||||
for tool in tools:
|
||||
fn = (tool.get("function") or tool) if isinstance(tool, dict) else {}
|
||||
if isinstance(fn, dict):
|
||||
responses_tools.append(
|
||||
{
|
||||
"type": "function",
|
||||
"name": fn.get("name", ""),
|
||||
"description": fn.get("description", ""),
|
||||
"parameters": fn.get("parameters", {}),
|
||||
}
|
||||
)
|
||||
else:
|
||||
responses_tools.append(tool)
|
||||
|
||||
# Build instructions + input (same shape as _generate_codex/_stream_codex_structured)
|
||||
instructions = self._get_system_prompt(messages) or None
|
||||
input_payload: List[Dict[str, Any]] = []
|
||||
for m in messages:
|
||||
if isinstance(m, LLMSystemMessage):
|
||||
continue
|
||||
if isinstance(m, LLMUserMessage):
|
||||
input_payload.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "input_text", "text": m.content}],
|
||||
}
|
||||
)
|
||||
elif isinstance(m, OpenAIAssistantMessage):
|
||||
text = m.content or ""
|
||||
if text:
|
||||
input_payload.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "output_text", "text": text}],
|
||||
}
|
||||
)
|
||||
else:
|
||||
text = getattr(m, "content", "") or ""
|
||||
if text:
|
||||
input_payload.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "input_text", "text": text}],
|
||||
}
|
||||
)
|
||||
|
||||
create_kwargs: Dict[str, Any] = {
|
||||
"model": model,
|
||||
"store": False,
|
||||
"stream": True,
|
||||
"text": {"verbosity": "medium"},
|
||||
"include": ["reasoning.encrypted_content"],
|
||||
"tool_choice": "auto",
|
||||
"parallel_tool_calls": True,
|
||||
}
|
||||
if instructions:
|
||||
create_kwargs["instructions"] = instructions
|
||||
if input_payload:
|
||||
create_kwargs["input"] = input_payload
|
||||
if responses_tools:
|
||||
create_kwargs["tools"] = responses_tools
|
||||
if max_tokens is not None:
|
||||
create_kwargs["max_output_tokens"] = max_tokens
|
||||
|
||||
stream = await client.responses.create(**create_kwargs)
|
||||
|
||||
def _event_dict(ev: Any) -> dict:
|
||||
if hasattr(ev, "model_dump"):
|
||||
return ev.model_dump()
|
||||
return {
|
||||
"type": getattr(ev, "type", None),
|
||||
"delta": getattr(ev, "delta", None),
|
||||
"item": getattr(ev, "item", None),
|
||||
"message": getattr(ev, "message", None),
|
||||
}
|
||||
|
||||
tool_calls_by_id: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
async for ev in stream:
|
||||
event = _event_dict(ev) if not isinstance(ev, dict) else ev
|
||||
event_type = event.get("type") or ""
|
||||
|
||||
if event_type == "response.output_text.delta":
|
||||
delta = event.get("delta") or ""
|
||||
if delta:
|
||||
yield delta
|
||||
elif event_type == "response.output_item.done":
|
||||
item = event.get("item") or {}
|
||||
if item.get("type") == "function_call":
|
||||
cid = item.get("call_id") or item.get("id", "")
|
||||
tool_calls_by_id[cid] = item
|
||||
elif event_type in ("response.error", "response.failed", "error"):
|
||||
err = event.get("message") or event.get("error") or str(event)
|
||||
raise HTTPException(status_code=502, detail=f"Codex stream error: {err}"[:400])
|
||||
|
||||
if tool_calls_by_id and responses_tools and depth < _MAX_RECURSION_DEPTH:
|
||||
parsed_tool_calls = [
|
||||
OpenAIToolCall(
|
||||
id=cid,
|
||||
type="function",
|
||||
function=OpenAIToolCallFunction(
|
||||
name=data.get("name", ""),
|
||||
arguments=data.get("arguments", ""),
|
||||
),
|
||||
)
|
||||
for cid, data in tool_calls_by_id.items()
|
||||
]
|
||||
tool_call_messages = await self.tool_calls_handler.handle_tool_calls_openai(
|
||||
parsed_tool_calls
|
||||
)
|
||||
new_messages = [
|
||||
*messages,
|
||||
OpenAIAssistantMessage(
|
||||
role="assistant",
|
||||
content=None,
|
||||
tool_calls=[tc.model_dump() for tc in parsed_tool_calls],
|
||||
),
|
||||
*tool_call_messages,
|
||||
]
|
||||
async for chunk in self._stream_codex(
|
||||
model=model,
|
||||
messages=new_messages,
|
||||
max_tokens=max_tokens,
|
||||
tools=tools,
|
||||
depth=depth + 1,
|
||||
):
|
||||
yield chunk
|
||||
|
||||
def _stream_ollama(
|
||||
self,
|
||||
model: str,
|
||||
|
|
@ -1112,6 +1543,13 @@ class LLMClient:
|
|||
max_tokens=max_tokens,
|
||||
tools=parsed_tools,
|
||||
)
|
||||
case LLMProvider.CODEX:
|
||||
return self._stream_codex(
|
||||
model=model,
|
||||
messages=messages,
|
||||
max_tokens=max_tokens,
|
||||
tools=parsed_tools,
|
||||
)
|
||||
case LLMProvider.GOOGLE:
|
||||
return self._stream_google(
|
||||
model=model,
|
||||
|
|
@ -1286,6 +1724,291 @@ class LLMClient:
|
|||
):
|
||||
yield event
|
||||
|
||||
|
||||
|
||||
async def _stream_codex_structured(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
response_format: dict,
|
||||
strict: bool = False,
|
||||
max_tokens: Optional[int] = None,
|
||||
tools: Optional[List[dict]] = None,
|
||||
depth: int = 0,
|
||||
extra_body: Optional[dict] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Stream structured responses using OpenAI's Responses API (Codex-style models).
|
||||
|
||||
This implementation is intentionally separate from ChatCompletion-based streaming
|
||||
because the Responses API uses a fundamentally different event model.
|
||||
|
||||
Why this function exists:
|
||||
|
||||
1. The Responses API does NOT return `choices[].delta` like ChatCompletions.
|
||||
Instead, it streams typed events such as:
|
||||
- response.output_text.delta
|
||||
- response.output_tool_call.delta
|
||||
- response.completed
|
||||
- response.error
|
||||
|
||||
2. Structured output can be achieved in two ways:
|
||||
a) Native JSON schema enforcement via `response_format`
|
||||
b) Tool-call-based structured output using a synthetic `ResponseSchema` tool
|
||||
|
||||
This function supports both approaches. When tool-call structured mode is enabled,
|
||||
a dynamic `ResponseSchema` tool is injected so the model returns structured data
|
||||
as tool call arguments.
|
||||
|
||||
3. Tool calls must be accumulated incrementally.
|
||||
The Responses API streams tool call arguments in chunks (`arguments_delta`),
|
||||
so we reconstruct the full argument payload before executing the tool.
|
||||
|
||||
4. Recursive tool execution is supported.
|
||||
If the model calls external tools (e.g., web search), we:
|
||||
- Execute the tools asynchronously
|
||||
- Append tool results as new messages
|
||||
- Reinvoke the model recursively
|
||||
This enables multi-step reasoning and grounding workflows.
|
||||
|
||||
5. Provider abstraction is preserved.
|
||||
The Responses API event format is converted into our internal tool-call model
|
||||
before being passed to the tool handler layer. This prevents SDK-specific
|
||||
structures from leaking into business logic.
|
||||
|
||||
6. Strict schema enforcement (optional).
|
||||
When `strict=True`, the provided JSON schema is hardened before being sent
|
||||
to the model to reduce malformed outputs.
|
||||
|
||||
Important architectural note:
|
||||
This function MUST NOT assume ChatCompletion-style streaming fields like
|
||||
`choices`, `delta.content`, or `delta.tool_calls`. It strictly follows the
|
||||
Responses API event model.
|
||||
|
||||
This separation ensures:
|
||||
- Future compatibility with GPT-5 / Codex models
|
||||
- Clean provider abstraction
|
||||
- Streaming-safe structured JSON assembly
|
||||
- Robust multi-tool recursive execution
|
||||
"""
|
||||
client: AsyncOpenAI = self._client
|
||||
response_schema = response_format
|
||||
# Apply strict schema once at root
|
||||
if strict and depth == 0:
|
||||
response_schema = ensure_strict_json_schema(
|
||||
response_schema,
|
||||
path=(),
|
||||
root=response_schema,
|
||||
)
|
||||
|
||||
# Codex Responses API requires all array schemas to specify `items`.
|
||||
def _fix_arrays(node: Any) -> Any:
|
||||
if isinstance(node, dict):
|
||||
# Add default items for arrays missing them
|
||||
if node.get("type") == "array" and "items" not in node:
|
||||
node["items"] = {"type": "string"}
|
||||
for key, value in list(node.items()):
|
||||
node[key] = _fix_arrays(value)
|
||||
elif isinstance(node, list):
|
||||
for idx, value in enumerate(node):
|
||||
node[idx] = _fix_arrays(value)
|
||||
return node
|
||||
|
||||
response_schema = _fix_arrays(response_schema)
|
||||
|
||||
# Responses API tool format: flat {type, name, description, parameters}
|
||||
response_schema_tool = {
|
||||
"type": "function",
|
||||
"name": "ResponseSchema",
|
||||
"description": "Provide structured response",
|
||||
"parameters": response_schema,
|
||||
}
|
||||
all_tools: List[dict] = [response_schema_tool]
|
||||
if tools:
|
||||
for tool in tools:
|
||||
fn = (tool.get("function") or tool) if isinstance(tool, dict) else {}
|
||||
if isinstance(fn, dict):
|
||||
all_tools.append({
|
||||
"type": "function",
|
||||
"name": fn.get("name", ""),
|
||||
"description": fn.get("description", ""),
|
||||
"parameters": fn.get("parameters", {}),
|
||||
})
|
||||
else:
|
||||
all_tools.append(tool)
|
||||
|
||||
# Build instructions + input like Codex adapter (instructions from system; input_text/output_text)
|
||||
instructions = self._get_system_prompt(messages) or None
|
||||
input_payload: List[Dict[str, Any]] = []
|
||||
for m in messages:
|
||||
if isinstance(m, LLMSystemMessage):
|
||||
continue
|
||||
if isinstance(m, LLMUserMessage):
|
||||
input_payload.append({
|
||||
"role": "user",
|
||||
"content": [{"type": "input_text", "text": m.content}],
|
||||
})
|
||||
elif isinstance(m, OpenAIAssistantMessage):
|
||||
text = m.content or ""
|
||||
if text:
|
||||
input_payload.append({
|
||||
"role": "assistant",
|
||||
"content": [{"type": "output_text", "text": text}],
|
||||
})
|
||||
else:
|
||||
text = getattr(m, "content", "") or ""
|
||||
if text:
|
||||
input_payload.append({
|
||||
"role": "user",
|
||||
"content": [{"type": "input_text", "text": text}],
|
||||
})
|
||||
|
||||
# Force model to use ResponseSchema for structured output
|
||||
tool_choice = {"type": "function", "name": "ResponseSchema"}
|
||||
create_kwargs: Dict[str, Any] = {
|
||||
"model": model,
|
||||
"store": False,
|
||||
"stream": True,
|
||||
"text": {"verbosity": "medium"},
|
||||
"include": ["reasoning.encrypted_content"],
|
||||
"tool_choice": tool_choice,
|
||||
"parallel_tool_calls": True,
|
||||
"tools": all_tools,
|
||||
}
|
||||
if instructions:
|
||||
create_kwargs["instructions"] = instructions
|
||||
if input_payload:
|
||||
create_kwargs["input"] = input_payload
|
||||
if max_tokens is not None:
|
||||
create_kwargs["max_output_tokens"] = max_tokens
|
||||
if extra_body:
|
||||
create_kwargs.update(extra_body)
|
||||
|
||||
stream = await client.responses.create(**create_kwargs)
|
||||
|
||||
|
||||
def _event_dict(ev: Any) -> dict:
|
||||
if hasattr(ev, "model_dump"):
|
||||
return ev.model_dump()
|
||||
return {
|
||||
"type": getattr(ev, "type", None),
|
||||
"delta": getattr(ev, "delta", None),
|
||||
"arguments": getattr(ev, "arguments", None),
|
||||
"arguments_delta": getattr(ev, "arguments_delta", None),
|
||||
"item": getattr(ev, "item", None),
|
||||
"id": getattr(ev, "id", None),
|
||||
"name": getattr(ev, "name", None),
|
||||
"error": getattr(ev, "error", None),
|
||||
"message": getattr(ev, "message", None),
|
||||
}
|
||||
|
||||
tool_calls_by_id: Dict[str, Dict[str, Any]] = {}
|
||||
current_call_id: Optional[str] = None
|
||||
has_response_schema_tool_call = False
|
||||
|
||||
async for ev in stream:
|
||||
event = _event_dict(ev) if not isinstance(ev, dict) else ev
|
||||
event_type = event.get("type") or ""
|
||||
|
||||
if event_type == "response.output_item.added":
|
||||
item = event.get("item") or {}
|
||||
if item.get("type") == "function_call" and item.get("name") == "ResponseSchema":
|
||||
current_call_id = item.get("call_id") or item.get("id")
|
||||
|
||||
elif event_type == "response.function_call_arguments.delta":
|
||||
if current_call_id:
|
||||
delta = event.get("delta") or ""
|
||||
if delta:
|
||||
has_response_schema_tool_call = True
|
||||
yield delta
|
||||
|
||||
elif event_type == "response.function_call_arguments.done":
|
||||
if event.get("name") == "ResponseSchema":
|
||||
args = event.get("arguments") or ""
|
||||
if args:
|
||||
has_response_schema_tool_call = True
|
||||
yield args
|
||||
|
||||
elif event_type == "response.output_item.done":
|
||||
item = event.get("item") or {}
|
||||
if item.get("type") == "function_call":
|
||||
cid = item.get("call_id") or item.get("id", "")
|
||||
tool_calls_by_id[cid] = item
|
||||
if item.get("name") == "ResponseSchema":
|
||||
args = item.get("arguments") or ""
|
||||
if args:
|
||||
has_response_schema_tool_call = True
|
||||
yield args
|
||||
|
||||
elif event_type == "response.output_tool_call.delta":
|
||||
call_id = event.get("id")
|
||||
name = event.get("name")
|
||||
arguments_delta = event.get("arguments_delta") or ""
|
||||
if call_id and name:
|
||||
if call_id not in tool_calls_by_id:
|
||||
tool_calls_by_id[call_id] = {"name": name, "arguments": ""}
|
||||
tool_calls_by_id[call_id]["arguments"] += arguments_delta
|
||||
if name == "ResponseSchema" and arguments_delta:
|
||||
has_response_schema_tool_call = True
|
||||
yield arguments_delta
|
||||
|
||||
elif event_type == "response.completed":
|
||||
break
|
||||
|
||||
elif event_type in ("response.error", "response.failed", "error"):
|
||||
err = event.get("error") or event.get("message") or str(event)
|
||||
raise RuntimeError(err)
|
||||
|
||||
# ============================================
|
||||
# EXECUTE NON-STRUCTURED TOOL CALLS (RECURSIVE)
|
||||
# ============================================
|
||||
|
||||
other_tool_calls = {
|
||||
cid: data
|
||||
for cid, data in tool_calls_by_id.items()
|
||||
if data.get("name") != "ResponseSchema"
|
||||
}
|
||||
if other_tool_calls and not has_response_schema_tool_call:
|
||||
parsed_tool_calls = []
|
||||
for call_id, data in other_tool_calls.items():
|
||||
args = data.get("arguments", "") if isinstance(data, dict) else ""
|
||||
parsed_tool_calls.append(
|
||||
OpenAIToolCall(
|
||||
id=call_id,
|
||||
type="function",
|
||||
function=OpenAIToolCallFunction(
|
||||
name=data.get("name", ""),
|
||||
arguments=args,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
tool_call_messages = await self.tool_calls_handler.handle_tool_calls_openai(
|
||||
parsed_tool_calls
|
||||
)
|
||||
|
||||
new_messages = [
|
||||
*messages,
|
||||
OpenAIAssistantMessage(
|
||||
role="assistant",
|
||||
content=None,
|
||||
tool_calls=[tc.model_dump() for tc in parsed_tool_calls],
|
||||
),
|
||||
*tool_call_messages,
|
||||
]
|
||||
|
||||
async for chunk in self._stream_codex_structured(
|
||||
model=model,
|
||||
messages=new_messages,
|
||||
response_format=response_schema,
|
||||
strict=strict,
|
||||
max_tokens=max_tokens,
|
||||
tools=tools,
|
||||
extra_body=extra_body,
|
||||
depth=depth + 1,
|
||||
):
|
||||
yield chunk
|
||||
|
||||
async def _stream_google_structured(
|
||||
self,
|
||||
model: str,
|
||||
|
|
@ -1538,6 +2261,15 @@ class LLMClient:
|
|||
tools=parsed_tools,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
case LLMProvider.CODEX:
|
||||
return self._stream_codex_structured(
|
||||
model=model,
|
||||
messages=messages,
|
||||
response_format=response_format,
|
||||
strict=strict,
|
||||
tools=parsed_tools,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
case LLMProvider.GOOGLE:
|
||||
return self._stream_google_structured(
|
||||
model=model,
|
||||
|
|
|
|||
|
|
@ -117,3 +117,24 @@ def get_dall_e_3_quality_env():
|
|||
# Gpt Image 1.5 Quality
|
||||
def get_gpt_image_1_5_quality_env():
|
||||
return os.getenv("GPT_IMAGE_1_5_QUALITY")
|
||||
|
||||
|
||||
# Codex OAuth
|
||||
def get_codex_access_token_env():
|
||||
return os.getenv("CODEX_ACCESS_TOKEN")
|
||||
|
||||
|
||||
def get_codex_refresh_token_env():
|
||||
return os.getenv("CODEX_REFRESH_TOKEN")
|
||||
|
||||
|
||||
def get_codex_token_expires_env():
|
||||
return os.getenv("CODEX_TOKEN_EXPIRES")
|
||||
|
||||
|
||||
def get_codex_account_id_env():
|
||||
return os.getenv("CODEX_ACCOUNT_ID")
|
||||
|
||||
|
||||
def get_codex_model_env():
|
||||
return os.getenv("CODEX_MODEL")
|
||||
|
|
|
|||
|
|
@ -2,12 +2,14 @@ from fastapi import HTTPException
|
|||
|
||||
from constants.llm import (
|
||||
DEFAULT_ANTHROPIC_MODEL,
|
||||
DEFAULT_CODEX_MODEL,
|
||||
DEFAULT_GOOGLE_MODEL,
|
||||
DEFAULT_OPENAI_MODEL,
|
||||
)
|
||||
from enums.llm_provider import LLMProvider
|
||||
from utils.get_env import (
|
||||
get_anthropic_model_env,
|
||||
get_codex_model_env,
|
||||
get_custom_model_env,
|
||||
get_google_model_env,
|
||||
get_llm_provider_env,
|
||||
|
|
@ -22,7 +24,7 @@ def get_llm_provider():
|
|||
except:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Invalid LLM provider. Please select one of: openai, google, anthropic, ollama, custom",
|
||||
detail=f"Invalid LLM provider. Please select one of: openai, google, anthropic, ollama, custom, codex",
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -46,6 +48,10 @@ def is_custom_llm_selected():
|
|||
return get_llm_provider() == LLMProvider.CUSTOM
|
||||
|
||||
|
||||
def is_codex_selected():
|
||||
return get_llm_provider() == LLMProvider.CODEX
|
||||
|
||||
|
||||
def get_model():
|
||||
selected_llm = get_llm_provider()
|
||||
if selected_llm == LLMProvider.OPENAI:
|
||||
|
|
@ -58,8 +64,10 @@ def get_model():
|
|||
return get_ollama_model_env()
|
||||
elif selected_llm == LLMProvider.CUSTOM:
|
||||
return get_custom_model_env()
|
||||
elif selected_llm == LLMProvider.CODEX:
|
||||
return get_codex_model_env() or DEFAULT_CODEX_MODEL
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Invalid LLM provider. Please select one of: openai, google, anthropic, ollama, custom",
|
||||
detail=f"Invalid LLM provider. Please select one of: openai, google, anthropic, ollama, custom, codex",
|
||||
)
|
||||
|
|
|
|||
0
electron/servers/fastapi/utils/oauth/__init__.py
Normal file
0
electron/servers/fastapi/utils/oauth/__init__.py
Normal file
348
electron/servers/fastapi/utils/oauth/openai_codex.py
Normal file
348
electron/servers/fastapi/utils/oauth/openai_codex.py
Normal file
|
|
@ -0,0 +1,348 @@
|
|||
"""
|
||||
OpenAI Codex (ChatGPT OAuth) flow — Python port of
|
||||
pi-mono-main/packages/ai/src/utils/oauth/openai-codex.ts
|
||||
|
||||
Handles PKCE authorization, local callback server, token exchange and refresh.
|
||||
No FastAPI dependencies; all HTTP is done with the standard library + httpx.
|
||||
"""
|
||||
import base64
|
||||
import json
|
||||
import secrets
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||
from typing import Optional
|
||||
from urllib.parse import parse_qs, urlencode, urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
from utils.oauth.pkce import generate_pkce
|
||||
|
||||
CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
||||
AUTHORIZE_URL = "https://auth.openai.com/oauth/authorize"
|
||||
TOKEN_URL = "https://auth.openai.com/oauth/token"
|
||||
REDIRECT_URI = "http://localhost:1455/auth/callback"
|
||||
SCOPE = "openid profile email offline_access"
|
||||
JWT_CLAIM_PATH = "https://api.openai.com/auth"
|
||||
|
||||
CALLBACK_PORT = 1455
|
||||
|
||||
SUCCESS_HTML = b"""<!doctype html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1" />
|
||||
<title>Authentication successful</title>
|
||||
</head>
|
||||
<body>
|
||||
<p>Authentication successful. Return to your terminal / application to continue.</p>
|
||||
</body>
|
||||
</html>"""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data types
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
|
||||
class TokenSuccess:
|
||||
access: str
|
||||
refresh: str
|
||||
expires: int # Unix ms timestamp when the token expires
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenFailure:
|
||||
reason: str
|
||||
|
||||
|
||||
TokenResult = TokenSuccess | TokenFailure
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuthorizationFlow:
|
||||
verifier: str
|
||||
state: str
|
||||
url: str
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# JWT helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _decode_jwt_payload(token: str) -> Optional[dict]:
|
||||
"""Decode the payload segment of a JWT without verifying the signature."""
|
||||
try:
|
||||
parts = token.split(".")
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
payload_b64 = parts[1]
|
||||
# Add padding if needed
|
||||
padding = 4 - len(payload_b64) % 4
|
||||
if padding != 4:
|
||||
payload_b64 += "=" * padding
|
||||
decoded = base64.urlsafe_b64decode(payload_b64)
|
||||
return json.loads(decoded)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def get_account_id(access_token: str) -> Optional[str]:
|
||||
"""Extract the ChatGPT account ID from an access token JWT."""
|
||||
payload = _decode_jwt_payload(access_token)
|
||||
if not payload:
|
||||
return None
|
||||
auth_claims = payload.get(JWT_CLAIM_PATH)
|
||||
if not isinstance(auth_claims, dict):
|
||||
return None
|
||||
account_id = auth_claims.get("chatgpt_account_id")
|
||||
if isinstance(account_id, str) and account_id:
|
||||
return account_id
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Authorization URL + PKCE
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def create_authorization_flow(originator: str = "pi") -> AuthorizationFlow:
|
||||
"""Generate PKCE verifier/challenge, state, and the full authorization URL."""
|
||||
verifier, challenge = generate_pkce()
|
||||
state = secrets.token_hex(16)
|
||||
|
||||
params = {
|
||||
"response_type": "code",
|
||||
"client_id": CLIENT_ID,
|
||||
"redirect_uri": REDIRECT_URI,
|
||||
"scope": SCOPE,
|
||||
"code_challenge": challenge,
|
||||
"code_challenge_method": "S256",
|
||||
"state": state,
|
||||
"id_token_add_organizations": "true",
|
||||
"codex_cli_simplified_flow": "true",
|
||||
"originator": originator,
|
||||
}
|
||||
url = f"{AUTHORIZE_URL}?{urlencode(params)}"
|
||||
return AuthorizationFlow(verifier=verifier, state=state, url=url)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Local callback server
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class _CallbackHandler(BaseHTTPRequestHandler):
|
||||
"""Minimal HTTP handler that captures the OAuth callback code."""
|
||||
|
||||
def do_GET(self): # noqa: N802
|
||||
parsed = urlparse(self.path)
|
||||
if parsed.path != "/auth/callback":
|
||||
self.send_response(404)
|
||||
self.end_headers()
|
||||
self.wfile.write(b"Not found")
|
||||
return
|
||||
|
||||
qs = parse_qs(parsed.query)
|
||||
state_vals = qs.get("state", [])
|
||||
code_vals = qs.get("code", [])
|
||||
|
||||
expected_state: str = self.server.expected_state # type: ignore[attr-defined]
|
||||
|
||||
if not state_vals or state_vals[0] != expected_state:
|
||||
self.send_response(400)
|
||||
self.end_headers()
|
||||
self.wfile.write(b"State mismatch")
|
||||
return
|
||||
|
||||
if not code_vals:
|
||||
self.send_response(400)
|
||||
self.end_headers()
|
||||
self.wfile.write(b"Missing authorization code")
|
||||
return
|
||||
|
||||
self.send_response(200)
|
||||
self.send_header("Content-Type", "text/html; charset=utf-8")
|
||||
self.end_headers()
|
||||
self.wfile.write(SUCCESS_HTML)
|
||||
|
||||
self.server.captured_code = code_vals[0] # type: ignore[attr-defined]
|
||||
|
||||
def log_message(self, format, *args): # noqa: A002
|
||||
pass # suppress default stderr logging
|
||||
|
||||
|
||||
class OAuthCallbackServer:
|
||||
"""
|
||||
Wraps an HTTPServer that listens on port 1455 for the OAuth callback.
|
||||
Runs in a background daemon thread so it doesn't block the caller.
|
||||
"""
|
||||
|
||||
def __init__(self, state: str):
|
||||
self._state = state
|
||||
self._server: Optional[HTTPServer] = None
|
||||
self._thread: Optional[threading.Thread] = None
|
||||
self._started = threading.Event()
|
||||
self._cancelled = False
|
||||
|
||||
def start(self) -> bool:
|
||||
"""Start the background HTTP server. Returns True if successful."""
|
||||
try:
|
||||
server = HTTPServer(("0.0.0.0", CALLBACK_PORT), _CallbackHandler)
|
||||
server.expected_state = self._state # type: ignore[attr-defined]
|
||||
server.captured_code = None # type: ignore[attr-defined]
|
||||
server.timeout = 0.2 # short poll interval so we can check cancel
|
||||
self._server = server
|
||||
|
||||
def _serve():
|
||||
self._started.set()
|
||||
while not self._cancelled and server.captured_code is None:
|
||||
server.handle_request()
|
||||
server.server_close()
|
||||
|
||||
self._thread = threading.Thread(target=_serve, daemon=True)
|
||||
self._thread.start()
|
||||
self._started.wait(timeout=2)
|
||||
return True
|
||||
except OSError:
|
||||
return False
|
||||
|
||||
def get_code_nowait(self) -> Optional[str]:
|
||||
"""Non-blocking peek — returns the captured code or None immediately."""
|
||||
if self._server is None:
|
||||
return None
|
||||
return self._server.captured_code # type: ignore[attr-defined]
|
||||
|
||||
def wait_for_code(self, timeout_seconds: int = 120) -> Optional[str]:
|
||||
"""
|
||||
Block until the callback delivers a code or timeout / cancellation.
|
||||
Returns the authorization code or None.
|
||||
"""
|
||||
if self._server is None:
|
||||
return None
|
||||
deadline = time.monotonic() + timeout_seconds
|
||||
while time.monotonic() < deadline:
|
||||
if self._cancelled:
|
||||
return None
|
||||
code = self._server.captured_code # type: ignore[attr-defined]
|
||||
if code:
|
||||
return code
|
||||
time.sleep(0.1)
|
||||
return None
|
||||
|
||||
def cancel(self):
|
||||
self._cancelled = True
|
||||
|
||||
def close(self):
|
||||
self._cancelled = True
|
||||
if self._thread:
|
||||
self._thread.join(timeout=2)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Token exchange / refresh (sync — called from thread or FastAPI background)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def exchange_authorization_code(
|
||||
code: str,
|
||||
verifier: str,
|
||||
redirect_uri: str = REDIRECT_URI,
|
||||
) -> TokenResult:
|
||||
"""Exchange an authorization code for access + refresh tokens."""
|
||||
try:
|
||||
response = httpx.post(
|
||||
TOKEN_URL,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
data={
|
||||
"grant_type": "authorization_code",
|
||||
"client_id": CLIENT_ID,
|
||||
"code": code,
|
||||
"code_verifier": verifier,
|
||||
"redirect_uri": redirect_uri,
|
||||
},
|
||||
timeout=30,
|
||||
)
|
||||
if not response.is_success:
|
||||
return TokenFailure(reason=f"HTTP {response.status_code}: {response.text[:200]}")
|
||||
|
||||
body = response.json()
|
||||
access = body.get("access_token")
|
||||
refresh = body.get("refresh_token")
|
||||
expires_in = body.get("expires_in")
|
||||
|
||||
if not access or not refresh or not isinstance(expires_in, (int, float)):
|
||||
return TokenFailure(reason=f"Token response missing fields: {list(body.keys())}")
|
||||
|
||||
expires_ms = int(time.time() * 1000) + int(expires_in) * 1000
|
||||
return TokenSuccess(access=access, refresh=refresh, expires=expires_ms)
|
||||
except Exception as exc:
|
||||
return TokenFailure(reason=str(exc))
|
||||
|
||||
|
||||
def refresh_access_token(refresh_token: str) -> TokenResult:
|
||||
"""Use a refresh token to obtain a new access token."""
|
||||
try:
|
||||
response = httpx.post(
|
||||
TOKEN_URL,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
data={
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refresh_token,
|
||||
"client_id": CLIENT_ID,
|
||||
},
|
||||
timeout=30,
|
||||
)
|
||||
if not response.is_success:
|
||||
return TokenFailure(reason=f"HTTP {response.status_code}: {response.text[:200]}")
|
||||
|
||||
body = response.json()
|
||||
access = body.get("access_token")
|
||||
refresh = body.get("refresh_token")
|
||||
expires_in = body.get("expires_in")
|
||||
|
||||
if not access or not refresh or not isinstance(expires_in, (int, float)):
|
||||
return TokenFailure(reason=f"Token refresh response missing fields: {list(body.keys())}")
|
||||
|
||||
expires_ms = int(time.time() * 1000) + int(expires_in) * 1000
|
||||
return TokenSuccess(access=access, refresh=refresh, expires=expires_ms)
|
||||
except Exception as exc:
|
||||
return TokenFailure(reason=str(exc))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Parsing helpers (for manual code paste / redirect URL fallback)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def parse_authorization_input(raw: str) -> dict:
|
||||
"""
|
||||
Accept a variety of user-pasted inputs:
|
||||
- Full redirect URL: http://localhost:1455/auth/callback?code=X&state=Y
|
||||
- code#state shorthand
|
||||
- Raw query string: code=X&state=Y
|
||||
- Bare code value
|
||||
Returns a dict with optional 'code' and 'state' keys.
|
||||
"""
|
||||
value = raw.strip()
|
||||
if not value:
|
||||
return {}
|
||||
|
||||
try:
|
||||
parsed = urlparse(value)
|
||||
if parsed.scheme in ("http", "https"):
|
||||
qs = parse_qs(parsed.query)
|
||||
return {
|
||||
k: qs[k][0]
|
||||
for k in ("code", "state")
|
||||
if k in qs
|
||||
}
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if "#" in value:
|
||||
parts = value.split("#", 1)
|
||||
return {"code": parts[0], "state": parts[1]}
|
||||
|
||||
if "code=" in value:
|
||||
qs = parse_qs(value)
|
||||
return {k: qs[k][0] for k in ("code", "state") if k in qs}
|
||||
|
||||
return {"code": value}
|
||||
23
electron/servers/fastapi/utils/oauth/pkce.py
Normal file
23
electron/servers/fastapi/utils/oauth/pkce.py
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
"""
|
||||
PKCE utilities using Python's secrets and hashlib.
|
||||
Python port of pi-mono-main/packages/ai/src/utils/oauth/pkce.ts
|
||||
"""
|
||||
import base64
|
||||
import hashlib
|
||||
import secrets
|
||||
|
||||
|
||||
def generate_pkce() -> tuple[str, str]:
|
||||
"""
|
||||
Generate PKCE code verifier and challenge (S256 method).
|
||||
|
||||
Returns:
|
||||
(verifier, challenge) — both base64url-encoded, no padding
|
||||
"""
|
||||
verifier_bytes = secrets.token_bytes(32)
|
||||
verifier = base64.urlsafe_b64encode(verifier_bytes).rstrip(b"=").decode()
|
||||
|
||||
digest = hashlib.sha256(verifier.encode()).digest()
|
||||
challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
|
||||
|
||||
return verifier, challenge
|
||||
|
|
@ -103,3 +103,24 @@ def set_dall_e_3_quality_env(value):
|
|||
|
||||
def set_gpt_image_1_5_quality_env(value):
|
||||
os.environ["GPT_IMAGE_1_5_QUALITY"] = value
|
||||
|
||||
|
||||
# Codex OAuth
|
||||
def set_codex_access_token_env(value: str):
|
||||
os.environ["CODEX_ACCESS_TOKEN"] = value
|
||||
|
||||
|
||||
def set_codex_refresh_token_env(value: str):
|
||||
os.environ["CODEX_REFRESH_TOKEN"] = value
|
||||
|
||||
|
||||
def set_codex_token_expires_env(value: str):
|
||||
os.environ["CODEX_TOKEN_EXPIRES"] = value
|
||||
|
||||
|
||||
def set_codex_account_id_env(value: str):
|
||||
os.environ["CODEX_ACCOUNT_ID"] = value
|
||||
|
||||
|
||||
def set_codex_model_env(value: str):
|
||||
os.environ["CODEX_MODEL"] = value
|
||||
|
|
|
|||
|
|
@ -28,6 +28,11 @@ from utils.get_env import (
|
|||
get_pixabay_api_key_env,
|
||||
get_extended_reasoning_env,
|
||||
get_web_grounding_env,
|
||||
get_codex_access_token_env,
|
||||
get_codex_refresh_token_env,
|
||||
get_codex_token_expires_env,
|
||||
get_codex_account_id_env,
|
||||
get_codex_model_env,
|
||||
)
|
||||
from utils.parsers import parse_bool_or_none
|
||||
from utils.set_env import (
|
||||
|
|
@ -55,6 +60,11 @@ from utils.set_env import (
|
|||
set_pixabay_api_key_env,
|
||||
set_tool_calls_env,
|
||||
set_web_grounding_env,
|
||||
set_codex_access_token_env,
|
||||
set_codex_refresh_token_env,
|
||||
set_codex_token_expires_env,
|
||||
set_codex_account_id_env,
|
||||
set_codex_model_env,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -118,6 +128,11 @@ def get_user_config():
|
|||
if existing_config.WEB_GROUNDING is not None
|
||||
else (parse_bool_or_none(get_web_grounding_env()) or False)
|
||||
),
|
||||
CODEX_MODEL=existing_config.CODEX_MODEL or get_codex_model_env(),
|
||||
CODEX_ACCESS_TOKEN=existing_config.CODEX_ACCESS_TOKEN or get_codex_access_token_env(),
|
||||
CODEX_REFRESH_TOKEN=existing_config.CODEX_REFRESH_TOKEN or get_codex_refresh_token_env(),
|
||||
CODEX_TOKEN_EXPIRES=existing_config.CODEX_TOKEN_EXPIRES or get_codex_token_expires_env(),
|
||||
CODEX_ACCOUNT_ID=existing_config.CODEX_ACCOUNT_ID or get_codex_account_id_env(),
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -171,3 +186,43 @@ def update_env_with_user_config():
|
|||
set_extended_reasoning_env(str(user_config.EXTENDED_REASONING))
|
||||
if user_config.WEB_GROUNDING is not None:
|
||||
set_web_grounding_env(str(user_config.WEB_GROUNDING))
|
||||
if user_config.CODEX_MODEL:
|
||||
set_codex_model_env(user_config.CODEX_MODEL)
|
||||
if user_config.CODEX_ACCESS_TOKEN:
|
||||
set_codex_access_token_env(user_config.CODEX_ACCESS_TOKEN)
|
||||
if user_config.CODEX_REFRESH_TOKEN:
|
||||
set_codex_refresh_token_env(user_config.CODEX_REFRESH_TOKEN)
|
||||
if user_config.CODEX_TOKEN_EXPIRES:
|
||||
set_codex_token_expires_env(user_config.CODEX_TOKEN_EXPIRES)
|
||||
if user_config.CODEX_ACCOUNT_ID:
|
||||
set_codex_account_id_env(user_config.CODEX_ACCOUNT_ID)
|
||||
|
||||
|
||||
def save_codex_tokens_to_user_config() -> None:
|
||||
"""
|
||||
Write the current in-memory Codex OAuth token env vars back to userConfig.json
|
||||
so they survive container restarts. Called after a successful token exchange
|
||||
and on logout (where the env vars have already been cleared to "").
|
||||
"""
|
||||
user_config_path = get_user_config_path_env()
|
||||
if not user_config_path:
|
||||
return
|
||||
|
||||
existing: dict = {}
|
||||
try:
|
||||
if os.path.exists(user_config_path):
|
||||
with open(user_config_path, "r") as f:
|
||||
existing = json.load(f)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
existing["CODEX_ACCESS_TOKEN"] = get_codex_access_token_env()
|
||||
existing["CODEX_REFRESH_TOKEN"] = get_codex_refresh_token_env()
|
||||
existing["CODEX_TOKEN_EXPIRES"] = get_codex_token_expires_env()
|
||||
existing["CODEX_ACCOUNT_ID"] = get_codex_account_id_env()
|
||||
|
||||
try:
|
||||
with open(user_config_path, "w") as f:
|
||||
json.dump(existing, f)
|
||||
except Exception:
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -87,6 +87,11 @@ export async function POST(request: Request) {
|
|||
userConfig.WEB_GROUNDING === undefined
|
||||
? existingConfig.WEB_GROUNDING
|
||||
: userConfig.WEB_GROUNDING,
|
||||
CODEX_MODEL: userConfig.CODEX_MODEL || existingConfig.CODEX_MODEL,
|
||||
CODEX_ACCESS_TOKEN: existingConfig.CODEX_ACCESS_TOKEN,
|
||||
CODEX_REFRESH_TOKEN: existingConfig.CODEX_REFRESH_TOKEN,
|
||||
CODEX_TOKEN_EXPIRES: existingConfig.CODEX_TOKEN_EXPIRES,
|
||||
CODEX_ACCOUNT_ID: existingConfig.CODEX_ACCOUNT_ID,
|
||||
USE_CUSTOM_URL:
|
||||
userConfig.USE_CUSTOM_URL === undefined
|
||||
? existingConfig.USE_CUSTOM_URL
|
||||
|
|
|
|||
430
electron/servers/nextjs/components/CodexConfig.tsx
Normal file
430
electron/servers/nextjs/components/CodexConfig.tsx
Normal file
|
|
@ -0,0 +1,430 @@
|
|||
"use client";
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import {
|
||||
Check,
|
||||
ChevronsUpDown,
|
||||
Loader2,
|
||||
LogIn,
|
||||
LogOut,
|
||||
RefreshCw,
|
||||
UserCheck,
|
||||
} from "lucide-react";
|
||||
import { Button } from "./ui/button";
|
||||
import {
|
||||
Command,
|
||||
CommandEmpty,
|
||||
CommandGroup,
|
||||
CommandInput,
|
||||
CommandItem,
|
||||
CommandList,
|
||||
} from "./ui/command";
|
||||
import { Popover, PopoverContent, PopoverTrigger } from "./ui/popover";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { toast } from "sonner";
|
||||
import { getApiUrl } from "@/utils/api";
|
||||
|
||||
interface CodexConfigProps {
|
||||
codexModel: string;
|
||||
onInputChange: (value: string | boolean, field: string) => void;
|
||||
}
|
||||
|
||||
type AuthStatus = "checking" | "unauthenticated" | "polling" | "authenticated";
|
||||
|
||||
interface StatusResponse {
|
||||
status: string;
|
||||
account_id?: string;
|
||||
detail?: string;
|
||||
}
|
||||
|
||||
interface CodexModel {
|
||||
id: string;
|
||||
name: string;
|
||||
}
|
||||
|
||||
const CHATGPT_MODELS: CodexModel[] = [
|
||||
{ id: "gpt-5.1", name: "GPT-5.1" },
|
||||
{ id: "gpt-5.1-codex-max", name: "GPT-5.1 Codex Max" },
|
||||
{ id: "gpt-5.1-codex-mini", name: "GPT-5.1 Codex Mini" },
|
||||
{ id: "gpt-5.2", name: "GPT-5.2" },
|
||||
{ id: "gpt-5.2-codex", name: "GPT-5.2 Codex" },
|
||||
{ id: "gpt-5.3-codex", name: "GPT-5.3 Codex" },
|
||||
{ id: "gpt-5.3-codex-spark", name: "GPT-5.3 Codex Spark (Free)" },
|
||||
];
|
||||
|
||||
const DEFAULT_CODEX_MODEL = "gpt-5.3-codex-spark";
|
||||
|
||||
export default function CodexConfig({
|
||||
codexModel,
|
||||
onInputChange,
|
||||
}: CodexConfigProps) {
|
||||
const [authStatus, setAuthStatus] = useState<AuthStatus>("checking");
|
||||
const [accountId, setAccountId] = useState<string | null>(null);
|
||||
const [sessionId, setSessionId] = useState<string | null>(null);
|
||||
const [manualCode, setManualCode] = useState("");
|
||||
const [isExchanging, setIsExchanging] = useState(false);
|
||||
const [isLoggingOut, setIsLoggingOut] = useState(false);
|
||||
const [isRefreshing, setIsRefreshing] = useState(false);
|
||||
const [openModelSelect, setOpenModelSelect] = useState(false);
|
||||
const pollIntervalRef = useRef<ReturnType<typeof setInterval> | null>(null);
|
||||
|
||||
const stopPolling = () => {
|
||||
if (pollIntervalRef.current) {
|
||||
clearInterval(pollIntervalRef.current);
|
||||
pollIntervalRef.current = null;
|
||||
}
|
||||
};
|
||||
|
||||
// Check current auth state on mount
|
||||
useEffect(() => {
|
||||
checkCurrentAuthStatus();
|
||||
return () => stopPolling();
|
||||
}, []);
|
||||
|
||||
const checkCurrentAuthStatus = async () => {
|
||||
try {
|
||||
const res = await fetch(getApiUrl("api/v1/ppt/codex/auth/status"));
|
||||
if (!res.ok) {
|
||||
setAuthStatus("unauthenticated");
|
||||
return;
|
||||
}
|
||||
const data: StatusResponse = await res.json();
|
||||
if (data.status === "authenticated") {
|
||||
setAuthStatus("authenticated");
|
||||
setAccountId(data.account_id ?? null);
|
||||
} else {
|
||||
setAuthStatus("unauthenticated");
|
||||
}
|
||||
} catch {
|
||||
setAuthStatus("unauthenticated");
|
||||
}
|
||||
};
|
||||
|
||||
const handleSignIn = async () => {
|
||||
try {
|
||||
const res = await fetch(getApiUrl("api/v1/ppt/codex/auth/initiate"), {
|
||||
method: "POST",
|
||||
});
|
||||
if (!res.ok) throw new Error("Failed to initiate auth");
|
||||
const data = await res.json();
|
||||
const { session_id, url } = data;
|
||||
|
||||
setSessionId(session_id);
|
||||
setAuthStatus("polling");
|
||||
window.open(url, "_blank", "noopener,noreferrer");
|
||||
|
||||
// Start polling the status endpoint every 2s
|
||||
pollIntervalRef.current = setInterval(async () => {
|
||||
try {
|
||||
const pollRes = await fetch(
|
||||
getApiUrl(`api/v1/ppt/codex/auth/status/${session_id}`)
|
||||
);
|
||||
if (!pollRes.ok) return;
|
||||
const pollData: StatusResponse = await pollRes.json();
|
||||
|
||||
if (pollData.status === "success") {
|
||||
stopPolling();
|
||||
setAuthStatus("authenticated");
|
||||
setAccountId(pollData.account_id ?? null);
|
||||
setSessionId(null);
|
||||
// Set a sensible default model if none chosen
|
||||
if (!codexModel) {
|
||||
onInputChange(DEFAULT_CODEX_MODEL, "codex_model");
|
||||
}
|
||||
toast.success("Signed in to ChatGPT successfully");
|
||||
} else if (pollData.status === "failed") {
|
||||
stopPolling();
|
||||
setAuthStatus("unauthenticated");
|
||||
toast.error("Authentication failed. Please try again.");
|
||||
}
|
||||
} catch {
|
||||
// keep polling on transient errors
|
||||
}
|
||||
}, 2000);
|
||||
} catch (err) {
|
||||
toast.error("Failed to start sign-in flow");
|
||||
setAuthStatus("unauthenticated");
|
||||
}
|
||||
};
|
||||
|
||||
const handleManualExchange = async () => {
|
||||
if (!sessionId || !manualCode.trim()) return;
|
||||
setIsExchanging(true);
|
||||
try {
|
||||
const res = await fetch(getApiUrl("api/v1/ppt/codex/auth/exchange"), {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({ session_id: sessionId, code: manualCode.trim() }),
|
||||
});
|
||||
if (!res.ok) {
|
||||
const err = await res.json().catch(() => ({}));
|
||||
throw new Error(err.detail || "Exchange failed");
|
||||
}
|
||||
const data = await res.json();
|
||||
stopPolling();
|
||||
setAuthStatus("authenticated");
|
||||
setAccountId(data.account_id);
|
||||
setSessionId(null);
|
||||
setManualCode("");
|
||||
if (!codexModel) {
|
||||
onInputChange(DEFAULT_CODEX_MODEL, "codex_model");
|
||||
}
|
||||
toast.success("Signed in to ChatGPT successfully");
|
||||
} catch (err: any) {
|
||||
toast.error(err.message || "Code exchange failed");
|
||||
} finally {
|
||||
setIsExchanging(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleCancelPolling = () => {
|
||||
stopPolling();
|
||||
setSessionId(null);
|
||||
setManualCode("");
|
||||
setAuthStatus("unauthenticated");
|
||||
};
|
||||
|
||||
const handleSignOut = async () => {
|
||||
setIsLoggingOut(true);
|
||||
try {
|
||||
await fetch(getApiUrl("api/v1/ppt/codex/auth/logout"), { method: "POST" });
|
||||
setAuthStatus("unauthenticated");
|
||||
setAccountId(null);
|
||||
onInputChange("", "codex_model");
|
||||
toast.success("Signed out from ChatGPT");
|
||||
} catch {
|
||||
toast.error("Sign out failed");
|
||||
} finally {
|
||||
setIsLoggingOut(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleRefreshToken = async () => {
|
||||
setIsRefreshing(true);
|
||||
try {
|
||||
const res = await fetch(getApiUrl("api/v1/ppt/codex/auth/refresh"), {
|
||||
method: "POST",
|
||||
});
|
||||
if (!res.ok) throw new Error("Refresh failed");
|
||||
const data = await res.json();
|
||||
if (data.account_id) setAccountId(data.account_id);
|
||||
toast.success("Token refreshed successfully");
|
||||
} catch {
|
||||
toast.error("Token refresh failed. Please sign in again.");
|
||||
setAuthStatus("unauthenticated");
|
||||
} finally {
|
||||
setIsRefreshing(false);
|
||||
}
|
||||
};
|
||||
|
||||
// ─── Checking ────────────────────────────────────────────────────────────
|
||||
if (authStatus === "checking") {
|
||||
return (
|
||||
<div className="flex items-center justify-center py-12 gap-3 text-gray-500">
|
||||
<Loader2 className="w-5 h-5 animate-spin" />
|
||||
<span className="text-sm">Checking authentication status…</span>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// ─── Polling / waiting ───────────────────────────────────────────────────
|
||||
if (authStatus === "polling") {
|
||||
return (
|
||||
<div className="space-y-6">
|
||||
<div className="flex flex-col items-center gap-4 py-8 px-4 bg-blue-50 rounded-xl border border-blue-100">
|
||||
<Loader2 className="w-8 h-8 text-blue-500 animate-spin" />
|
||||
<div className="text-center">
|
||||
<p className="text-sm font-medium text-blue-900">
|
||||
Waiting for authentication…
|
||||
</p>
|
||||
<p className="text-xs text-blue-600 mt-1">
|
||||
Complete the sign-in in the browser tab that just opened.
|
||||
</p>
|
||||
</div>
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={handleCancelPolling}
|
||||
className="text-gray-600"
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
{/* Manual fallback */}
|
||||
<div className="space-y-3">
|
||||
<p className="text-sm font-medium text-gray-700">
|
||||
Didn't get redirected automatically?
|
||||
</p>
|
||||
<p className="text-xs text-gray-500">
|
||||
After completing the sign-in, paste the full redirect URL or
|
||||
authorization code below.
|
||||
</p>
|
||||
<input
|
||||
type="text"
|
||||
placeholder="Paste redirect URL or authorization code…"
|
||||
className="w-full px-4 py-2.5 outline-none border border-gray-300 rounded-lg focus:ring-2 focus:ring-blue-500/20 focus:border-blue-500 transition-colors text-sm"
|
||||
value={manualCode}
|
||||
onChange={(e) => setManualCode(e.target.value)}
|
||||
/>
|
||||
<Button
|
||||
onClick={handleManualExchange}
|
||||
disabled={isExchanging || !manualCode.trim()}
|
||||
className="w-full"
|
||||
>
|
||||
{isExchanging ? (
|
||||
<div className="flex items-center gap-2">
|
||||
<Loader2 className="w-4 h-4 animate-spin" />
|
||||
Exchanging…
|
||||
</div>
|
||||
) : (
|
||||
"Submit Code"
|
||||
)}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// ─── Authenticated ───────────────────────────────────────────────────────
|
||||
if (authStatus === "authenticated") {
|
||||
return (
|
||||
<div className="space-y-6">
|
||||
{/* Account info */}
|
||||
<div className="flex items-center gap-3 p-4 bg-green-50 rounded-xl border border-green-100">
|
||||
<UserCheck className="w-6 h-6 text-green-600 shrink-0" />
|
||||
<div className="flex-1 min-w-0">
|
||||
<p className="text-sm font-medium text-green-900">
|
||||
Signed in to ChatGPT
|
||||
</p>
|
||||
{accountId && (
|
||||
<p className="text-xs text-green-700 truncate mt-0.5">
|
||||
Account: {accountId}
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
<div className="flex gap-2 shrink-0">
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={handleRefreshToken}
|
||||
disabled={isRefreshing}
|
||||
title="Refresh access token"
|
||||
className="text-gray-600 border-gray-300"
|
||||
>
|
||||
{isRefreshing ? (
|
||||
<Loader2 className="w-3.5 h-3.5 animate-spin" />
|
||||
) : (
|
||||
<RefreshCw className="w-3.5 h-3.5" />
|
||||
)}
|
||||
</Button>
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={handleSignOut}
|
||||
disabled={isLoggingOut}
|
||||
className="text-red-600 border-red-200 hover:bg-red-50"
|
||||
>
|
||||
{isLoggingOut ? (
|
||||
<Loader2 className="w-3.5 h-3.5 animate-spin" />
|
||||
) : (
|
||||
<LogOut className="w-3.5 h-3.5" />
|
||||
)}
|
||||
<span className="ml-1.5">Sign out</span>
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Model selection */}
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-gray-700 mb-3">
|
||||
Select ChatGPT Model
|
||||
</label>
|
||||
<Popover open={openModelSelect} onOpenChange={setOpenModelSelect}>
|
||||
<PopoverTrigger asChild>
|
||||
<Button
|
||||
variant="outline"
|
||||
role="combobox"
|
||||
aria-expanded={openModelSelect}
|
||||
className="w-full h-12 px-4 py-4 outline-none border border-gray-300 rounded-lg focus:ring-2 focus:ring-blue-500/20 focus:border-blue-500 transition-colors hover:border-gray-400 justify-between"
|
||||
>
|
||||
<span className="text-sm font-medium text-gray-900">
|
||||
{codexModel
|
||||
? (CHATGPT_MODELS.find((m) => m.id === codexModel)?.name ?? codexModel)
|
||||
: "Select a model"}
|
||||
</span>
|
||||
<ChevronsUpDown className="w-4 h-4 text-gray-500" />
|
||||
</Button>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent
|
||||
className="p-0"
|
||||
align="start"
|
||||
style={{ width: "var(--radix-popover-trigger-width)" }}
|
||||
>
|
||||
<Command>
|
||||
<CommandInput placeholder="Search models…" />
|
||||
<CommandList>
|
||||
<CommandEmpty>No model found.</CommandEmpty>
|
||||
<CommandGroup>
|
||||
{CHATGPT_MODELS.map((model) => (
|
||||
<CommandItem
|
||||
key={model.id}
|
||||
value={model.id}
|
||||
onSelect={(value) => {
|
||||
onInputChange(value, "codex_model");
|
||||
setOpenModelSelect(false);
|
||||
}}
|
||||
>
|
||||
<Check
|
||||
className={cn(
|
||||
"mr-2 h-4 w-4",
|
||||
codexModel === model.id ? "opacity-100" : "opacity-0"
|
||||
)}
|
||||
/>
|
||||
<span className="text-sm font-medium text-gray-900">
|
||||
{model.name}
|
||||
</span>
|
||||
</CommandItem>
|
||||
))}
|
||||
</CommandGroup>
|
||||
</CommandList>
|
||||
</Command>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
<p className="mt-2 text-xs text-gray-500 flex items-center gap-2">
|
||||
<span className="block w-1 h-1 rounded-full bg-gray-400" />
|
||||
Model availability depends on your ChatGPT subscription tier.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// ─── Unauthenticated ─────────────────────────────────────────────────────
|
||||
return (
|
||||
<div className="space-y-6">
|
||||
<div className="p-4 bg-gray-50 rounded-xl border border-gray-200">
|
||||
<h3 className="text-sm font-semibold text-gray-900 mb-1">
|
||||
ChatGPT Plus / Pro
|
||||
</h3>
|
||||
<p className="text-sm text-gray-600">
|
||||
Sign in with your OpenAI account to use ChatGPT models directly via
|
||||
OAuth — no API key required.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<Button
|
||||
onClick={handleSignIn}
|
||||
className="w-full h-12 gap-2 bg-[#10a37f] hover:bg-[#0e8f6f] text-white"
|
||||
>
|
||||
<LogIn className="w-4 h-4" />
|
||||
Sign in with ChatGPT
|
||||
</Button>
|
||||
|
||||
<p className="text-xs text-gray-500 flex items-start gap-2">
|
||||
<span className="block w-1 h-1 rounded-full bg-gray-400 mt-1.5 shrink-0" />
|
||||
A browser window will open for you to authenticate with your OpenAI
|
||||
account. Your credentials are stored locally and never shared.
|
||||
</p>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
@ -19,6 +19,7 @@ import GoogleConfig from "./GoogleConfig";
|
|||
import AnthropicConfig from "./AnthropicConfig";
|
||||
import OllamaConfig from "./OllamaConfig";
|
||||
import CustomConfig from "./CustomConfig";
|
||||
import CodexConfig from "./CodexConfig";
|
||||
import {
|
||||
updateLLMConfig,
|
||||
changeProvider as changeProviderUtil,
|
||||
|
|
@ -95,7 +96,8 @@ export default function LLMProviderSelection({
|
|||
(llmConfig.LLM === "google" && !llmConfig.GOOGLE_MODEL) ||
|
||||
(llmConfig.LLM === "ollama" && !llmConfig.OLLAMA_MODEL) ||
|
||||
(llmConfig.LLM === "custom" && !llmConfig.CUSTOM_MODEL) ||
|
||||
(llmConfig.LLM === "anthropic" && !llmConfig.ANTHROPIC_MODEL);
|
||||
(llmConfig.LLM === "anthropic" && !llmConfig.ANTHROPIC_MODEL) ||
|
||||
(llmConfig.LLM === "codex" && !llmConfig.CODEX_MODEL);
|
||||
|
||||
const needsProviderApiKey =
|
||||
(llmConfig.LLM === "openai" && !llmConfig.OPENAI_API_KEY) ||
|
||||
|
|
@ -335,12 +337,13 @@ export default function LLMProviderSelection({
|
|||
onValueChange={handleProviderChange}
|
||||
className="w-full"
|
||||
>
|
||||
<TabsList className="grid w-full grid-cols-5 bg-transparent h-10">
|
||||
<TabsList className="grid w-full grid-cols-6 bg-transparent h-10">
|
||||
<TabsTrigger value="openai">OpenAI</TabsTrigger>
|
||||
<TabsTrigger value="google">Google</TabsTrigger>
|
||||
<TabsTrigger value="anthropic">Anthropic</TabsTrigger>
|
||||
<TabsTrigger value="ollama">Ollama</TabsTrigger>
|
||||
<TabsTrigger value="custom">Custom</TabsTrigger>
|
||||
<TabsTrigger value="codex">ChatGPT</TabsTrigger>
|
||||
</TabsList>
|
||||
</Tabs>
|
||||
</div>
|
||||
|
|
@ -404,6 +407,14 @@ export default function LLMProviderSelection({
|
|||
onInputChange={input_field_changed}
|
||||
/>
|
||||
</TabsContent>
|
||||
|
||||
{/* ChatGPT / Codex Content */}
|
||||
<TabsContent value="codex" className="mt-6">
|
||||
<CodexConfig
|
||||
codexModel={llmConfig.CODEX_MODEL || ""}
|
||||
onInputChange={input_field_changed}
|
||||
/>
|
||||
</TabsContent>
|
||||
</Tabs>
|
||||
|
||||
{/* Image Generation Toggle */}
|
||||
|
|
@ -652,6 +663,8 @@ export default function LLMProviderSelection({
|
|||
? llmConfig.GOOGLE_MODEL ?? "xxxxx"
|
||||
: llmConfig.LLM === "openai"
|
||||
? llmConfig.OPENAI_MODEL ?? "xxxxx"
|
||||
: llmConfig.LLM === "codex"
|
||||
? llmConfig.CODEX_MODEL ?? "xxxxx"
|
||||
: "xxxxx"}{" "}
|
||||
for text generation{" "}
|
||||
{isImageGenerationDisabled ? (
|
||||
|
|
|
|||
|
|
@ -43,6 +43,13 @@ export interface LLMConfig {
|
|||
EXTENDED_REASONING?: boolean;
|
||||
WEB_GROUNDING?: boolean;
|
||||
|
||||
// Codex OAuth (ChatGPT)
|
||||
CODEX_MODEL?: string;
|
||||
CODEX_ACCESS_TOKEN?: string;
|
||||
CODEX_REFRESH_TOKEN?: string;
|
||||
CODEX_TOKEN_EXPIRES?: string;
|
||||
CODEX_ACCOUNT_ID?: string;
|
||||
|
||||
// Only used in UI settings
|
||||
USE_CUSTOM_URL?: boolean;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -116,4 +116,9 @@ export const LLM_PROVIDERS: Record<string, LLMProviderOption> = {
|
|||
label: "Custom",
|
||||
description: "Custom LLM",
|
||||
},
|
||||
codex: {
|
||||
value: "codex",
|
||||
label: "ChatGPT",
|
||||
description: "ChatGPT Plus/Pro via OAuth",
|
||||
},
|
||||
};
|
||||
|
|
|
|||
|
|
@ -53,6 +53,7 @@ export const updateLLMConfig = (
|
|||
comfyui_workflow: "COMFYUI_WORKFLOW",
|
||||
dall_e_3_quality: "DALL_E_3_QUALITY",
|
||||
gpt_image_1_5_quality: "GPT_IMAGE_1_5_QUALITY",
|
||||
codex_model: "CODEX_MODEL",
|
||||
};
|
||||
|
||||
const configKey = fieldMappings[field];
|
||||
|
|
@ -78,7 +79,7 @@ export const changeProvider = (
|
|||
} else if (provider === "google") {
|
||||
newConfig.IMAGE_PROVIDER = "gemini_flash";
|
||||
} else {
|
||||
newConfig.IMAGE_PROVIDER = "pexels"; // default for ollama and custom
|
||||
newConfig.IMAGE_PROVIDER = "pexels"; // default for ollama, custom, codex
|
||||
}
|
||||
|
||||
return newConfig;
|
||||
|
|
|
|||
|
|
@ -67,6 +67,11 @@ export const hasValidLLMConfig = (llmConfig: LLMConfig) => {
|
|||
llmConfig.CUSTOM_MODEL !== null &&
|
||||
llmConfig.CUSTOM_MODEL !== undefined;
|
||||
|
||||
const isCodexConfigValid =
|
||||
llmConfig.CODEX_MODEL !== "" &&
|
||||
llmConfig.CODEX_MODEL !== null &&
|
||||
llmConfig.CODEX_MODEL !== undefined;
|
||||
|
||||
const shouldValidateImages = !llmConfig.DISABLE_IMAGE_GENERATION;
|
||||
|
||||
const isImageConfigValid = () => {
|
||||
|
|
@ -104,6 +109,8 @@ export const hasValidLLMConfig = (llmConfig: LLMConfig) => {
|
|||
? isOllamaConfigValid
|
||||
: llmConfig.LLM === "custom"
|
||||
? isCustomConfigValid
|
||||
: llmConfig.LLM === "codex"
|
||||
? isCodexConfigValid
|
||||
: false;
|
||||
|
||||
return isLLMConfigValid && isImageConfigValid();
|
||||
|
|
|
|||
278
servers/fastapi/api/v1/ppt/endpoints/codex_auth.py
Normal file
278
servers/fastapi/api/v1/ppt/endpoints/codex_auth.py
Normal file
|
|
@ -0,0 +1,278 @@
|
|||
"""
|
||||
OpenAI Codex OAuth endpoints.
|
||||
|
||||
Flow:
|
||||
1. POST /codex/auth/initiate — start the flow, get back an auth URL + session_id
|
||||
2. Browser opens the URL, user authenticates with OpenAI
|
||||
3. OpenAI redirects to http://localhost:1455/auth/callback (captured by local server)
|
||||
4. GET /codex/auth/status/{session_id} — poll until code captured; exchanges and stores tokens
|
||||
5. POST /codex/auth/exchange — manual fallback if browser callback didn't fire
|
||||
6. POST /codex/auth/refresh — refresh a stored token
|
||||
"""
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from utils.oauth.openai_codex import (
|
||||
OAuthCallbackServer,
|
||||
TokenSuccess,
|
||||
create_authorization_flow,
|
||||
exchange_authorization_code,
|
||||
get_account_id,
|
||||
parse_authorization_input,
|
||||
refresh_access_token,
|
||||
)
|
||||
from utils.get_env import (
|
||||
get_codex_access_token_env,
|
||||
get_codex_refresh_token_env,
|
||||
get_codex_token_expires_env,
|
||||
)
|
||||
from utils.set_env import (
|
||||
set_codex_access_token_env,
|
||||
set_codex_account_id_env,
|
||||
set_codex_refresh_token_env,
|
||||
set_codex_token_expires_env,
|
||||
set_codex_model_env,
|
||||
)
|
||||
from utils.user_config import save_codex_tokens_to_user_config
|
||||
|
||||
CODEX_AUTH_ROUTER = APIRouter(prefix="/codex/auth", tags=["Codex OAuth"])
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# In-memory session store {session_id: {"verifier": str, "state": str, "server": OAuthCallbackServer}}
|
||||
# Sessions are short-lived; garbage-collected when consumed.
|
||||
# ---------------------------------------------------------------------------
|
||||
_sessions: dict[str, dict] = {}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Request / Response models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class InitiateResponse(BaseModel):
|
||||
session_id: str
|
||||
url: str
|
||||
instructions: str
|
||||
|
||||
|
||||
class StatusResponse(BaseModel):
|
||||
status: str # "pending" | "success" | "failed"
|
||||
account_id: Optional[str] = None
|
||||
detail: Optional[str] = None
|
||||
|
||||
|
||||
class ExchangeRequest(BaseModel):
|
||||
session_id: str
|
||||
code: str # raw code OR full redirect URL OR code#state shorthand
|
||||
|
||||
|
||||
class ExchangeResponse(BaseModel):
|
||||
account_id: str
|
||||
|
||||
|
||||
class RefreshResponse(BaseModel):
|
||||
account_id: Optional[str]
|
||||
detail: str
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _store_token(result: TokenSuccess) -> Optional[str]:
|
||||
"""Persist token fields in env vars and userConfig.json. Returns account_id or None."""
|
||||
set_codex_access_token_env(result.access)
|
||||
set_codex_refresh_token_env(result.refresh)
|
||||
set_codex_token_expires_env(str(result.expires))
|
||||
account_id = get_account_id(result.access)
|
||||
if account_id:
|
||||
set_codex_account_id_env(account_id)
|
||||
save_codex_tokens_to_user_config()
|
||||
return account_id
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@CODEX_AUTH_ROUTER.post("/initiate", response_model=InitiateResponse)
|
||||
async def initiate_codex_auth():
|
||||
"""
|
||||
Start the OpenAI Codex OAuth flow.
|
||||
|
||||
Returns an authorization URL to open in the browser and a session_id to use
|
||||
when polling /status or calling /exchange. A local HTTP server is started
|
||||
on port 1455 to receive the redirect automatically.
|
||||
"""
|
||||
flow = create_authorization_flow()
|
||||
server = OAuthCallbackServer(state=flow.state)
|
||||
server_started = server.start()
|
||||
|
||||
session_id = str(uuid.uuid4())
|
||||
_sessions[session_id] = {
|
||||
"verifier": flow.verifier,
|
||||
"state": flow.state,
|
||||
"server": server,
|
||||
"server_started": server_started,
|
||||
}
|
||||
|
||||
instructions = (
|
||||
"Open the URL in your browser and complete the OpenAI login. "
|
||||
+ (
|
||||
"The callback will be captured automatically."
|
||||
if server_started
|
||||
else "Port 1455 could not be bound — paste the redirect URL or code into /exchange."
|
||||
)
|
||||
)
|
||||
|
||||
return InitiateResponse(
|
||||
session_id=session_id,
|
||||
url=flow.url,
|
||||
instructions=instructions,
|
||||
)
|
||||
|
||||
|
||||
@CODEX_AUTH_ROUTER.get("/status/{session_id}", response_model=StatusResponse)
|
||||
async def poll_codex_auth_status(session_id: str):
|
||||
"""
|
||||
Poll for the result of an ongoing OAuth flow.
|
||||
|
||||
Returns {"status": "pending"} until the callback server captures the code.
|
||||
On success the tokens are stored in environment variables and the session
|
||||
is cleaned up.
|
||||
"""
|
||||
session = _sessions.get(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Session not found or already consumed")
|
||||
|
||||
server: OAuthCallbackServer = session["server"]
|
||||
|
||||
# Non-blocking peek — check whether the callback server already received a code
|
||||
code = server.get_code_nowait() if session.get("server_started") else None
|
||||
|
||||
if code is None:
|
||||
return StatusResponse(status="pending")
|
||||
|
||||
# We have a code — exchange it
|
||||
verifier: str = session["verifier"]
|
||||
result = exchange_authorization_code(code, verifier)
|
||||
|
||||
# Clean up session
|
||||
server.close()
|
||||
_sessions.pop(session_id, None)
|
||||
|
||||
if not isinstance(result, TokenSuccess):
|
||||
return StatusResponse(status="failed", detail=result.reason)
|
||||
|
||||
account_id = _store_token(result)
|
||||
return StatusResponse(status="success", account_id=account_id)
|
||||
|
||||
|
||||
@CODEX_AUTH_ROUTER.post("/exchange", response_model=ExchangeResponse)
|
||||
async def exchange_codex_code(body: ExchangeRequest):
|
||||
"""
|
||||
Manual code exchange fallback.
|
||||
|
||||
Accepts the session_id from /initiate and either:
|
||||
- a bare authorization code
|
||||
- the full redirect URL (http://localhost:1455/auth/callback?code=…&state=…)
|
||||
- the code#state shorthand
|
||||
|
||||
Exchanges the code for tokens and stores them in environment variables.
|
||||
"""
|
||||
session = _sessions.get(body.session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Session not found or already consumed")
|
||||
|
||||
parsed = parse_authorization_input(body.code)
|
||||
code = parsed.get("code")
|
||||
incoming_state = parsed.get("state")
|
||||
|
||||
if not code:
|
||||
raise HTTPException(status_code=400, detail="Could not extract authorization code from input")
|
||||
|
||||
if incoming_state and incoming_state != session["state"]:
|
||||
raise HTTPException(status_code=400, detail="State mismatch — possible CSRF")
|
||||
|
||||
verifier: str = session["verifier"]
|
||||
server: OAuthCallbackServer = session["server"]
|
||||
|
||||
result = exchange_authorization_code(code, verifier)
|
||||
|
||||
server.close()
|
||||
_sessions.pop(body.session_id, None)
|
||||
|
||||
if not isinstance(result, TokenSuccess):
|
||||
raise HTTPException(status_code=502, detail=f"Token exchange failed: {result.reason}")
|
||||
|
||||
account_id = _store_token(result)
|
||||
if not account_id:
|
||||
raise HTTPException(status_code=502, detail="Token exchanged but could not extract account ID")
|
||||
|
||||
return ExchangeResponse(account_id=account_id)
|
||||
|
||||
|
||||
@CODEX_AUTH_ROUTER.post("/refresh", response_model=RefreshResponse)
|
||||
async def refresh_codex_token():
|
||||
"""
|
||||
Refresh the stored Codex OAuth access token using the refresh token.
|
||||
|
||||
Updates environment variables with the new tokens.
|
||||
"""
|
||||
refresh_token = get_codex_refresh_token_env()
|
||||
if not refresh_token:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No Codex refresh token stored. Please authenticate first via /initiate",
|
||||
)
|
||||
|
||||
result = refresh_access_token(refresh_token)
|
||||
if not isinstance(result, TokenSuccess):
|
||||
raise HTTPException(status_code=502, detail=f"Token refresh failed: {result.reason}")
|
||||
|
||||
account_id = _store_token(result)
|
||||
return RefreshResponse(
|
||||
account_id=account_id,
|
||||
detail="Token refreshed successfully",
|
||||
)
|
||||
|
||||
|
||||
@CODEX_AUTH_ROUTER.get("/status", response_model=StatusResponse)
|
||||
async def get_codex_auth_status():
|
||||
"""
|
||||
Return whether a valid Codex OAuth token is currently stored.
|
||||
"""
|
||||
import time
|
||||
|
||||
access_token = get_codex_access_token_env()
|
||||
if not access_token:
|
||||
return StatusResponse(status="not_authenticated", detail="No access token stored")
|
||||
|
||||
expires_str = get_codex_token_expires_env()
|
||||
if expires_str:
|
||||
try:
|
||||
expires_ms = int(expires_str)
|
||||
now_ms = int(time.time() * 1000)
|
||||
if now_ms >= expires_ms:
|
||||
return StatusResponse(status="expired", detail="Access token has expired — call /refresh")
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
account_id = get_account_id(access_token)
|
||||
return StatusResponse(status="authenticated", account_id=account_id)
|
||||
|
||||
|
||||
@CODEX_AUTH_ROUTER.post("/logout")
|
||||
async def logout_codex():
|
||||
"""
|
||||
Clear all stored Codex OAuth credentials from environment variables and userConfig.json.
|
||||
"""
|
||||
set_codex_access_token_env("")
|
||||
set_codex_refresh_token_env("")
|
||||
set_codex_token_expires_env("")
|
||||
set_codex_account_id_env("")
|
||||
set_codex_model_env("")
|
||||
save_codex_tokens_to_user_config()
|
||||
return {"detail": "Logged out successfully"}
|
||||
|
|
@ -3,6 +3,7 @@ from fastapi import APIRouter
|
|||
from api.v1.ppt.endpoints.slide_to_html import SLIDE_TO_HTML_ROUTER, HTML_TO_REACT_ROUTER, HTML_EDIT_ROUTER, LAYOUT_MANAGEMENT_ROUTER
|
||||
from api.v1.ppt.endpoints.presentation import PRESENTATION_ROUTER
|
||||
from api.v1.ppt.endpoints.anthropic import ANTHROPIC_ROUTER
|
||||
from api.v1.ppt.endpoints.codex_auth import CODEX_AUTH_ROUTER
|
||||
from api.v1.ppt.endpoints.google import GOOGLE_ROUTER
|
||||
from api.v1.ppt.endpoints.openai import OPENAI_ROUTER
|
||||
from api.v1.ppt.endpoints.files import FILES_ROUTER
|
||||
|
|
@ -36,4 +37,5 @@ API_V1_PPT_ROUTER.include_router(PDF_SLIDES_ROUTER)
|
|||
API_V1_PPT_ROUTER.include_router(OPENAI_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(ANTHROPIC_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(GOOGLE_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(CODEX_AUTH_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(PPTX_FONTS_ROUTER)
|
||||
|
|
|
|||
|
|
@ -4,3 +4,4 @@ OPENAI_URL = "https://api.openai.com/v1"
|
|||
DEFAULT_OPENAI_MODEL = "gpt-4.1"
|
||||
DEFAULT_GOOGLE_MODEL = "models/gemini-2.5-flash"
|
||||
DEFAULT_ANTHROPIC_MODEL = "claude-sonnet-4-20250514"
|
||||
DEFAULT_CODEX_MODEL = "gpt-5.3-codex-spark"
|
||||
|
|
|
|||
|
|
@ -7,3 +7,4 @@ class LLMProvider(Enum):
|
|||
GOOGLE = "google"
|
||||
ANTHROPIC = "anthropic"
|
||||
CUSTOM = "custom"
|
||||
CODEX = "codex"
|
||||
|
|
|
|||
|
|
@ -48,3 +48,10 @@ class UserConfig(BaseModel):
|
|||
|
||||
# Web Search
|
||||
WEB_GROUNDING: Optional[bool] = None
|
||||
|
||||
# Codex OAuth (ChatGPT)
|
||||
CODEX_MODEL: Optional[str] = None
|
||||
CODEX_ACCESS_TOKEN: Optional[str] = None
|
||||
CODEX_REFRESH_TOKEN: Optional[str] = None
|
||||
CODEX_TOKEN_EXPIRES: Optional[str] = None
|
||||
CODEX_ACCOUNT_ID: Optional[str] = None
|
||||
|
|
|
|||
431
servers/fastapi/services/codex_llm.py
Normal file
431
servers/fastapi/services/codex_llm.py
Normal file
|
|
@ -0,0 +1,431 @@
|
|||
"""Codex (Responses API) adapter for structured and unstructured LLM calls.
|
||||
|
||||
Stateless adapter: receives AsyncOpenAI client and tool_calls_handler at call time.
|
||||
Auth and client creation stay in LLMClient. Structure matches other providers:
|
||||
generate = call API, collect content + tool_calls, recurse on tool_calls; stream = same but yield deltas.
|
||||
|
||||
Uses LLMToolCallsHandler directly: tools are parsed via parse_tools() in llm_client (handler supports
|
||||
Codex and returns OpenAI-style dicts); this module flattens them for the Responses API. Tool execution
|
||||
uses tool_calls_handler.handle_tool_calls_openai().
|
||||
"""
|
||||
|
||||
import dirtyjson
|
||||
from typing import Any, AsyncGenerator, List, Optional, Union
|
||||
|
||||
from fastapi import HTTPException
|
||||
from openai import APIStatusError, AsyncOpenAI, OpenAIError
|
||||
|
||||
from models.llm_message import (
|
||||
LLMMessage,
|
||||
OpenAIAssistantMessage,
|
||||
LLMSystemMessage,
|
||||
LLMUserMessage,
|
||||
)
|
||||
from models.llm_tool_call import OpenAIToolCall, OpenAIToolCallFunction
|
||||
from utils.schema_utils import ensure_strict_json_schema
|
||||
|
||||
# Responses API requires flat tool format: {"type":"function","name":...,"description":...,"parameters":...}
|
||||
RESPONSE_SCHEMA_NAME = "ResponseSchema"
|
||||
# Required tool choice for structured: force ResponseSchema (no plain-text fallback).
|
||||
STRUCTURED_TOOL_CHOICE = {"type": "function", "name": RESPONSE_SCHEMA_NAME}
|
||||
MAX_RECURSION_DEPTH = 5
|
||||
|
||||
|
||||
def _to_responses_tools(chat_tools: List[dict]) -> List[dict]:
|
||||
"""Convert Chat Completions tool format to flat Responses API format."""
|
||||
result = []
|
||||
for tool in chat_tools:
|
||||
if tool.get("type") != "function":
|
||||
result.append(tool)
|
||||
continue
|
||||
fn = tool.get("function") or tool
|
||||
result.append({
|
||||
"type": "function",
|
||||
"name": fn.get("name", ""),
|
||||
"description": fn.get("description", ""),
|
||||
"parameters": fn.get("parameters", {}),
|
||||
})
|
||||
return result
|
||||
|
||||
|
||||
def _items_to_openai_calls(items_by_id: dict[str, dict]) -> List[OpenAIToolCall]:
|
||||
"""Build OpenAIToolCall list from Responses API output_item map."""
|
||||
return [
|
||||
OpenAIToolCall(
|
||||
id=item.get("call_id", item.get("id", "")),
|
||||
type="function",
|
||||
function=OpenAIToolCallFunction(
|
||||
name=item.get("name", ""),
|
||||
arguments=item.get("arguments", "{}"),
|
||||
),
|
||||
)
|
||||
for item in items_by_id.values()
|
||||
]
|
||||
|
||||
|
||||
async def _messages_after_tool_turn(
|
||||
messages: List[LLMMessage],
|
||||
items_by_id: dict[str, dict],
|
||||
tool_calls_handler: Any,
|
||||
) -> List[LLMMessage]:
|
||||
"""Handle tool calls and return messages extended with assistant turn + tool results."""
|
||||
openai_calls = _items_to_openai_calls(items_by_id)
|
||||
tool_call_messages = await tool_calls_handler.handle_tool_calls_openai(openai_calls)
|
||||
return [
|
||||
*messages,
|
||||
OpenAIAssistantMessage(
|
||||
role="assistant",
|
||||
content=None,
|
||||
tool_calls=[tc.model_dump() for tc in openai_calls],
|
||||
),
|
||||
*tool_call_messages,
|
||||
]
|
||||
|
||||
|
||||
def _build_body(
|
||||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
tools: Optional[List[dict]] = None,
|
||||
tool_choice: Optional[Union[str, dict]] = None,
|
||||
) -> dict:
|
||||
"""Build Responses API request body."""
|
||||
instructions = None
|
||||
input_messages = []
|
||||
|
||||
for msg in messages:
|
||||
if isinstance(msg, LLMSystemMessage):
|
||||
instructions = msg.content
|
||||
elif isinstance(msg, LLMUserMessage):
|
||||
input_messages.append({
|
||||
"role": "user",
|
||||
"content": [{"type": "input_text", "text": msg.content}],
|
||||
})
|
||||
elif isinstance(msg, OpenAIAssistantMessage):
|
||||
text = msg.content or ""
|
||||
if text:
|
||||
input_messages.append({
|
||||
"role": "assistant",
|
||||
"content": [{"type": "output_text", "text": text}],
|
||||
})
|
||||
else:
|
||||
text = getattr(msg, "content", "") or ""
|
||||
if text:
|
||||
input_messages.append({
|
||||
"role": "user",
|
||||
"content": [{"type": "input_text", "text": text}],
|
||||
})
|
||||
|
||||
body: dict = {
|
||||
"model": model,
|
||||
"store": False,
|
||||
"stream": True,
|
||||
"text": {"verbosity": "medium"},
|
||||
"include": ["reasoning.encrypted_content"],
|
||||
"tool_choice": tool_choice if tool_choice is not None else "auto",
|
||||
"parallel_tool_calls": True,
|
||||
}
|
||||
if instructions:
|
||||
body["instructions"] = instructions
|
||||
if input_messages:
|
||||
body["input"] = input_messages
|
||||
if tools:
|
||||
body["tools"] = tools
|
||||
|
||||
return body
|
||||
|
||||
|
||||
def _event_to_dict(event: Any) -> dict:
|
||||
"""Convert SDK event to dict."""
|
||||
if hasattr(event, "model_dump"):
|
||||
return event.model_dump()
|
||||
return {
|
||||
"type": getattr(event, "type", None),
|
||||
"delta": getattr(event, "delta", None),
|
||||
"item": getattr(event, "item", None),
|
||||
"message": getattr(event, "message", None),
|
||||
"arguments": getattr(event, "arguments", None),
|
||||
"name": getattr(event, "name", None),
|
||||
}
|
||||
|
||||
|
||||
async def _stream_raw(
|
||||
client: AsyncOpenAI,
|
||||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
tools: Optional[List[dict]] = None,
|
||||
tool_choice: Optional[Union[str, dict]] = None,
|
||||
) -> AsyncGenerator[dict, None]:
|
||||
"""Yield raw SSE event dicts from Codex Responses API."""
|
||||
body = _build_body(model, messages, tools, tool_choice=tool_choice)
|
||||
create_kwargs = {k: v for k, v in body.items() if k != "stream"}
|
||||
|
||||
try:
|
||||
stream = await client.responses.create(stream=True, **create_kwargs)
|
||||
except (APIStatusError, OpenAIError) as e:
|
||||
status = getattr(e, "status_code", 502)
|
||||
detail = getattr(e, "message", str(e)) or str(e)
|
||||
raise HTTPException(
|
||||
status_code=status,
|
||||
detail=f"Codex API error: {detail}"[:400],
|
||||
) from e
|
||||
|
||||
async for event in stream:
|
||||
yield _event_to_dict(event)
|
||||
|
||||
|
||||
class CodexLLMAdapter:
|
||||
"""Stateless adapter for Codex Responses API. Matches other providers: generate/stream + tool recursion."""
|
||||
|
||||
@staticmethod
|
||||
async def generate_codex(
|
||||
client: AsyncOpenAI,
|
||||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
tool_calls_handler: Any,
|
||||
max_tokens: Optional[int] = None,
|
||||
tools: Optional[List[dict]] = None,
|
||||
depth: int = 0,
|
||||
) -> Optional[str]:
|
||||
"""Generate text; on tool_calls handle and recurse (like _generate_openai / _generate_anthropic)."""
|
||||
print(
|
||||
f"Codex generate: model={model} depth={depth} tools_count={len(tools) if tools else 0}"
|
||||
)
|
||||
responses_tools = _to_responses_tools(tools) if tools else None
|
||||
text_parts: List[str] = []
|
||||
tool_calls_by_id: dict[str, dict] = {}
|
||||
|
||||
async for event in _stream_raw(client, model, messages, responses_tools, tool_choice=None):
|
||||
event_type = event.get("type", "")
|
||||
|
||||
if event_type == "response.output_text.delta":
|
||||
delta = event.get("delta", "")
|
||||
if delta:
|
||||
text_parts.append(delta)
|
||||
elif event_type == "response.output_item.done":
|
||||
item = event.get("item") or {}
|
||||
if item.get("type") == "function_call":
|
||||
tool_calls_by_id[item.get("call_id", item.get("id", ""))] = item
|
||||
elif event_type in ("response.failed", "error"):
|
||||
msg_text = event.get("message") or str(event)
|
||||
raise HTTPException(status_code=502, detail=f"Codex error: {msg_text}")
|
||||
|
||||
if tool_calls_by_id and tools and depth < MAX_RECURSION_DEPTH:
|
||||
print(
|
||||
f"Codex generate: tool calls detected depth={depth} count={len(tool_calls_by_id)}"
|
||||
)
|
||||
new_messages = await _messages_after_tool_turn(
|
||||
messages, tool_calls_by_id, tool_calls_handler
|
||||
)
|
||||
return await CodexLLMAdapter.generate_codex(
|
||||
client, model, new_messages, tool_calls_handler,
|
||||
max_tokens=max_tokens, tools=tools, depth=depth + 1,
|
||||
)
|
||||
|
||||
return "".join(text_parts) or None
|
||||
|
||||
@staticmethod
|
||||
async def stream_codex(
|
||||
client: AsyncOpenAI,
|
||||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
tool_calls_handler: Any,
|
||||
max_tokens: Optional[int] = None,
|
||||
tools: Optional[List[dict]] = None,
|
||||
depth: int = 0,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Stream text deltas; on tool_calls handle and recurse (like _stream_openai)."""
|
||||
print(
|
||||
f"Codex stream: model={model} depth={depth} tools_count={len(tools) if tools else 0}"
|
||||
)
|
||||
responses_tools = _to_responses_tools(tools) if tools else None
|
||||
tool_calls_by_id: dict[str, dict] = {}
|
||||
|
||||
async for event in _stream_raw(client, model, messages, responses_tools, tool_choice=None):
|
||||
event_type = event.get("type", "")
|
||||
|
||||
if event_type == "response.output_text.delta":
|
||||
delta = event.get("delta", "")
|
||||
if delta:
|
||||
yield delta
|
||||
elif event_type == "response.output_item.done":
|
||||
item = event.get("item") or {}
|
||||
if item.get("type") == "function_call":
|
||||
tool_calls_by_id[item.get("call_id", item.get("id", ""))] = item
|
||||
elif event_type in ("response.failed", "error"):
|
||||
msg_text = event.get("message") or str(event)
|
||||
raise HTTPException(status_code=502, detail=f"Codex stream error: {msg_text}")
|
||||
|
||||
if tool_calls_by_id and tools and depth < MAX_RECURSION_DEPTH:
|
||||
print(
|
||||
f"Codex stream: tool calls detected depth={depth} count={len(tool_calls_by_id)}"
|
||||
)
|
||||
new_messages = await _messages_after_tool_turn(
|
||||
messages, tool_calls_by_id, tool_calls_handler
|
||||
)
|
||||
async for chunk in CodexLLMAdapter.stream_codex(
|
||||
client, model, new_messages, tool_calls_handler,
|
||||
max_tokens=max_tokens, tools=tools, depth=depth + 1,
|
||||
):
|
||||
yield chunk
|
||||
|
||||
@staticmethod
|
||||
async def stream_codex_structured(
|
||||
client: AsyncOpenAI,
|
||||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
response_format: dict,
|
||||
tool_calls_handler: Any,
|
||||
strict: bool = False,
|
||||
max_tokens: Optional[int] = None,
|
||||
tools: Optional[List[dict]] = None,
|
||||
depth: int = 0,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Stream JSON chunks from ResponseSchema tool; recurse for other tool_calls.
|
||||
|
||||
Structured output is achieved by always adding an internal ResponseSchema "tool"
|
||||
(with response_format as its parameters) and tool_choice=ResponseSchema. So
|
||||
user_tools=0 only means no extra tools like web search; we still use the
|
||||
ResponseSchema tool to receive the model's JSON.
|
||||
"""
|
||||
user_tools_count = len(tools) if tools else 0
|
||||
print(
|
||||
f"Codex stream_structured: model={model} depth={depth} strict={strict} "
|
||||
f"user_tools={user_tools_count} (always adding ResponseSchema tool for structured JSON)"
|
||||
)
|
||||
schema = ensure_strict_json_schema(response_format, path=(), root=response_format) if strict and depth == 0 else response_format
|
||||
response_schema_tool = {
|
||||
"type": "function",
|
||||
"name": RESPONSE_SCHEMA_NAME,
|
||||
"description": "Provide response to the user",
|
||||
"parameters": schema,
|
||||
}
|
||||
all_tools: List[dict] = [response_schema_tool]
|
||||
if tools:
|
||||
all_tools.extend(_to_responses_tools(tools))
|
||||
|
||||
tool_calls_by_id: dict[str, dict] = {}
|
||||
current_call_id: Optional[str] = None
|
||||
|
||||
async for event in _stream_raw(
|
||||
client, model, messages, all_tools, tool_choice=STRUCTURED_TOOL_CHOICE
|
||||
):
|
||||
event_type = event.get("type", "")
|
||||
|
||||
if event_type == "response.output_item.added":
|
||||
item = event.get("item") or {}
|
||||
if item.get("type") == "function_call" and item.get("name") == RESPONSE_SCHEMA_NAME:
|
||||
current_call_id = item.get("call_id", item.get("id"))
|
||||
print(
|
||||
f"Codex stream_structured: ResponseSchema call started call_id={current_call_id}"
|
||||
)
|
||||
|
||||
elif event_type == "response.function_call_arguments.delta":
|
||||
if current_call_id is not None:
|
||||
delta = event.get("delta", "")
|
||||
if delta:
|
||||
# Log only first few chunks to avoid log spam
|
||||
print(
|
||||
f"Codex stream_structured: ResponseSchema delta chunk len={len(delta)}"
|
||||
)
|
||||
yield delta
|
||||
|
||||
elif event_type == "response.function_call_arguments.done":
|
||||
if event.get("name") == RESPONSE_SCHEMA_NAME:
|
||||
arguments = event.get("arguments", "")
|
||||
if arguments:
|
||||
print(
|
||||
f"Codex stream_structured: ResponseSchema arguments.done len={len(arguments)}"
|
||||
)
|
||||
yield arguments
|
||||
|
||||
elif event_type == "response.output_item.done":
|
||||
item = event.get("item") or {}
|
||||
if item.get("type") == "function_call":
|
||||
tool_calls_by_id[item.get("call_id", item.get("id", ""))] = item
|
||||
if item.get("name") == RESPONSE_SCHEMA_NAME:
|
||||
arguments = item.get("arguments", "")
|
||||
if arguments:
|
||||
print(
|
||||
f"Codex stream_structured: ResponseSchema output_item.done len={len(arguments)}"
|
||||
)
|
||||
yield arguments
|
||||
|
||||
elif event_type in ("response.failed", "error"):
|
||||
msg_text = event.get("message") or str(event)
|
||||
raise HTTPException(status_code=502, detail=f"Codex structured error: {msg_text}")
|
||||
|
||||
other_tool_calls = {
|
||||
k: v for k, v in tool_calls_by_id.items()
|
||||
if v.get("name") != RESPONSE_SCHEMA_NAME
|
||||
}
|
||||
if other_tool_calls and tools and depth < MAX_RECURSION_DEPTH:
|
||||
print(
|
||||
f"Codex stream_structured: recursing for non-ResponseSchema tool calls "
|
||||
f"depth={depth} count={len(other_tool_calls)}"
|
||||
)
|
||||
new_messages = await _messages_after_tool_turn(
|
||||
messages, other_tool_calls, tool_calls_handler
|
||||
)
|
||||
async for chunk in CodexLLMAdapter.stream_codex_structured(
|
||||
client, model, new_messages, response_format, tool_calls_handler,
|
||||
strict=strict, max_tokens=max_tokens, tools=tools, depth=depth + 1,
|
||||
):
|
||||
yield chunk
|
||||
|
||||
@staticmethod
|
||||
async def generate_codex_structured(
|
||||
client: AsyncOpenAI,
|
||||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
response_format: dict,
|
||||
tool_calls_handler: Any,
|
||||
strict: bool = False,
|
||||
max_tokens: Optional[int] = None,
|
||||
tools: Optional[List[dict]] = None,
|
||||
depth: int = 0,
|
||||
) -> Optional[dict]:
|
||||
"""Collect stream and parse JSON (like _generate_openai_structured)."""
|
||||
user_tools_count = len(tools) if tools else 0
|
||||
print(
|
||||
f"Codex generate_structured: model={model} depth={depth} strict={strict} "
|
||||
f"user_tools={user_tools_count} (using ResponseSchema tool for structured JSON)"
|
||||
)
|
||||
accumulated: List[str] = []
|
||||
async for chunk in CodexLLMAdapter.stream_codex_structured(
|
||||
client, model, messages, response_format, tool_calls_handler,
|
||||
strict=strict, max_tokens=max_tokens, tools=tools, depth=depth,
|
||||
):
|
||||
accumulated.append(chunk)
|
||||
|
||||
raw = "".join(accumulated)
|
||||
if not raw:
|
||||
return None
|
||||
|
||||
if depth == 0:
|
||||
try:
|
||||
parsed = dict(dirtyjson.loads(raw))
|
||||
print(
|
||||
f"Codex generate_structured: parsed JSON keys={list(parsed.keys())[:8]}"
|
||||
)
|
||||
return parsed
|
||||
except Exception:
|
||||
start = raw.find("{")
|
||||
if start >= 0:
|
||||
try:
|
||||
parsed = dict(dirtyjson.loads(raw[start:]))
|
||||
print(
|
||||
"Codex generate_structured: parsed JSON from offset "
|
||||
f"{start} keys={list(parsed.keys())[:8]}"
|
||||
)
|
||||
return parsed
|
||||
except Exception:
|
||||
pass
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=(
|
||||
"Model did not return valid structured output (expected JSON from ResponseSchema). "
|
||||
"Please retry."
|
||||
),
|
||||
)
|
||||
|
||||
return None
|
||||
|
|
@ -1,9 +1,9 @@
|
|||
import asyncio
|
||||
import dirtyjson
|
||||
import json
|
||||
from typing import AsyncGenerator, List, Optional
|
||||
from typing import AsyncGenerator, List, Optional, Dict, Any
|
||||
from fastapi import HTTPException
|
||||
from openai import AsyncOpenAI
|
||||
from openai import APIStatusError, AsyncOpenAI, OpenAIError
|
||||
from openai.types.chat.chat_completion_chunk import (
|
||||
ChatCompletionChunk as OpenAIChatCompletionChunk,
|
||||
)
|
||||
|
|
@ -44,6 +44,10 @@ from utils.async_iterator import iterator_to_async
|
|||
from utils.dummy_functions import do_nothing_async
|
||||
from utils.get_env import (
|
||||
get_anthropic_api_key_env,
|
||||
get_codex_access_token_env,
|
||||
get_codex_account_id_env,
|
||||
get_codex_refresh_token_env,
|
||||
get_codex_token_expires_env,
|
||||
get_custom_llm_api_key_env,
|
||||
get_custom_llm_url_env,
|
||||
get_disable_thinking_env,
|
||||
|
|
@ -53,6 +57,12 @@ from utils.get_env import (
|
|||
get_tool_calls_env,
|
||||
get_web_grounding_env,
|
||||
)
|
||||
from utils.set_env import (
|
||||
set_codex_access_token_env,
|
||||
set_codex_account_id_env,
|
||||
set_codex_refresh_token_env,
|
||||
set_codex_token_expires_env,
|
||||
)
|
||||
from utils.llm_provider import get_llm_provider, get_model
|
||||
from utils.parsers import parse_bool_or_none
|
||||
from utils.schema_utils import (
|
||||
|
|
@ -62,6 +72,7 @@ from utils.schema_utils import (
|
|||
)
|
||||
|
||||
|
||||
|
||||
class LLMClient:
|
||||
def __init__(self):
|
||||
self.llm_provider = get_llm_provider()
|
||||
|
|
@ -100,10 +111,12 @@ class LLMClient:
|
|||
return self._get_ollama_client()
|
||||
case LLMProvider.CUSTOM:
|
||||
return self._get_custom_client()
|
||||
case LLMProvider.CODEX:
|
||||
return self._get_codex_client()
|
||||
case _:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="LLM Provider must be either openai, google, anthropic, ollama, or custom",
|
||||
detail="LLM Provider must be either openai, google, anthropic, ollama, custom, or codex",
|
||||
)
|
||||
|
||||
def _get_openai_client(self):
|
||||
|
|
@ -147,6 +160,74 @@ class LLMClient:
|
|||
api_key=get_custom_llm_api_key_env() or "null",
|
||||
)
|
||||
|
||||
def _get_codex_headers(self) -> dict:
|
||||
"""Return the HTTP headers required for Codex Responses API requests.
|
||||
|
||||
Handles token auto-refresh if the stored token is expired or within
|
||||
60 s of expiry before building the header dict.
|
||||
"""
|
||||
access_token = get_codex_access_token_env()
|
||||
if not access_token:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Codex OAuth access token is not set. Please authenticate via /api/v1/ppt/codex/auth/initiate",
|
||||
)
|
||||
|
||||
# Auto-refresh if the token is expired or about to expire (within 60 s)
|
||||
expires_str = get_codex_token_expires_env()
|
||||
if expires_str:
|
||||
try:
|
||||
expires_ms = int(expires_str)
|
||||
now_ms = int(__import__("time").time() * 1000)
|
||||
if now_ms >= expires_ms - 60_000:
|
||||
refresh_token = get_codex_refresh_token_env()
|
||||
if refresh_token:
|
||||
from utils.oauth.openai_codex import (
|
||||
get_account_id,
|
||||
refresh_access_token,
|
||||
TokenSuccess,
|
||||
)
|
||||
result = refresh_access_token(refresh_token)
|
||||
if isinstance(result, TokenSuccess):
|
||||
set_codex_access_token_env(result.access)
|
||||
set_codex_refresh_token_env(result.refresh)
|
||||
set_codex_token_expires_env(str(result.expires))
|
||||
account_id = get_account_id(result.access)
|
||||
if account_id:
|
||||
set_codex_account_id_env(account_id)
|
||||
access_token = result.access
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
account_id = get_codex_account_id_env() or ""
|
||||
return {
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"chatgpt-account-id": account_id,
|
||||
"OpenAI-Beta": "responses=experimental",
|
||||
"originator": "pi",
|
||||
"content-type": "application/json",
|
||||
"accept": "text/event-stream",
|
||||
}
|
||||
|
||||
def _get_codex_client(self) -> AsyncOpenAI:
|
||||
"""Return an AsyncOpenAI client configured for the Codex Responses API.
|
||||
Client is built per call so headers/token are fresh after refresh.
|
||||
Only Codex-specific headers are passed; content-type and accept are left
|
||||
to the SDK so the server does not reject the request.
|
||||
"""
|
||||
headers = self._get_codex_headers()
|
||||
access_token = (headers.get("Authorization") or "").replace("Bearer ", "").strip()
|
||||
skip = {"authorization", "content-type", "accept"}
|
||||
default_headers = {
|
||||
k: v for k, v in headers.items() if k.lower() not in skip
|
||||
}
|
||||
return AsyncOpenAI(
|
||||
base_url="https://chatgpt.com/backend-api/codex",
|
||||
api_key=access_token or "codex",
|
||||
default_headers=default_headers,
|
||||
timeout=120.0,
|
||||
)
|
||||
|
||||
# ? Prompts
|
||||
def _get_system_prompt(self, messages: List[LLMMessage]) -> str:
|
||||
for message in messages:
|
||||
|
|
@ -401,6 +482,147 @@ class LLMClient:
|
|||
depth=depth,
|
||||
)
|
||||
|
||||
async def _generate_codex(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
max_tokens: Optional[int] = None,
|
||||
tools: Optional[List[dict]] = None,
|
||||
depth: int = 0,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Generate plain text using the Codex Responses API. On tool calls, run
|
||||
handlers and recurse (same pattern as _generate_openai).
|
||||
"""
|
||||
_MAX_RECURSION_DEPTH = 5
|
||||
client: AsyncOpenAI = self._client
|
||||
|
||||
# Flatten tools to Responses API format
|
||||
responses_tools: Optional[List[dict]] = None
|
||||
if tools:
|
||||
responses_tools = []
|
||||
for tool in tools:
|
||||
fn = (tool.get("function") or tool) if isinstance(tool, dict) else {}
|
||||
if isinstance(fn, dict):
|
||||
responses_tools.append({
|
||||
"type": "function",
|
||||
"name": fn.get("name", ""),
|
||||
"description": fn.get("description", ""),
|
||||
"parameters": fn.get("parameters", {}),
|
||||
})
|
||||
else:
|
||||
responses_tools.append(tool)
|
||||
|
||||
# Build instructions + input (same shape as _stream_codex_structured)
|
||||
instructions = self._get_system_prompt(messages) or None
|
||||
input_payload: List[Dict[str, Any]] = []
|
||||
for m in messages:
|
||||
if isinstance(m, LLMSystemMessage):
|
||||
continue
|
||||
if isinstance(m, LLMUserMessage):
|
||||
input_payload.append({
|
||||
"role": "user",
|
||||
"content": [{"type": "input_text", "text": m.content}],
|
||||
})
|
||||
elif isinstance(m, OpenAIAssistantMessage):
|
||||
text = m.content or ""
|
||||
if text:
|
||||
input_payload.append({
|
||||
"role": "assistant",
|
||||
"content": [{"type": "output_text", "text": text}],
|
||||
})
|
||||
else:
|
||||
text = getattr(m, "content", "") or ""
|
||||
if text:
|
||||
input_payload.append({
|
||||
"role": "user",
|
||||
"content": [{"type": "input_text", "text": text}],
|
||||
})
|
||||
|
||||
create_kwargs: Dict[str, Any] = {
|
||||
"model": model,
|
||||
"store": False,
|
||||
"stream": True,
|
||||
"text": {"verbosity": "medium"},
|
||||
"include": ["reasoning.encrypted_content"],
|
||||
"tool_choice": "auto",
|
||||
"parallel_tool_calls": True,
|
||||
}
|
||||
if instructions:
|
||||
create_kwargs["instructions"] = instructions
|
||||
if input_payload:
|
||||
create_kwargs["input"] = input_payload
|
||||
if responses_tools:
|
||||
create_kwargs["tools"] = responses_tools
|
||||
if max_tokens is not None:
|
||||
create_kwargs["max_output_tokens"] = max_tokens
|
||||
|
||||
stream = await client.responses.create(**create_kwargs)
|
||||
|
||||
def _event_dict(ev: Any) -> dict:
|
||||
if hasattr(ev, "model_dump"):
|
||||
return ev.model_dump()
|
||||
return {
|
||||
"type": getattr(ev, "type", None),
|
||||
"delta": getattr(ev, "delta", None),
|
||||
"item": getattr(ev, "item", None),
|
||||
"message": getattr(ev, "message", None),
|
||||
}
|
||||
|
||||
text_parts: List[str] = []
|
||||
tool_calls_by_id: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
async for ev in stream:
|
||||
event = _event_dict(ev) if not isinstance(ev, dict) else ev
|
||||
event_type = event.get("type") or ""
|
||||
|
||||
if event_type == "response.output_text.delta":
|
||||
delta = event.get("delta") or ""
|
||||
if delta:
|
||||
text_parts.append(delta)
|
||||
elif event_type == "response.output_item.done":
|
||||
item = event.get("item") or {}
|
||||
if item.get("type") == "function_call":
|
||||
cid = item.get("call_id") or item.get("id", "")
|
||||
tool_calls_by_id[cid] = item
|
||||
elif event_type in ("response.error", "response.failed", "error"):
|
||||
err = event.get("message") or event.get("error") or str(event)
|
||||
raise HTTPException(status_code=502, detail=f"Codex error: {err}"[:400])
|
||||
|
||||
if tool_calls_by_id and responses_tools and depth < _MAX_RECURSION_DEPTH:
|
||||
parsed_tool_calls = [
|
||||
OpenAIToolCall(
|
||||
id=cid,
|
||||
type="function",
|
||||
function=OpenAIToolCallFunction(
|
||||
name=data.get("name", ""),
|
||||
arguments=data.get("arguments", ""),
|
||||
),
|
||||
)
|
||||
for cid, data in tool_calls_by_id.items()
|
||||
]
|
||||
tool_call_messages = await self.tool_calls_handler.handle_tool_calls_openai(
|
||||
parsed_tool_calls
|
||||
)
|
||||
new_messages = [
|
||||
*messages,
|
||||
OpenAIAssistantMessage(
|
||||
role="assistant",
|
||||
content=None,
|
||||
tool_calls=[tc.model_dump() for tc in parsed_tool_calls],
|
||||
),
|
||||
*tool_call_messages,
|
||||
]
|
||||
return await self._generate_codex(
|
||||
model=model,
|
||||
messages=new_messages,
|
||||
max_tokens=max_tokens,
|
||||
tools=tools,
|
||||
depth=depth + 1,
|
||||
)
|
||||
|
||||
return "".join(text_parts) or None
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
model: str,
|
||||
|
|
@ -419,6 +641,13 @@ class LLMClient:
|
|||
max_tokens=max_tokens,
|
||||
tools=parsed_tools,
|
||||
)
|
||||
case LLMProvider.CODEX:
|
||||
content = await self._generate_codex(
|
||||
model=model,
|
||||
messages=messages,
|
||||
max_tokens=max_tokens,
|
||||
tools=parsed_tools,
|
||||
)
|
||||
case LLMProvider.GOOGLE:
|
||||
content = await self._generate_google(
|
||||
model=model,
|
||||
|
|
@ -566,6 +795,48 @@ class LLMClient:
|
|||
return content
|
||||
return None
|
||||
|
||||
async def _generate_codex_structured(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
response_format: dict,
|
||||
strict: bool = False,
|
||||
max_tokens: Optional[int] = None,
|
||||
tools: Optional[List[dict]] = None,
|
||||
extra_body: Optional[dict] = None,
|
||||
depth: int = 0,
|
||||
) -> dict | None:
|
||||
"""
|
||||
Generate structured Codex output using the Responses API.
|
||||
|
||||
This reuses the streaming Codex structured implementation and simply
|
||||
accumulates the streamed JSON chunks into a single string, then parses
|
||||
it at the root call.
|
||||
"""
|
||||
# Reuse the Responses API streaming implementation for Codex.
|
||||
accumulated: List[str] = []
|
||||
async for chunk in self._stream_codex_structured(
|
||||
model=model,
|
||||
messages=messages,
|
||||
response_format=response_format,
|
||||
strict=strict,
|
||||
max_tokens=max_tokens,
|
||||
tools=tools,
|
||||
extra_body=extra_body,
|
||||
depth=depth,
|
||||
):
|
||||
accumulated.append(chunk)
|
||||
|
||||
raw = "".join(accumulated)
|
||||
if not raw:
|
||||
return None
|
||||
|
||||
# At the root level we parse into a dict; recursive calls just
|
||||
# propagate the raw JSON/text, mirroring other providers.
|
||||
if depth == 0:
|
||||
return dict(dirtyjson.loads(raw))
|
||||
return {"raw": raw}
|
||||
|
||||
async def _generate_google_structured(
|
||||
self,
|
||||
model: str,
|
||||
|
|
@ -795,6 +1066,15 @@ class LLMClient:
|
|||
tools=parsed_tools,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
case LLMProvider.CODEX:
|
||||
content = await self._generate_codex_structured(
|
||||
model=model,
|
||||
messages=messages,
|
||||
response_format=response_format,
|
||||
strict=strict,
|
||||
tools=parsed_tools,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
case LLMProvider.GOOGLE:
|
||||
content = await self._generate_google_structured(
|
||||
model=model,
|
||||
|
|
@ -1068,6 +1348,157 @@ class LLMClient:
|
|||
):
|
||||
yield event
|
||||
|
||||
async def _stream_codex(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
max_tokens: Optional[int] = None,
|
||||
tools: Optional[List[dict]] = None,
|
||||
depth: int = 0,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Stream plain text from Codex (Responses API). On tool calls, execute tools
|
||||
and recurse, mirroring _stream_openai but using Responses events.
|
||||
"""
|
||||
_MAX_RECURSION_DEPTH = 5
|
||||
client: AsyncOpenAI = (
|
||||
self._get_codex_client()
|
||||
if self.llm_provider == LLMProvider.CODEX
|
||||
else self._client
|
||||
)
|
||||
|
||||
# Flatten tools to Responses API format
|
||||
responses_tools: Optional[List[dict]] = None
|
||||
if tools:
|
||||
responses_tools = []
|
||||
for tool in tools:
|
||||
fn = (tool.get("function") or tool) if isinstance(tool, dict) else {}
|
||||
if isinstance(fn, dict):
|
||||
responses_tools.append(
|
||||
{
|
||||
"type": "function",
|
||||
"name": fn.get("name", ""),
|
||||
"description": fn.get("description", ""),
|
||||
"parameters": fn.get("parameters", {}),
|
||||
}
|
||||
)
|
||||
else:
|
||||
responses_tools.append(tool)
|
||||
|
||||
# Build instructions + input (same shape as _generate_codex/_stream_codex_structured)
|
||||
instructions = self._get_system_prompt(messages) or None
|
||||
input_payload: List[Dict[str, Any]] = []
|
||||
for m in messages:
|
||||
if isinstance(m, LLMSystemMessage):
|
||||
continue
|
||||
if isinstance(m, LLMUserMessage):
|
||||
input_payload.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "input_text", "text": m.content}],
|
||||
}
|
||||
)
|
||||
elif isinstance(m, OpenAIAssistantMessage):
|
||||
text = m.content or ""
|
||||
if text:
|
||||
input_payload.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "output_text", "text": text}],
|
||||
}
|
||||
)
|
||||
else:
|
||||
text = getattr(m, "content", "") or ""
|
||||
if text:
|
||||
input_payload.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "input_text", "text": text}],
|
||||
}
|
||||
)
|
||||
|
||||
create_kwargs: Dict[str, Any] = {
|
||||
"model": model,
|
||||
"store": False,
|
||||
"stream": True,
|
||||
"text": {"verbosity": "medium"},
|
||||
"include": ["reasoning.encrypted_content"],
|
||||
"tool_choice": "auto",
|
||||
"parallel_tool_calls": True,
|
||||
}
|
||||
if instructions:
|
||||
create_kwargs["instructions"] = instructions
|
||||
if input_payload:
|
||||
create_kwargs["input"] = input_payload
|
||||
if responses_tools:
|
||||
create_kwargs["tools"] = responses_tools
|
||||
if max_tokens is not None:
|
||||
create_kwargs["max_output_tokens"] = max_tokens
|
||||
|
||||
stream = await client.responses.create(**create_kwargs)
|
||||
|
||||
def _event_dict(ev: Any) -> dict:
|
||||
if hasattr(ev, "model_dump"):
|
||||
return ev.model_dump()
|
||||
return {
|
||||
"type": getattr(ev, "type", None),
|
||||
"delta": getattr(ev, "delta", None),
|
||||
"item": getattr(ev, "item", None),
|
||||
"message": getattr(ev, "message", None),
|
||||
}
|
||||
|
||||
tool_calls_by_id: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
async for ev in stream:
|
||||
event = _event_dict(ev) if not isinstance(ev, dict) else ev
|
||||
event_type = event.get("type") or ""
|
||||
|
||||
if event_type == "response.output_text.delta":
|
||||
delta = event.get("delta") or ""
|
||||
if delta:
|
||||
yield delta
|
||||
elif event_type == "response.output_item.done":
|
||||
item = event.get("item") or {}
|
||||
if item.get("type") == "function_call":
|
||||
cid = item.get("call_id") or item.get("id", "")
|
||||
tool_calls_by_id[cid] = item
|
||||
elif event_type in ("response.error", "response.failed", "error"):
|
||||
err = event.get("message") or event.get("error") or str(event)
|
||||
raise HTTPException(status_code=502, detail=f"Codex stream error: {err}"[:400])
|
||||
|
||||
if tool_calls_by_id and responses_tools and depth < _MAX_RECURSION_DEPTH:
|
||||
parsed_tool_calls = [
|
||||
OpenAIToolCall(
|
||||
id=cid,
|
||||
type="function",
|
||||
function=OpenAIToolCallFunction(
|
||||
name=data.get("name", ""),
|
||||
arguments=data.get("arguments", ""),
|
||||
),
|
||||
)
|
||||
for cid, data in tool_calls_by_id.items()
|
||||
]
|
||||
tool_call_messages = await self.tool_calls_handler.handle_tool_calls_openai(
|
||||
parsed_tool_calls
|
||||
)
|
||||
new_messages = [
|
||||
*messages,
|
||||
OpenAIAssistantMessage(
|
||||
role="assistant",
|
||||
content=None,
|
||||
tool_calls=[tc.model_dump() for tc in parsed_tool_calls],
|
||||
),
|
||||
*tool_call_messages,
|
||||
]
|
||||
async for chunk in self._stream_codex(
|
||||
model=model,
|
||||
messages=new_messages,
|
||||
max_tokens=max_tokens,
|
||||
tools=tools,
|
||||
depth=depth + 1,
|
||||
):
|
||||
yield chunk
|
||||
|
||||
def _stream_ollama(
|
||||
self,
|
||||
model: str,
|
||||
|
|
@ -1112,6 +1543,13 @@ class LLMClient:
|
|||
max_tokens=max_tokens,
|
||||
tools=parsed_tools,
|
||||
)
|
||||
case LLMProvider.CODEX:
|
||||
return self._stream_codex(
|
||||
model=model,
|
||||
messages=messages,
|
||||
max_tokens=max_tokens,
|
||||
tools=parsed_tools,
|
||||
)
|
||||
case LLMProvider.GOOGLE:
|
||||
return self._stream_google(
|
||||
model=model,
|
||||
|
|
@ -1286,6 +1724,291 @@ class LLMClient:
|
|||
):
|
||||
yield event
|
||||
|
||||
|
||||
|
||||
async def _stream_codex_structured(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
response_format: dict,
|
||||
strict: bool = False,
|
||||
max_tokens: Optional[int] = None,
|
||||
tools: Optional[List[dict]] = None,
|
||||
depth: int = 0,
|
||||
extra_body: Optional[dict] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Stream structured responses using OpenAI's Responses API (Codex-style models).
|
||||
|
||||
This implementation is intentionally separate from ChatCompletion-based streaming
|
||||
because the Responses API uses a fundamentally different event model.
|
||||
|
||||
Why this function exists:
|
||||
|
||||
1. The Responses API does NOT return `choices[].delta` like ChatCompletions.
|
||||
Instead, it streams typed events such as:
|
||||
- response.output_text.delta
|
||||
- response.output_tool_call.delta
|
||||
- response.completed
|
||||
- response.error
|
||||
|
||||
2. Structured output can be achieved in two ways:
|
||||
a) Native JSON schema enforcement via `response_format`
|
||||
b) Tool-call-based structured output using a synthetic `ResponseSchema` tool
|
||||
|
||||
This function supports both approaches. When tool-call structured mode is enabled,
|
||||
a dynamic `ResponseSchema` tool is injected so the model returns structured data
|
||||
as tool call arguments.
|
||||
|
||||
3. Tool calls must be accumulated incrementally.
|
||||
The Responses API streams tool call arguments in chunks (`arguments_delta`),
|
||||
so we reconstruct the full argument payload before executing the tool.
|
||||
|
||||
4. Recursive tool execution is supported.
|
||||
If the model calls external tools (e.g., web search), we:
|
||||
- Execute the tools asynchronously
|
||||
- Append tool results as new messages
|
||||
- Reinvoke the model recursively
|
||||
This enables multi-step reasoning and grounding workflows.
|
||||
|
||||
5. Provider abstraction is preserved.
|
||||
The Responses API event format is converted into our internal tool-call model
|
||||
before being passed to the tool handler layer. This prevents SDK-specific
|
||||
structures from leaking into business logic.
|
||||
|
||||
6. Strict schema enforcement (optional).
|
||||
When `strict=True`, the provided JSON schema is hardened before being sent
|
||||
to the model to reduce malformed outputs.
|
||||
|
||||
Important architectural note:
|
||||
This function MUST NOT assume ChatCompletion-style streaming fields like
|
||||
`choices`, `delta.content`, or `delta.tool_calls`. It strictly follows the
|
||||
Responses API event model.
|
||||
|
||||
This separation ensures:
|
||||
- Future compatibility with GPT-5 / Codex models
|
||||
- Clean provider abstraction
|
||||
- Streaming-safe structured JSON assembly
|
||||
- Robust multi-tool recursive execution
|
||||
"""
|
||||
client: AsyncOpenAI = self._client
|
||||
response_schema = response_format
|
||||
# Apply strict schema once at root
|
||||
if strict and depth == 0:
|
||||
response_schema = ensure_strict_json_schema(
|
||||
response_schema,
|
||||
path=(),
|
||||
root=response_schema,
|
||||
)
|
||||
|
||||
# Codex Responses API requires all array schemas to specify `items`.
|
||||
def _fix_arrays(node: Any) -> Any:
|
||||
if isinstance(node, dict):
|
||||
# Add default items for arrays missing them
|
||||
if node.get("type") == "array" and "items" not in node:
|
||||
node["items"] = {"type": "string"}
|
||||
for key, value in list(node.items()):
|
||||
node[key] = _fix_arrays(value)
|
||||
elif isinstance(node, list):
|
||||
for idx, value in enumerate(node):
|
||||
node[idx] = _fix_arrays(value)
|
||||
return node
|
||||
|
||||
response_schema = _fix_arrays(response_schema)
|
||||
|
||||
# Responses API tool format: flat {type, name, description, parameters}
|
||||
response_schema_tool = {
|
||||
"type": "function",
|
||||
"name": "ResponseSchema",
|
||||
"description": "Provide structured response",
|
||||
"parameters": response_schema,
|
||||
}
|
||||
all_tools: List[dict] = [response_schema_tool]
|
||||
if tools:
|
||||
for tool in tools:
|
||||
fn = (tool.get("function") or tool) if isinstance(tool, dict) else {}
|
||||
if isinstance(fn, dict):
|
||||
all_tools.append({
|
||||
"type": "function",
|
||||
"name": fn.get("name", ""),
|
||||
"description": fn.get("description", ""),
|
||||
"parameters": fn.get("parameters", {}),
|
||||
})
|
||||
else:
|
||||
all_tools.append(tool)
|
||||
|
||||
# Build instructions + input like Codex adapter (instructions from system; input_text/output_text)
|
||||
instructions = self._get_system_prompt(messages) or None
|
||||
input_payload: List[Dict[str, Any]] = []
|
||||
for m in messages:
|
||||
if isinstance(m, LLMSystemMessage):
|
||||
continue
|
||||
if isinstance(m, LLMUserMessage):
|
||||
input_payload.append({
|
||||
"role": "user",
|
||||
"content": [{"type": "input_text", "text": m.content}],
|
||||
})
|
||||
elif isinstance(m, OpenAIAssistantMessage):
|
||||
text = m.content or ""
|
||||
if text:
|
||||
input_payload.append({
|
||||
"role": "assistant",
|
||||
"content": [{"type": "output_text", "text": text}],
|
||||
})
|
||||
else:
|
||||
text = getattr(m, "content", "") or ""
|
||||
if text:
|
||||
input_payload.append({
|
||||
"role": "user",
|
||||
"content": [{"type": "input_text", "text": text}],
|
||||
})
|
||||
|
||||
# Force model to use ResponseSchema for structured output
|
||||
tool_choice = {"type": "function", "name": "ResponseSchema"}
|
||||
create_kwargs: Dict[str, Any] = {
|
||||
"model": model,
|
||||
"store": False,
|
||||
"stream": True,
|
||||
"text": {"verbosity": "medium"},
|
||||
"include": ["reasoning.encrypted_content"],
|
||||
"tool_choice": tool_choice,
|
||||
"parallel_tool_calls": True,
|
||||
"tools": all_tools,
|
||||
}
|
||||
if instructions:
|
||||
create_kwargs["instructions"] = instructions
|
||||
if input_payload:
|
||||
create_kwargs["input"] = input_payload
|
||||
if max_tokens is not None:
|
||||
create_kwargs["max_output_tokens"] = max_tokens
|
||||
if extra_body:
|
||||
create_kwargs.update(extra_body)
|
||||
|
||||
stream = await client.responses.create(**create_kwargs)
|
||||
|
||||
|
||||
def _event_dict(ev: Any) -> dict:
|
||||
if hasattr(ev, "model_dump"):
|
||||
return ev.model_dump()
|
||||
return {
|
||||
"type": getattr(ev, "type", None),
|
||||
"delta": getattr(ev, "delta", None),
|
||||
"arguments": getattr(ev, "arguments", None),
|
||||
"arguments_delta": getattr(ev, "arguments_delta", None),
|
||||
"item": getattr(ev, "item", None),
|
||||
"id": getattr(ev, "id", None),
|
||||
"name": getattr(ev, "name", None),
|
||||
"error": getattr(ev, "error", None),
|
||||
"message": getattr(ev, "message", None),
|
||||
}
|
||||
|
||||
tool_calls_by_id: Dict[str, Dict[str, Any]] = {}
|
||||
current_call_id: Optional[str] = None
|
||||
has_response_schema_tool_call = False
|
||||
|
||||
async for ev in stream:
|
||||
event = _event_dict(ev) if not isinstance(ev, dict) else ev
|
||||
event_type = event.get("type") or ""
|
||||
|
||||
if event_type == "response.output_item.added":
|
||||
item = event.get("item") or {}
|
||||
if item.get("type") == "function_call" and item.get("name") == "ResponseSchema":
|
||||
current_call_id = item.get("call_id") or item.get("id")
|
||||
|
||||
elif event_type == "response.function_call_arguments.delta":
|
||||
if current_call_id:
|
||||
delta = event.get("delta") or ""
|
||||
if delta:
|
||||
has_response_schema_tool_call = True
|
||||
yield delta
|
||||
|
||||
elif event_type == "response.function_call_arguments.done":
|
||||
if event.get("name") == "ResponseSchema":
|
||||
args = event.get("arguments") or ""
|
||||
if args:
|
||||
has_response_schema_tool_call = True
|
||||
yield args
|
||||
|
||||
elif event_type == "response.output_item.done":
|
||||
item = event.get("item") or {}
|
||||
if item.get("type") == "function_call":
|
||||
cid = item.get("call_id") or item.get("id", "")
|
||||
tool_calls_by_id[cid] = item
|
||||
if item.get("name") == "ResponseSchema":
|
||||
args = item.get("arguments") or ""
|
||||
if args:
|
||||
has_response_schema_tool_call = True
|
||||
yield args
|
||||
|
||||
elif event_type == "response.output_tool_call.delta":
|
||||
call_id = event.get("id")
|
||||
name = event.get("name")
|
||||
arguments_delta = event.get("arguments_delta") or ""
|
||||
if call_id and name:
|
||||
if call_id not in tool_calls_by_id:
|
||||
tool_calls_by_id[call_id] = {"name": name, "arguments": ""}
|
||||
tool_calls_by_id[call_id]["arguments"] += arguments_delta
|
||||
if name == "ResponseSchema" and arguments_delta:
|
||||
has_response_schema_tool_call = True
|
||||
yield arguments_delta
|
||||
|
||||
elif event_type == "response.completed":
|
||||
break
|
||||
|
||||
elif event_type in ("response.error", "response.failed", "error"):
|
||||
err = event.get("error") or event.get("message") or str(event)
|
||||
raise RuntimeError(err)
|
||||
|
||||
# ============================================
|
||||
# EXECUTE NON-STRUCTURED TOOL CALLS (RECURSIVE)
|
||||
# ============================================
|
||||
|
||||
other_tool_calls = {
|
||||
cid: data
|
||||
for cid, data in tool_calls_by_id.items()
|
||||
if data.get("name") != "ResponseSchema"
|
||||
}
|
||||
if other_tool_calls and not has_response_schema_tool_call:
|
||||
parsed_tool_calls = []
|
||||
for call_id, data in other_tool_calls.items():
|
||||
args = data.get("arguments", "") if isinstance(data, dict) else ""
|
||||
parsed_tool_calls.append(
|
||||
OpenAIToolCall(
|
||||
id=call_id,
|
||||
type="function",
|
||||
function=OpenAIToolCallFunction(
|
||||
name=data.get("name", ""),
|
||||
arguments=args,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
tool_call_messages = await self.tool_calls_handler.handle_tool_calls_openai(
|
||||
parsed_tool_calls
|
||||
)
|
||||
|
||||
new_messages = [
|
||||
*messages,
|
||||
OpenAIAssistantMessage(
|
||||
role="assistant",
|
||||
content=None,
|
||||
tool_calls=[tc.model_dump() for tc in parsed_tool_calls],
|
||||
),
|
||||
*tool_call_messages,
|
||||
]
|
||||
|
||||
async for chunk in self._stream_codex_structured(
|
||||
model=model,
|
||||
messages=new_messages,
|
||||
response_format=response_schema,
|
||||
strict=strict,
|
||||
max_tokens=max_tokens,
|
||||
tools=tools,
|
||||
extra_body=extra_body,
|
||||
depth=depth + 1,
|
||||
):
|
||||
yield chunk
|
||||
|
||||
async def _stream_google_structured(
|
||||
self,
|
||||
model: str,
|
||||
|
|
@ -1538,6 +2261,15 @@ class LLMClient:
|
|||
tools=parsed_tools,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
case LLMProvider.CODEX:
|
||||
return self._stream_codex_structured(
|
||||
model=model,
|
||||
messages=messages,
|
||||
response_format=response_format,
|
||||
strict=strict,
|
||||
tools=parsed_tools,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
case LLMProvider.GOOGLE:
|
||||
return self._stream_google_structured(
|
||||
model=model,
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@ class LLMToolCallsHandler:
|
|||
self.dynamic_tools.append(tool)
|
||||
|
||||
match self.client.llm_provider:
|
||||
case LLMProvider.OPENAI | LLMProvider.OLLAMA | LLMProvider.CUSTOM:
|
||||
case LLMProvider.OPENAI | LLMProvider.OLLAMA | LLMProvider.CUSTOM | LLMProvider.CODEX:
|
||||
return self.parse_tool_openai(tool, strict)
|
||||
case LLMProvider.ANTHROPIC:
|
||||
return self.parse_tool_anthropic(tool)
|
||||
|
|
@ -63,7 +63,7 @@ class LLMToolCallsHandler:
|
|||
return self.parse_tool_google(tool)
|
||||
case _:
|
||||
raise ValueError(
|
||||
f"LLM provider must be either openai, anthropic, or google"
|
||||
f"LLM provider must be one of: openai, anthropic, google, codex"
|
||||
)
|
||||
|
||||
def parse_tool_openai(
|
||||
|
|
|
|||
|
|
@ -117,3 +117,24 @@ def get_dall_e_3_quality_env():
|
|||
# Gpt Image 1.5 Quality
|
||||
def get_gpt_image_1_5_quality_env():
|
||||
return os.getenv("GPT_IMAGE_1_5_QUALITY")
|
||||
|
||||
|
||||
# Codex OAuth
|
||||
def get_codex_access_token_env():
|
||||
return os.getenv("CODEX_ACCESS_TOKEN")
|
||||
|
||||
|
||||
def get_codex_refresh_token_env():
|
||||
return os.getenv("CODEX_REFRESH_TOKEN")
|
||||
|
||||
|
||||
def get_codex_token_expires_env():
|
||||
return os.getenv("CODEX_TOKEN_EXPIRES")
|
||||
|
||||
|
||||
def get_codex_account_id_env():
|
||||
return os.getenv("CODEX_ACCOUNT_ID")
|
||||
|
||||
|
||||
def get_codex_model_env():
|
||||
return os.getenv("CODEX_MODEL")
|
||||
|
|
|
|||
|
|
@ -125,20 +125,28 @@ async def get_slide_content_from_type_and_outline(
|
|||
True,
|
||||
)
|
||||
|
||||
messages = get_messages(
|
||||
outline.content,
|
||||
language,
|
||||
tone,
|
||||
verbosity,
|
||||
instructions,
|
||||
)
|
||||
print(
|
||||
f"get_slide_content_from_type_and_outline: model={model} outline_len={len(outline.content or '')} language={language}"
|
||||
)
|
||||
try:
|
||||
response = await client.generate_structured(
|
||||
model=model,
|
||||
messages=get_messages(
|
||||
outline.content,
|
||||
language,
|
||||
tone,
|
||||
verbosity,
|
||||
instructions,
|
||||
),
|
||||
messages=messages,
|
||||
response_format=response_schema,
|
||||
strict=False,
|
||||
)
|
||||
print(
|
||||
f"get_slide_content_from_type_and_outline: response is None={response is None} keys={list(response.keys())[:6] if isinstance(response, dict) else None}"
|
||||
)
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
print(f"get_slide_content_from_type_and_outline: exception={e}")
|
||||
raise handle_llm_client_exceptions(e)
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from constants.llm import (
|
|||
from enums.llm_provider import LLMProvider
|
||||
from utils.get_env import (
|
||||
get_anthropic_model_env,
|
||||
get_codex_model_env,
|
||||
get_custom_model_env,
|
||||
get_google_model_env,
|
||||
get_llm_provider_env,
|
||||
|
|
@ -22,7 +23,7 @@ def get_llm_provider():
|
|||
except:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Invalid LLM provider. Please select one of: openai, google, anthropic, ollama, custom",
|
||||
detail=f"Invalid LLM provider. Please select one of: openai, google, anthropic, ollama, custom, codex",
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -46,6 +47,10 @@ def is_custom_llm_selected():
|
|||
return get_llm_provider() == LLMProvider.CUSTOM
|
||||
|
||||
|
||||
def is_codex_selected():
|
||||
return get_llm_provider() == LLMProvider.CODEX
|
||||
|
||||
|
||||
def get_model():
|
||||
selected_llm = get_llm_provider()
|
||||
if selected_llm == LLMProvider.OPENAI:
|
||||
|
|
@ -58,8 +63,10 @@ def get_model():
|
|||
return get_ollama_model_env()
|
||||
elif selected_llm == LLMProvider.CUSTOM:
|
||||
return get_custom_model_env()
|
||||
elif selected_llm == LLMProvider.CODEX:
|
||||
return get_codex_model_env()
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Invalid LLM provider. Please select one of: openai, google, anthropic, ollama, custom",
|
||||
detail=f"Invalid LLM provider. Please select one of: openai, google, anthropic, ollama, custom, codex",
|
||||
)
|
||||
|
|
|
|||
0
servers/fastapi/utils/oauth/__init__.py
Normal file
0
servers/fastapi/utils/oauth/__init__.py
Normal file
348
servers/fastapi/utils/oauth/openai_codex.py
Normal file
348
servers/fastapi/utils/oauth/openai_codex.py
Normal file
|
|
@ -0,0 +1,348 @@
|
|||
"""
|
||||
OpenAI Codex (ChatGPT OAuth) flow — Python port of
|
||||
pi-mono-main/packages/ai/src/utils/oauth/openai-codex.ts
|
||||
|
||||
Handles PKCE authorization, local callback server, token exchange and refresh.
|
||||
No FastAPI dependencies; all HTTP is done with the standard library + httpx.
|
||||
"""
|
||||
import base64
|
||||
import json
|
||||
import secrets
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||
from typing import Optional
|
||||
from urllib.parse import parse_qs, urlencode, urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
from utils.oauth.pkce import generate_pkce
|
||||
|
||||
CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
||||
AUTHORIZE_URL = "https://auth.openai.com/oauth/authorize"
|
||||
TOKEN_URL = "https://auth.openai.com/oauth/token"
|
||||
REDIRECT_URI = "http://localhost:1455/auth/callback"
|
||||
SCOPE = "openid profile email offline_access"
|
||||
JWT_CLAIM_PATH = "https://api.openai.com/auth"
|
||||
|
||||
CALLBACK_PORT = 1455
|
||||
|
||||
SUCCESS_HTML = b"""<!doctype html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1" />
|
||||
<title>Authentication successful</title>
|
||||
</head>
|
||||
<body>
|
||||
<p>Authentication successful. Return to your terminal / application to continue.</p>
|
||||
</body>
|
||||
</html>"""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data types
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
|
||||
class TokenSuccess:
|
||||
access: str
|
||||
refresh: str
|
||||
expires: int # Unix ms timestamp when the token expires
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenFailure:
|
||||
reason: str
|
||||
|
||||
|
||||
TokenResult = TokenSuccess | TokenFailure
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuthorizationFlow:
|
||||
verifier: str
|
||||
state: str
|
||||
url: str
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# JWT helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _decode_jwt_payload(token: str) -> Optional[dict]:
|
||||
"""Decode the payload segment of a JWT without verifying the signature."""
|
||||
try:
|
||||
parts = token.split(".")
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
payload_b64 = parts[1]
|
||||
# Add padding if needed
|
||||
padding = 4 - len(payload_b64) % 4
|
||||
if padding != 4:
|
||||
payload_b64 += "=" * padding
|
||||
decoded = base64.urlsafe_b64decode(payload_b64)
|
||||
return json.loads(decoded)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def get_account_id(access_token: str) -> Optional[str]:
|
||||
"""Extract the ChatGPT account ID from an access token JWT."""
|
||||
payload = _decode_jwt_payload(access_token)
|
||||
if not payload:
|
||||
return None
|
||||
auth_claims = payload.get(JWT_CLAIM_PATH)
|
||||
if not isinstance(auth_claims, dict):
|
||||
return None
|
||||
account_id = auth_claims.get("chatgpt_account_id")
|
||||
if isinstance(account_id, str) and account_id:
|
||||
return account_id
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Authorization URL + PKCE
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def create_authorization_flow(originator: str = "pi") -> AuthorizationFlow:
|
||||
"""Generate PKCE verifier/challenge, state, and the full authorization URL."""
|
||||
verifier, challenge = generate_pkce()
|
||||
state = secrets.token_hex(16)
|
||||
|
||||
params = {
|
||||
"response_type": "code",
|
||||
"client_id": CLIENT_ID,
|
||||
"redirect_uri": REDIRECT_URI,
|
||||
"scope": SCOPE,
|
||||
"code_challenge": challenge,
|
||||
"code_challenge_method": "S256",
|
||||
"state": state,
|
||||
"id_token_add_organizations": "true",
|
||||
"codex_cli_simplified_flow": "true",
|
||||
"originator": originator,
|
||||
}
|
||||
url = f"{AUTHORIZE_URL}?{urlencode(params)}"
|
||||
return AuthorizationFlow(verifier=verifier, state=state, url=url)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Local callback server
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class _CallbackHandler(BaseHTTPRequestHandler):
|
||||
"""Minimal HTTP handler that captures the OAuth callback code."""
|
||||
|
||||
def do_GET(self): # noqa: N802
|
||||
parsed = urlparse(self.path)
|
||||
if parsed.path != "/auth/callback":
|
||||
self.send_response(404)
|
||||
self.end_headers()
|
||||
self.wfile.write(b"Not found")
|
||||
return
|
||||
|
||||
qs = parse_qs(parsed.query)
|
||||
state_vals = qs.get("state", [])
|
||||
code_vals = qs.get("code", [])
|
||||
|
||||
expected_state: str = self.server.expected_state # type: ignore[attr-defined]
|
||||
|
||||
if not state_vals or state_vals[0] != expected_state:
|
||||
self.send_response(400)
|
||||
self.end_headers()
|
||||
self.wfile.write(b"State mismatch")
|
||||
return
|
||||
|
||||
if not code_vals:
|
||||
self.send_response(400)
|
||||
self.end_headers()
|
||||
self.wfile.write(b"Missing authorization code")
|
||||
return
|
||||
|
||||
self.send_response(200)
|
||||
self.send_header("Content-Type", "text/html; charset=utf-8")
|
||||
self.end_headers()
|
||||
self.wfile.write(SUCCESS_HTML)
|
||||
|
||||
self.server.captured_code = code_vals[0] # type: ignore[attr-defined]
|
||||
|
||||
def log_message(self, format, *args): # noqa: A002
|
||||
pass # suppress default stderr logging
|
||||
|
||||
|
||||
class OAuthCallbackServer:
|
||||
"""
|
||||
Wraps an HTTPServer that listens on port 1455 for the OAuth callback.
|
||||
Runs in a background daemon thread so it doesn't block the caller.
|
||||
"""
|
||||
|
||||
def __init__(self, state: str):
|
||||
self._state = state
|
||||
self._server: Optional[HTTPServer] = None
|
||||
self._thread: Optional[threading.Thread] = None
|
||||
self._started = threading.Event()
|
||||
self._cancelled = False
|
||||
|
||||
def start(self) -> bool:
|
||||
"""Start the background HTTP server. Returns True if successful."""
|
||||
try:
|
||||
server = HTTPServer(("0.0.0.0", CALLBACK_PORT), _CallbackHandler)
|
||||
server.expected_state = self._state # type: ignore[attr-defined]
|
||||
server.captured_code = None # type: ignore[attr-defined]
|
||||
server.timeout = 0.2 # short poll interval so we can check cancel
|
||||
self._server = server
|
||||
|
||||
def _serve():
|
||||
self._started.set()
|
||||
while not self._cancelled and server.captured_code is None:
|
||||
server.handle_request()
|
||||
server.server_close()
|
||||
|
||||
self._thread = threading.Thread(target=_serve, daemon=True)
|
||||
self._thread.start()
|
||||
self._started.wait(timeout=2)
|
||||
return True
|
||||
except OSError:
|
||||
return False
|
||||
|
||||
def get_code_nowait(self) -> Optional[str]:
|
||||
"""Non-blocking peek — returns the captured code or None immediately."""
|
||||
if self._server is None:
|
||||
return None
|
||||
return self._server.captured_code # type: ignore[attr-defined]
|
||||
|
||||
def wait_for_code(self, timeout_seconds: int = 120) -> Optional[str]:
|
||||
"""
|
||||
Block until the callback delivers a code or timeout / cancellation.
|
||||
Returns the authorization code or None.
|
||||
"""
|
||||
if self._server is None:
|
||||
return None
|
||||
deadline = time.monotonic() + timeout_seconds
|
||||
while time.monotonic() < deadline:
|
||||
if self._cancelled:
|
||||
return None
|
||||
code = self._server.captured_code # type: ignore[attr-defined]
|
||||
if code:
|
||||
return code
|
||||
time.sleep(0.1)
|
||||
return None
|
||||
|
||||
def cancel(self):
|
||||
self._cancelled = True
|
||||
|
||||
def close(self):
|
||||
self._cancelled = True
|
||||
if self._thread:
|
||||
self._thread.join(timeout=2)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Token exchange / refresh (sync — called from thread or FastAPI background)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def exchange_authorization_code(
|
||||
code: str,
|
||||
verifier: str,
|
||||
redirect_uri: str = REDIRECT_URI,
|
||||
) -> TokenResult:
|
||||
"""Exchange an authorization code for access + refresh tokens."""
|
||||
try:
|
||||
response = httpx.post(
|
||||
TOKEN_URL,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
data={
|
||||
"grant_type": "authorization_code",
|
||||
"client_id": CLIENT_ID,
|
||||
"code": code,
|
||||
"code_verifier": verifier,
|
||||
"redirect_uri": redirect_uri,
|
||||
},
|
||||
timeout=30,
|
||||
)
|
||||
if not response.is_success:
|
||||
return TokenFailure(reason=f"HTTP {response.status_code}: {response.text[:200]}")
|
||||
|
||||
body = response.json()
|
||||
access = body.get("access_token")
|
||||
refresh = body.get("refresh_token")
|
||||
expires_in = body.get("expires_in")
|
||||
|
||||
if not access or not refresh or not isinstance(expires_in, (int, float)):
|
||||
return TokenFailure(reason=f"Token response missing fields: {list(body.keys())}")
|
||||
|
||||
expires_ms = int(time.time() * 1000) + int(expires_in) * 1000
|
||||
return TokenSuccess(access=access, refresh=refresh, expires=expires_ms)
|
||||
except Exception as exc:
|
||||
return TokenFailure(reason=str(exc))
|
||||
|
||||
|
||||
def refresh_access_token(refresh_token: str) -> TokenResult:
|
||||
"""Use a refresh token to obtain a new access token."""
|
||||
try:
|
||||
response = httpx.post(
|
||||
TOKEN_URL,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
data={
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refresh_token,
|
||||
"client_id": CLIENT_ID,
|
||||
},
|
||||
timeout=30,
|
||||
)
|
||||
if not response.is_success:
|
||||
return TokenFailure(reason=f"HTTP {response.status_code}: {response.text[:200]}")
|
||||
|
||||
body = response.json()
|
||||
access = body.get("access_token")
|
||||
refresh = body.get("refresh_token")
|
||||
expires_in = body.get("expires_in")
|
||||
|
||||
if not access or not refresh or not isinstance(expires_in, (int, float)):
|
||||
return TokenFailure(reason=f"Token refresh response missing fields: {list(body.keys())}")
|
||||
|
||||
expires_ms = int(time.time() * 1000) + int(expires_in) * 1000
|
||||
return TokenSuccess(access=access, refresh=refresh, expires=expires_ms)
|
||||
except Exception as exc:
|
||||
return TokenFailure(reason=str(exc))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Parsing helpers (for manual code paste / redirect URL fallback)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def parse_authorization_input(raw: str) -> dict:
|
||||
"""
|
||||
Accept a variety of user-pasted inputs:
|
||||
- Full redirect URL: http://localhost:1455/auth/callback?code=X&state=Y
|
||||
- code#state shorthand
|
||||
- Raw query string: code=X&state=Y
|
||||
- Bare code value
|
||||
Returns a dict with optional 'code' and 'state' keys.
|
||||
"""
|
||||
value = raw.strip()
|
||||
if not value:
|
||||
return {}
|
||||
|
||||
try:
|
||||
parsed = urlparse(value)
|
||||
if parsed.scheme in ("http", "https"):
|
||||
qs = parse_qs(parsed.query)
|
||||
return {
|
||||
k: qs[k][0]
|
||||
for k in ("code", "state")
|
||||
if k in qs
|
||||
}
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if "#" in value:
|
||||
parts = value.split("#", 1)
|
||||
return {"code": parts[0], "state": parts[1]}
|
||||
|
||||
if "code=" in value:
|
||||
qs = parse_qs(value)
|
||||
return {k: qs[k][0] for k in ("code", "state") if k in qs}
|
||||
|
||||
return {"code": value}
|
||||
23
servers/fastapi/utils/oauth/pkce.py
Normal file
23
servers/fastapi/utils/oauth/pkce.py
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
"""
|
||||
PKCE utilities using Python's secrets and hashlib.
|
||||
Python port of pi-mono-main/packages/ai/src/utils/oauth/pkce.ts
|
||||
"""
|
||||
import base64
|
||||
import hashlib
|
||||
import secrets
|
||||
|
||||
|
||||
def generate_pkce() -> tuple[str, str]:
|
||||
"""
|
||||
Generate PKCE code verifier and challenge (S256 method).
|
||||
|
||||
Returns:
|
||||
(verifier, challenge) — both base64url-encoded, no padding
|
||||
"""
|
||||
verifier_bytes = secrets.token_bytes(32)
|
||||
verifier = base64.urlsafe_b64encode(verifier_bytes).rstrip(b"=").decode()
|
||||
|
||||
digest = hashlib.sha256(verifier.encode()).digest()
|
||||
challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
|
||||
|
||||
return verifier, challenge
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
from copy import deepcopy
|
||||
from typing import Any, List
|
||||
from typing import Any, List, Mapping, Union
|
||||
|
||||
from openai import NOT_GIVEN
|
||||
|
||||
|
|
@ -22,6 +22,51 @@ supported_string_formats = [
|
|||
]
|
||||
|
||||
|
||||
def _is_json_object(value: object) -> bool:
|
||||
"""True if value is a dict-like object but not a list."""
|
||||
return isinstance(value, Mapping) and not isinstance(value, list)
|
||||
|
||||
|
||||
def _convert_pydantic_schema(schema: object) -> dict | None:
|
||||
"""Return JSON schema dict from a Pydantic model or class, else None."""
|
||||
if BaseModel is None:
|
||||
return None
|
||||
if isinstance(schema, BaseModel):
|
||||
return schema.model_json_schema()
|
||||
if isinstance(schema, type) and issubclass(schema, BaseModel):
|
||||
return schema.model_json_schema()
|
||||
if hasattr(schema, "model_json_schema"):
|
||||
try:
|
||||
return getattr(schema, "model_json_schema")()
|
||||
except TypeError:
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def normalize_output_schema(
|
||||
schema: Union[dict, type, object] | None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""
|
||||
Normalize output schema to a plain JSON schema dict (SDK-style).
|
||||
Accepts a JSON schema dict, a Pydantic model class, or a Pydantic instance.
|
||||
Returns None if schema is None; otherwise returns a dict suitable for
|
||||
ResponseSchema / structured output.
|
||||
"""
|
||||
if schema is None:
|
||||
return None
|
||||
|
||||
converted = _convert_pydantic_schema(schema)
|
||||
if converted is not None:
|
||||
return converted
|
||||
|
||||
if not _is_json_object(schema):
|
||||
raise ValueError(
|
||||
"output_schema must be a plain JSON object (dict) or a Pydantic model"
|
||||
)
|
||||
|
||||
return dict(schema)
|
||||
|
||||
|
||||
def remove_fields_from_schema(schema: dict, fields_to_remove: List[str]):
|
||||
schema = deepcopy(schema)
|
||||
properties_paths = get_dict_paths_with_key(schema, "properties")
|
||||
|
|
|
|||
|
|
@ -103,3 +103,24 @@ def set_dall_e_3_quality_env(value):
|
|||
|
||||
def set_gpt_image_1_5_quality_env(value):
|
||||
os.environ["GPT_IMAGE_1_5_QUALITY"] = value
|
||||
|
||||
|
||||
# Codex OAuth
|
||||
def set_codex_access_token_env(value: str):
|
||||
os.environ["CODEX_ACCESS_TOKEN"] = value
|
||||
|
||||
|
||||
def set_codex_refresh_token_env(value: str):
|
||||
os.environ["CODEX_REFRESH_TOKEN"] = value
|
||||
|
||||
|
||||
def set_codex_token_expires_env(value: str):
|
||||
os.environ["CODEX_TOKEN_EXPIRES"] = value
|
||||
|
||||
|
||||
def set_codex_account_id_env(value: str):
|
||||
os.environ["CODEX_ACCOUNT_ID"] = value
|
||||
|
||||
|
||||
def set_codex_model_env(value: str):
|
||||
os.environ["CODEX_MODEL"] = value
|
||||
|
|
|
|||
|
|
@ -28,6 +28,11 @@ from utils.get_env import (
|
|||
get_pixabay_api_key_env,
|
||||
get_extended_reasoning_env,
|
||||
get_web_grounding_env,
|
||||
get_codex_access_token_env,
|
||||
get_codex_refresh_token_env,
|
||||
get_codex_token_expires_env,
|
||||
get_codex_account_id_env,
|
||||
get_codex_model_env,
|
||||
)
|
||||
from utils.parsers import parse_bool_or_none
|
||||
from utils.set_env import (
|
||||
|
|
@ -55,6 +60,11 @@ from utils.set_env import (
|
|||
set_pixabay_api_key_env,
|
||||
set_tool_calls_env,
|
||||
set_web_grounding_env,
|
||||
set_codex_access_token_env,
|
||||
set_codex_refresh_token_env,
|
||||
set_codex_token_expires_env,
|
||||
set_codex_account_id_env,
|
||||
set_codex_model_env,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -118,6 +128,11 @@ def get_user_config():
|
|||
if existing_config.WEB_GROUNDING is not None
|
||||
else (parse_bool_or_none(get_web_grounding_env()) or False)
|
||||
),
|
||||
CODEX_MODEL=existing_config.CODEX_MODEL or get_codex_model_env(),
|
||||
CODEX_ACCESS_TOKEN=existing_config.CODEX_ACCESS_TOKEN or get_codex_access_token_env(),
|
||||
CODEX_REFRESH_TOKEN=existing_config.CODEX_REFRESH_TOKEN or get_codex_refresh_token_env(),
|
||||
CODEX_TOKEN_EXPIRES=existing_config.CODEX_TOKEN_EXPIRES or get_codex_token_expires_env(),
|
||||
CODEX_ACCOUNT_ID=existing_config.CODEX_ACCOUNT_ID or get_codex_account_id_env(),
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -171,3 +186,43 @@ def update_env_with_user_config():
|
|||
set_extended_reasoning_env(str(user_config.EXTENDED_REASONING))
|
||||
if user_config.WEB_GROUNDING is not None:
|
||||
set_web_grounding_env(str(user_config.WEB_GROUNDING))
|
||||
if user_config.CODEX_MODEL:
|
||||
set_codex_model_env(user_config.CODEX_MODEL)
|
||||
if user_config.CODEX_ACCESS_TOKEN:
|
||||
set_codex_access_token_env(user_config.CODEX_ACCESS_TOKEN)
|
||||
if user_config.CODEX_REFRESH_TOKEN:
|
||||
set_codex_refresh_token_env(user_config.CODEX_REFRESH_TOKEN)
|
||||
if user_config.CODEX_TOKEN_EXPIRES:
|
||||
set_codex_token_expires_env(user_config.CODEX_TOKEN_EXPIRES)
|
||||
if user_config.CODEX_ACCOUNT_ID:
|
||||
set_codex_account_id_env(user_config.CODEX_ACCOUNT_ID)
|
||||
|
||||
|
||||
def save_codex_tokens_to_user_config() -> None:
|
||||
"""
|
||||
Write the current in-memory Codex OAuth token env vars back to userConfig.json
|
||||
so they survive container restarts. Called after a successful token exchange
|
||||
and on logout (where the env vars have already been cleared to "").
|
||||
"""
|
||||
user_config_path = get_user_config_path_env()
|
||||
if not user_config_path:
|
||||
return
|
||||
|
||||
existing: dict = {}
|
||||
try:
|
||||
if os.path.exists(user_config_path):
|
||||
with open(user_config_path, "r") as f:
|
||||
existing = json.load(f)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
existing["CODEX_ACCESS_TOKEN"] = get_codex_access_token_env()
|
||||
existing["CODEX_REFRESH_TOKEN"] = get_codex_refresh_token_env()
|
||||
existing["CODEX_TOKEN_EXPIRES"] = get_codex_token_expires_env()
|
||||
existing["CODEX_ACCOUNT_ID"] = get_codex_account_id_env()
|
||||
|
||||
try:
|
||||
with open(user_config_path, "w") as f:
|
||||
json.dump(existing, f)
|
||||
except Exception:
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -91,6 +91,11 @@ export async function POST(request: Request) {
|
|||
userConfig.USE_CUSTOM_URL === undefined
|
||||
? existingConfig.USE_CUSTOM_URL
|
||||
: userConfig.USE_CUSTOM_URL,
|
||||
CODEX_MODEL: userConfig.CODEX_MODEL || existingConfig.CODEX_MODEL,
|
||||
CODEX_ACCESS_TOKEN: existingConfig.CODEX_ACCESS_TOKEN,
|
||||
CODEX_REFRESH_TOKEN: existingConfig.CODEX_REFRESH_TOKEN,
|
||||
CODEX_TOKEN_EXPIRES: existingConfig.CODEX_TOKEN_EXPIRES,
|
||||
CODEX_ACCOUNT_ID: existingConfig.CODEX_ACCOUNT_ID,
|
||||
};
|
||||
fs.writeFileSync(userConfigPath, JSON.stringify(mergedConfig));
|
||||
return NextResponse.json(mergedConfig);
|
||||
|
|
|
|||
429
servers/nextjs/components/CodexConfig.tsx
Normal file
429
servers/nextjs/components/CodexConfig.tsx
Normal file
|
|
@ -0,0 +1,429 @@
|
|||
"use client";
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import {
|
||||
Check,
|
||||
ChevronsUpDown,
|
||||
Loader2,
|
||||
LogIn,
|
||||
LogOut,
|
||||
RefreshCw,
|
||||
UserCheck,
|
||||
} from "lucide-react";
|
||||
import { Button } from "./ui/button";
|
||||
import {
|
||||
Command,
|
||||
CommandEmpty,
|
||||
CommandGroup,
|
||||
CommandInput,
|
||||
CommandItem,
|
||||
CommandList,
|
||||
} from "./ui/command";
|
||||
import { Popover, PopoverContent, PopoverTrigger } from "./ui/popover";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { toast } from "sonner";
|
||||
|
||||
interface CodexConfigProps {
|
||||
codexModel: string;
|
||||
onInputChange: (value: string | boolean, field: string) => void;
|
||||
}
|
||||
|
||||
type AuthStatus = "checking" | "unauthenticated" | "polling" | "authenticated";
|
||||
|
||||
interface StatusResponse {
|
||||
status: string;
|
||||
account_id?: string;
|
||||
detail?: string;
|
||||
}
|
||||
|
||||
interface CodexModel {
|
||||
id: string;
|
||||
name: string;
|
||||
}
|
||||
|
||||
const CHATGPT_MODELS: CodexModel[] = [
|
||||
{ id: "gpt-5.1", name: "GPT-5.1" },
|
||||
{ id: "gpt-5.1-codex-max", name: "GPT-5.1 Codex Max" },
|
||||
{ id: "gpt-5.1-codex-mini", name: "GPT-5.1 Codex Mini" },
|
||||
{ id: "gpt-5.2", name: "GPT-5.2" },
|
||||
{ id: "gpt-5.2-codex", name: "GPT-5.2 Codex" },
|
||||
{ id: "gpt-5.3-codex", name: "GPT-5.3 Codex" },
|
||||
{ id: "gpt-5.3-codex-spark", name: "GPT-5.3 Codex Spark (Free)" },
|
||||
];
|
||||
|
||||
const DEFAULT_CODEX_MODEL = "gpt-5.3-codex-spark";
|
||||
|
||||
export default function CodexConfig({
|
||||
codexModel,
|
||||
onInputChange,
|
||||
}: CodexConfigProps) {
|
||||
const [authStatus, setAuthStatus] = useState<AuthStatus>("checking");
|
||||
const [accountId, setAccountId] = useState<string | null>(null);
|
||||
const [sessionId, setSessionId] = useState<string | null>(null);
|
||||
const [manualCode, setManualCode] = useState("");
|
||||
const [isExchanging, setIsExchanging] = useState(false);
|
||||
const [isLoggingOut, setIsLoggingOut] = useState(false);
|
||||
const [isRefreshing, setIsRefreshing] = useState(false);
|
||||
const [openModelSelect, setOpenModelSelect] = useState(false);
|
||||
const pollIntervalRef = useRef<ReturnType<typeof setInterval> | null>(null);
|
||||
|
||||
const stopPolling = () => {
|
||||
if (pollIntervalRef.current) {
|
||||
clearInterval(pollIntervalRef.current);
|
||||
pollIntervalRef.current = null;
|
||||
}
|
||||
};
|
||||
|
||||
// Check current auth state on mount
|
||||
useEffect(() => {
|
||||
checkCurrentAuthStatus();
|
||||
return () => stopPolling();
|
||||
}, []);
|
||||
|
||||
const checkCurrentAuthStatus = async () => {
|
||||
try {
|
||||
const res = await fetch("/api/v1/ppt/codex/auth/status");
|
||||
if (!res.ok) {
|
||||
setAuthStatus("unauthenticated");
|
||||
return;
|
||||
}
|
||||
const data: StatusResponse = await res.json();
|
||||
if (data.status === "authenticated") {
|
||||
setAuthStatus("authenticated");
|
||||
setAccountId(data.account_id ?? null);
|
||||
} else {
|
||||
setAuthStatus("unauthenticated");
|
||||
}
|
||||
} catch {
|
||||
setAuthStatus("unauthenticated");
|
||||
}
|
||||
};
|
||||
|
||||
const handleSignIn = async () => {
|
||||
try {
|
||||
const res = await fetch("/api/v1/ppt/codex/auth/initiate", {
|
||||
method: "POST",
|
||||
});
|
||||
if (!res.ok) throw new Error("Failed to initiate auth");
|
||||
const data = await res.json();
|
||||
const { session_id, url } = data;
|
||||
|
||||
setSessionId(session_id);
|
||||
setAuthStatus("polling");
|
||||
window.open(url, "_blank", "noopener,noreferrer");
|
||||
|
||||
// Start polling the status endpoint every 2s
|
||||
pollIntervalRef.current = setInterval(async () => {
|
||||
try {
|
||||
const pollRes = await fetch(
|
||||
`/api/v1/ppt/codex/auth/status/${session_id}`
|
||||
);
|
||||
if (!pollRes.ok) return;
|
||||
const pollData: StatusResponse = await pollRes.json();
|
||||
|
||||
if (pollData.status === "success") {
|
||||
stopPolling();
|
||||
setAuthStatus("authenticated");
|
||||
setAccountId(pollData.account_id ?? null);
|
||||
setSessionId(null);
|
||||
// Set a sensible default model if none chosen
|
||||
if (!codexModel) {
|
||||
onInputChange(DEFAULT_CODEX_MODEL, "codex_model");
|
||||
}
|
||||
toast.success("Signed in to ChatGPT successfully");
|
||||
} else if (pollData.status === "failed") {
|
||||
stopPolling();
|
||||
setAuthStatus("unauthenticated");
|
||||
toast.error("Authentication failed. Please try again.");
|
||||
}
|
||||
} catch {
|
||||
// keep polling on transient errors
|
||||
}
|
||||
}, 2000);
|
||||
} catch (err) {
|
||||
toast.error("Failed to start sign-in flow");
|
||||
setAuthStatus("unauthenticated");
|
||||
}
|
||||
};
|
||||
|
||||
const handleManualExchange = async () => {
|
||||
if (!sessionId || !manualCode.trim()) return;
|
||||
setIsExchanging(true);
|
||||
try {
|
||||
const res = await fetch("/api/v1/ppt/codex/auth/exchange", {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({ session_id: sessionId, code: manualCode.trim() }),
|
||||
});
|
||||
if (!res.ok) {
|
||||
const err = await res.json().catch(() => ({}));
|
||||
throw new Error(err.detail || "Exchange failed");
|
||||
}
|
||||
const data = await res.json();
|
||||
stopPolling();
|
||||
setAuthStatus("authenticated");
|
||||
setAccountId(data.account_id);
|
||||
setSessionId(null);
|
||||
setManualCode("");
|
||||
if (!codexModel) {
|
||||
onInputChange(DEFAULT_CODEX_MODEL, "codex_model");
|
||||
}
|
||||
toast.success("Signed in to ChatGPT successfully");
|
||||
} catch (err: any) {
|
||||
toast.error(err.message || "Code exchange failed");
|
||||
} finally {
|
||||
setIsExchanging(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleCancelPolling = () => {
|
||||
stopPolling();
|
||||
setSessionId(null);
|
||||
setManualCode("");
|
||||
setAuthStatus("unauthenticated");
|
||||
};
|
||||
|
||||
const handleSignOut = async () => {
|
||||
setIsLoggingOut(true);
|
||||
try {
|
||||
await fetch("/api/v1/ppt/codex/auth/logout", { method: "POST" });
|
||||
setAuthStatus("unauthenticated");
|
||||
setAccountId(null);
|
||||
onInputChange("", "codex_model");
|
||||
toast.success("Signed out from ChatGPT");
|
||||
} catch {
|
||||
toast.error("Sign out failed");
|
||||
} finally {
|
||||
setIsLoggingOut(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleRefreshToken = async () => {
|
||||
setIsRefreshing(true);
|
||||
try {
|
||||
const res = await fetch("/api/v1/ppt/codex/auth/refresh", {
|
||||
method: "POST",
|
||||
});
|
||||
if (!res.ok) throw new Error("Refresh failed");
|
||||
const data = await res.json();
|
||||
if (data.account_id) setAccountId(data.account_id);
|
||||
toast.success("Token refreshed successfully");
|
||||
} catch {
|
||||
toast.error("Token refresh failed. Please sign in again.");
|
||||
setAuthStatus("unauthenticated");
|
||||
} finally {
|
||||
setIsRefreshing(false);
|
||||
}
|
||||
};
|
||||
|
||||
// ─── Checking ────────────────────────────────────────────────────────────
|
||||
if (authStatus === "checking") {
|
||||
return (
|
||||
<div className="flex items-center justify-center py-12 gap-3 text-gray-500">
|
||||
<Loader2 className="w-5 h-5 animate-spin" />
|
||||
<span className="text-sm">Checking authentication status…</span>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// ─── Polling / waiting ───────────────────────────────────────────────────
|
||||
if (authStatus === "polling") {
|
||||
return (
|
||||
<div className="space-y-6">
|
||||
<div className="flex flex-col items-center gap-4 py-8 px-4 bg-blue-50 rounded-xl border border-blue-100">
|
||||
<Loader2 className="w-8 h-8 text-blue-500 animate-spin" />
|
||||
<div className="text-center">
|
||||
<p className="text-sm font-medium text-blue-900">
|
||||
Waiting for authentication…
|
||||
</p>
|
||||
<p className="text-xs text-blue-600 mt-1">
|
||||
Complete the sign-in in the browser tab that just opened.
|
||||
</p>
|
||||
</div>
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={handleCancelPolling}
|
||||
className="text-gray-600"
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
{/* Manual fallback */}
|
||||
<div className="space-y-3">
|
||||
<p className="text-sm font-medium text-gray-700">
|
||||
Didn't get redirected automatically?
|
||||
</p>
|
||||
<p className="text-xs text-gray-500">
|
||||
After completing the sign-in, paste the full redirect URL or
|
||||
authorization code below.
|
||||
</p>
|
||||
<input
|
||||
type="text"
|
||||
placeholder="Paste redirect URL or authorization code…"
|
||||
className="w-full px-4 py-2.5 outline-none border border-gray-300 rounded-lg focus:ring-2 focus:ring-blue-500/20 focus:border-blue-500 transition-colors text-sm"
|
||||
value={manualCode}
|
||||
onChange={(e) => setManualCode(e.target.value)}
|
||||
/>
|
||||
<Button
|
||||
onClick={handleManualExchange}
|
||||
disabled={isExchanging || !manualCode.trim()}
|
||||
className="w-full"
|
||||
>
|
||||
{isExchanging ? (
|
||||
<div className="flex items-center gap-2">
|
||||
<Loader2 className="w-4 h-4 animate-spin" />
|
||||
Exchanging…
|
||||
</div>
|
||||
) : (
|
||||
"Submit Code"
|
||||
)}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// ─── Authenticated ───────────────────────────────────────────────────────
|
||||
if (authStatus === "authenticated") {
|
||||
return (
|
||||
<div className="space-y-6">
|
||||
{/* Account info */}
|
||||
<div className="flex items-center gap-3 p-4 bg-green-50 rounded-xl border border-green-100">
|
||||
<UserCheck className="w-6 h-6 text-green-600 shrink-0" />
|
||||
<div className="flex-1 min-w-0">
|
||||
<p className="text-sm font-medium text-green-900">
|
||||
Signed in to ChatGPT
|
||||
</p>
|
||||
{accountId && (
|
||||
<p className="text-xs text-green-700 truncate mt-0.5">
|
||||
Account: {accountId}
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
<div className="flex gap-2 shrink-0">
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={handleRefreshToken}
|
||||
disabled={isRefreshing}
|
||||
title="Refresh access token"
|
||||
className="text-gray-600 border-gray-300"
|
||||
>
|
||||
{isRefreshing ? (
|
||||
<Loader2 className="w-3.5 h-3.5 animate-spin" />
|
||||
) : (
|
||||
<RefreshCw className="w-3.5 h-3.5" />
|
||||
)}
|
||||
</Button>
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={handleSignOut}
|
||||
disabled={isLoggingOut}
|
||||
className="text-red-600 border-red-200 hover:bg-red-50"
|
||||
>
|
||||
{isLoggingOut ? (
|
||||
<Loader2 className="w-3.5 h-3.5 animate-spin" />
|
||||
) : (
|
||||
<LogOut className="w-3.5 h-3.5" />
|
||||
)}
|
||||
<span className="ml-1.5">Sign out</span>
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Model selection */}
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-gray-700 mb-3">
|
||||
Select ChatGPT Model
|
||||
</label>
|
||||
<Popover open={openModelSelect} onOpenChange={setOpenModelSelect}>
|
||||
<PopoverTrigger asChild>
|
||||
<Button
|
||||
variant="outline"
|
||||
role="combobox"
|
||||
aria-expanded={openModelSelect}
|
||||
className="w-full h-12 px-4 py-4 outline-none border border-gray-300 rounded-lg focus:ring-2 focus:ring-blue-500/20 focus:border-blue-500 transition-colors hover:border-gray-400 justify-between"
|
||||
>
|
||||
<span className="text-sm font-medium text-gray-900">
|
||||
{codexModel
|
||||
? (CHATGPT_MODELS.find((m) => m.id === codexModel)?.name ?? codexModel)
|
||||
: "Select a model"}
|
||||
</span>
|
||||
<ChevronsUpDown className="w-4 h-4 text-gray-500" />
|
||||
</Button>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent
|
||||
className="p-0"
|
||||
align="start"
|
||||
style={{ width: "var(--radix-popover-trigger-width)" }}
|
||||
>
|
||||
<Command>
|
||||
<CommandInput placeholder="Search models…" />
|
||||
<CommandList>
|
||||
<CommandEmpty>No model found.</CommandEmpty>
|
||||
<CommandGroup>
|
||||
{CHATGPT_MODELS.map((model) => (
|
||||
<CommandItem
|
||||
key={model.id}
|
||||
value={model.id}
|
||||
onSelect={(value) => {
|
||||
onInputChange(value, "codex_model");
|
||||
setOpenModelSelect(false);
|
||||
}}
|
||||
>
|
||||
<Check
|
||||
className={cn(
|
||||
"mr-2 h-4 w-4",
|
||||
codexModel === model.id ? "opacity-100" : "opacity-0"
|
||||
)}
|
||||
/>
|
||||
<span className="text-sm font-medium text-gray-900">
|
||||
{model.name}
|
||||
</span>
|
||||
</CommandItem>
|
||||
))}
|
||||
</CommandGroup>
|
||||
</CommandList>
|
||||
</Command>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
<p className="mt-2 text-xs text-gray-500 flex items-center gap-2">
|
||||
<span className="block w-1 h-1 rounded-full bg-gray-400" />
|
||||
Model availability depends on your ChatGPT subscription tier.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// ─── Unauthenticated ─────────────────────────────────────────────────────
|
||||
return (
|
||||
<div className="space-y-6">
|
||||
<div className="p-4 bg-gray-50 rounded-xl border border-gray-200">
|
||||
<h3 className="text-sm font-semibold text-gray-900 mb-1">
|
||||
ChatGPT Plus / Pro
|
||||
</h3>
|
||||
<p className="text-sm text-gray-600">
|
||||
Sign in with your OpenAI account to use ChatGPT models directly via
|
||||
OAuth — no API key required.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<Button
|
||||
onClick={handleSignIn}
|
||||
className="w-full h-12 gap-2 bg-[#10a37f] hover:bg-[#0e8f6f] text-white"
|
||||
>
|
||||
<LogIn className="w-4 h-4" />
|
||||
Sign in with ChatGPT
|
||||
</Button>
|
||||
|
||||
<p className="text-xs text-gray-500 flex items-start gap-2">
|
||||
<span className="block w-1 h-1 rounded-full bg-gray-400 mt-1.5 shrink-0" />
|
||||
A browser window will open for you to authenticate with your OpenAI
|
||||
account. Your credentials are stored locally and never shared.
|
||||
</p>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
@ -19,6 +19,7 @@ import GoogleConfig from "./GoogleConfig";
|
|||
import AnthropicConfig from "./AnthropicConfig";
|
||||
import OllamaConfig from "./OllamaConfig";
|
||||
import CustomConfig from "./CustomConfig";
|
||||
import CodexConfig from "./CodexConfig";
|
||||
import {
|
||||
updateLLMConfig,
|
||||
changeProvider as changeProviderUtil,
|
||||
|
|
@ -95,7 +96,8 @@ export default function LLMProviderSelection({
|
|||
(llmConfig.LLM === "google" && !llmConfig.GOOGLE_MODEL) ||
|
||||
(llmConfig.LLM === "ollama" && !llmConfig.OLLAMA_MODEL) ||
|
||||
(llmConfig.LLM === "custom" && !llmConfig.CUSTOM_MODEL) ||
|
||||
(llmConfig.LLM === "anthropic" && !llmConfig.ANTHROPIC_MODEL);
|
||||
(llmConfig.LLM === "anthropic" && !llmConfig.ANTHROPIC_MODEL) ||
|
||||
(llmConfig.LLM === "codex" && !llmConfig.CODEX_MODEL);
|
||||
|
||||
const needsProviderApiKey =
|
||||
(llmConfig.LLM === "openai" && !llmConfig.OPENAI_API_KEY) ||
|
||||
|
|
@ -335,12 +337,13 @@ export default function LLMProviderSelection({
|
|||
onValueChange={handleProviderChange}
|
||||
className="w-full"
|
||||
>
|
||||
<TabsList className="grid w-full grid-cols-5 bg-transparent h-10">
|
||||
<TabsList className="grid w-full grid-cols-6 bg-transparent h-10">
|
||||
<TabsTrigger value="openai">OpenAI</TabsTrigger>
|
||||
<TabsTrigger value="google">Google</TabsTrigger>
|
||||
<TabsTrigger value="anthropic">Anthropic</TabsTrigger>
|
||||
<TabsTrigger value="ollama">Ollama</TabsTrigger>
|
||||
<TabsTrigger value="custom">Custom</TabsTrigger>
|
||||
<TabsTrigger value="codex">ChatGPT</TabsTrigger>
|
||||
</TabsList>
|
||||
</Tabs>
|
||||
</div>
|
||||
|
|
@ -404,6 +407,14 @@ export default function LLMProviderSelection({
|
|||
onInputChange={input_field_changed}
|
||||
/>
|
||||
</TabsContent>
|
||||
|
||||
{/* ChatGPT / Codex Content */}
|
||||
<TabsContent value="codex" className="mt-6">
|
||||
<CodexConfig
|
||||
codexModel={llmConfig.CODEX_MODEL || ""}
|
||||
onInputChange={input_field_changed}
|
||||
/>
|
||||
</TabsContent>
|
||||
</Tabs>
|
||||
|
||||
{/* Image Generation Toggle */}
|
||||
|
|
@ -652,6 +663,8 @@ export default function LLMProviderSelection({
|
|||
? llmConfig.GOOGLE_MODEL ?? "xxxxx"
|
||||
: llmConfig.LLM === "openai"
|
||||
? llmConfig.OPENAI_MODEL ?? "xxxxx"
|
||||
: llmConfig.LLM === "codex"
|
||||
? llmConfig.CODEX_MODEL ?? "xxxxx"
|
||||
: "xxxxx"}{" "}
|
||||
for text generation{" "}
|
||||
{isImageGenerationDisabled ? (
|
||||
|
|
|
|||
|
|
@ -43,6 +43,13 @@ export interface LLMConfig {
|
|||
EXTENDED_REASONING?: boolean;
|
||||
WEB_GROUNDING?: boolean;
|
||||
|
||||
// Codex OAuth (ChatGPT)
|
||||
CODEX_MODEL?: string;
|
||||
CODEX_ACCESS_TOKEN?: string;
|
||||
CODEX_REFRESH_TOKEN?: string;
|
||||
CODEX_TOKEN_EXPIRES?: string;
|
||||
CODEX_ACCOUNT_ID?: string;
|
||||
|
||||
// Only used in UI settings
|
||||
USE_CUSTOM_URL?: boolean;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -116,4 +116,9 @@ export const LLM_PROVIDERS: Record<string, LLMProviderOption> = {
|
|||
label: "Custom",
|
||||
description: "Custom LLM",
|
||||
},
|
||||
codex: {
|
||||
value: "codex",
|
||||
label: "ChatGPT",
|
||||
description: "ChatGPT Plus/Pro via OAuth",
|
||||
},
|
||||
};
|
||||
|
|
|
|||
|
|
@ -52,6 +52,7 @@ export const updateLLMConfig = (
|
|||
comfyui_workflow: "COMFYUI_WORKFLOW",
|
||||
dall_e_3_quality: "DALL_E_3_QUALITY",
|
||||
gpt_image_1_5_quality: "GPT_IMAGE_1_5_QUALITY",
|
||||
codex_model: "CODEX_MODEL",
|
||||
};
|
||||
|
||||
const configKey = fieldMappings[field];
|
||||
|
|
@ -77,7 +78,7 @@ export const changeProvider = (
|
|||
} else if (provider === "google") {
|
||||
newConfig.IMAGE_PROVIDER = "gemini_flash";
|
||||
} else {
|
||||
newConfig.IMAGE_PROVIDER = "pexels"; // default for ollama and custom
|
||||
newConfig.IMAGE_PROVIDER = "pexels"; // default for ollama, custom, codex
|
||||
}
|
||||
|
||||
return newConfig;
|
||||
|
|
|
|||
|
|
@ -59,6 +59,11 @@ export const hasValidLLMConfig = (llmConfig: LLMConfig) => {
|
|||
llmConfig.CUSTOM_MODEL !== null &&
|
||||
llmConfig.CUSTOM_MODEL !== undefined;
|
||||
|
||||
const isCodexConfigValid =
|
||||
llmConfig.CODEX_MODEL !== "" &&
|
||||
llmConfig.CODEX_MODEL !== null &&
|
||||
llmConfig.CODEX_MODEL !== undefined;
|
||||
|
||||
const shouldValidateImages = !llmConfig.DISABLE_IMAGE_GENERATION;
|
||||
|
||||
const isImageConfigValid = () => {
|
||||
|
|
@ -96,6 +101,8 @@ export const hasValidLLMConfig = (llmConfig: LLMConfig) => {
|
|||
? isOllamaConfigValid
|
||||
: llmConfig.LLM === "custom"
|
||||
? isCustomConfigValid
|
||||
: llmConfig.LLM === "codex"
|
||||
? isCodexConfigValid
|
||||
: false;
|
||||
|
||||
return isLLMConfigValid && isImageConfigValid();
|
||||
|
|
|
|||
7
start.js
7
start.js
|
|
@ -65,7 +65,7 @@ const setupUserConfigFromEnv = () => {
|
|||
existingConfig = JSON.parse(readFileSync(userConfigPath, "utf8"));
|
||||
}
|
||||
|
||||
if (!["ollama", "openai", "google"].includes(existingConfig.LLM)) {
|
||||
if (!["ollama", "openai", "google", "anthropic", "custom", "codex"].includes(existingConfig.LLM)) {
|
||||
existingConfig.LLM = undefined;
|
||||
}
|
||||
|
||||
|
|
@ -103,6 +103,11 @@ const setupUserConfigFromEnv = () => {
|
|||
process.env.DALL_E_3_QUALITY || existingConfig.DALL_E_3_QUALITY,
|
||||
GPT_IMAGE_1_5_QUALITY:
|
||||
process.env.GPT_IMAGE_1_5_QUALITY || existingConfig.GPT_IMAGE_1_5_QUALITY,
|
||||
CODEX_MODEL: process.env.CODEX_MODEL || existingConfig.CODEX_MODEL,
|
||||
CODEX_ACCESS_TOKEN: existingConfig.CODEX_ACCESS_TOKEN,
|
||||
CODEX_REFRESH_TOKEN: existingConfig.CODEX_REFRESH_TOKEN,
|
||||
CODEX_TOKEN_EXPIRES: existingConfig.CODEX_TOKEN_EXPIRES,
|
||||
CODEX_ACCOUNT_ID: existingConfig.CODEX_ACCOUNT_ID,
|
||||
};
|
||||
|
||||
writeFileSync(userConfigPath, JSON.stringify(userConfig));
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue