presenton/servers/fastapi/utils/llm_utils.py
sudipnext b20199a4e3 feat: Integrate Vertex AI and Azure OpenAI support
- Added environment variables for Vertex AI and Azure OpenAI configurations in docker-compose and user configuration models.
- Updated the application logic to handle Vertex and Azure as new LLM providers, including validation and API key management.
- Enhanced the UI components to support model selection and API key input for Vertex and Azure.
- Updated relevant utility functions and constants to accommodate the new providers.
- Ensured proper error handling for configuration requirements specific to Vertex and Azure.
2026-04-30 06:03:39 +05:45

171 lines
4.9 KiB
Python

import asyncio
import json
import threading
from collections.abc import AsyncGenerator, Sequence
from typing import Any, Optional

import dirtyjson
from llmai.shared import (
    LLMTool,
    Message,
    ResponseFormat,
    normalize_content_parts,
)
from llmai.shared.tools import Tool  # type: ignore[import-not-found]
from pydantic import BaseModel

from enums.llm_provider import LLMProvider
from utils.llm_config import get_extra_body
from utils.llm_provider import get_llm_provider
from utils.schema_utils import flatten_json_schema
def _tools_for_google_gemini(tools: list[LLMTool]) -> list[LLMTool]:
    """Return *tools* with every ``Tool``'s parameter schema inlined for Gemini.

    Gemini's Python SDK rejects ``$ref`` / ``$defs`` in function parameters, so
    each ``Tool`` whose ``input_schema`` is a dict, a pydantic model class, or a
    pydantic model instance is rebuilt around the output of
    ``flatten_json_schema``. Non-``Tool`` entries and tools with any other kind
    of schema object are passed through untouched. Input order is preserved.
    """

    def _raw_schema(source: Any) -> Optional[dict]:
        # Resolve the supported schema shapes to a JSON-schema dict;
        # None signals "leave this tool as it is".
        if isinstance(source, dict):
            return dict(source)  # shallow copy, never hand out the caller's dict
        if isinstance(source, type) and issubclass(source, BaseModel):
            return source.model_json_schema()
        if isinstance(source, BaseModel):
            return source.__class__.model_json_schema()
        return None

    flattened: list[LLMTool] = []
    for entry in tools:
        raw = _raw_schema(entry.input_schema) if isinstance(entry, Tool) else None
        if raw is None:
            flattened.append(entry)
            continue
        flat_schema = flatten_json_schema(raw)
        flattened.append(
            Tool(
                name=entry.name,
                description=entry.description,
                schema=flat_schema,
                strict=entry.strict,
            )
        )
    return flattened
def get_generate_kwargs(
    model: str,
    messages: Sequence[Message],
    max_tokens: Optional[int] = None,
    tools: Optional[list[LLMTool]] = None,
    response_format: Optional[ResponseFormat] = None,
    stream: bool = False,
) -> dict[str, Any]:
    """Assemble the keyword arguments for a client ``generate(...)`` call.

    ``model``, ``messages`` (copied into a list) and ``stream`` are always
    present. ``max_tokens`` and ``response_format`` are added only when not
    ``None``. ``tools`` are added only when non-empty; for the GOOGLE and
    VERTEX providers their schemas are first flattened with
    ``_tools_for_google_gemini``. ``extra_body`` is added when
    ``get_extra_body()`` returns a truthy value.
    """
    request: dict[str, Any] = {
        "model": model,
        "messages": list(messages),
        "stream": stream,
    }

    prepared_tools: Optional[list[LLMTool]] = None
    if tools:
        # Gemini-backed providers cannot resolve $ref/$defs in tool schemas.
        on_gemini = get_llm_provider() in (LLMProvider.GOOGLE, LLMProvider.VERTEX)
        prepared_tools = _tools_for_google_gemini(tools) if on_gemini else tools

    optional_fields = (
        ("max_tokens", max_tokens),
        ("tools", prepared_tools),
        ("response_format", response_format),
    )
    request.update(
        (key, value) for key, value in optional_fields if value is not None
    )

    extra_body = get_extra_body()
    if extra_body:
        request["extra_body"] = extra_body
    return request
def extract_text(content: Any) -> Optional[str]:
    """Return the plain text carried by *content*, or ``None``.

    * ``None`` and ``str`` values are returned unchanged.
    * A non-bytes sequence is treated as a list of parts: string parts and
      parts with a string ``text`` attribute are concatenated in order, and
      an empty result becomes ``None``.
    * Any other object contributes its ``text`` attribute if that is a string.
    """
    if content is None or isinstance(content, str):
        return content

    is_part_list = isinstance(content, Sequence) and not isinstance(
        content, (bytes, bytearray)
    )
    if not is_part_list:
        text = getattr(content, "text", None)
        return text if isinstance(text, str) else None

    pieces: list[str] = []
    for part in content:
        candidate = part if isinstance(part, str) else getattr(part, "text", None)
        if isinstance(candidate, str):
            pieces.append(candidate)
    return "".join(pieces) or None
def extract_structured_content(content: Any) -> Optional[dict]:
    """Best-effort conversion of a response payload into a ``dict``.

    Resolution order:

    1. ``None`` yields ``None``; a ``dict`` is returned as-is (not copied).
    2. An object with a ``model_dump`` attribute (e.g. a pydantic model) is
       dumped with ``mode="json"``, and the result is used when it is a dict.
    3. Otherwise the text is pulled out with ``extract_text`` and parsed
       leniently with ``dirtyjson``. Only a top-level object is accepted, and
       it is returned as a plain ``dict``.

    Any parse failure yields ``None`` rather than raising.
    """
    if content is None or isinstance(content, dict):
        return content

    if hasattr(content, "model_dump"):
        dumped = content.model_dump(mode="json")
        if isinstance(dumped, dict):
            return dumped

    text = extract_text(content)
    if not text:
        return None
    try:
        candidate = dirtyjson.loads(text)
    except Exception:  # deliberately lenient: malformed model output -> None
        return None
    return dict(candidate) if isinstance(candidate, dict) else None
def serialize_structured_content(content: Any) -> Optional[str]:
    """Render *content* as a string.

    If ``extract_structured_content`` resolves it to a dict, the dict is
    serialized as JSON with non-ASCII characters kept as-is. Otherwise the raw
    text from ``extract_text`` is returned. ``None`` is returned when neither
    path produces anything (an empty string also maps to ``None``).
    """
    structured = extract_structured_content(content)
    if structured is None:
        return extract_text(content) or None
    return json.dumps(structured, ensure_ascii=False)
def message_content_to_text(content: Sequence[Any] | str | None) -> Optional[str]:
    """Concatenate the text of every part of a message's content.

    ``content`` is normalized with ``normalize_content_parts``. Parts whose
    ``text`` attribute is missing or not a string are skipped. Returns ``None``
    when no text is collected, including when every collected text is empty.
    """
    texts: list[str] = []
    for part in normalize_content_parts(content):
        text = getattr(part, "text", None)
        if isinstance(text, str):
            texts.append(text)
    return "".join(texts) or None
class _StreamFailure:
    """Carries an exception raised inside the ``stream_generate_events`` worker.

    Errors are wrapped rather than queued bare, so they cannot be confused
    with stream events that merely happen to be exception instances.
    """

    __slots__ = ("exc",)

    def __init__(self, exc: BaseException) -> None:
        self.exc = exc


async def stream_generate_events(client: Any, **kwargs) -> AsyncGenerator[Any, None]:
    """Asynchronously yield the events of the blocking ``client.generate(**kwargs)``.

    The synchronous iterable returned by ``client.generate`` is consumed on a
    worker thread (via ``asyncio.to_thread``). Each event is handed back to
    the running event loop through an ``asyncio.Queue``. An exception raised
    by ``client.generate`` or its iterator is re-raised here, after every
    event produced before it has been yielded.

    If the consumer stops early (``aclose()``, cancellation, or an exception),
    the worker is signalled to stop. It discards the next event it receives,
    calls ``close()`` on the upstream iterator when it has one (for a
    generator this runs its cleanup immediately), and exits. Shutdown still
    waits for the worker thread, which can take up to one upstream event's
    latency, but it no longer drains the entire upstream stream.

    Args:
        client: Object whose ``generate(**kwargs)`` returns an iterable of
            events.
        **kwargs: Forwarded verbatim to ``client.generate``.

    Yields:
        Each event produced by ``client.generate``, in order.

    Raises:
        Exception: Whatever ``client.generate`` or its iterator raised. An
            error raised by the upstream ``close()`` surfaces when the worker
            is awaited during shutdown.
    """
    loop = asyncio.get_running_loop()
    queue: asyncio.Queue[Any] = asyncio.Queue()
    sentinel = object()
    stop = threading.Event()

    def post(item: Any) -> None:
        # asyncio.Queue is not thread-safe; schedule the put on the loop thread.
        loop.call_soon_threadsafe(queue.put_nowait, item)

    def worker() -> None:
        events = None
        try:
            events = client.generate(**kwargs)
            for event in events:
                if stop.is_set():
                    # The consumer has gone away; stop pulling upstream events.
                    break
                post(event)
        except Exception as exc:
            post(_StreamFailure(exc))
        finally:
            try:
                # Release the upstream iterator now instead of at garbage
                # collection (a no-op for exhausted or failed generators).
                close = getattr(events, "close", None)
                if callable(close):
                    close()
            finally:
                post(sentinel)

    worker_task = asyncio.create_task(asyncio.to_thread(worker))
    try:
        while True:
            item = await queue.get()
            if item is sentinel:
                break
            if isinstance(item, _StreamFailure):
                raise item.exc
            yield item
    finally:
        # Tell the worker to stop before waiting on it; otherwise an early
        # exit would block until the whole upstream stream was consumed.
        stop.set()
        await worker_task