presenton/servers/fastapi/ppt_generator/fix_validation_errors.py
2025-06-23 15:13:04 +05:45

91 lines
2.8 KiB
Python

import os
from typing import Optional
from fastapi import HTTPException
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel, ValidationError
from api.utils.utils import get_large_model
def get_prompt_template():
    """Build the chat prompt that asks the LLM to repair a structured output.

    The returned template has a fixed system instruction and a user turn with
    two template variables:

    - ``input``: the response that failed validation.
    - ``errors``: the validation error details to fix.
    """
    system_message = """
Analyze the provided [Input] and [Errors] then provide structured output by fixing the errors.
# Steps
1. Go through the provided [Input].
2. Find mentioned [Errors] in the [Input].
3. Check provided schema and follow every constraints.
4. Provide structured output.
# Notes
- Only output fields mentioned in the schema.
- Check if fields' key may have been misnamed in the provided **Input**.
- Change fields' name to match the schema.
"""
    user_message = """
- Input: {input}
- Errors: {errors}
"""
    conversation = [
        ("system", system_message),
        ("user", user_message),
    ]
    return ChatPromptTemplate(messages=conversation)
async def fix_validation_errors(response_model: BaseModel, response, errors):
    """Ask the large model to rewrite ``response`` so that ``errors`` are fixed.

    Args:
        response_model: Pydantic model *class* (``model_json_schema()`` is
            called on it). Its JSON schema is bound as the structured-output
            schema of the large model.
        response: The previous output that failed validation. It is
            substituted into the prompt's ``input`` variable.
        errors: Validation error details, substituted into the prompt's
            ``errors`` variable.

    Returns:
        The raw structured output produced by the repair chain. It is not
        validated here.
    """
    schema = response_model.model_json_schema()
    structured_llm = get_large_model().with_structured_output(schema)
    repair_chain = get_prompt_template() | structured_llm
    prompt_vars = {"input": response, "errors": errors}
    return await repair_chain.ainvoke(prompt_vars)
def _format_validation_errors(exc):
    """Flatten a Pydantic ``ValidationError`` into simple dicts for the LLM.

    Each entry comes from ``exc.errors()`` and holds three keys:

    - ``loc``: the error location parts joined with ``" -> "``.
    - ``msg``: the error message.
    - ``type``: the error type.
    """
    return [
        {
            "loc": " -> ".join(str(loc) for loc in error["loc"]),
            "msg": error["msg"],
            "type": error["type"],
        }
        for error in exc.errors()
    ]


async def get_validated_response(
    chain,
    input_dict,
    response_model: BaseModel,
    validation_model: Optional[BaseModel] = None,
    retries: int = 1,
):
    """Run ``chain`` and validate its output, letting the LLM repair failures.

    Args:
        chain: Object whose ``ainvoke(input_dict)`` produces the raw response.
        input_dict: Variables passed to ``chain.ainvoke``.
        response_model: Pydantic model *class*. Its JSON schema is given to
            the LLM when an invalid response is repaired.
        validation_model: Optional model class used for validation. Defaults
            to ``response_model``.
        retries: Number of LLM repair attempts. Validation runs at most
            ``retries + 1`` times.

    Returns:
        ``validation_model(**response)`` for the first response that
        validates.

    Raises:
        HTTPException: Status 400 when no response validates within the
            allotted attempts.
    """
    from collections.abc import Mapping

    response = await chain.ainvoke(input_dict)
    validation_model = validation_model or response_model
    attempt = 0
    while retries >= attempt:
        attempt += 1
        print("-" * 50)
        print(f"Validation Retry attempt - {attempt}")
        # A list is assumed to hold tool calls ([{"args": {...}}, ...]);
        # validate the first call's arguments. Any other list shape is left
        # as-is and reported below as invalid instead of raising
        # KeyError/TypeError.
        if isinstance(response, list) and response:
            first_call = response[0]
            if isinstance(first_call, Mapping) and "args" in first_call:
                response = first_call["args"]
        if isinstance(response, Mapping):
            try:
                return validation_model(**response)
            except ValidationError as e:
                error_details = _format_validation_errors(e)
        else:
            # ``validation_model(**response)`` would raise an uncaught
            # TypeError for a non-mapping (e.g. None or an empty list) and
            # skip the repair loop entirely. Treat it as a validation failure
            # so the LLM can repair it, or the 400 below is raised.
            error_details = [
                {
                    "loc": "",
                    "msg": (
                        "Input should be a valid dictionary, got "
                        f"{type(response).__name__}"
                    ),
                    "type": "model_type",
                }
            ]
        if retries < attempt:
            break
        response = await fix_validation_errors(
            response_model, response, error_details
        )
    raise HTTPException(status_code=400, detail="Error while validating response")