feat: adds support for deepseek using custom llm (tool calls for deepseek)

This commit is contained in:
sauravniraula 2025-08-01 18:16:59 +05:45
parent 60bc61ae48
commit 4b6de697ec
No known key found for this signature in database
GPG key ID: 60FCC1B5A5E83326
3 changed files with 217 additions and 87 deletions

View file

@ -1,6 +1,6 @@
import asyncio
import json
from typing import List
from typing import List, Optional
from fastapi import HTTPException
from openai import AsyncOpenAI
from google import genai
@ -24,10 +24,17 @@ from utils.schema_utils import ensure_strict_json_schema
class LLMClient:
def __init__(self, max_tokens: int = 4000):
def __init__(self):
self.llm_provider = get_llm_provider()
self._client = self._get_client()
self.max_tokens = max_tokens
# Supports json_schema
def supports_json_schema(self, model: str) -> bool:
    """Report whether ``model`` should be sent a ``json_schema`` response format.

    Args:
        model: Model identifier as passed to the provider API.

    Returns:
        ``False`` when the identifier starts with ``"deepseek"`` or
        ``"claude"`` (the match is case-sensitive), ``True`` for every
        other identifier.
    """
    # str.startswith accepts a tuple, so both prefixes are tested in one call.
    unsupported_prefixes = ("deepseek", "claude")
    return not model.startswith(unsupported_prefixes)
# ? Clients
def _get_client(self):
@ -103,16 +110,26 @@ class LLMClient:
return [message for message in messages if message.role == "user"]
# ? Generate Unstructured Content
async def _generate_openai(self, model: str, messages: List[LLMMessage]):
async def _generate_openai(
self,
model: str,
messages: List[LLMMessage],
max_tokens: Optional[int] = None,
):
client: AsyncOpenAI = self._client
response = await client.chat.completions.create(
model=model,
messages=[message.model_dump() for message in messages],
max_completion_tokens=self.max_tokens,
max_completion_tokens=max_tokens,
)
return response.choices[0].message.content
async def _generate_google(self, model: str, messages: List[LLMMessage]):
async def _generate_google(
self,
model: str,
messages: List[LLMMessage],
max_tokens: Optional[int] = None,
):
client: genai.Client = self._client
response = await asyncio.to_thread(
client.models.generate_content,
@ -121,12 +138,17 @@ class LLMClient:
config=GenerateContentConfig(
system_instruction=self._get_system_prompt(messages),
response_mime_type="text/plain",
max_output_tokens=self.max_tokens,
max_output_tokens=max_tokens,
),
)
return response.text
async def _generate_anthropic(self, model: str, messages: List[LLMMessage]):
async def _generate_anthropic(
self,
model: str,
messages: List[LLMMessage],
max_tokens: Optional[int] = None,
):
client: AsyncAnthropic = self._client
response: AnthropicMessage = await client.messages.create(
model=model,
@ -135,7 +157,7 @@ class LLMClient:
message.model_dump()
for message in self._get_user_llm_messages(messages)
],
max_tokens=self.max_tokens,
max_tokens=max_tokens or 4000,
)
text = ""
for content in response.content:
@ -145,25 +167,34 @@ class LLMClient:
return None
return text
async def _generate_ollama(self, model: str, messages: List[LLMMessage]):
return await self._generate_openai(model, messages)
async def _generate_ollama(
self, model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None
):
return await self._generate_openai(model, messages, max_tokens)
async def _generate_custom(self, model: str, messages: List[LLMMessage]):
return await self._generate_openai(model, messages)
async def _generate_custom(
self, model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None
):
return await self._generate_openai(model, messages, max_tokens)
async def generate(self, model: str, messages: List[LLMMessage]):
async def generate(
self,
model: str,
messages: List[LLMMessage],
max_tokens: Optional[int] = None,
):
content = None
match self.llm_provider:
case LLMProvider.OPENAI:
content = await self._generate_openai(model, messages)
content = await self._generate_openai(model, messages, max_tokens)
case LLMProvider.GOOGLE:
content = await self._generate_google(model, messages)
content = await self._generate_google(model, messages, max_tokens)
case LLMProvider.ANTHROPIC:
content = await self._generate_anthropic(model, messages)
content = await self._generate_anthropic(model, messages, max_tokens)
case LLMProvider.OLLAMA:
content = await self._generate_ollama(model, messages)
content = await self._generate_ollama(model, messages, max_tokens)
case LLMProvider.CUSTOM:
content = await self._generate_custom(model, messages)
content = await self._generate_custom(model, messages, max_tokens)
if content is None:
raise HTTPException(
status_code=400,
@ -178,8 +209,10 @@ class LLMClient:
messages: List[LLMMessage],
response_format: dict,
strict: bool = False,
max_tokens: Optional[int] = None,
):
client: AsyncOpenAI = self._client
supports_json_schema = self.supports_json_schema(model)
response_schema = response_format
if strict:
response_schema = ensure_strict_json_schema(
@ -187,22 +220,45 @@ class LLMClient:
path=(),
root=response_schema,
)
response = await client.chat.completions.create(
model=model,
messages=[message.model_dump() for message in messages],
response_format={
"type": "json_schema",
"json_schema": (
if supports_json_schema:
response = await client.chat.completions.create(
model=model,
messages=[message.model_dump() for message in messages],
response_format={
"type": "json_schema",
"json_schema": (
{
"name": "ResponseSchema",
"strict": strict,
"schema": response_schema,
}
),
},
max_completion_tokens=max_tokens,
)
content = response.choices[0].message.content
else:
response = await client.chat.completions.create(
model=model,
messages=[message.model_dump() for message in messages],
tools=[
{
"name": "ResponseSchema",
"strict": strict,
"schema": response_schema,
"type": "function",
"function": {
"name": "ResponseSchema",
"description": "A response to the user's message",
"strict": strict,
"parameters": response_format,
},
}
),
},
max_completion_tokens=self.max_tokens,
)
content = response.choices[0].message.content
],
tool_choice="required",
max_completion_tokens=max_tokens,
)
tool_calls = response.choices[0].message.tool_calls
if tool_calls:
content = tool_calls[0].function.arguments
if content:
return json.loads(content)
return None
@ -212,6 +268,7 @@ class LLMClient:
model: str,
messages: List[LLMMessage],
response_format: dict,
max_tokens: Optional[int] = None,
):
client: genai.Client = self._client
response = await asyncio.to_thread(
@ -222,7 +279,7 @@ class LLMClient:
system_instruction=self._get_system_prompt(messages),
response_mime_type="application/json",
response_json_schema=response_format,
max_output_tokens=self.max_tokens,
max_output_tokens=max_tokens,
),
)
content = None
@ -236,6 +293,7 @@ class LLMClient:
model: str,
messages: List[LLMMessage],
response_format: dict,
max_tokens: Optional[int] = None,
):
client: AsyncAnthropic = self._client
response: AnthropicMessage = await client.messages.create(
@ -245,7 +303,7 @@ class LLMClient:
message.model_dump()
for message in self._get_user_llm_messages(messages)
],
max_tokens=self.max_tokens,
max_tokens=max_tokens or 4000,
tools=[
{
"name": "ResponseSchema",
@ -253,6 +311,10 @@ class LLMClient:
"input_schema": response_format,
}
],
tool_choice={
"type": "tool",
"name": "ResponseSchema",
},
)
content: dict | None = None
for content_block in response.content:
@ -267,9 +329,10 @@ class LLMClient:
messages: List[LLMMessage],
response_format: dict,
strict: bool = False,
max_tokens: Optional[int] = None,
):
return await self._generate_openai_structured(
model, messages, response_format, strict
model, messages, response_format, strict, max_tokens
)
async def _generate_custom_structured(
@ -278,9 +341,10 @@ class LLMClient:
messages: List[LLMMessage],
response_format: dict,
strict: bool = False,
max_tokens: Optional[int] = None,
):
return await self._generate_openai_structured(
model, messages, response_format, strict
model, messages, response_format, strict, max_tokens
)
async def generate_structured(
@ -289,28 +353,29 @@ class LLMClient:
messages: List[LLMMessage],
response_format: dict,
strict: bool = False,
max_tokens: Optional[int] = None,
) -> dict:
content = None
match self.llm_provider:
case LLMProvider.OPENAI:
content = await self._generate_openai_structured(
model, messages, response_format, strict
model, messages, response_format, strict, max_tokens
)
case LLMProvider.GOOGLE:
content = await self._generate_google_structured(
model, messages, response_format
model, messages, response_format, max_tokens
)
case LLMProvider.ANTHROPIC:
content = await self._generate_anthropic_structured(
model, messages, response_format
model, messages, response_format, max_tokens
)
case LLMProvider.OLLAMA:
content = await self._generate_ollama_structured(
model, messages, response_format, strict
model, messages, response_format, strict, max_tokens
)
case LLMProvider.CUSTOM:
content = await self._generate_custom_structured(
model, messages, response_format, strict
model, messages, response_format, strict, max_tokens
)
if content is None:
raise HTTPException(
@ -320,18 +385,28 @@ class LLMClient:
return content
# ? Stream Unstructured Content
async def _stream_openai(self, model: str, messages: List[LLMMessage]):
async def _stream_openai(
self,
model: str,
messages: List[LLMMessage],
max_tokens: Optional[int] = None,
):
client: AsyncOpenAI = self._client
async with client.chat.completions.stream(
model=model,
messages=[message.model_dump() for message in messages],
max_completion_tokens=self.max_tokens,
max_completion_tokens=max_tokens,
) as stream:
async for event in stream:
if event.type == "content.delta":
yield event.delta
async def _stream_google(self, model: str, messages: List[LLMMessage]):
async def _stream_google(
self,
model: str,
messages: List[LLMMessage],
max_tokens: Optional[int] = None,
):
client: genai.Client = self._client
async for event in iterator_to_async(client.models.generate_content_stream)(
model=model,
@ -339,13 +414,18 @@ class LLMClient:
config=GenerateContentConfig(
system_instruction=self._get_system_prompt(messages),
response_mime_type="text/plain",
max_output_tokens=self.max_tokens,
max_output_tokens=max_tokens,
),
):
if event.text:
yield event.text
async def _stream_anthropic(self, model: str, messages: List[LLMMessage]):
async def _stream_anthropic(
self,
model: str,
messages: List[LLMMessage],
max_tokens: Optional[int] = None,
):
client: AsyncAnthropic = self._client
async with client.messages.stream(
model=model,
@ -354,31 +434,43 @@ class LLMClient:
message.model_dump()
for message in self._get_user_llm_messages(messages)
],
max_tokens=self.max_tokens,
max_tokens=max_tokens or 4000,
) as stream:
async for event in stream:
event: AnthropicMessageStreamEvent = event
if event.type == "text" and isinstance(event.text, str):
yield event.text
def _stream_ollama(self, model: str, messages: List[LLMMessage]):
return self._stream_openai(model, messages)
def _stream_ollama(
self,
model: str,
messages: List[LLMMessage],
max_tokens: Optional[int] = None,
):
return self._stream_openai(model, messages, max_tokens)
def _stream_custom(self, model: str, messages: List[LLMMessage]):
return self._stream_openai(model, messages)
def _stream_custom(
self,
model: str,
messages: List[LLMMessage],
max_tokens: Optional[int] = None,
):
return self._stream_openai(model, messages, max_tokens)
def stream(self, model: str, messages: List[LLMMessage]):
def stream(
self, model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None
):
match self.llm_provider:
case LLMProvider.OPENAI:
return self._stream_openai(model, messages)
return self._stream_openai(model, messages, max_tokens)
case LLMProvider.GOOGLE:
return self._stream_google(model, messages)
return self._stream_google(model, messages, max_tokens)
case LLMProvider.ANTHROPIC:
return self._stream_anthropic(model, messages)
return self._stream_anthropic(model, messages, max_tokens)
case LLMProvider.OLLAMA:
return self._stream_ollama(model, messages)
return self._stream_ollama(model, messages, max_tokens)
case LLMProvider.CUSTOM:
return self._stream_custom(model, messages)
return self._stream_custom(model, messages, max_tokens)
# ? Stream Structured Content
async def _stream_openai_structured(
@ -387,8 +479,10 @@ class LLMClient:
messages: List[LLMMessage],
response_format: dict,
strict: bool = False,
max_tokens: Optional[int] = None,
):
client: AsyncOpenAI = self._client
supports_json_schema = self.supports_json_schema(model)
response_schema = response_format
if strict:
response_schema = ensure_strict_json_schema(
@ -396,30 +490,53 @@ class LLMClient:
path=(),
root=response_schema,
)
async with client.chat.completions.stream(
model=model,
messages=[message.model_dump() for message in messages],
max_completion_tokens=self.max_tokens,
response_format=(
{
"type": "json_schema",
"json_schema": {
"name": "ResponseSchema",
"strict": strict,
"schema": response_schema,
},
}
),
) as stream:
async for event in stream:
if event.type == "content.delta":
yield event.delta
if supports_json_schema:
async with client.chat.completions.stream(
model=model,
messages=[message.model_dump() for message in messages],
max_completion_tokens=max_tokens,
response_format=(
{
"type": "json_schema",
"json_schema": {
"name": "ResponseSchema",
"strict": strict,
"schema": response_schema,
},
}
),
) as stream:
async for event in stream:
if event.type == "content.delta":
yield event.delta
else:
async with client.chat.completions.stream(
model=model,
messages=[message.model_dump() for message in messages],
max_completion_tokens=max_tokens,
tools=[
{
"type": "function",
"function": {
"name": "ResponseSchema",
"description": "A response to the user's message",
"strict": strict,
"parameters": response_format,
},
}
],
tool_choice="required",
) as stream:
async for event in stream:
if event.type == "tool_calls.function.arguments.delta":
yield event.arguments_delta
async def _stream_google_structured(
self,
model: str,
messages: List[LLMMessage],
response_format: dict,
max_tokens: Optional[int] = None,
):
client: genai.Client = self._client
async for event in iterator_to_async(client.models.generate_content_stream)(
@ -429,7 +546,7 @@ class LLMClient:
system_instruction=self._get_system_prompt(messages),
response_mime_type="application/json",
response_json_schema=response_format,
max_output_tokens=self.max_tokens,
max_output_tokens=max_tokens,
),
):
if event.text:
@ -440,6 +557,7 @@ class LLMClient:
model: str,
messages: List[LLMMessage],
response_format: dict,
max_tokens: Optional[int] = None,
):
client: AsyncAnthropic = self._client
async with client.messages.stream(
@ -449,7 +567,7 @@ class LLMClient:
message.model_dump()
for message in self._get_user_llm_messages(messages)
],
max_tokens=self.max_tokens,
max_tokens=max_tokens or 4000,
tools=[
{
"name": "ResponseSchema",
@ -457,6 +575,10 @@ class LLMClient:
"input_schema": response_format,
}
],
tool_choice={
"type": "tool",
"name": "ResponseSchema",
},
) as stream:
async for event in stream:
event: AnthropicMessageStreamEvent = event
@ -469,8 +591,11 @@ class LLMClient:
messages: List[LLMMessage],
response_format: dict,
strict: bool = False,
max_tokens: Optional[int] = None,
):
return self._stream_openai_structured(model, messages, response_format, strict)
return self._stream_openai_structured(
model, messages, response_format, strict, max_tokens
)
def _stream_custom_structured(
self,
@ -478,8 +603,11 @@ class LLMClient:
messages: List[LLMMessage],
response_format: dict,
strict: bool = False,
max_tokens: Optional[int] = None,
):
return self._stream_openai_structured(model, messages, response_format, strict)
return self._stream_openai_structured(
model, messages, response_format, strict, max_tokens
)
def stream_structured(
self,
@ -487,23 +615,26 @@ class LLMClient:
messages: List[LLMMessage],
response_format: dict,
strict: bool = False,
max_tokens: Optional[int] = None,
):
match self.llm_provider:
case LLMProvider.OPENAI:
return self._stream_openai_structured(
model, messages, response_format, strict
model, messages, response_format, strict, max_tokens
)
case LLMProvider.GOOGLE:
return self._stream_google_structured(model, messages, response_format)
return self._stream_google_structured(
model, messages, response_format, max_tokens
)
case LLMProvider.ANTHROPIC:
return self._stream_anthropic_structured(
model, messages, response_format
model, messages, response_format, max_tokens
)
case LLMProvider.OLLAMA:
return self._stream_ollama_structured(
model, messages, response_format, strict
model, messages, response_format, strict, max_tokens
)
case LLMProvider.CUSTOM:
return self._stream_custom_structured(
model, messages, response_format, strict
model, messages, response_format, strict, max_tokens
)

View file

@ -8,7 +8,6 @@ import { NextResponse, NextRequest } from 'next/server';
export async function POST(req: NextRequest) {
const { id, title } = await req.json();
console.log('path', process.env.APP_DATA_DIRECTORY);
if (!id) {
return NextResponse.json({ error: "Missing Presentation ID" }, { status: 400 });
}
@ -69,7 +68,7 @@ export async function POST(req: NextRequest) {
const sanitizedTitle = sanitizeFilename(title);
const destinationPath = path.join(process.env.APP_DATA_DIRECTORY!, 'exports', `${sanitizedTitle}.pdf`);
console.log('destinationPath', destinationPath);
await fs.promises.mkdir(path.dirname(destinationPath), { recursive: true });
await fs.promises.writeFile(destinationPath, pdfBuffer);
return NextResponse.json({