amazon-transcreation/backend/app/llm/client.py
DJP 100eddbc21 Switch LLM calls to streaming + tighten batch sizes
The Anthropic SDK refuses non-streaming calls expected to take >10
minutes ("Streaming is required..."). Long-output batches (32k tokens
of densely-formatted markdown) hit this on real 172-line briefs.

Both LLMClient.create_message and create_message_cached now use the
streaming context manager (client.messages.stream(...)) and accumulate
text chunks; final usage + stop_reason come from get_final_message().
No explicit timeout is set on streaming requests.

Tightened the batching tiers so that individual streams stay well under
any time limit and progress reporting / failure recovery is more granular:

- ≤50 lines: single call
- 51-120: batches of 30 (max_tokens=16k each)
- 121+:   batches of 25 (max_tokens=16k each)

Verified with the 172-line case: 7 batches (six of 25, one of 22), 172 drafts produced.
Live streaming call confirmed end-to-end (haiku returned, usage and
stop_reason populated correctly).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-06 12:20:16 -04:00

304 lines
11 KiB
Python

"""Anthropic SDK wrapper with retry logic and token tracking."""
import logging
import time
from typing import Any
import anthropic
from app.config import settings
logger = logging.getLogger(__name__)

# Approximate USD price per token for each known model, stored as
# (input_cost_per_token, output_cost_per_token). Prices are quoted per
# million tokens, hence the division by 1e6.
# NOTE(review): these figures are hand-maintained — confirm them against
# Anthropic's current pricing page whenever a model is added or bumped.
MODEL_COSTS: dict[str, tuple[float, float]] = {
    "claude-sonnet-4-6": (3.0 / 1e6, 15.0 / 1e6),
    "claude-opus-4-6": (15.0 / 1e6, 75.0 / 1e6),
}

# Fallback rates for any model missing from MODEL_COSTS (Sonnet pricing).
COST_PER_INPUT_TOKEN, COST_PER_OUTPUT_TOKEN = MODEL_COSTS["claude-sonnet-4-6"]
class LLMClient:
    """Wrapper around the Anthropic SDK with retry and token tracking.

    Every request goes through the SDK's streaming API and is retried with
    exponential backoff on rate-limit errors and 5xx server errors. Each call
    returns its token usage plus an estimated USD cost, and the most recent
    usage dict is also kept on ``last_usage`` for cost monitoring.
    """

    def __init__(
        self,
        api_key: str | None = None,
        model: str | None = None,
        max_retries: int = 3,
        base_delay: float = 1.0,
    ) -> None:
        """Create the wrapper and the underlying SDK client.

        Args:
            api_key: Anthropic API key; falls back to
                ``settings.ANTHROPIC_API_KEY`` when falsy.
            model: Model id; falls back to ``settings.LLM_MODEL`` when falsy.
            max_retries: Total number of attempts per call (the first try
                counts as one). Values below 1 are treated as 1.
            base_delay: Seconds to wait after the first failed attempt; the
                wait doubles after each further failure.
        """
        self.api_key = api_key or settings.ANTHROPIC_API_KEY
        self.model = model or settings.LLM_MODEL
        self.max_retries = max_retries
        self.base_delay = base_delay
        # NOTE(review): the SDK client has its own built-in retries, which
        # stack with the retry loop below — confirm that is intended.
        self.client = anthropic.Anthropic(api_key=self.api_key)
        # Usage dict of the most recent successful call ({} until then).
        self.last_usage: dict[str, Any] = {}

    # ------------------------------------------------------------------
    # Internal helpers shared by the public call paths
    # ------------------------------------------------------------------

    def _backoff_delay(self, attempt: int) -> float:
        """Return the wait in seconds after failed attempt ``attempt`` (1-based)."""
        return self.base_delay * (2 ** (attempt - 1))

    def _token_rates(self) -> tuple[float, float]:
        """Return ``(input_rate, output_rate)`` in USD per token for ``self.model``.

        Models missing from ``MODEL_COSTS`` fall back to the default rates.
        """
        return MODEL_COSTS.get(
            self.model, (COST_PER_INPUT_TOKEN, COST_PER_OUTPUT_TOKEN)
        )

    def _stream_with_retry(
        self,
        *,
        system: str | list[dict[str, Any]],
        messages: list[dict[str, Any]],
        max_tokens: int,
        temperature: float,
    ) -> tuple[str, Any]:
        """Run one streaming request, retrying rate-limit and 5xx errors.

        Args:
            system: System prompt, as a plain string or a list of content
                blocks (e.g. carrying ``cache_control`` markers).
            messages: The ``messages`` list passed to the API.
            max_tokens: Maximum tokens in the response.
            temperature: Sampling temperature.

        Returns:
            ``(response_text, final_message)`` where ``final_message`` is the
            object returned by ``stream.get_final_message()``.

        Raises:
            anthropic.RateLimitError: re-raised once all attempts are used.
            anthropic.APIStatusError: re-raised once all attempts are used
                for 5xx errors; raised immediately for other status codes.
            Any other exception (e.g. connection errors) propagates without
            retry.
        """
        # Guarantee at least one attempt so the loop always returns or raises.
        attempts = max(1, self.max_retries)
        for attempt in range(1, attempts + 1):
            try:
                # Streaming is required for long requests: the SDK refuses
                # non-streaming calls it expects to take more than ~10 min.
                with self.client.messages.stream(
                    model=self.model,
                    max_tokens=max_tokens,
                    system=system,
                    messages=messages,
                    temperature=temperature,
                ) as stream:
                    # Join once instead of repeated ``str +=`` in a loop.
                    response_text = "".join(stream.text_stream)
                    final = stream.get_final_message()
                return response_text, final
            except anthropic.RateLimitError:
                # Must precede APIStatusError: RateLimitError subclasses it.
                if attempt == attempts:
                    raise  # no point sleeping before giving up
                delay = self._backoff_delay(attempt)
                logger.warning(
                    "Rate limited (attempt %d/%d), retrying in %ss",
                    attempt,
                    attempts,
                    delay,
                )
            except anthropic.APIStatusError as e:
                if e.status_code < 500 or attempt == attempts:
                    raise
                delay = self._backoff_delay(attempt)
                logger.warning(
                    "Server error %s (attempt %d/%d), retrying in %ss",
                    e.status_code,
                    attempt,
                    attempts,
                    delay,
                )
            # Only reached after a handled, retryable failure.
            time.sleep(delay)
        # The final attempt always returns or re-raises above.
        raise RuntimeError("retry loop exited without a result")

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def create_message(
        self,
        system_prompt: str,
        user_message: str,
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> tuple[str, dict[str, Any]]:
        """Send a message to Claude and return the response with usage data.

        Args:
            system_prompt: The system prompt.
            user_message: The user message.
            max_tokens: Maximum tokens in the response.
            temperature: Sampling temperature.

        Returns:
            Tuple of (response_text, usage_dict). usage_dict has keys:
            input_tokens, output_tokens, total_tokens, estimated_cost_usd,
            model, stop_reason.

        Raises:
            anthropic.APIStatusError (including anthropic.RateLimitError):
            after all retries are exhausted, or immediately for non-5xx
            status errors.
        """
        response_text, final = self._stream_with_retry(
            system=system_prompt,
            messages=[{"role": "user", "content": user_message}],
            max_tokens=max_tokens,
            temperature=temperature,
        )

        input_tokens = final.usage.input_tokens
        output_tokens = final.usage.output_tokens
        input_rate, output_rate = self._token_rates()
        estimated_cost = input_tokens * input_rate + output_tokens * output_rate

        stop_reason = getattr(final, "stop_reason", None)
        if stop_reason == "max_tokens":
            logger.warning(
                "LLM hit max_tokens (%d) — response was truncated. "
                "Consider raising max_tokens or batching the input.",
                max_tokens,
            )

        usage = {
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "total_tokens": input_tokens + output_tokens,
            "estimated_cost_usd": round(estimated_cost, 6),
            "model": self.model,
            "stop_reason": stop_reason,
        }
        self.last_usage = usage
        return response_text, usage

    async def acreate_message(
        self,
        system_prompt: str,
        user_message: str,
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> tuple[str, dict[str, Any]]:
        """Async wrapper around :meth:`create_message`.

        No async SDK client is used: the blocking call runs in a worker
        thread (``asyncio.to_thread``) so the event loop is not blocked.
        Same arguments and return value as :meth:`create_message`.
        """
        import asyncio

        return await asyncio.to_thread(
            self.create_message,
            system_prompt,
            user_message,
            max_tokens,
            temperature,
        )

    def create_message_cached(
        self,
        system_prompt: str,
        cached_user_content: str,
        dynamic_user_content: str,
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> tuple[str, dict[str, Any]]:
        """Send a message with prompt caching enabled.

        Two cache breakpoints are placed:
          1. on the system prompt (the V25 instructions — ~30k tokens)
          2. on the static user content (TM + reference data — variable)

        ``dynamic_user_content`` (per-batch source lines) is appended after
        the cached block and is NOT cached, so every batch within a job can
        re-use the cached prefix at ~10% of the input cost.

        The default ephemeral cache TTL is 5 minutes and is refreshed each
        time the cache is read, so consecutive batches keep hitting the cache
        as long as no gap between requests exceeds it.

        Args:
            system_prompt: System prompt (cached).
            cached_user_content: Static user content shared by all batches
                (cached).
            dynamic_user_content: Per-batch user content (not cached).
            max_tokens: Maximum tokens in the response.
            temperature: Sampling temperature.

        Returns:
            Tuple of (response_text, usage_dict). usage_dict has keys:
            input_tokens, output_tokens, total_tokens,
            cache_read_input_tokens, cache_creation_input_tokens,
            estimated_cost_usd, model, stop_reason. ``total_tokens`` counts
            only uncached input plus output tokens.

        Raises:
            anthropic.APIStatusError (including anthropic.RateLimitError):
            after all retries are exhausted, or immediately for non-5xx
            status errors.
        """
        system_blocks: list[dict[str, Any]] = [
            {
                "type": "text",
                "text": system_prompt,
                "cache_control": {"type": "ephemeral"},
            }
        ]
        user_content: list[dict[str, Any]] = [
            {
                "type": "text",
                "text": cached_user_content,
                "cache_control": {"type": "ephemeral"},
            },
            {
                "type": "text",
                "text": dynamic_user_content,
            },
        ]

        response_text, final = self._stream_with_retry(
            system=system_blocks,
            messages=[{"role": "user", "content": user_content}],
            max_tokens=max_tokens,
            temperature=temperature,
        )

        input_tokens = final.usage.input_tokens
        output_tokens = final.usage.output_tokens
        # The cache fields may be absent or None depending on SDK/response.
        cache_read = getattr(final.usage, "cache_read_input_tokens", 0) or 0
        cache_creation = getattr(final.usage, "cache_creation_input_tokens", 0) or 0

        input_rate, output_rate = self._token_rates()
        # Cache writes cost 1.25x the input rate, cache reads 0.1x.
        # ``input_tokens`` excludes both cached categories, so the three
        # input terms do not overlap.
        estimated_cost = (
            input_tokens * input_rate
            + cache_creation * input_rate * 1.25
            + cache_read * input_rate * 0.1
            + output_tokens * output_rate
        )

        stop_reason = getattr(final, "stop_reason", None)
        if stop_reason == "max_tokens":
            logger.warning(
                "LLM hit max_tokens (%d) — response was truncated. "
                "Consider raising max_tokens or smaller batches.",
                max_tokens,
            )

        usage = {
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "total_tokens": input_tokens + output_tokens,
            "cache_read_input_tokens": cache_read,
            "cache_creation_input_tokens": cache_creation,
            "estimated_cost_usd": round(estimated_cost, 6),
            "model": self.model,
            "stop_reason": stop_reason,
        }
        self.last_usage = usage
        return response_text, usage

    async def acreate_message_cached(
        self,
        system_prompt: str,
        cached_user_content: str,
        dynamic_user_content: str,
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> tuple[str, dict[str, Any]]:
        """Async wrapper around :meth:`create_message_cached`.

        The blocking call runs in a worker thread (``asyncio.to_thread``) so
        the event loop is not blocked. Same arguments and return value as
        :meth:`create_message_cached`.
        """
        import asyncio

        return await asyncio.to_thread(
            self.create_message_cached,
            system_prompt,
            cached_user_content,
            dynamic_user_content,
            max_tokens,
            temperature,
        )