fix(llm_client): google structured stream and tool call issue

This commit is contained in:
sauravniraula 2025-08-11 16:33:59 +05:45
parent c1b56747c9
commit 998e6be325
No known key found for this signature in database
GPG key ID: 60FCC1B5A5E83326
3 changed files with 67 additions and 50 deletions

View file

@ -66,15 +66,12 @@ async def stream_outlines(
event="response",
data=json.dumps({"type": "chunk", "chunk": chunk}),
).to_string()
presentation_outlines_text += chunk
try:
presentation_outlines_json = json.loads(presentation_outlines_text)
except Exception as e:
print(e)
with open("./debug/outlines.txt", "w") as f:
f.write(presentation_outlines_text)
print(presentation_outlines_text)
raise HTTPException(
status_code=400,
detail="Failed to generate presentation outlines. Please try again.",

View file

@ -8,7 +8,13 @@ from openai.types.chat.chat_completion_chunk import (
)
from google import genai
from google.genai.types import Content as GoogleContent, Part as GoogleContentPart
from google.genai.types import GenerateContentConfig, GoogleSearch
from google.genai.types import (
GenerateContentConfig,
GoogleSearch,
ToolConfig as GoogleToolConfig,
FunctionCallingConfig as GoogleFunctionCallingConfig,
FunctionCallingConfigMode as GoogleFunctionCallingConfigMode,
)
from google.genai.types import Tool as GoogleTool
from anthropic import AsyncAnthropic
from anthropic.types import Message as AnthropicMessage
@ -48,7 +54,7 @@ from utils.get_env import (
)
from utils.llm_provider import get_llm_provider, get_model
from utils.parsers import parse_bool_or_none
from utils.schema_utils import ensure_strict_json_schema
from utils.schema_utils import ensure_strict_json_schema, flatten_json_schema
class LLMClient:
@ -565,7 +571,7 @@ class LLMClient:
{
"name": "ResponseSchema",
"description": "Provide response to the user",
"parameters_json_schema": response_format,
"parameters": flatten_json_schema(response_format),
}
]
)
@ -577,6 +583,15 @@ class LLMClient:
contents=self._get_google_messages(messages),
config=GenerateContentConfig(
tools=google_tools,
tool_config=(
GoogleToolConfig(
function_calling_config=GoogleFunctionCallingConfig(
mode=GoogleFunctionCallingConfigMode.ANY,
),
)
if tools
else None
),
system_instruction=self._get_system_prompt(messages),
response_mime_type="application/json" if not tools else None,
response_json_schema=response_format if not tools else None,
@ -1256,7 +1271,7 @@ class LLMClient:
{
"name": "ResponseSchema",
"description": "Provide response to the user",
"parameters_json_schema": response_format,
"parameters": flatten_json_schema(response_format),
}
]
)
@ -1269,54 +1284,59 @@ class LLMClient:
contents=self._get_google_messages(messages),
config=GenerateContentConfig(
tools=google_tools,
tool_config=(
GoogleToolConfig(
function_calling_config=GoogleFunctionCallingConfig(
mode=GoogleFunctionCallingConfigMode.ANY,
),
)
if tools
else None
),
system_instruction=self._get_system_prompt(messages),
response_mime_type="application/json" if not tools else None,
response_json_schema=response_format if not tools else None,
max_output_tokens=max_tokens,
),
):
if event.text:
yield event.text
for each_part in event.candidates[0].content.parts:
if each_part.text:
yield each_part.text
if event.function_calls:
tool_calls = [
GoogleToolCall(
name=each.name,
arguments=each.args,
if each_part.function_call:
if each_part.function_call.name == "ResponseSchema":
has_response_schema_tool_call = True
if each_part.function_call.args:
yield json.dumps(each_part.function_call.args)
tool_calls.append(
GoogleToolCall(
name=each_part.function_call.name,
arguments=each_part.function_call.args,
)
)
for each in event.function_calls
]
for each in tool_calls:
if each.name == "ResponseSchema":
has_response_schema_tool_call = True
if each.arguments:
yield json.dumps(each.arguments)
if has_response_schema_tool_call:
continue
if tool_calls:
tool_call_messages = (
await self.tool_calls_handler.handle_tool_calls_google(tool_calls)
)
new_messages = [
*messages,
GoogleAssistantMessage(
role="assistant",
content=event.candidates[0].content,
),
*tool_call_messages,
]
async for event in self._stream_google_structured(
model=model,
messages=new_messages,
max_tokens=max_tokens,
response_format=response_format,
tools=tools,
depth=depth + 1,
):
yield event
if tool_calls and not has_response_schema_tool_call:
tool_call_messages = await self.tool_calls_handler.handle_tool_calls_google(
tool_calls
)
new_messages = [
*messages,
GoogleAssistantMessage(
role="assistant",
content=event.candidates[0].content,
),
*tool_call_messages,
]
async for event in self._stream_google_structured(
model=model,
messages=new_messages,
max_tokens=max_tokens,
response_format=response_format,
tools=tools,
depth=depth + 1,
):
yield event
async def _stream_anthropic_structured(
self,

View file

@ -89,9 +89,9 @@ class LLMToolCallsHandler:
def parse_tool_google(self, tool: type[LLMTool] | LLMDynamicTool):
parsed = self.parse_tool_openai(tool)
# parsed["function"]["parameters"] = flatten_json_schema(
# parsed["function"]["parameters"]
# )
parsed["function"]["parameters"] = flatten_json_schema(
parsed["function"]["parameters"]
)
return {
"name": parsed["function"]["name"],
"description": parsed["function"]["description"],