Merge pull request #177 from presenton/feat/tool-call-support-in-custom-llm
feat: adds support for deepseek using custom llm (tool calls for deepseek)
This commit is contained in:
commit
2823b8c12f
3 changed files with 217 additions and 87 deletions
Binary file not shown.
|
|
@ -1,6 +1,6 @@
|
|||
import asyncio
|
||||
import json
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
from fastapi import HTTPException
|
||||
from openai import AsyncOpenAI
|
||||
from google import genai
|
||||
|
|
@ -24,10 +24,17 @@ from utils.schema_utils import ensure_strict_json_schema
|
|||
|
||||
|
||||
class LLMClient:
|
||||
def __init__(self, max_tokens: int = 4000):
|
||||
def __init__(self):
|
||||
self.llm_provider = get_llm_provider()
|
||||
self._client = self._get_client()
|
||||
self.max_tokens = max_tokens
|
||||
|
||||
# Supports json_schema
|
||||
def supports_json_schema(self, model: str) -> bool:
|
||||
if model.startswith("deepseek"):
|
||||
return False
|
||||
if model.startswith("claude"):
|
||||
return False
|
||||
return True
|
||||
|
||||
# ? Clients
|
||||
def _get_client(self):
|
||||
|
|
@ -103,16 +110,26 @@ class LLMClient:
|
|||
return [message for message in messages if message.role == "user"]
|
||||
|
||||
# ? Generate Unstructured Content
|
||||
async def _generate_openai(self, model: str, messages: List[LLMMessage]):
|
||||
async def _generate_openai(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
max_tokens: Optional[int] = None,
|
||||
):
|
||||
client: AsyncOpenAI = self._client
|
||||
response = await client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[message.model_dump() for message in messages],
|
||||
max_completion_tokens=self.max_tokens,
|
||||
max_completion_tokens=max_tokens,
|
||||
)
|
||||
return response.choices[0].message.content
|
||||
|
||||
async def _generate_google(self, model: str, messages: List[LLMMessage]):
|
||||
async def _generate_google(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
max_tokens: Optional[int] = None,
|
||||
):
|
||||
client: genai.Client = self._client
|
||||
response = await asyncio.to_thread(
|
||||
client.models.generate_content,
|
||||
|
|
@ -121,12 +138,17 @@ class LLMClient:
|
|||
config=GenerateContentConfig(
|
||||
system_instruction=self._get_system_prompt(messages),
|
||||
response_mime_type="text/plain",
|
||||
max_output_tokens=self.max_tokens,
|
||||
max_output_tokens=max_tokens,
|
||||
),
|
||||
)
|
||||
return response.text
|
||||
|
||||
async def _generate_anthropic(self, model: str, messages: List[LLMMessage]):
|
||||
async def _generate_anthropic(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
max_tokens: Optional[int] = None,
|
||||
):
|
||||
client: AsyncAnthropic = self._client
|
||||
response: AnthropicMessage = await client.messages.create(
|
||||
model=model,
|
||||
|
|
@ -135,7 +157,7 @@ class LLMClient:
|
|||
message.model_dump()
|
||||
for message in self._get_user_llm_messages(messages)
|
||||
],
|
||||
max_tokens=self.max_tokens,
|
||||
max_tokens=max_tokens or 4000,
|
||||
)
|
||||
text = ""
|
||||
for content in response.content:
|
||||
|
|
@ -145,25 +167,34 @@ class LLMClient:
|
|||
return None
|
||||
return text
|
||||
|
||||
async def _generate_ollama(self, model: str, messages: List[LLMMessage]):
|
||||
return await self._generate_openai(model, messages)
|
||||
async def _generate_ollama(
|
||||
self, model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None
|
||||
):
|
||||
return await self._generate_openai(model, messages, max_tokens)
|
||||
|
||||
async def _generate_custom(self, model: str, messages: List[LLMMessage]):
|
||||
return await self._generate_openai(model, messages)
|
||||
async def _generate_custom(
|
||||
self, model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None
|
||||
):
|
||||
return await self._generate_openai(model, messages, max_tokens)
|
||||
|
||||
async def generate(self, model: str, messages: List[LLMMessage]):
|
||||
async def generate(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
max_tokens: Optional[int] = None,
|
||||
):
|
||||
content = None
|
||||
match self.llm_provider:
|
||||
case LLMProvider.OPENAI:
|
||||
content = await self._generate_openai(model, messages)
|
||||
content = await self._generate_openai(model, messages, max_tokens)
|
||||
case LLMProvider.GOOGLE:
|
||||
content = await self._generate_google(model, messages)
|
||||
content = await self._generate_google(model, messages, max_tokens)
|
||||
case LLMProvider.ANTHROPIC:
|
||||
content = await self._generate_anthropic(model, messages)
|
||||
content = await self._generate_anthropic(model, messages, max_tokens)
|
||||
case LLMProvider.OLLAMA:
|
||||
content = await self._generate_ollama(model, messages)
|
||||
content = await self._generate_ollama(model, messages, max_tokens)
|
||||
case LLMProvider.CUSTOM:
|
||||
content = await self._generate_custom(model, messages)
|
||||
content = await self._generate_custom(model, messages, max_tokens)
|
||||
if content is None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
|
|
@ -178,8 +209,10 @@ class LLMClient:
|
|||
messages: List[LLMMessage],
|
||||
response_format: dict,
|
||||
strict: bool = False,
|
||||
max_tokens: Optional[int] = None,
|
||||
):
|
||||
client: AsyncOpenAI = self._client
|
||||
supports_json_schema = self.supports_json_schema(model)
|
||||
response_schema = response_format
|
||||
if strict:
|
||||
response_schema = ensure_strict_json_schema(
|
||||
|
|
@ -187,22 +220,45 @@ class LLMClient:
|
|||
path=(),
|
||||
root=response_schema,
|
||||
)
|
||||
response = await client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[message.model_dump() for message in messages],
|
||||
response_format={
|
||||
"type": "json_schema",
|
||||
"json_schema": (
|
||||
if supports_json_schema:
|
||||
response = await client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[message.model_dump() for message in messages],
|
||||
response_format={
|
||||
"type": "json_schema",
|
||||
"json_schema": (
|
||||
{
|
||||
"name": "ResponseSchema",
|
||||
"strict": strict,
|
||||
"schema": response_schema,
|
||||
}
|
||||
),
|
||||
},
|
||||
max_completion_tokens=max_tokens,
|
||||
)
|
||||
content = response.choices[0].message.content
|
||||
else:
|
||||
response = await client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[message.model_dump() for message in messages],
|
||||
tools=[
|
||||
{
|
||||
"name": "ResponseSchema",
|
||||
"strict": strict,
|
||||
"schema": response_schema,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "ResponseSchema",
|
||||
"description": "A response to the user's message",
|
||||
"strict": strict,
|
||||
"parameters": response_format,
|
||||
},
|
||||
}
|
||||
),
|
||||
},
|
||||
max_completion_tokens=self.max_tokens,
|
||||
)
|
||||
content = response.choices[0].message.content
|
||||
],
|
||||
tool_choice="required",
|
||||
max_completion_tokens=max_tokens,
|
||||
)
|
||||
tool_calls = response.choices[0].message.tool_calls
|
||||
if tool_calls:
|
||||
content = tool_calls[0].function.arguments
|
||||
|
||||
if content:
|
||||
return json.loads(content)
|
||||
return None
|
||||
|
|
@ -212,6 +268,7 @@ class LLMClient:
|
|||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
response_format: dict,
|
||||
max_tokens: Optional[int] = None,
|
||||
):
|
||||
client: genai.Client = self._client
|
||||
response = await asyncio.to_thread(
|
||||
|
|
@ -222,7 +279,7 @@ class LLMClient:
|
|||
system_instruction=self._get_system_prompt(messages),
|
||||
response_mime_type="application/json",
|
||||
response_json_schema=response_format,
|
||||
max_output_tokens=self.max_tokens,
|
||||
max_output_tokens=max_tokens,
|
||||
),
|
||||
)
|
||||
content = None
|
||||
|
|
@ -236,6 +293,7 @@ class LLMClient:
|
|||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
response_format: dict,
|
||||
max_tokens: Optional[int] = None,
|
||||
):
|
||||
client: AsyncAnthropic = self._client
|
||||
response: AnthropicMessage = await client.messages.create(
|
||||
|
|
@ -245,7 +303,7 @@ class LLMClient:
|
|||
message.model_dump()
|
||||
for message in self._get_user_llm_messages(messages)
|
||||
],
|
||||
max_tokens=self.max_tokens,
|
||||
max_tokens=max_tokens or 4000,
|
||||
tools=[
|
||||
{
|
||||
"name": "ResponseSchema",
|
||||
|
|
@ -253,6 +311,10 @@ class LLMClient:
|
|||
"input_schema": response_format,
|
||||
}
|
||||
],
|
||||
tool_choice={
|
||||
"type": "tool",
|
||||
"name": "ResponseSchema",
|
||||
},
|
||||
)
|
||||
content: dict | None = None
|
||||
for content_block in response.content:
|
||||
|
|
@ -267,9 +329,10 @@ class LLMClient:
|
|||
messages: List[LLMMessage],
|
||||
response_format: dict,
|
||||
strict: bool = False,
|
||||
max_tokens: Optional[int] = None,
|
||||
):
|
||||
return await self._generate_openai_structured(
|
||||
model, messages, response_format, strict
|
||||
model, messages, response_format, strict, max_tokens
|
||||
)
|
||||
|
||||
async def _generate_custom_structured(
|
||||
|
|
@ -278,9 +341,10 @@ class LLMClient:
|
|||
messages: List[LLMMessage],
|
||||
response_format: dict,
|
||||
strict: bool = False,
|
||||
max_tokens: Optional[int] = None,
|
||||
):
|
||||
return await self._generate_openai_structured(
|
||||
model, messages, response_format, strict
|
||||
model, messages, response_format, strict, max_tokens
|
||||
)
|
||||
|
||||
async def generate_structured(
|
||||
|
|
@ -289,28 +353,29 @@ class LLMClient:
|
|||
messages: List[LLMMessage],
|
||||
response_format: dict,
|
||||
strict: bool = False,
|
||||
max_tokens: Optional[int] = None,
|
||||
) -> dict:
|
||||
content = None
|
||||
match self.llm_provider:
|
||||
case LLMProvider.OPENAI:
|
||||
content = await self._generate_openai_structured(
|
||||
model, messages, response_format, strict
|
||||
model, messages, response_format, strict, max_tokens
|
||||
)
|
||||
case LLMProvider.GOOGLE:
|
||||
content = await self._generate_google_structured(
|
||||
model, messages, response_format
|
||||
model, messages, response_format, max_tokens
|
||||
)
|
||||
case LLMProvider.ANTHROPIC:
|
||||
content = await self._generate_anthropic_structured(
|
||||
model, messages, response_format
|
||||
model, messages, response_format, max_tokens
|
||||
)
|
||||
case LLMProvider.OLLAMA:
|
||||
content = await self._generate_ollama_structured(
|
||||
model, messages, response_format, strict
|
||||
model, messages, response_format, strict, max_tokens
|
||||
)
|
||||
case LLMProvider.CUSTOM:
|
||||
content = await self._generate_custom_structured(
|
||||
model, messages, response_format, strict
|
||||
model, messages, response_format, strict, max_tokens
|
||||
)
|
||||
if content is None:
|
||||
raise HTTPException(
|
||||
|
|
@ -320,18 +385,28 @@ class LLMClient:
|
|||
return content
|
||||
|
||||
# ? Stream Unstructured Content
|
||||
async def _stream_openai(self, model: str, messages: List[LLMMessage]):
|
||||
async def _stream_openai(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
max_tokens: Optional[int] = None,
|
||||
):
|
||||
client: AsyncOpenAI = self._client
|
||||
async with client.chat.completions.stream(
|
||||
model=model,
|
||||
messages=[message.model_dump() for message in messages],
|
||||
max_completion_tokens=self.max_tokens,
|
||||
max_completion_tokens=max_tokens,
|
||||
) as stream:
|
||||
async for event in stream:
|
||||
if event.type == "content.delta":
|
||||
yield event.delta
|
||||
|
||||
async def _stream_google(self, model: str, messages: List[LLMMessage]):
|
||||
async def _stream_google(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
max_tokens: Optional[int] = None,
|
||||
):
|
||||
client: genai.Client = self._client
|
||||
async for event in iterator_to_async(client.models.generate_content_stream)(
|
||||
model=model,
|
||||
|
|
@ -339,13 +414,18 @@ class LLMClient:
|
|||
config=GenerateContentConfig(
|
||||
system_instruction=self._get_system_prompt(messages),
|
||||
response_mime_type="text/plain",
|
||||
max_output_tokens=self.max_tokens,
|
||||
max_output_tokens=max_tokens,
|
||||
),
|
||||
):
|
||||
if event.text:
|
||||
yield event.text
|
||||
|
||||
async def _stream_anthropic(self, model: str, messages: List[LLMMessage]):
|
||||
async def _stream_anthropic(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
max_tokens: Optional[int] = None,
|
||||
):
|
||||
client: AsyncAnthropic = self._client
|
||||
async with client.messages.stream(
|
||||
model=model,
|
||||
|
|
@ -354,31 +434,43 @@ class LLMClient:
|
|||
message.model_dump()
|
||||
for message in self._get_user_llm_messages(messages)
|
||||
],
|
||||
max_tokens=self.max_tokens,
|
||||
max_tokens=max_tokens or 4000,
|
||||
) as stream:
|
||||
async for event in stream:
|
||||
event: AnthropicMessageStreamEvent = event
|
||||
if event.type == "text" and isinstance(event.text, str):
|
||||
yield event.text
|
||||
|
||||
def _stream_ollama(self, model: str, messages: List[LLMMessage]):
|
||||
return self._stream_openai(model, messages)
|
||||
def _stream_ollama(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
max_tokens: Optional[int] = None,
|
||||
):
|
||||
return self._stream_openai(model, messages, max_tokens)
|
||||
|
||||
def _stream_custom(self, model: str, messages: List[LLMMessage]):
|
||||
return self._stream_openai(model, messages)
|
||||
def _stream_custom(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
max_tokens: Optional[int] = None,
|
||||
):
|
||||
return self._stream_openai(model, messages, max_tokens)
|
||||
|
||||
def stream(self, model: str, messages: List[LLMMessage]):
|
||||
def stream(
|
||||
self, model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None
|
||||
):
|
||||
match self.llm_provider:
|
||||
case LLMProvider.OPENAI:
|
||||
return self._stream_openai(model, messages)
|
||||
return self._stream_openai(model, messages, max_tokens)
|
||||
case LLMProvider.GOOGLE:
|
||||
return self._stream_google(model, messages)
|
||||
return self._stream_google(model, messages, max_tokens)
|
||||
case LLMProvider.ANTHROPIC:
|
||||
return self._stream_anthropic(model, messages)
|
||||
return self._stream_anthropic(model, messages, max_tokens)
|
||||
case LLMProvider.OLLAMA:
|
||||
return self._stream_ollama(model, messages)
|
||||
return self._stream_ollama(model, messages, max_tokens)
|
||||
case LLMProvider.CUSTOM:
|
||||
return self._stream_custom(model, messages)
|
||||
return self._stream_custom(model, messages, max_tokens)
|
||||
|
||||
# ? Stream Structured Content
|
||||
async def _stream_openai_structured(
|
||||
|
|
@ -387,8 +479,10 @@ class LLMClient:
|
|||
messages: List[LLMMessage],
|
||||
response_format: dict,
|
||||
strict: bool = False,
|
||||
max_tokens: Optional[int] = None,
|
||||
):
|
||||
client: AsyncOpenAI = self._client
|
||||
supports_json_schema = self.supports_json_schema(model)
|
||||
response_schema = response_format
|
||||
if strict:
|
||||
response_schema = ensure_strict_json_schema(
|
||||
|
|
@ -396,30 +490,53 @@ class LLMClient:
|
|||
path=(),
|
||||
root=response_schema,
|
||||
)
|
||||
async with client.chat.completions.stream(
|
||||
model=model,
|
||||
messages=[message.model_dump() for message in messages],
|
||||
max_completion_tokens=self.max_tokens,
|
||||
response_format=(
|
||||
{
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": "ResponseSchema",
|
||||
"strict": strict,
|
||||
"schema": response_schema,
|
||||
},
|
||||
}
|
||||
),
|
||||
) as stream:
|
||||
async for event in stream:
|
||||
if event.type == "content.delta":
|
||||
yield event.delta
|
||||
if supports_json_schema:
|
||||
async with client.chat.completions.stream(
|
||||
model=model,
|
||||
messages=[message.model_dump() for message in messages],
|
||||
max_completion_tokens=max_tokens,
|
||||
response_format=(
|
||||
{
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": "ResponseSchema",
|
||||
"strict": strict,
|
||||
"schema": response_schema,
|
||||
},
|
||||
}
|
||||
),
|
||||
) as stream:
|
||||
async for event in stream:
|
||||
if event.type == "content.delta":
|
||||
yield event.delta
|
||||
else:
|
||||
async with client.chat.completions.stream(
|
||||
model=model,
|
||||
messages=[message.model_dump() for message in messages],
|
||||
max_completion_tokens=max_tokens,
|
||||
tools=[
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "ResponseSchema",
|
||||
"description": "A response to the user's message",
|
||||
"strict": strict,
|
||||
"parameters": response_format,
|
||||
},
|
||||
}
|
||||
],
|
||||
tool_choice="required",
|
||||
) as stream:
|
||||
async for event in stream:
|
||||
if event.type == "tool_calls.function.arguments.delta":
|
||||
yield event.arguments_delta
|
||||
|
||||
async def _stream_google_structured(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
response_format: dict,
|
||||
max_tokens: Optional[int] = None,
|
||||
):
|
||||
client: genai.Client = self._client
|
||||
async for event in iterator_to_async(client.models.generate_content_stream)(
|
||||
|
|
@ -429,7 +546,7 @@ class LLMClient:
|
|||
system_instruction=self._get_system_prompt(messages),
|
||||
response_mime_type="application/json",
|
||||
response_json_schema=response_format,
|
||||
max_output_tokens=self.max_tokens,
|
||||
max_output_tokens=max_tokens,
|
||||
),
|
||||
):
|
||||
if event.text:
|
||||
|
|
@ -440,6 +557,7 @@ class LLMClient:
|
|||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
response_format: dict,
|
||||
max_tokens: Optional[int] = None,
|
||||
):
|
||||
client: AsyncAnthropic = self._client
|
||||
async with client.messages.stream(
|
||||
|
|
@ -449,7 +567,7 @@ class LLMClient:
|
|||
message.model_dump()
|
||||
for message in self._get_user_llm_messages(messages)
|
||||
],
|
||||
max_tokens=self.max_tokens,
|
||||
max_tokens=max_tokens or 4000,
|
||||
tools=[
|
||||
{
|
||||
"name": "ResponseSchema",
|
||||
|
|
@ -457,6 +575,10 @@ class LLMClient:
|
|||
"input_schema": response_format,
|
||||
}
|
||||
],
|
||||
tool_choice={
|
||||
"type": "tool",
|
||||
"name": "ResponseSchema",
|
||||
},
|
||||
) as stream:
|
||||
async for event in stream:
|
||||
event: AnthropicMessageStreamEvent = event
|
||||
|
|
@ -469,8 +591,11 @@ class LLMClient:
|
|||
messages: List[LLMMessage],
|
||||
response_format: dict,
|
||||
strict: bool = False,
|
||||
max_tokens: Optional[int] = None,
|
||||
):
|
||||
return self._stream_openai_structured(model, messages, response_format, strict)
|
||||
return self._stream_openai_structured(
|
||||
model, messages, response_format, strict, max_tokens
|
||||
)
|
||||
|
||||
def _stream_custom_structured(
|
||||
self,
|
||||
|
|
@ -478,8 +603,11 @@ class LLMClient:
|
|||
messages: List[LLMMessage],
|
||||
response_format: dict,
|
||||
strict: bool = False,
|
||||
max_tokens: Optional[int] = None,
|
||||
):
|
||||
return self._stream_openai_structured(model, messages, response_format, strict)
|
||||
return self._stream_openai_structured(
|
||||
model, messages, response_format, strict, max_tokens
|
||||
)
|
||||
|
||||
def stream_structured(
|
||||
self,
|
||||
|
|
@ -487,23 +615,26 @@ class LLMClient:
|
|||
messages: List[LLMMessage],
|
||||
response_format: dict,
|
||||
strict: bool = False,
|
||||
max_tokens: Optional[int] = None,
|
||||
):
|
||||
match self.llm_provider:
|
||||
case LLMProvider.OPENAI:
|
||||
return self._stream_openai_structured(
|
||||
model, messages, response_format, strict
|
||||
model, messages, response_format, strict, max_tokens
|
||||
)
|
||||
case LLMProvider.GOOGLE:
|
||||
return self._stream_google_structured(model, messages, response_format)
|
||||
return self._stream_google_structured(
|
||||
model, messages, response_format, max_tokens
|
||||
)
|
||||
case LLMProvider.ANTHROPIC:
|
||||
return self._stream_anthropic_structured(
|
||||
model, messages, response_format
|
||||
model, messages, response_format, max_tokens
|
||||
)
|
||||
case LLMProvider.OLLAMA:
|
||||
return self._stream_ollama_structured(
|
||||
model, messages, response_format, strict
|
||||
model, messages, response_format, strict, max_tokens
|
||||
)
|
||||
case LLMProvider.CUSTOM:
|
||||
return self._stream_custom_structured(
|
||||
model, messages, response_format, strict
|
||||
model, messages, response_format, strict, max_tokens
|
||||
)
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ import { NextResponse, NextRequest } from 'next/server';
|
|||
|
||||
export async function POST(req: NextRequest) {
|
||||
const { id, title } = await req.json();
|
||||
console.log('path', process.env.APP_DATA_DIRECTORY);
|
||||
if (!id) {
|
||||
return NextResponse.json({ error: "Missing Presentation ID" }, { status: 400 });
|
||||
}
|
||||
|
|
@ -69,7 +68,7 @@ export async function POST(req: NextRequest) {
|
|||
|
||||
const sanitizedTitle = sanitizeFilename(title);
|
||||
const destinationPath = path.join(process.env.APP_DATA_DIRECTORY!, 'exports', `${sanitizedTitle}.pdf`);
|
||||
console.log('destinationPath', destinationPath);
|
||||
await fs.promises.mkdir(path.dirname(destinationPath), { recursive: true });
|
||||
await fs.promises.writeFile(destinationPath, pdfBuffer);
|
||||
|
||||
return NextResponse.json({
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue