All checks were successful
Deploy to Production / deploy (push) Successful in 2m23s
Includes frontend redesign (Navigation, billingApi), backend updates (auth routes, admin routes, LLM service refactor), MSAL removal, and dependency updates. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
559 lines
23 KiB
Python
Executable file
559 lines
23 KiB
Python
Executable file
"""
|
|
LLM Service for Cohorta — Azure AI Foundry (Responses API)
|
|
|
|
All model calls route through a single Azure AI Foundry project endpoint.
|
|
gpt-5.4 — complex tasks: persona responses, moderator, detailed generation
|
|
gpt-5.4-mini — cheap tasks: summaries, key themes, conversation decisions, basic gen
|
|
"""
|
|
|
|
import os
|
|
import json
|
|
import asyncio
|
|
import logging
|
|
import base64
|
|
import traceback
|
|
import time
|
|
import io
|
|
from openai import AsyncOpenAI
|
|
from typing import Dict, Any, Optional, Union, List
|
|
from PIL import Image
|
|
|
|
|
|
def _require_env(key: str) -> str:
|
|
value = os.environ.get(key)
|
|
if not value:
|
|
raise RuntimeError(
|
|
f"Required environment variable '{key}' is not set. "
|
|
"Set it in backend/.env before starting the server."
|
|
)
|
|
return value
|
|
|
|
|
|
# Connection settings are mandatory: fail fast at import time if either is missing.
AZURE_AI_ENDPOINT = _require_env('AZURE_AI_ENDPOINT')
AZURE_AI_API_KEY = _require_env('AZURE_AI_API_KEY')

# Deployment names. `or` (rather than a get() default) also covers variables
# that are present but empty, e.g. `AZURE_AI_MODEL_MAIN=` in a .env file,
# which would otherwise send an empty model name on every request.
AZURE_MODEL_MAIN = os.environ.get('AZURE_AI_MODEL_MAIN') or 'gpt-5.4'
AZURE_MODEL_MINI = os.environ.get('AZURE_AI_MODEL_MINI') or 'gpt-5.4-mini'

# Features automatically routed to the cheaper mini model
MINI_FEATURES = frozenset({
    'summary',
    'key_themes',
    'conversation_decision',
    'persona_basic',
    'discussion_guide',
    'audience_brief',
})

# Model used when no model, or an unrecognised one, is requested.
DEFAULT_MODEL = AZURE_MODEL_MAIN

# Models that may be passed straight through to the API, mapped to their provider.
SUPPORTED_MODELS = {
    AZURE_MODEL_MAIN: 'azure',
    AZURE_MODEL_MINI: 'azure',
}

# Legacy model IDs stored in the database — all map to the Azure main model
MODEL_ALIASES = {
    'gemini-3.1-pro-preview': AZURE_MODEL_MAIN,
    'gemini-3-pro-preview': AZURE_MODEL_MAIN,
    'gpt-5.4-2026-03-05': AZURE_MODEL_MAIN,
    'gpt-5': AZURE_MODEL_MAIN,
    'gpt-5.2': AZURE_MODEL_MAIN,
    'gpt-4.1': AZURE_MODEL_MAIN,
}
|
|
|
|
|
|
def get_azure_client() -> AsyncOpenAI:
    """Return a new Azure AI Foundry client bound to the configured endpoint.

    A fresh client is built on every call instead of being cached at module
    level. This sidesteps event-loop mismatch problems in ASGI servers, where
    requests may run on different event loops. Construction cost is small
    next to the model call itself.

    AZURE_AI_ENDPOINT is expected to already end in ``/v1`` (with or without a
    trailing slash). Trailing slashes are normalised to exactly one so the SDK
    appends operation paths correctly (``responses`` -> ``.../v1/responses``).

    The caller owns the returned client and should close it when finished,
    e.g. ``async with get_azure_client() as client: ...``.
    """
    client_options = {
        "base_url": f"{AZURE_AI_ENDPOINT.rstrip('/')}/",
        "api_key": AZURE_AI_API_KEY,
        "timeout": 600.0,
    }
    return AsyncOpenAI(**client_options)
|
|
|
|
|
|
class LLMServiceError(Exception):
    """Raised by LLMService when a model call fails, an input image is missing,
    or the model's output cannot be parsed into the expected shape."""
|
|
|
|
|
|
class LLMService:
    """Centralized service for LLM operations via Azure AI Foundry.

    All methods are static; the class acts as a namespace. Every public
    generation method follows the same pipeline:

    1. enforce the usage quota;
    2. resolve the model (legacy aliases plus mini-feature routing);
    3. build a Responses API request;
    4. send it with bounded retries;
    5. record usage.
    """

    # Total attempts (first try + retries) per request in _call_with_retries.
    _MAX_RETRIES = 3

    # HTTP statuses treated as transient: request timeout and rate limiting.
    # Any 5xx status is retried as well (see _is_retryable).
    _RETRYABLE_STATUS_CODES = frozenset({408, 429})

    # Lower-case message fragments that mark an error *without* an HTTP status
    # as transient. "rate limit" (not bare "rate") avoids matching words such
    # as "generate" or "separate".
    _RETRYABLE_MARKERS = (
        "500",
        "internal error",
        "internal server error",
        "service unavailable",
        "timeout",
        "timed out",
        "rate limit",
        "connection error",
    )

    # File extension -> MIME type for images sent inline; unknown -> JPEG.
    _IMAGE_MIME_TYPES = {
        'jpg': 'image/jpeg',
        'jpeg': 'image/jpeg',
        'png': 'image/png',
        'gif': 'image/gif',
        'webp': 'image/webp',
    }

    # Strong references to fire-and-forget tasks. The event loop keeps only
    # weak references to tasks, so an unreferenced task can be garbage
    # collected before it finishes.
    _background_tasks: set = set()

    @staticmethod
    def _extract_responses_api_content(response) -> str:
        """Extract the text output from an Azure / OpenAI Responses API response.

        Concatenates every string ``text`` found in ``response.output[*].content``.
        If that yields nothing, falls back to ``response.output_text`` and then
        ``response.text``, but only when they are non-empty strings. On SDK
        Response objects ``text`` holds the text *configuration* (format and
        verbosity), which must never be returned as content.

        Returns:
            The extracted text, stripped; "" when the response has no text.
        """
        parts = []
        for item in getattr(response, 'output', None) or ():
            for content in getattr(item, 'content', None) or ():
                text = getattr(content, 'text', None)
                if isinstance(text, str):
                    parts.append(text)
        result = "".join(parts)

        if not result:
            for attr in ('output_text', 'text'):
                candidate = getattr(response, attr, None)
                if isinstance(candidate, str) and candidate:
                    result = candidate
                    break

        return result.strip()

    @staticmethod
    def _resolve_model(model_name: Optional[str] = None) -> str:
        """Resolve a model name, applying feature-based mini routing.

        Resolution order:
          1. A legacy alias (MODEL_ALIASES) is mapped to its replacement.
          2. The result is used if it is in SUPPORTED_MODELS. Otherwise, or
             when model_name is None or empty, DEFAULT_MODEL is used.
          3. Regardless of the above, if the current usage context's feature
             is in MINI_FEATURES, AZURE_MODEL_MINI is returned.
        """
        base = DEFAULT_MODEL
        if model_name:
            resolved = MODEL_ALIASES.get(model_name, model_name)
            if resolved in SUPPORTED_MODELS:
                base = resolved

        # Feature override: mini features always get the cheaper model.
        try:
            from app.services.llm_usage_context import current_context
            if current_context().feature in MINI_FEATURES:
                return AZURE_MODEL_MINI
        except Exception:
            # Context unavailable or lookup failed: keep the base model.
            pass

        return base

    @staticmethod
    def _get_model_provider(model_name: Optional[str] = None) -> str:
        """Return the provider for the resolved model (always 'azure')."""
        return 'azure'

    @staticmethod
    def _extract_usage_metadata(response, provider: str) -> dict:
        """Extract token counts from a response's ``usage`` object.

        Supports the Responses API shape (input_/output_tokens) and falls back
        to the Chat Completions shape (prompt_/completion_tokens). Every count
        defaults to 0. ``provider`` is currently unused.

        Returns:
            dict with int keys 'prompt', 'completion', 'cached', 'reasoning'.
        """
        usage = getattr(response, 'usage', None)
        if usage is None:
            logging.getLogger(__name__).warning(
                "Azure response missing usage — token counts will be 0"
            )
            return {'prompt': 0, 'completion': 0, 'cached': 0, 'reasoning': 0}

        def _count(obj, attr: str) -> int:
            # getattr() on a None/missing details object returns the default,
            # and `or 0` turns an explicit None count into 0.
            return getattr(obj, attr, 0) or 0

        if hasattr(usage, 'input_tokens'):
            # Responses API shape
            return {
                'prompt': _count(usage, 'input_tokens'),
                'completion': _count(usage, 'output_tokens'),
                'cached': _count(getattr(usage, 'input_tokens_details', None), 'cached_tokens'),
                'reasoning': _count(getattr(usage, 'output_tokens_details', None), 'reasoning_tokens'),
            }

        # Chat Completions API shape (fallback)
        return {
            'prompt': _count(usage, 'prompt_tokens'),
            'completion': _count(usage, 'completion_tokens'),
            'cached': _count(getattr(usage, 'prompt_tokens_details', None), 'cached_tokens'),
            'reasoning': 0,
        }

    @staticmethod
    def _forget_background_task(task: "asyncio.Task") -> None:
        """Done-callback: drop the strong reference to ``task`` and log its failure, if any."""
        LLMService._background_tasks.discard(task)
        if task.cancelled():
            return
        error = task.exception()
        if error is not None:
            logging.getLogger(__name__).warning("Background task failed: %r", error)

    @staticmethod
    async def _record_usage(response, provider: str, model: str, start_time: float, retry_count: int) -> None:
        """Record a usage event after a successful LLM call. Never raises.

        Computes cost from ModelPricing.current_for(model) and writes a
        UsageEvent tagged with the current usage context. When the context has
        a focus_group_id, a 'usage_update' websocket event is emitted in a
        background task. Any failure is logged as a warning and swallowed.
        """
        logger = logging.getLogger(__name__)
        try:
            from app.services.llm_usage_context import current_context
            from app.models.usage_event import UsageEvent
            from app.models.model_pricing import ModelPricing

            ctx = current_context()
            tokens = LLMService._extract_usage_metadata(response, provider)
            pricing = await ModelPricing.current_for(model)
            cost = ModelPricing.compute_cost(
                pricing,
                prompt_tokens=tokens['prompt'],
                completion_tokens=tokens['completion'],
                cached_tokens=tokens['cached'],
            )
            price_id = pricing.get('_id') if pricing else None

            await UsageEvent.record(
                provider=provider,
                model=model,
                prompt_tokens=tokens['prompt'],
                completion_tokens=tokens['completion'],
                cached_tokens=tokens['cached'],
                reasoning_tokens=tokens['reasoning'],
                cost_usd=cost,
                price_snapshot_id=price_id,
                duration_ms=int((time.monotonic() - start_time) * 1000),
                retry_count=retry_count,
                status="success",
                user_id=ctx.user_id,
                focus_group_id=ctx.focus_group_id,
                persona_id=ctx.persona_id,
                feature=ctx.feature,
                task_id=ctx.task_id,
            )

            try:
                if ctx.focus_group_id:
                    from app.models.focus_group import emit_websocket_event
                    task = asyncio.create_task(emit_websocket_event(
                        'usage_update',
                        ctx.focus_group_id,
                        {
                            'cost_delta': cost.get('total', 0),
                            'tokens_delta': tokens['prompt'] + tokens['completion'],
                            'feature': ctx.feature,
                        }
                    ))
                    # Keep the task alive until it finishes (see _background_tasks).
                    LLMService._background_tasks.add(task)
                    task.add_done_callback(LLMService._forget_background_task)
            except Exception:
                logger.debug("usage_update websocket emit failed (non-fatal)", exc_info=True)
        except Exception:
            logger.warning("_record_usage failed (non-fatal)", exc_info=True)

    @staticmethod
    def _build_responses_kwargs(
        actual_model: str,
        input_content,
        reasoning_effort: Optional[str] = None,
        verbosity: Optional[str] = None,
    ) -> dict:
        """Build the kwargs dict for a Responses API call.

        reasoning_effort defaults to "low" and verbosity to "medium"; the
        output format is always plain text.
        """
        return {
            "model": actual_model,
            "input": input_content,
            "reasoning": {"effort": reasoning_effort or "low"},
            "text": {
                "format": {"type": "text"},
                "verbosity": verbosity or "medium",
            },
        }

    @staticmethod
    async def _enforce_quota() -> None:
        """Run ``check_quota`` for the current usage context before an LLM call.

        QuotaExceededError is re-raised so the caller is blocked. Any other
        failure is logged as a warning and the call proceeds. This covers the
        quota module not importing, no usage context, or check_quota itself
        erroring. The check is deliberately best-effort (fails open).

        Raises:
            QuotaExceededError: propagated from app.models.quota.check_quota.
        """
        logger = logging.getLogger(__name__)
        try:
            from app.models.quota import check_quota, QuotaExceededError
        except Exception:
            # Without the module there is nothing to check (or re-raise).
            logger.warning("Quota module unavailable — skipping quota check", exc_info=True)
            return

        try:
            from app.services.llm_usage_context import current_context
            ctx = current_context()
            await check_quota(ctx.user_id, ctx.focus_group_id)
        except QuotaExceededError:
            raise
        except Exception:
            logger.warning("Quota check failed (non-fatal) — proceeding without it", exc_info=True)

    @staticmethod
    def _is_retryable(error: BaseException) -> bool:
        """Decide whether a failed API call is worth retrying.

        Errors that carry an integer ``status_code`` (as the openai SDK's
        APIStatusError does) are classified by status alone: 408, 429 and any
        5xx are transient, and every other status (e.g. a 400) is not.
        Errors without a status are retried when they are timeout or
        connection errors, or when their message contains a transient marker.
        This covers the SDK's "Request timed out." / "Connection error.".
        """
        status = getattr(error, 'status_code', None)
        if isinstance(status, int):
            return status in LLMService._RETRYABLE_STATUS_CODES or status >= 500
        if isinstance(error, (asyncio.TimeoutError, TimeoutError, ConnectionError)):
            return True
        message = str(error).lower()
        return any(marker in message for marker in LLMService._RETRYABLE_MARKERS)

    @staticmethod
    async def _call_with_retries(
        kwargs: Dict[str, Any],
        actual_model: str,
        start_time: float,
        operation: str,
        error_prefix: str,
    ) -> str:
        """Send a Responses API request with exponential-backoff retries.

        Makes up to _MAX_RETRIES attempts, sleeping 1s, 2s, ... between
        retryable failures. On success, records usage and returns the
        extracted text.

        Args:
            kwargs: Request built by _build_responses_kwargs.
            actual_model: Resolved model name, used for usage accounting.
            start_time: time.monotonic() at request start, for duration_ms.
            operation: Caller name used in log messages.
            error_prefix: Prefix of the LLMServiceError message on failure.

        Raises:
            LLMServiceError: after a non-retryable error or when all attempts
                fail; the last underlying error is chained as the cause.
        """
        logger = logging.getLogger(__name__)
        max_retries = LLMService._MAX_RETRIES
        last_error: Optional[BaseException] = None
        attempts_made = 0

        for attempt in range(max_retries):
            attempts_made = attempt + 1
            logger.debug("%s attempt %d/%d model=%s", operation, attempts_made, max_retries, actual_model)
            try:
                # A fresh client per call (see get_azure_client); closing it
                # releases its HTTP connection pool instead of leaking it.
                async with get_azure_client() as client:
                    response = await client.responses.create(**kwargs)
                result = LLMService._extract_responses_api_content(response)
            except Exception as e:
                last_error = e
                retryable = LLMService._is_retryable(e)
                logger.warning(
                    "%s attempt %d/%d failed — %s: %s",
                    operation, attempts_made, max_retries, type(e).__name__, e,
                    exc_info=True,
                )
                if retryable and attempt < max_retries - 1:
                    wait = 2 ** attempt
                    logger.info("Retryable error — waiting %ds before retry %d", wait, attempts_made + 1)
                    await asyncio.sleep(wait)
                    continue
                if not retryable:
                    logger.error("%s hit a non-retryable error: %s", operation, e)
                break
            else:
                if attempt > 0:
                    logger.info("%s succeeded on attempt %d/%d", operation, attempts_made, max_retries)
                await LLMService._record_usage(response, 'azure', actual_model, start_time, attempt)
                return result

        # str() of some exceptions is empty; fall back to repr() so the message is useful.
        error_detail = str(last_error) or repr(last_error)
        logger.error("%s failed after %d attempt(s): %s", operation, attempts_made, error_detail)
        raise LLMServiceError(f"{error_prefix}: {error_detail}") from last_error

    @staticmethod
    async def generate_content(
        prompt: str,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        model_name: Optional[str] = None,
        system_prompt: Optional[str] = None,
        reasoning_effort: Optional[str] = None,
        verbosity: Optional[str] = None
    ) -> str:
        """Generate text content via the Azure AI Foundry Responses API.

        Args:
            prompt: The user prompt.
            temperature: Ignored for Responses API (kept for interface compatibility).
            max_tokens: Ignored for Responses API (kept for interface compatibility).
            model_name: Override model. If None, auto-routes by feature context.
            system_prompt: Optional system instruction prepended to the prompt.
            reasoning_effort: Responses API reasoning effort (low/medium/high).
            verbosity: Responses API verbosity (low/medium/high).

        Returns:
            The model's text output, stripped of surrounding whitespace.

        Raises:
            QuotaExceededError: propagated from the quota check.
            LLMServiceError: when the request ultimately fails.
        """
        await LLMService._enforce_quota()

        actual_model = LLMService._resolve_model(model_name)
        start_time = time.monotonic()

        # The system prompt is inlined into the single text input.
        if system_prompt:
            input_content = f"System: {system_prompt}\n\nUser: {prompt}"
        else:
            input_content = prompt

        kwargs = LLMService._build_responses_kwargs(actual_model, input_content, reasoning_effort, verbosity)
        return await LLMService._call_with_retries(
            kwargs,
            actual_model,
            start_time,
            operation="generate_content",
            error_prefix="Error generating content",
        )

    @staticmethod
    def parse_json_response(response_text: str) -> Union[Dict[str, Any], List[Any]]:
        """Parse a JSON response from the LLM, stripping markdown code fences.

        Accepts bare JSON or JSON wrapped in a ``` fence. The fence may carry a
        ``json`` tag in any case and may be surrounded by whitespace.

        Raises:
            LLMServiceError: if the remaining text is not valid JSON (the
                JSONDecodeError is chained as the cause).
        """
        clean = response_text.strip()
        if clean.startswith("```"):
            clean = clean[3:]
            # Drop an optional language tag on the opening fence (json / JSON).
            if clean[:4].lower() == "json":
                clean = clean[4:]
            clean = clean.strip()
            if clean.endswith("```"):
                clean = clean[:-3].rstrip()
        try:
            return json.loads(clean)
        except json.JSONDecodeError as e:
            msg = f"Failed to parse JSON response: {e}. Raw: {response_text[:200]}..."
            logging.getLogger(__name__).error(msg)
            raise LLMServiceError(msg) from e

    @staticmethod
    async def generate_structured_response(
        prompt: str,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        model_name: Optional[str] = None,
        system_prompt: Optional[str] = None,
        reasoning_effort: Optional[str] = None,
        verbosity: Optional[str] = None
    ) -> Dict[str, Any]:
        """Generate and parse a structured JSON response (expected to be an object).

        The parsed value is returned as-is; unlike generate_structured_array,
        its type is not validated.
        """
        text = await LLMService.generate_content(
            prompt=prompt, temperature=temperature, max_tokens=max_tokens,
            model_name=model_name, system_prompt=system_prompt,
            reasoning_effort=reasoning_effort, verbosity=verbosity,
        )
        return LLMService.parse_json_response(text)

    @staticmethod
    async def generate_structured_array(
        prompt: str,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        model_name: Optional[str] = None,
        system_prompt: Optional[str] = None,
        reasoning_effort: Optional[str] = None,
        verbosity: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """Generate and parse a structured JSON array response.

        Raises:
            LLMServiceError: if generation or parsing fails, or the parsed
                value is not a JSON array.
        """
        text = await LLMService.generate_content(
            prompt=prompt, temperature=temperature, max_tokens=max_tokens,
            model_name=model_name, system_prompt=system_prompt,
            reasoning_effort=reasoning_effort, verbosity=verbosity,
        )
        result = LLMService.parse_json_response(text)
        if not isinstance(result, list):
            raise LLMServiceError(f"Expected JSON array but received {type(result)}")
        return result

    @staticmethod
    async def generate_multimodal_content(
        prompt: str,
        image_paths: List[str],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        model_name: Optional[str] = None
    ) -> str:
        """Generate content from text + images via the Azure Responses API.

        Each image file is read from disk and sent inline as a base64 data
        URL. Its MIME type is inferred from the extension (JPEG if unknown).
        temperature and max_tokens are ignored (interface compatibility).

        Raises:
            QuotaExceededError: propagated from the quota check.
            LLMServiceError: if an image file is missing or the request fails.
        """
        logger = logging.getLogger(__name__)
        await LLMService._enforce_quota()

        actual_model = LLMService._resolve_model(model_name)
        logger.info("generate_multimodal_content: %d image(s), model=%s", len(image_paths), actual_model)
        start_time = time.monotonic()

        # Build the multimodal input list: the prompt first, then each image.
        content_items = [{"type": "input_text", "text": prompt}]
        for image_path in image_paths:
            if not os.path.exists(image_path):
                raise LLMServiceError(f"Image file not found: {image_path}")
            ext = image_path.lower().rsplit('.', 1)[-1]
            mime_type = LLMService._IMAGE_MIME_TYPES.get(ext, 'image/jpeg')
            with open(image_path, "rb") as f:
                b64 = base64.b64encode(f.read()).decode('utf-8')
            content_items.append({
                "type": "input_image",
                "image_url": f"data:{mime_type};base64,{b64}",
            })
            logger.debug("Loaded image: %s", image_path)

        input_content = [{"role": "user", "content": content_items}]
        kwargs = LLMService._build_responses_kwargs(actual_model, input_content)
        return await LLMService._call_with_retries(
            kwargs,
            actual_model,
            start_time,
            operation="generate_multimodal_content",
            error_prefix="Error generating multimodal content",
        )

    @staticmethod
    def _load_context_image(item: Dict[str, Any]):
        """Load one conversation-context image as an RGB PIL image.

        Returns None, after logging a warning, when the file is missing or
        cannot be opened or converted.
        """
        logger = logging.getLogger(__name__)
        try:
            image_path = item["path"]
            if not os.path.exists(image_path):
                logger.warning("Context image not found: %s", image_path)
                return None
            with Image.open(image_path) as img:
                rgb = img if img.mode == 'RGB' else img.convert('RGB')
                # copy() forces a full load so the image outlives the file handle.
                loaded = rgb.copy()
            logger.debug("Loaded context image: %s", item.get('filename', image_path))
            return loaded
        except Exception as e:
            logger.warning("Failed to load context image %s: %s", item.get('path', '?'), e)
            return None

    @staticmethod
    async def generate_contextual_response(
        prompt: str,
        conversation_context: List[Dict[str, Any]],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        model_name: Optional[str] = None,
        reasoning_effort: Optional[str] = None,
        verbosity: Optional[str] = None
    ) -> str:
        """Generate content using conversation context that may include text and images.

        How context items are handled:
        - "text" items are joined under a "CONVERSATION CONTEXT:" header in
          front of the prompt.
        - "image" items are loaded from their "path", converted to RGB and
          sent as PNG.
        - Images that are missing or unreadable are skipped with a warning.
        - Items with other or missing "type" values are ignored.

        If no image loads, the call delegates to generate_content (text only).
        Otherwise it builds a multimodal Responses API request.

        Raises:
            QuotaExceededError: propagated from the quota check.
            LLMServiceError: when the request ultimately fails.
        """
        logger = logging.getLogger(__name__)
        logger.debug("generate_contextual_response: %d context items", len(conversation_context))

        text_parts: List[str] = []
        pil_images = []
        for item in conversation_context:
            kind = item.get("type")
            if kind == "text":
                text_parts.append(item["content"])
            elif kind == "image":
                image = LLMService._load_context_image(item)
                if image is not None:
                    pil_images.append(image)

        context_prefix = ("CONVERSATION CONTEXT:\n" + "\n".join(text_parts) + "\n\n") if text_parts else ""
        full_prompt = context_prefix + prompt

        if not pil_images:
            # generate_content enforces the quota itself; checking here too
            # would run the quota query twice for the same request.
            return await LLMService.generate_content(
                prompt=full_prompt,
                temperature=temperature,
                max_tokens=max_tokens,
                model_name=model_name,
                reasoning_effort=reasoning_effort,
                verbosity=verbosity,
            )

        # Multimodal path
        await LLMService._enforce_quota()
        actual_model = LLMService._resolve_model(model_name)
        start_time = time.monotonic()

        content_items = [{"type": "input_text", "text": full_prompt}]
        for img in pil_images:
            buf = io.BytesIO()
            img.save(buf, format='PNG')
            b64 = base64.b64encode(buf.getvalue()).decode('utf-8')
            content_items.append({
                "type": "input_image",
                "image_url": f"data:image/png;base64,{b64}",
            })

        input_content = [{"role": "user", "content": content_items}]
        kwargs = LLMService._build_responses_kwargs(
            actual_model, input_content, reasoning_effort, verbosity
        )
        return await LLMService._call_with_retries(
            kwargs,
            actual_model,
            start_time,
            operation="generate_contextual_response",
            error_prefix="Error generating contextual response",
        )
|