fix(llm_client): google structured stream and tool call issue
This commit is contained in:
parent
c1b56747c9
commit
998e6be325
3 changed files with 67 additions and 50 deletions
|
|
@ -66,15 +66,12 @@ async def stream_outlines(
|
|||
event="response",
|
||||
data=json.dumps({"type": "chunk", "chunk": chunk}),
|
||||
).to_string()
|
||||
|
||||
presentation_outlines_text += chunk
|
||||
|
||||
try:
|
||||
presentation_outlines_json = json.loads(presentation_outlines_text)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
with open("./debug/outlines.txt", "w") as f:
|
||||
f.write(presentation_outlines_text)
|
||||
print(presentation_outlines_text)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Failed to generate presentation outlines. Please try again.",
|
||||
|
|
|
|||
|
|
@ -8,7 +8,13 @@ from openai.types.chat.chat_completion_chunk import (
|
|||
)
|
||||
from google import genai
|
||||
from google.genai.types import Content as GoogleContent, Part as GoogleContentPart
|
||||
from google.genai.types import GenerateContentConfig, GoogleSearch
|
||||
from google.genai.types import (
|
||||
GenerateContentConfig,
|
||||
GoogleSearch,
|
||||
ToolConfig as GoogleToolConfig,
|
||||
FunctionCallingConfig as GoogleFunctionCallingConfig,
|
||||
FunctionCallingConfigMode as GoogleFunctionCallingConfigMode,
|
||||
)
|
||||
from google.genai.types import Tool as GoogleTool
|
||||
from anthropic import AsyncAnthropic
|
||||
from anthropic.types import Message as AnthropicMessage
|
||||
|
|
@ -48,7 +54,7 @@ from utils.get_env import (
|
|||
)
|
||||
from utils.llm_provider import get_llm_provider, get_model
|
||||
from utils.parsers import parse_bool_or_none
|
||||
from utils.schema_utils import ensure_strict_json_schema
|
||||
from utils.schema_utils import ensure_strict_json_schema, flatten_json_schema
|
||||
|
||||
|
||||
class LLMClient:
|
||||
|
|
@ -565,7 +571,7 @@ class LLMClient:
|
|||
{
|
||||
"name": "ResponseSchema",
|
||||
"description": "Provide response to the user",
|
||||
"parameters_json_schema": response_format,
|
||||
"parameters": flatten_json_schema(response_format),
|
||||
}
|
||||
]
|
||||
)
|
||||
|
|
@ -577,6 +583,15 @@ class LLMClient:
|
|||
contents=self._get_google_messages(messages),
|
||||
config=GenerateContentConfig(
|
||||
tools=google_tools,
|
||||
tool_config=(
|
||||
GoogleToolConfig(
|
||||
function_calling_config=GoogleFunctionCallingConfig(
|
||||
mode=GoogleFunctionCallingConfigMode.ANY,
|
||||
),
|
||||
)
|
||||
if tools
|
||||
else None
|
||||
),
|
||||
system_instruction=self._get_system_prompt(messages),
|
||||
response_mime_type="application/json" if not tools else None,
|
||||
response_json_schema=response_format if not tools else None,
|
||||
|
|
@ -1256,7 +1271,7 @@ class LLMClient:
|
|||
{
|
||||
"name": "ResponseSchema",
|
||||
"description": "Provide response to the user",
|
||||
"parameters_json_schema": response_format,
|
||||
"parameters": flatten_json_schema(response_format),
|
||||
}
|
||||
]
|
||||
)
|
||||
|
|
@ -1269,54 +1284,59 @@ class LLMClient:
|
|||
contents=self._get_google_messages(messages),
|
||||
config=GenerateContentConfig(
|
||||
tools=google_tools,
|
||||
tool_config=(
|
||||
GoogleToolConfig(
|
||||
function_calling_config=GoogleFunctionCallingConfig(
|
||||
mode=GoogleFunctionCallingConfigMode.ANY,
|
||||
),
|
||||
)
|
||||
if tools
|
||||
else None
|
||||
),
|
||||
system_instruction=self._get_system_prompt(messages),
|
||||
response_mime_type="application/json" if not tools else None,
|
||||
response_json_schema=response_format if not tools else None,
|
||||
max_output_tokens=max_tokens,
|
||||
),
|
||||
):
|
||||
if event.text:
|
||||
yield event.text
|
||||
for each_part in event.candidates[0].content.parts:
|
||||
if each_part.text:
|
||||
yield each_part.text
|
||||
|
||||
if event.function_calls:
|
||||
tool_calls = [
|
||||
GoogleToolCall(
|
||||
name=each.name,
|
||||
arguments=each.args,
|
||||
if each_part.function_call:
|
||||
if each_part.function_call.name == "ResponseSchema":
|
||||
has_response_schema_tool_call = True
|
||||
if each_part.function_call.args:
|
||||
yield json.dumps(each_part.function_call.args)
|
||||
|
||||
tool_calls.append(
|
||||
GoogleToolCall(
|
||||
name=each_part.function_call.name,
|
||||
arguments=each_part.function_call.args,
|
||||
)
|
||||
)
|
||||
for each in event.function_calls
|
||||
]
|
||||
|
||||
for each in tool_calls:
|
||||
if each.name == "ResponseSchema":
|
||||
has_response_schema_tool_call = True
|
||||
if each.arguments:
|
||||
yield json.dumps(each.arguments)
|
||||
|
||||
if has_response_schema_tool_call:
|
||||
continue
|
||||
|
||||
if tool_calls:
|
||||
tool_call_messages = (
|
||||
await self.tool_calls_handler.handle_tool_calls_google(tool_calls)
|
||||
)
|
||||
new_messages = [
|
||||
*messages,
|
||||
GoogleAssistantMessage(
|
||||
role="assistant",
|
||||
content=event.candidates[0].content,
|
||||
),
|
||||
*tool_call_messages,
|
||||
]
|
||||
async for event in self._stream_google_structured(
|
||||
model=model,
|
||||
messages=new_messages,
|
||||
max_tokens=max_tokens,
|
||||
response_format=response_format,
|
||||
tools=tools,
|
||||
depth=depth + 1,
|
||||
):
|
||||
yield event
|
||||
if tool_calls and not has_response_schema_tool_call:
|
||||
tool_call_messages = await self.tool_calls_handler.handle_tool_calls_google(
|
||||
tool_calls
|
||||
)
|
||||
new_messages = [
|
||||
*messages,
|
||||
GoogleAssistantMessage(
|
||||
role="assistant",
|
||||
content=event.candidates[0].content,
|
||||
),
|
||||
*tool_call_messages,
|
||||
]
|
||||
async for event in self._stream_google_structured(
|
||||
model=model,
|
||||
messages=new_messages,
|
||||
max_tokens=max_tokens,
|
||||
response_format=response_format,
|
||||
tools=tools,
|
||||
depth=depth + 1,
|
||||
):
|
||||
yield event
|
||||
|
||||
async def _stream_anthropic_structured(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -89,9 +89,9 @@ class LLMToolCallsHandler:
|
|||
|
||||
def parse_tool_google(self, tool: type[LLMTool] | LLMDynamicTool):
|
||||
parsed = self.parse_tool_openai(tool)
|
||||
# parsed["function"]["parameters"] = flatten_json_schema(
|
||||
# parsed["function"]["parameters"]
|
||||
# )
|
||||
parsed["function"]["parameters"] = flatten_json_schema(
|
||||
parsed["function"]["parameters"]
|
||||
)
|
||||
return {
|
||||
"name": parsed["function"]["name"],
|
||||
"description": parsed["function"]["description"],
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue