# presenton/servers/fastapi/utils/llm_utils.py
#
# (listing metadata from the original source view: 134 lines, 3.6 KiB, Python)

import asyncio
import json
import threading
from collections.abc import AsyncGenerator, Sequence
from typing import Any, Optional

import dirtyjson
from llmai.shared import (
    LLMTool,
    Message,
    ResponseFormat,
    normalize_content_parts,
)

from utils.llm_config import get_extra_body
def get_generate_kwargs(
    model: str,
    messages: Sequence[Message],
    max_tokens: Optional[int] = None,
    tools: Optional[list[LLMTool]] = None,
    response_format: Optional[ResponseFormat] = None,
    stream: bool = False,
) -> dict[str, Any]:
    """Assemble the keyword arguments for an LLM ``generate`` call.

    ``model``, ``messages`` (copied into a new list) and ``stream`` are always
    present. ``max_tokens`` and ``response_format`` are added only when they
    are not ``None``; ``tools`` only when it is truthy (non-empty). The result
    of ``get_extra_body()`` is attached as ``"extra_body"`` when truthy.

    Returns:
        A fresh dict suitable for ``client.generate(**kwargs)``.
    """
    kwargs: dict[str, Any] = {
        "model": model,
        "messages": list(messages),
        "stream": stream,
    }
    # Optional entries in their insertion order; a ``None`` value means "omit".
    # ``tools`` and the extra body are omitted when merely falsy (e.g. empty).
    optional_entries = (
        ("max_tokens", max_tokens),
        ("tools", tools if tools else None),
        ("response_format", response_format),
        ("extra_body", get_extra_body() or None),
    )
    kwargs.update(
        (key, value) for key, value in optional_entries if value is not None
    )
    return kwargs
def extract_text(content: Any) -> Optional[str]:
    """Best-effort conversion of a content payload into plain text.

    - ``None`` yields ``None``; a ``str`` is returned unchanged (even ``""``).
    - Any other non-bytes ``Sequence`` is treated as a list of parts. String
      parts, and parts whose ``.text`` attribute is a ``str``, are concatenated
      in order; all other parts are skipped. An empty result becomes ``None``.
    - Anything else contributes its own ``.text`` attribute if that is a
      ``str``; otherwise ``None`` is returned.
    """
    if content is None or isinstance(content, str):
        return content

    treat_as_parts = isinstance(content, Sequence) and not isinstance(
        content, (bytes, bytearray)
    )
    if not treat_as_parts:
        own_text = getattr(content, "text", None)
        return own_text if isinstance(own_text, str) else None

    def part_text(part: Any) -> Optional[str]:
        # A bare string part wins over any ``.text`` attribute it might carry.
        if isinstance(part, str):
            return part
        candidate = getattr(part, "text", None)
        return candidate if isinstance(candidate, str) else None

    pieces = (part_text(part) for part in content)
    return "".join(piece for piece in pieces if piece is not None) or None
def extract_structured_content(content: Any) -> Optional[dict]:
    """Coerce a response payload into a JSON-like ``dict`` when possible.

    Resolution order:

    1. ``None`` -> ``None``.
    2. A ``dict`` is returned as-is (the same object, not a copy).
    3. An object with a ``model_dump`` attribute is dumped with
       ``mode="json"``; the dump is returned if it is a dict.
    4. Otherwise the text from ``extract_text`` is parsed leniently with
       ``dirtyjson``; a top-level object is returned as a plain ``dict``.

    Any parse error, empty text, or non-object JSON value yields ``None``.
    """
    if content is None:
        return None
    if isinstance(content, dict):
        return content

    if hasattr(content, "model_dump"):
        as_json = content.model_dump(mode="json")
        if isinstance(as_json, dict):
            return as_json

    source_text = extract_text(content)
    if source_text:
        try:
            candidate = dirtyjson.loads(source_text)
        except Exception:
            # Even the lenient parser gave up: report "no structured content".
            candidate = None
        if isinstance(candidate, dict):
            return dict(candidate)
    return None
def serialize_structured_content(content: Any) -> Optional[str]:
    """Render ``content`` as a string, preferring a JSON encoding.

    If ``extract_structured_content`` produces a dict, it is serialized with
    ``json.dumps(..., ensure_ascii=False)``. Otherwise the plain text from
    ``extract_text`` is returned, with a missing or empty result mapped to
    ``None``.
    """
    structured = extract_structured_content(content)
    if structured is None:
        return extract_text(content) or None
    return json.dumps(structured, ensure_ascii=False)
def message_content_to_text(content: Sequence[Any] | str | None) -> Optional[str]:
    """Concatenate the string ``text`` of every normalized content part.

    ``content`` is first passed through ``normalize_content_parts``. Parts
    whose ``text`` attribute is missing or not a ``str`` are skipped. Returns
    ``None`` when the concatenation is empty.
    """
    collected: list[str] = []
    for part in normalize_content_parts(content):
        text = getattr(part, "text", None)
        if isinstance(text, str):
            collected.append(text)
    return "".join(collected) or None
async def stream_generate_events(client: Any, **kwargs) -> AsyncGenerator[Any, None]:
    """Bridge a blocking ``client.generate(...)`` iterator into an async generator.

    ``client.generate(**kwargs)`` is called and iterated on a worker thread
    (via ``asyncio.to_thread``). Each event is handed back to the event loop
    through a queue, so the loop is not blocked while the stream is pending.

    If the consumer stops early (``break``, ``aclose()``, cancellation), the
    worker is told to stop and quits after the next event it receives,
    instead of draining the remaining stream.

    Args:
        client: Object exposing a synchronous ``generate(**kwargs)`` that
            returns an iterable of events.
        **kwargs: Forwarded verbatim to ``client.generate``.

    Yields:
        Each event produced by ``client.generate``, in order. Events are
        yielded even if they happen to be exception instances.

    Raises:
        Exception: Any ``Exception`` raised by ``client.generate``, either
            when called or while iterating, is re-raised to the consumer.
    """
    loop = asyncio.get_running_loop()
    # Items are tagged (kind, payload) so failures are signalled out-of-band
    # rather than inferred from the payload's type.
    queue: asyncio.Queue[tuple[str, Any]] = asyncio.Queue()
    stop_requested = threading.Event()

    def post(kind: str, payload: Any = None) -> None:
        # asyncio.Queue is not thread-safe; schedule the put on the loop thread.
        loop.call_soon_threadsafe(queue.put_nowait, (kind, payload))

    def worker() -> None:
        try:
            for event in client.generate(**kwargs):
                if stop_requested.is_set():
                    # Consumer is gone: stop pulling from the upstream stream.
                    break
                post("event", event)
        except Exception as exc:
            post("error", exc)
        finally:
            post("done")

    worker_task = asyncio.create_task(asyncio.to_thread(worker))
    try:
        while True:
            kind, payload = await queue.get()
            if kind == "done":
                break
            if kind == "error":
                raise payload
            yield payload
    finally:
        stop_requested.set()
        # A thread cannot be interrupted inside a blocking ``next()``; it exits
        # once the next event arrives (or the stream ends).
        await worker_task