"""Anthropic SDK wrapper with retry logic and token tracking.""" import logging import time from typing import Any import anthropic from app.config import settings logger = logging.getLogger(__name__) # Cost per token by model (approximate) MODEL_COSTS: dict[str, tuple[float, float]] = { # (input_cost_per_token, output_cost_per_token) "claude-sonnet-4-6": (3.0 / 1_000_000, 15.0 / 1_000_000), "claude-opus-4-6": (15.0 / 1_000_000, 75.0 / 1_000_000), } # Default fallback (Sonnet pricing) COST_PER_INPUT_TOKEN = 3.0 / 1_000_000 COST_PER_OUTPUT_TOKEN = 15.0 / 1_000_000 class LLMClient: """Wrapper around the Anthropic SDK with retry and token tracking. Provides exponential backoff retry on rate limit and server errors. Tracks token usage per call for cost monitoring. """ def __init__( self, api_key: str | None = None, model: str | None = None, max_retries: int = 3, base_delay: float = 1.0, ) -> None: self.api_key = api_key or settings.ANTHROPIC_API_KEY self.model = model or settings.LLM_MODEL self.max_retries = max_retries self.base_delay = base_delay self.client = anthropic.Anthropic(api_key=self.api_key) self.last_usage: dict[str, Any] = {} def create_message( self, system_prompt: str, user_message: str, max_tokens: int = 4096, temperature: float = 0.7, ) -> tuple[str, dict[str, Any]]: """Send a message to Claude and return the response with usage data. Args: system_prompt: The system prompt. user_message: The user message. max_tokens: Maximum tokens in the response. temperature: Sampling temperature. Returns: Tuple of (response_text, usage_dict). usage_dict has keys: input_tokens, output_tokens, total_tokens, estimated_cost_usd. Raises: anthropic.APIError: If all retries are exhausted. """ last_error = None for attempt in range(1, self.max_retries + 1): try: # Streaming required for long requests (>10 min cap on # non-streaming calls). All call paths use streaming for # consistency. response_text = "" with self.client.messages.stream( model=self.model, max_tokens=max_tokens, system=system_prompt, messages=[{"role": "user", "content": user_message}], temperature=temperature, ) as stream: for chunk in stream.text_stream: response_text += chunk final = stream.get_final_message() input_tokens = final.usage.input_tokens output_tokens = final.usage.output_tokens total_tokens = input_tokens + output_tokens input_rate, output_rate = MODEL_COSTS.get( self.model, (COST_PER_INPUT_TOKEN, COST_PER_OUTPUT_TOKEN) ) estimated_cost = ( input_tokens * input_rate + output_tokens * output_rate ) stop_reason = getattr(final, "stop_reason", None) if stop_reason == "max_tokens": logger.warning( "LLM hit max_tokens (%d) — response was truncated. " "Consider raising max_tokens or batching the input.", max_tokens, ) usage = { "input_tokens": input_tokens, "output_tokens": output_tokens, "total_tokens": total_tokens, "estimated_cost_usd": round(estimated_cost, 6), "model": self.model, "stop_reason": stop_reason, } self.last_usage = usage return response_text, usage except anthropic.RateLimitError as e: last_error = e delay = self.base_delay * (2 ** (attempt - 1)) logger.warning( f"Rate limited (attempt {attempt}/{self.max_retries}), " f"retrying in {delay}s" ) time.sleep(delay) except anthropic.APIStatusError as e: if e.status_code >= 500: last_error = e delay = self.base_delay * (2 ** (attempt - 1)) logger.warning( f"Server error {e.status_code} (attempt {attempt}/{self.max_retries}), " f"retrying in {delay}s" ) time.sleep(delay) else: raise raise last_error # type: ignore[misc] async def acreate_message( self, system_prompt: str, user_message: str, max_tokens: int = 4096, temperature: float = 0.7, ) -> tuple[str, dict[str, Any]]: """Async version of create_message using the async client. Same interface as create_message but uses asyncio. """ import asyncio # Run sync client in executor to avoid blocking loop = asyncio.get_event_loop() return await loop.run_in_executor( None, lambda: self.create_message( system_prompt, user_message, max_tokens, temperature ), ) def create_message_cached( self, system_prompt: str, cached_user_content: str, dynamic_user_content: str, max_tokens: int = 4096, temperature: float = 0.7, ) -> tuple[str, dict[str, Any]]: """Send a message with prompt caching enabled. Two cache breakpoints are placed: 1. on the system prompt (the V25 instructions — ~30k tokens) 2. on the static user content (TM + reference data — variable) The dynamic_user_content (per-batch source lines) is appended after the cached block and is NOT cached. This means every batch within a job re-uses the cached prefix at ~10% of the input cost, making source-line batching cheap. Cache TTL is 5 minutes by default — well within a single job's runtime, so all batches benefit from cache hits after the first. """ last_error = None # Build structured system + user content with cache_control markers. system_blocks: list[dict[str, Any]] = [ { "type": "text", "text": system_prompt, "cache_control": {"type": "ephemeral"}, } ] user_content: list[dict[str, Any]] = [ { "type": "text", "text": cached_user_content, "cache_control": {"type": "ephemeral"}, }, { "type": "text", "text": dynamic_user_content, }, ] for attempt in range(1, self.max_retries + 1): try: # Streaming required: a single non-streaming request is # capped by the SDK at 10 minutes. Long batches (32k+ # output tokens) routinely exceed that. Streaming keeps # the connection alive event-by-event and has no time cap. response_text = "" with self.client.messages.stream( model=self.model, max_tokens=max_tokens, system=system_blocks, messages=[{"role": "user", "content": user_content}], temperature=temperature, ) as stream: for chunk in stream.text_stream: response_text += chunk final = stream.get_final_message() input_tokens = final.usage.input_tokens output_tokens = final.usage.output_tokens cache_read = getattr(final.usage, "cache_read_input_tokens", 0) or 0 cache_creation = getattr(final.usage, "cache_creation_input_tokens", 0) or 0 total_tokens = input_tokens + output_tokens input_rate, output_rate = MODEL_COSTS.get( self.model, (COST_PER_INPUT_TOKEN, COST_PER_OUTPUT_TOKEN) ) # Cache writes cost 1.25x input rate, cache reads 0.1x input rate. estimated_cost = ( input_tokens * input_rate + cache_creation * input_rate * 1.25 + cache_read * input_rate * 0.1 + output_tokens * output_rate ) stop_reason = getattr(final, "stop_reason", None) if stop_reason == "max_tokens": logger.warning( "LLM hit max_tokens (%d) — response was truncated. " "Consider raising max_tokens or smaller batches.", max_tokens, ) usage = { "input_tokens": input_tokens, "output_tokens": output_tokens, "total_tokens": total_tokens, "cache_read_input_tokens": cache_read, "cache_creation_input_tokens": cache_creation, "estimated_cost_usd": round(estimated_cost, 6), "model": self.model, "stop_reason": stop_reason, } self.last_usage = usage return response_text, usage except anthropic.RateLimitError as e: last_error = e delay = self.base_delay * (2 ** (attempt - 1)) logger.warning( f"Rate limited (attempt {attempt}/{self.max_retries}), retrying in {delay}s" ) time.sleep(delay) except anthropic.APIStatusError as e: if e.status_code >= 500: last_error = e delay = self.base_delay * (2 ** (attempt - 1)) logger.warning( f"Server error {e.status_code} (attempt {attempt}/{self.max_retries}), retrying in {delay}s" ) time.sleep(delay) else: raise raise last_error # type: ignore[misc] async def acreate_message_cached( self, system_prompt: str, cached_user_content: str, dynamic_user_content: str, max_tokens: int = 4096, temperature: float = 0.7, ) -> tuple[str, dict[str, Any]]: """Async version of create_message_cached.""" import asyncio loop = asyncio.get_event_loop() return await loop.run_in_executor( None, lambda: self.create_message_cached( system_prompt, cached_user_content, dynamic_user_content, max_tokens, temperature, ), )