The Anthropic SDK refuses non-streaming calls expected to take >10
minutes ("Streaming is required..."). Long-output batches (32k tokens
of densely-formatted markdown) hit this on real 172-line briefs.
Both LLMClient.create_message and create_message_cached now use the
streaming context manager (client.messages.stream(...)) and accumulate
text chunks; final usage + stop_reason come from get_final_message().
No explicit timeout is passed on streaming requests (SDK defaults apply).
Tightened the batch tiers so individual streams stay well under any
ceiling and progress tracking / failure recovery are more granular:
- ≤50 lines: single call
- 51-120: batches of 30 (max_tokens=16k each)
- 121+: batches of 25 (max_tokens=16k each)
Verified with the 172-line case: 7 batches (six of 25 lines, one of 22), 172 drafts produced.
Live streaming call confirmed end-to-end (haiku returned, usage and
stop_reason populated correctly).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
304 lines
11 KiB
Python
304 lines
11 KiB
Python
"""Anthropic SDK wrapper with retry logic and token tracking."""
|
|
|
|
import logging
|
|
import time
|
|
from typing import Any
|
|
|
|
import anthropic
|
|
|
|
from app.config import settings
|
|
|
|
logger = logging.getLogger(__name__)

_PER_MILLION = 1_000_000

# Approximate USD price per token, keyed by model id.
# NOTE(review): these rates are hand-maintained; verify them against
# Anthropic's current pricing page whenever a model is added or upgraded.
MODEL_COSTS: dict[str, tuple[float, float]] = {
    # model id: (input_cost_per_token, output_cost_per_token)
    "claude-sonnet-4-6": (3.0 / _PER_MILLION, 15.0 / _PER_MILLION),
    "claude-opus-4-6": (15.0 / _PER_MILLION, 75.0 / _PER_MILLION),
}

# Fallback rates for any model id missing from MODEL_COSTS (Sonnet pricing).
COST_PER_INPUT_TOKEN, COST_PER_OUTPUT_TOKEN = MODEL_COSTS["claude-sonnet-4-6"]
|
|
|
|
|
|
class LLMClient:
    """Wrapper around the Anthropic SDK with retry and token tracking.

    Every request goes through the streaming API (``client.messages.stream``).
    The SDK refuses non-streaming calls it expects to take longer than
    10 minutes, and long-output batches hit that limit. Text deltas are
    accumulated as they arrive; usage and stop_reason come from the final
    message.

    Rate-limit (429), server (5xx) and connection errors are retried with
    exponential backoff. Token usage and an estimated USD cost are returned
    with every response and also kept on ``last_usage``.
    """

    def __init__(
        self,
        api_key: str | None = None,
        model: str | None = None,
        max_retries: int = 3,
        base_delay: float = 1.0,
    ) -> None:
        """Create the client.

        Args:
            api_key: Anthropic API key; falls back to ``settings.ANTHROPIC_API_KEY``.
            model: Model id; falls back to ``settings.LLM_MODEL``.
            max_retries: Total number of attempts per request (including the
                first). Values below 1 are treated as 1.
            base_delay: Delay in seconds before the second attempt; doubles for
                each later attempt.
        """
        self.api_key = api_key or settings.ANTHROPIC_API_KEY
        self.model = model or settings.LLM_MODEL
        self.max_retries = max_retries
        self.base_delay = base_delay

        self.client = anthropic.Anthropic(api_key=self.api_key)
        # Usage dict from the most recent successful call (empty until then).
        self.last_usage: dict[str, Any] = {}

    def create_message(
        self,
        system_prompt: str,
        user_message: str,
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> tuple[str, dict[str, Any]]:
        """Send a message to Claude and return the response with usage data.

        Args:
            system_prompt: The system prompt.
            user_message: The user message.
            max_tokens: Maximum tokens in the response.
            temperature: Sampling temperature.

        Returns:
            Tuple of (response_text, usage_dict).
            usage_dict has keys: input_tokens, output_tokens, total_tokens,
            estimated_cost_usd, model, stop_reason.

        Raises:
            anthropic.APIError: If all retries are exhausted, or immediately
                for non-retryable (non-429 4xx) status errors.
        """
        response_text, final = self._stream_with_retry(
            system=system_prompt,
            content=user_message,
            max_tokens=max_tokens,
            temperature=temperature,
        )

        input_tokens = final.usage.input_tokens
        output_tokens = final.usage.output_tokens
        stop_reason = getattr(final, "stop_reason", None)
        self._warn_if_truncated(
            stop_reason, max_tokens, "Consider raising max_tokens or batching the input."
        )

        usage = {
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "total_tokens": input_tokens + output_tokens,
            "estimated_cost_usd": round(
                self._estimate_cost(input_tokens, output_tokens), 6
            ),
            "model": self.model,
            "stop_reason": stop_reason,
        }
        self.last_usage = usage
        return response_text, usage

    async def acreate_message(
        self,
        system_prompt: str,
        user_message: str,
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> tuple[str, dict[str, Any]]:
        """Async version of create_message.

        Same interface as create_message. It runs the synchronous (blocking)
        create_message in a worker thread so the event loop stays responsive;
        it does not use the SDK's async client. Cancelling the awaiting task
        does not interrupt the request already running in the thread.
        """
        import asyncio

        return await asyncio.to_thread(
            self.create_message, system_prompt, user_message, max_tokens, temperature
        )

    def create_message_cached(
        self,
        system_prompt: str,
        cached_user_content: str,
        dynamic_user_content: str,
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> tuple[str, dict[str, Any]]:
        """Send a message with prompt caching enabled.

        Two cache breakpoints are placed:
          1. on the system prompt (the V25 instructions — ~30k tokens)
          2. on the static user content (TM + reference data — variable)

        The dynamic_user_content (per-batch source lines) is appended after
        the cached block and is NOT cached. Every batch within a job can
        therefore re-use the cached prefix at ~10% of the input cost, which
        makes source-line batching cheap.

        The ephemeral cache has a 5-minute TTL, refreshed each time the
        prefix is read. Later batches get cache hits only if each one starts
        within that window of the previous use. Very long streaming batches
        may outlive it and pay the cache-write price again.

        Returns:
            Tuple of (response_text, usage_dict). usage_dict has keys:
            input_tokens, output_tokens, total_tokens,
            cache_read_input_tokens, cache_creation_input_tokens,
            estimated_cost_usd, model, stop_reason. total_tokens is
            input_tokens + output_tokens and does not include cache tokens.

        Raises:
            anthropic.APIError: If all retries are exhausted, or immediately
                for non-retryable (non-429 4xx) status errors.
        """
        # Structured system + user content with cache_control markers.
        system_blocks: list[dict[str, Any]] = [
            {
                "type": "text",
                "text": system_prompt,
                "cache_control": {"type": "ephemeral"},
            }
        ]
        user_content: list[dict[str, Any]] = [
            {
                "type": "text",
                "text": cached_user_content,
                "cache_control": {"type": "ephemeral"},
            },
            {
                "type": "text",
                "text": dynamic_user_content,
            },
        ]

        response_text, final = self._stream_with_retry(
            system=system_blocks,
            content=user_content,
            max_tokens=max_tokens,
            temperature=temperature,
        )

        input_tokens = final.usage.input_tokens
        output_tokens = final.usage.output_tokens
        # Cache token counts are reported separately from input_tokens; the
        # attributes may be absent or None, so normalise them to 0.
        cache_read = getattr(final.usage, "cache_read_input_tokens", 0) or 0
        cache_creation = getattr(final.usage, "cache_creation_input_tokens", 0) or 0
        stop_reason = getattr(final, "stop_reason", None)
        self._warn_if_truncated(
            stop_reason, max_tokens, "Consider raising max_tokens or smaller batches."
        )

        estimated_cost = self._estimate_cost(
            input_tokens,
            output_tokens,
            cache_creation=cache_creation,
            cache_read=cache_read,
        )
        usage = {
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "total_tokens": input_tokens + output_tokens,
            "cache_read_input_tokens": cache_read,
            "cache_creation_input_tokens": cache_creation,
            "estimated_cost_usd": round(estimated_cost, 6),
            "model": self.model,
            "stop_reason": stop_reason,
        }
        self.last_usage = usage
        return response_text, usage

    async def acreate_message_cached(
        self,
        system_prompt: str,
        cached_user_content: str,
        dynamic_user_content: str,
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> tuple[str, dict[str, Any]]:
        """Async version of create_message_cached.

        Runs the synchronous create_message_cached in a worker thread; see
        acreate_message for the caveats.
        """
        import asyncio

        return await asyncio.to_thread(
            self.create_message_cached,
            system_prompt,
            cached_user_content,
            dynamic_user_content,
            max_tokens,
            temperature,
        )

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _stream_once(
        self,
        system: str | list[dict[str, Any]],
        content: str | list[dict[str, Any]],
        max_tokens: int,
        temperature: float,
    ) -> tuple[str, Any]:
        """Issue one streaming request; return (response_text, final_message)."""
        with self.client.messages.stream(
            model=self.model,
            max_tokens=max_tokens,
            system=system,
            messages=[{"role": "user", "content": content}],
            temperature=temperature,
        ) as stream:
            # Join once instead of repeated ``+=`` on an ever-growing string.
            response_text = "".join(stream.text_stream)
            final = stream.get_final_message()
        return response_text, final

    def _stream_with_retry(
        self,
        system: str | list[dict[str, Any]],
        content: str | list[dict[str, Any]],
        max_tokens: int,
        temperature: float,
    ) -> tuple[str, Any]:
        """Run _stream_once, retrying transient failures with exponential backoff.

        Makes up to ``max(1, self.max_retries)`` attempts. Before attempt
        k + 1 it sleeps ``base_delay * 2 ** (k - 1)`` seconds. There is no
        sleep after the final attempt.

        NOTE(review): an error delivered mid-stream as an SSE ``error`` event
        may not surface as a 429/5xx status error, depending on SDK version.
        Confirm it is classified (and retried) as intended.

        Raises:
            anthropic.APIStatusError: Immediately, for non-429 4xx errors.
            anthropic.APIError: The last transient error once attempts run out.
        """
        attempts = max(1, self.max_retries)
        last_error: anthropic.APIError | None = None

        for attempt in range(1, attempts + 1):
            try:
                return self._stream_once(system, content, max_tokens, temperature)
            except anthropic.RateLimitError as e:
                last_error = e
                reason = "Rate limited"
            except anthropic.APIStatusError as e:
                if e.status_code < 500:
                    raise
                last_error = e
                reason = f"Server error {e.status_code}"
            except anthropic.APIConnectionError as e:
                # Dropped or timed-out connections are transient; retry them too.
                last_error = e
                reason = "Connection error"

            if attempt == attempts:
                break  # Out of attempts: don't sleep just to re-raise.
            delay = self.base_delay * (2 ** (attempt - 1))
            logger.warning(
                "%s (attempt %d/%d), retrying in %ss", reason, attempt, attempts, delay
            )
            time.sleep(delay)

        raise last_error  # type: ignore[misc]

    def _estimate_cost(
        self,
        input_tokens: int,
        output_tokens: int,
        cache_creation: int = 0,
        cache_read: int = 0,
    ) -> float:
        """Estimate the USD cost of one call from its token counts.

        Uses MODEL_COSTS for self.model, falling back to the default
        (Sonnet) rates for unknown model ids.
        """
        input_rate, output_rate = MODEL_COSTS.get(
            self.model, (COST_PER_INPUT_TOKEN, COST_PER_OUTPUT_TOKEN)
        )
        # Cache writes cost 1.25x input rate, cache reads 0.1x input rate.
        return (
            input_tokens * input_rate
            + cache_creation * input_rate * 1.25
            + cache_read * input_rate * 0.1
            + output_tokens * output_rate
        )

    @staticmethod
    def _warn_if_truncated(stop_reason: str | None, max_tokens: int, advice: str) -> None:
        """Log a warning when the response stopped because it hit max_tokens."""
        if stop_reason == "max_tokens":
            logger.warning(
                "LLM hit max_tokens (%d) — response was truncated. %s",
                max_tokens,
                advice,
            )
|