cohorta/backend/app/services/llm_service.py
Vadym Samoilenko e01569c412
All checks were successful
Deploy to Production / deploy (push) Successful in 2m23s
feat: commit all app changes — billing API, new auth, design overhaul
Includes frontend redesign (Navigation, billingApi), backend updates
(auth routes, admin routes, LLM service refactor), MSAL removal,
and dependency updates.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-23 19:04:43 +01:00

559 lines
23 KiB
Python
Executable file

"""
LLM Service for Cohorta — Azure AI Foundry (Responses API)
All model calls route through a single Azure AI Foundry project endpoint.
gpt-5.4 — complex tasks: persona responses, moderator, detailed generation
gpt-5.4-mini — cheap tasks: summaries, key themes, conversation decisions, basic gen
"""
import os
import json
import asyncio
import logging
import base64
import traceback
import time
import io
from openai import AsyncOpenAI
from typing import Dict, Any, Optional, Union, List
from PIL import Image
def _require_env(key: str) -> str:
value = os.environ.get(key)
if not value:
raise RuntimeError(
f"Required environment variable '{key}' is not set. "
"Set it in backend/.env before starting the server."
)
return value
# Connection settings: both are mandatory, and the module refuses to import without them.
AZURE_AI_ENDPOINT = _require_env('AZURE_AI_ENDPOINT')
AZURE_AI_API_KEY = _require_env('AZURE_AI_API_KEY')

# Model (deployment) names, overridable through the environment.
AZURE_MODEL_MAIN = os.environ.get('AZURE_AI_MODEL_MAIN', 'gpt-5.4')
AZURE_MODEL_MINI = os.environ.get('AZURE_AI_MODEL_MINI', 'gpt-5.4-mini')

# Usage-context feature names that LLMService._resolve_model always routes to
# the cheaper mini model, regardless of the model requested.
MINI_FEATURES = frozenset((
    'summary',
    'key_themes',
    'conversation_decision',
    'persona_basic',
    'discussion_guide',
    'audience_brief',
))

DEFAULT_MODEL = AZURE_MODEL_MAIN

# Directly supported model names -> provider; every one is served by Azure.
SUPPORTED_MODELS = dict.fromkeys((AZURE_MODEL_MAIN, AZURE_MODEL_MINI), 'azure')

# Legacy model IDs stored in the database — all map to the Azure main model
MODEL_ALIASES = dict.fromkeys(
    (
        'gemini-3.1-pro-preview',
        'gemini-3-pro-preview',
        'gpt-5.4-2026-03-05',
        'gpt-5',
        'gpt-5.2',
        'gpt-4.1',
    ),
    AZURE_MODEL_MAIN,
)
def get_azure_client() -> AsyncOpenAI:
    """Build a brand-new AsyncOpenAI client for the Azure AI Foundry endpoint.

    A fresh client is returned on every call, so no client is shared across
    event loops. That avoids loop-mismatch errors in ASGI servers, and the
    construction cost is negligible next to the LLM request itself.

    ``AZURE_AI_ENDPOINT`` is normalised to exactly one trailing slash so the SDK
    appends operation paths correctly (``.../v1/`` + ``responses`` ->
    ``.../v1/responses``). The configured endpoint is expected to already end
    in ``/v1``; this function does not add it. The request timeout is 600 s.
    The returned client is not closed here, so callers own its lifetime.
    """
    client_settings = {
        "base_url": f"{AZURE_AI_ENDPOINT.rstrip('/')}/",
        "api_key": AZURE_AI_API_KEY,
        "timeout": 600.0,
    }
    return AsyncOpenAI(**client_settings)
class LLMServiceError(Exception):
    """Error raised when an LLM operation (generation, retries, JSON parsing) fails."""
class LLMService:
    """Centralized service for LLM operations via Azure AI Foundry.

    Every public generation entry point does four things:
    - runs a best-effort quota check;
    - resolves the model (legacy aliases plus feature-based mini routing);
    - calls the Responses API with bounded retries;
    - records a usage event for each successful call.
    """

    # Lower-cased message substrings that mark a transient failure for
    # exceptions that are not typed OpenAI API errors (see _is_retryable).
    _RETRYABLE_KEYWORDS = (
        "500", "internal error", "internal server error",
        "service unavailable", "timeout", "rate",
    )

    # Strong references to the fire-and-forget tasks spawned by _record_usage.
    # The event loop keeps only weak references to tasks, so an unreferenced
    # task can be garbage-collected before it finishes.
    _background_tasks: set = set()

    @staticmethod
    def _extract_responses_api_content(response) -> str:
        """Extract the generated text from a Responses API response.

        Concatenates the ``text`` of every content part in ``response.output``
        and skips items of type ``reasoning``, whose content (when the service
        returns any) is reasoning text rather than the answer. If nothing is
        found, it falls back to ``output_text`` and then ``text``, but only when
        the value is a string. On the SDK ``Response`` object, ``text`` is the
        text-format configuration, not generated output.

        Returns:
            The extracted text with surrounding whitespace stripped ("" if none).
        """
        parts = []
        for item in getattr(response, 'output', None) or []:
            if getattr(item, 'type', None) == 'reasoning':
                continue
            for content in getattr(item, 'content', None) or []:
                text = getattr(content, 'text', None)
                if isinstance(text, str):
                    parts.append(text)
        result = "".join(parts)
        for fallback_attr in ('output_text', 'text'):
            if result:
                break
            candidate = getattr(response, fallback_attr, None)
            if isinstance(candidate, str):
                result = candidate
        return result.strip()

    @staticmethod
    def _resolve_model(model_name: Optional[str] = None) -> str:
        """Resolve a model name, applying feature-based mini routing.

        Resolution order:
        1. If model_name is one of the directly supported models, use it, but
           still override to mini when the current feature is a mini feature.
        2. If model_name is a legacy alias, resolve it, then apply mini routing.
        3. If model_name is None or unknown, use DEFAULT_MODEL, then apply mini
           routing.
        """
        if model_name:
            resolved = MODEL_ALIASES.get(model_name, model_name)
            base = resolved if resolved in SUPPORTED_MODELS else DEFAULT_MODEL
        else:
            base = DEFAULT_MODEL
        # Feature override: mini features always get the cheaper model. This is
        # best effort; without a usable usage context the base model is kept.
        try:
            from app.services.llm_usage_context import current_context
            ctx = current_context()
            if ctx.feature in MINI_FEATURES:
                return AZURE_MODEL_MINI
        except Exception:
            pass
        return base

    @staticmethod
    def _get_model_provider(model_name: Optional[str] = None) -> str:
        """Return the provider for the resolved model (always 'azure')."""
        return 'azure'

    @staticmethod
    def _extract_usage_metadata(response, provider: str) -> dict:
        """Return token counts from ``response.usage``.

        Understands the Responses API usage shape (``input_tokens`` /
        ``output_tokens``) and, as a fallback, the Chat Completions shape
        (``prompt_tokens`` / ``completion_tokens``). Missing or ``None`` values
        count as 0. *provider* is currently unused.

        Returns:
            Dict with int keys ``prompt``, ``completion``, ``cached``, ``reasoning``.
        """
        usage = getattr(response, 'usage', None)
        if usage is None:
            logging.getLogger(__name__).warning("Azure response missing usage — token counts will be 0")
            return {'prompt': 0, 'completion': 0, 'cached': 0, 'reasoning': 0}

        def count(obj, attr: str) -> int:
            # getattr(None, attr, 0) is 0, so absent detail objects need no guard.
            return getattr(obj, attr, 0) or 0

        if hasattr(usage, 'input_tokens'):
            # Responses API shape
            input_details = getattr(usage, 'input_tokens_details', None)
            output_details = getattr(usage, 'output_tokens_details', None)
            return {
                'prompt': count(usage, 'input_tokens'),
                'completion': count(usage, 'output_tokens'),
                'cached': count(input_details, 'cached_tokens'),
                'reasoning': count(output_details, 'reasoning_tokens'),
            }
        # Chat Completions API shape (fallback)
        prompt_details = getattr(usage, 'prompt_tokens_details', None)
        return {
            'prompt': count(usage, 'prompt_tokens'),
            'completion': count(usage, 'completion_tokens'),
            'cached': count(prompt_details, 'cached_tokens'),
            'reasoning': 0,
        }

    @staticmethod
    async def _record_usage(response, provider: str, model: str, start_time: float, retry_count: int) -> None:
        """Record a usage event after a successful LLM call. Never raises.

        Prices the call from ModelPricing, writes a UsageEvent tagged with the
        current usage context, and, when the context has a focus group,
        schedules a 'usage_update' websocket event without awaiting it.
        """
        try:
            from app.services.llm_usage_context import current_context
            from app.models.usage_event import UsageEvent
            from app.models.model_pricing import ModelPricing
            ctx = current_context()
            tokens = LLMService._extract_usage_metadata(response, provider)
            pricing = await ModelPricing.current_for(model)
            cost = ModelPricing.compute_cost(
                pricing,
                prompt_tokens=tokens['prompt'],
                completion_tokens=tokens['completion'],
                cached_tokens=tokens['cached'],
            )
            price_id = pricing.get('_id') if pricing else None
            await UsageEvent.record(
                provider=provider,
                model=model,
                prompt_tokens=tokens['prompt'],
                completion_tokens=tokens['completion'],
                cached_tokens=tokens['cached'],
                reasoning_tokens=tokens['reasoning'],
                cost_usd=cost,
                price_snapshot_id=price_id,
                duration_ms=int((time.monotonic() - start_time) * 1000),
                retry_count=retry_count,
                status="success",
                user_id=ctx.user_id,
                focus_group_id=ctx.focus_group_id,
                persona_id=ctx.persona_id,
                feature=ctx.feature,
                task_id=ctx.task_id,
            )
            try:
                if ctx.focus_group_id:
                    from app.models.focus_group import emit_websocket_event
                    task = asyncio.create_task(emit_websocket_event(
                        'usage_update',
                        ctx.focus_group_id,
                        {
                            'cost_delta': cost.get('total', 0),
                            'tokens_delta': tokens['prompt'] + tokens['completion'],
                            'feature': ctx.feature,
                        }
                    ))
                    # Keep a strong reference until the task completes.
                    LLMService._background_tasks.add(task)
                    task.add_done_callback(LLMService._background_tasks.discard)
            except Exception:
                logging.getLogger(__name__).debug("usage_update websocket emit failed (non-fatal)", exc_info=True)
        except Exception:
            logging.getLogger(__name__).warning("_record_usage failed (non-fatal)", exc_info=True)

    @staticmethod
    def _build_responses_kwargs(
        actual_model: str,
        input_content,
        reasoning_effort: Optional[str] = None,
        verbosity: Optional[str] = None,
    ) -> dict:
        """Build the kwargs for a Responses API call (effort defaults to "low", verbosity to "medium")."""
        return {
            "model": actual_model,
            "input": input_content,
            "reasoning": {"effort": reasoning_effort or "low"},
            "text": {
                "format": {"type": "text"},
                "verbosity": verbosity or "medium",
            },
        }

    @staticmethod
    async def _enforce_quota() -> None:
        """Run the quota check for the current usage context.

        QuotaExceededError propagates. Any other failure while reading the
        usage context or checking the quota is logged and ignored, so an
        infrastructure hiccup does not block generation. A failure to import
        the quota module itself is not swallowed.
        """
        from app.models.quota import check_quota, QuotaExceededError
        try:
            from app.services.llm_usage_context import current_context
            ctx = current_context()
            await check_quota(ctx.user_id, ctx.focus_group_id)
        except QuotaExceededError:
            raise
        except Exception:
            logging.getLogger(__name__).warning(
                "Quota check failed (non-fatal) — continuing without enforcement",
                exc_info=True,
            )

    @staticmethod
    def _is_retryable(error: Exception) -> bool:
        """Return True when *error* looks transient and worth retrying.

        Typed OpenAI SDK errors are classified by kind and HTTP status:
        connection errors (including timeouts) and HTTP 408, 429 and 5xx are
        retryable, while any other status (e.g. 400, 401, 404) is not. Other
        exceptions fall back to matching _RETRYABLE_KEYWORDS in the message.
        """
        import openai

        if isinstance(error, openai.APIStatusError):
            status = getattr(error, 'status_code', None)
            return status in (408, 429) or (isinstance(status, int) and status >= 500)
        if isinstance(error, openai.APIConnectionError):  # includes APITimeoutError
            return True
        message = str(error).lower()
        return any(keyword in message for keyword in LLMService._RETRYABLE_KEYWORDS)

    @staticmethod
    async def _create_with_retries(
        kwargs: dict,
        actual_model: str,
        operation: str,
        error_prefix: str,
        max_retries: int = 3,
    ) -> str:
        """Call the Responses API with retries and return the extracted text.

        Each attempt uses a fresh client that is closed when the attempt ends.
        A retryable failure waits 1 s, 2 s, and so on before the next attempt;
        a non-retryable failure stops immediately. On success a usage event is
        recorded with the number of retries that preceded it.

        Args:
            kwargs: Request kwargs from _build_responses_kwargs.
            actual_model: Resolved model name, used for usage recording.
            operation: Label used in log messages.
            error_prefix: Prefix of the LLMServiceError message on failure.
            max_retries: Maximum number of attempts.

        Raises:
            LLMServiceError: when all attempts fail or a non-retryable error
                occurs; the last underlying exception is chained as __cause__.
        """
        logger = logging.getLogger(__name__)
        start_time = time.monotonic()
        last_error: Optional[Exception] = None
        attempts_made = 0

        for attempt in range(max_retries):
            attempts_made = attempt + 1
            logger.debug(f"{operation} attempt {attempts_made}/{max_retries} model={actual_model}")
            try:
                async with get_azure_client() as client:
                    response = await client.responses.create(**kwargs)
                result = LLMService._extract_responses_api_content(response)
            except Exception as e:
                last_error = e
                logger.warning(
                    f"{operation} attempt {attempts_made}/{max_retries} failed — {type(e).__name__}: {e}",
                    exc_info=True,
                )
                if not LLMService._is_retryable(e):
                    logger.error(f"Non-retryable error: {e}")
                    break
                if attempt < max_retries - 1:
                    wait = 2 ** attempt
                    logger.info(f"Retryable error — waiting {wait}s before retry {attempts_made + 1}")
                    await asyncio.sleep(wait)
                continue

            if attempt > 0:
                logger.info(f"{operation} succeeded on attempt {attempts_made}/{max_retries}")
            # Kept outside the try block: usage bookkeeping must never be
            # mistaken for an API failure and trigger another (billed) call.
            await LLMService._record_usage(response, 'azure', actual_model, start_time, attempt)
            return result

        if last_error is not None:
            error_detail = str(last_error) or repr(last_error)
        else:
            error_detail = "no attempts were made"
        logger.error(f"{operation} failed after {attempts_made} attempt(s): {error_detail}")
        raise LLMServiceError(f"{error_prefix}: {error_detail}") from last_error

    @staticmethod
    async def _generate_text(
        prompt: str,
        model_name: Optional[str] = None,
        system_prompt: Optional[str] = None,
        reasoning_effort: Optional[str] = None,
        verbosity: Optional[str] = None,
    ) -> str:
        """Make a text-only Responses API call without a quota check (callers run it first)."""
        actual_model = LLMService._resolve_model(model_name)
        # The system prompt is inlined into the single input string.
        if system_prompt:
            input_content = f"System: {system_prompt}\n\nUser: {prompt}"
        else:
            input_content = prompt
        kwargs = LLMService._build_responses_kwargs(actual_model, input_content, reasoning_effort, verbosity)
        return await LLMService._create_with_retries(
            kwargs,
            actual_model,
            operation="generate_content",
            error_prefix="Error generating content",
        )

    @staticmethod
    async def generate_content(
        prompt: str,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        model_name: Optional[str] = None,
        system_prompt: Optional[str] = None,
        reasoning_effort: Optional[str] = None,
        verbosity: Optional[str] = None
    ) -> str:
        """Generate text content via the Azure AI Foundry Responses API.

        Args:
            prompt: The user prompt.
            temperature: Ignored for Responses API (kept for interface compatibility).
            max_tokens: Ignored for Responses API (kept for interface compatibility).
            model_name: Override model. If None, auto-routes by feature context.
            system_prompt: Optional system instruction prepended to the prompt.
            reasoning_effort: Responses API reasoning effort (low/medium/high).
            verbosity: Responses API verbosity (low/medium/high).

        Returns:
            The generated text, stripped of surrounding whitespace.

        Raises:
            QuotaExceededError: if the quota check reports the quota is exceeded.
            LLMServiceError: if the call fails after retries.
        """
        await LLMService._enforce_quota()
        return await LLMService._generate_text(
            prompt,
            model_name=model_name,
            system_prompt=system_prompt,
            reasoning_effort=reasoning_effort,
            verbosity=verbosity,
        )

    @staticmethod
    def parse_json_response(response_text: str) -> Union[Dict[str, Any], List[Any]]:
        """Parse a JSON value from LLM output, tolerating a Markdown code fence.

        Surrounding whitespace is ignored. If the text starts with a fence, the
        opening ```` ``` ````, an optional ``json`` tag (any case) and the
        closing ```` ``` ```` are removed before parsing.

        Raises:
            LLMServiceError: if the remaining text is not valid JSON.
        """
        clean = response_text.strip()
        if clean.startswith("```"):
            # Remove the fence as an exact prefix/suffix. str.strip("```json")
            # treats its argument as a character set, so it also eats payload
            # characters and leaves the fence behind when a newline follows it.
            clean = clean[3:]
            if clean[:4].lower() == "json":
                clean = clean[4:]
            if clean.endswith("```"):
                clean = clean[:-3]
            clean = clean.strip()
        try:
            return json.loads(clean)
        except json.JSONDecodeError as e:
            msg = f"Failed to parse JSON response: {e}. Raw: {response_text[:200]}..."
            logging.getLogger(__name__).error(msg)
            raise LLMServiceError(msg) from e

    @staticmethod
    async def generate_structured_response(
        prompt: str,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        model_name: Optional[str] = None,
        system_prompt: Optional[str] = None,
        reasoning_effort: Optional[str] = None,
        verbosity: Optional[str] = None
    ) -> Dict[str, Any]:
        """Generate content and parse it as JSON (expected to be an object).

        The parsed value is returned as-is; its type is not checked.
        """
        text = await LLMService.generate_content(
            prompt=prompt, temperature=temperature, max_tokens=max_tokens,
            model_name=model_name, system_prompt=system_prompt,
            reasoning_effort=reasoning_effort, verbosity=verbosity,
        )
        return LLMService.parse_json_response(text)

    @staticmethod
    async def generate_structured_array(
        prompt: str,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        model_name: Optional[str] = None,
        system_prompt: Optional[str] = None,
        reasoning_effort: Optional[str] = None,
        verbosity: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """Generate content and parse it as a JSON array.

        Raises:
            LLMServiceError: if generation or parsing fails, or the JSON is not an array.
        """
        text = await LLMService.generate_content(
            prompt=prompt, temperature=temperature, max_tokens=max_tokens,
            model_name=model_name, system_prompt=system_prompt,
            reasoning_effort=reasoning_effort, verbosity=verbosity,
        )
        result = LLMService.parse_json_response(text)
        if not isinstance(result, list):
            raise LLMServiceError(f"Expected JSON array but received {type(result)}")
        return result

    @staticmethod
    async def generate_multimodal_content(
        prompt: str,
        image_paths: List[str],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        model_name: Optional[str] = None
    ) -> str:
        """Generate content from text plus image files via the Azure Responses API.

        Each image is read from disk and sent inline as a base64 data URL. The
        MIME type comes from the file extension (jpg/jpeg/png/gif/webp) and
        defaults to image/jpeg. temperature and max_tokens are ignored.

        Raises:
            QuotaExceededError: if the quota check reports the quota is exceeded.
            LLMServiceError: if an image file does not exist, or the call fails
                after retries.
        """
        logger = logging.getLogger(__name__)
        await LLMService._enforce_quota()
        actual_model = LLMService._resolve_model(model_name)
        logger.info(f"generate_multimodal_content: {len(image_paths)} image(s), model={actual_model}")

        mime_types = {'jpg': 'image/jpeg', 'jpeg': 'image/jpeg', 'png': 'image/png',
                      'gif': 'image/gif', 'webp': 'image/webp'}
        content_items = [{"type": "input_text", "text": prompt}]
        for image_path in image_paths:
            if not os.path.exists(image_path):
                raise LLMServiceError(f"Image file not found: {image_path}")
            ext = image_path.lower().split('.')[-1]
            mime_type = mime_types.get(ext, 'image/jpeg')
            with open(image_path, "rb") as f:
                b64 = base64.b64encode(f.read()).decode('utf-8')
            content_items.append({
                "type": "input_image",
                "image_url": f"data:{mime_type};base64,{b64}",
            })
            logger.debug(f"Loaded image: {image_path}")

        input_content = [{"role": "user", "content": content_items}]
        kwargs = LLMService._build_responses_kwargs(actual_model, input_content)
        return await LLMService._create_with_retries(
            kwargs,
            actual_model,
            operation="generate_multimodal_content",
            error_prefix="Error generating multimodal content",
        )

    @staticmethod
    async def generate_contextual_response(
        prompt: str,
        conversation_context: List[Dict[str, Any]],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        model_name: Optional[str] = None,
        reasoning_effort: Optional[str] = None,
        verbosity: Optional[str] = None
    ) -> str:
        """Generate content from conversation context that may include text and images.

        Context items with ``type == "text"`` contribute their ``content`` to a
        "CONVERSATION CONTEXT" prefix. Items with ``type == "image"`` are loaded
        from ``path``, converted to RGB and sent as PNG. Images that are missing
        or unreadable are logged and skipped. Without any usable image this is a
        text-only call. The quota is checked once, up front.

        Raises:
            QuotaExceededError: if the quota check reports the quota is exceeded.
            LLMServiceError: if the call fails after retries.
        """
        logger = logging.getLogger(__name__)
        await LLMService._enforce_quota()

        text_parts = []
        pil_images = []
        logger.debug(f"generate_contextual_response: {len(conversation_context)} context items")
        for item in conversation_context:
            if item["type"] == "text":
                text_parts.append(item["content"])
            elif item["type"] == "image":
                try:
                    image_path = item["path"]
                    if os.path.exists(image_path):
                        with Image.open(image_path) as opened:
                            rgb = opened if opened.mode == 'RGB' else opened.convert('RGB')
                            # Copy while the file is still open, so the stored
                            # image does not depend on the closed file handle.
                            pil_images.append(rgb.copy())
                        logger.debug(f"Loaded context image: {item.get('filename', image_path)}")
                    else:
                        logger.warning(f"Context image not found: {image_path}")
                except Exception as e:
                    logger.warning(f"Failed to load context image {item.get('path', '?')}: {e}")

        context_prefix = ("CONVERSATION CONTEXT:\n" + "\n".join(text_parts) + "\n\n") if text_parts else ""
        full_prompt = context_prefix + prompt

        if not pil_images:
            # Text-only path. The quota was already checked above, so call the
            # unchecked helper to avoid a second quota lookup.
            return await LLMService._generate_text(
                full_prompt,
                model_name=model_name,
                reasoning_effort=reasoning_effort,
                verbosity=verbosity,
            )

        # Multimodal path: send every loaded image inline as a PNG data URL.
        actual_model = LLMService._resolve_model(model_name)
        content_items = [{"type": "input_text", "text": full_prompt}]
        for img in pil_images:
            buf = io.BytesIO()
            img.save(buf, format='PNG')
            b64 = base64.b64encode(buf.getvalue()).decode('utf-8')
            content_items.append({
                "type": "input_image",
                "image_url": f"data:image/png;base64,{b64}",
            })

        input_content = [{"role": "user", "content": content_items}]
        kwargs = LLMService._build_responses_kwargs(
            actual_model, input_content, reasoning_effort, verbosity
        )
        return await LLMService._create_with_retries(
            kwargs,
            actual_model,
            operation="generate_contextual_response",
            error_prefix="Error generating contextual response",
        )