"""Anthropic SDK wrapper with retry logic and token tracking.""" import logging import time from typing import Any import anthropic from app.config import settings logger = logging.getLogger(__name__) # Cost per token (approximate, varies by model) COST_PER_INPUT_TOKEN = 3.0 / 1_000_000 # $3 per 1M input tokens COST_PER_OUTPUT_TOKEN = 15.0 / 1_000_000 # $15 per 1M output tokens class LLMClient: """Wrapper around the Anthropic SDK with retry and token tracking. Provides exponential backoff retry on rate limit and server errors. Tracks token usage per call for cost monitoring. """ def __init__( self, api_key: str | None = None, model: str | None = None, max_retries: int = 3, base_delay: float = 1.0, ) -> None: self.api_key = api_key or settings.ANTHROPIC_API_KEY self.model = model or settings.LLM_MODEL self.max_retries = max_retries self.base_delay = base_delay self.client = anthropic.Anthropic(api_key=self.api_key) self.last_usage: dict[str, Any] = {} def create_message( self, system_prompt: str, user_message: str, max_tokens: int = 4096, temperature: float = 0.7, ) -> tuple[str, dict[str, Any]]: """Send a message to Claude and return the response with usage data. Args: system_prompt: The system prompt. user_message: The user message. max_tokens: Maximum tokens in the response. temperature: Sampling temperature. Returns: Tuple of (response_text, usage_dict). usage_dict has keys: input_tokens, output_tokens, total_tokens, estimated_cost_usd. Raises: anthropic.APIError: If all retries are exhausted. """ last_error = None for attempt in range(1, self.max_retries + 1): try: response = self.client.messages.create( model=self.model, max_tokens=max_tokens, system=system_prompt, messages=[{"role": "user", "content": user_message}], temperature=temperature, ) # Extract text response_text = "" for block in response.content: if hasattr(block, "text"): response_text += block.text # Track usage input_tokens = response.usage.input_tokens output_tokens = response.usage.output_tokens total_tokens = input_tokens + output_tokens estimated_cost = ( input_tokens * COST_PER_INPUT_TOKEN + output_tokens * COST_PER_OUTPUT_TOKEN ) usage = { "input_tokens": input_tokens, "output_tokens": output_tokens, "total_tokens": total_tokens, "estimated_cost_usd": round(estimated_cost, 6), "model": self.model, } self.last_usage = usage return response_text, usage except anthropic.RateLimitError as e: last_error = e delay = self.base_delay * (2 ** (attempt - 1)) logger.warning( f"Rate limited (attempt {attempt}/{self.max_retries}), " f"retrying in {delay}s" ) time.sleep(delay) except anthropic.APIStatusError as e: if e.status_code >= 500: last_error = e delay = self.base_delay * (2 ** (attempt - 1)) logger.warning( f"Server error {e.status_code} (attempt {attempt}/{self.max_retries}), " f"retrying in {delay}s" ) time.sleep(delay) else: raise raise last_error # type: ignore[misc] async def acreate_message( self, system_prompt: str, user_message: str, max_tokens: int = 4096, temperature: float = 0.7, ) -> tuple[str, dict[str, Any]]: """Async version of create_message using the async client. Same interface as create_message but uses asyncio. """ import asyncio # Run sync client in executor to avoid blocking loop = asyncio.get_event_loop() return await loop.run_in_executor( None, lambda: self.create_message( system_prompt, user_message, max_tokens, temperature ), )