90 lines
2.8 KiB
Python
90 lines
2.8 KiB
Python
import os
|
|
from typing import Optional
|
|
from fastapi import HTTPException
|
|
from langchain_ollama import ChatOllama
|
|
from langchain_core.prompts import ChatPromptTemplate
|
|
from pydantic import BaseModel, ValidationError
|
|
|
|
from api.utils.model_utils import get_large_model
|
|
|
|
|
|
def get_prompt_template():
    """Build the chat prompt used to repair a response that failed validation.

    The prompt has two messages: a ``system`` message holding the repair
    instructions, and a ``user`` message exposing the ``{input}`` and
    ``{errors}`` template variables.

    Returns:
        A ``ChatPromptTemplate`` built from those two messages.
    """
    repair_prompt = ChatPromptTemplate(
        messages=[
            # Instructions: what to fix and which constraints to respect.
            (
                "system",
                """
                Analyze the provided [Input] and [Errors] then provide structured output by fixing the errors.

                # Steps
                1. Go through the provided [Input].
                2. Find mentioned [Errors] in the [Input].
                3. Check provided schema and follow every constraints.
                4. Provide structured output.

                # Notes
                - Only output fields mentioned in the schema.
                - Check if fields' key may have been misnamed in the provided **Input**.
                - Change fields' name to match the schema.
                """,
            ),
            # Payload: the offending response and the errors found in it.
            (
                "user",
                """
                - Input: {input}
                - Errors: {errors}
                """,
            ),
        ]
    )
    return repair_prompt
|
|
|
|
|
|
async def fix_validation_errors(response_model: BaseModel, response, errors):
    """Ask the large Ollama model to rewrite *response* so it fits the schema.

    Args:
        response_model: Pydantic model whose JSON schema constrains the
            structured output. ``model_json_schema()`` is called on it
            directly, so callers pass the model class.
        response: The output that failed validation; rendered into the
            prompt's ``{input}`` variable.
        errors: Description of the validation failures; rendered into the
            prompt's ``{errors}`` variable.

    Returns:
        Whatever the structured-output chain produces for the repaired
        response.
    """
    llm = ChatOllama(model=get_large_model(), temperature=0.8)

    prompt = get_prompt_template()
    # Constrain the model's output to the target schema.
    structured_llm = llm.with_structured_output(response_model.model_json_schema())
    repair_chain = prompt | structured_llm

    prompt_variables = {"input": response, "errors": errors}
    return await repair_chain.ainvoke(prompt_variables)
|
|
|
|
|
|
def _extract_args_payload(response):
    """Return ``response[0]["args"]`` for a list-shaped response, else *response*.

    Anything that does not have that exact shape is returned unchanged so the
    caller's validation step can report it as a validation error.
    """
    # presumably the list form is a list of tool-call dicts; verify against
    # the chains that are passed in.
    if isinstance(response, list) and response:
        first = response[0]
        if isinstance(first, dict) and "args" in first:
            return first["args"]
    return response


async def get_validated_response(
    chain,
    input_dict,
    response_model: type[BaseModel],
    validation_model: Optional[type[BaseModel]] = None,
    retries: int = 1,
):
    """Invoke *chain* and return its output validated as a Pydantic model.

    The chain output is validated against ``validation_model``. When
    validation fails, the errors are summarised and passed to
    ``fix_validation_errors`` together with the failing output, and the
    repaired output is validated again. This is repeated up to ``retries``
    times.

    Args:
        chain: Object exposing an async ``ainvoke(input_dict)`` method.
        input_dict: Input passed to ``chain.ainvoke``.
        response_model: Pydantic model class whose schema is handed to
            ``fix_validation_errors`` when a repair is needed.
        validation_model: Pydantic model class used for validation; falls
            back to ``response_model`` when not given.
        retries: Number of repair rounds allowed after the first failed
            validation. Negative values are treated as ``0``.

    Returns:
        An instance of the validation model.

    Raises:
        HTTPException: With status 400 when the output is still invalid
            after all repair rounds.
    """
    response = await chain.ainvoke(input_dict)
    validation_model = validation_model or response_model

    # One validation of the initial output plus one after each repair round.
    # Clamping keeps a negative ``retries`` from skipping validation entirely.
    total_attempts = max(retries, 0) + 1

    for attempt in range(1, total_attempts + 1):
        print("-" * 50)
        print(f"Validation Retry attempt - {attempt}")

        response = _extract_args_payload(response)
        try:
            # model_validate reports non-mapping input (None, empty list, ...)
            # as a ValidationError, so it enters the repair loop instead of
            # escaping as a TypeError from ``validation_model(**response)``.
            return validation_model.model_validate(response)
        except ValidationError as exc:
            if attempt == total_attempts:
                # Out of repair rounds; fall through to the 400 below.
                break

            error_details = [
                {
                    "loc": " -> ".join(str(part) for part in error["loc"]),
                    "msg": error["msg"],
                    "type": error["type"],
                }
                for error in exc.errors()
            ]
            response = await fix_validation_errors(
                response_model, response, error_details
            )

    raise HTTPException(status_code=400, detail="Error while validating response")
|