- Model renames: gpt-5.2 → gpt-5.4-2026-03-05, gemini-3-pro-preview → gemini-3.1-pro-preview; retire gpt-4.1 via alias fallback
- New: llm_usage_context.py (ContextVar-based attribution), model_pricing.py (tiered pricing + 60s cache), usage_event.py (append-only telemetry), quota.py (user/FG quota enforcement with 80% warning)
- Wire _record_usage into all 3 LLM methods; set_llm_context at every service entry point
- Fix admin_required decorator (was sync, never awaited User.find_by_id); add active_required and with_user_context decorators
- Inject user_id into ContextVar from JWT on every authenticated request
- Add DB indexes for usage_events, model_pricing, users collections
- Seed script for model pricing (gpt-5.4 single-tier, gemini-3.1 two-tier 200k threshold)
- Fix parse_json_response NameError (logger undefined at module level)
- 70 passing tests: conftest.py with sys.modules stubs, test_usage_infrastructure.py (52 tests), rewritten stale test_llm_service.py (18 tests)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
664 lines
29 KiB
Python
664 lines
29 KiB
Python
"""
|
|
Tests for the usage tracking infrastructure added in Phase A-C:
|
|
- LLMCallContext / llm_usage_context module
|
|
- ModelPricing.pick_tier and ModelPricing.compute_cost
|
|
- LLMService._extract_usage_metadata and LLMService._resolve_model
|
|
- UsageEvent.record (non-fatal DB error handling, field validation)
|
|
- check_quota (admin bypass, warning threshold, exceeded error)
- QuotaExceededError (attributes and message formatting)
|
|
"""
|
|
|
|
import asyncio
|
|
import sys
|
|
import pytest
|
|
from unittest.mock import MagicMock, AsyncMock, patch
|
|
|
|
|
|
# ══════════════════════════════════════════════════════════════════════════════
|
|
# 1. LLM usage context — pure stdlib, no mocking needed
|
|
# ══════════════════════════════════════════════════════════════════════════════
|
|
|
|
class TestLLMUsageContext:
    """Tests for ``app.services.llm_usage_context`` — pure stdlib, no mocking needed."""

    def setup_method(self):
        # Imported per test (like the other test classes in this file) rather
        # than at module level.
        from app.services.llm_usage_context import (
            LLMCallContext, _ctx, current_context, set_llm_context, llm_context
        )
        self.LLMCallContext = LLMCallContext
        self._ctx = _ctx
        self.current_context = current_context
        self.set_llm_context = set_llm_context
        self.llm_context = llm_context
        # Start every test from a default context. Keep the token so that
        # teardown_method can restore whatever value was active before the test.
        self._ctx_token = self._ctx.set(LLMCallContext())

    def teardown_method(self):
        # Several tests (e.g. test_set_llm_context_mutates) deliberately leave
        # attribution values in the context. Resetting to the pre-test value
        # keeps them from leaking into tests that run later in the same context.
        self._ctx.reset(self._ctx_token)

    def test_default_context(self):
        ctx = self.current_context()
        assert ctx.user_id is None
        assert ctx.focus_group_id is None
        assert ctx.persona_id is None
        assert ctx.feature == "other"
        assert ctx.task_id is None

    def test_set_llm_context_mutates(self):
        self.set_llm_context(user_id="u1", feature="persona_generate")
        ctx = self.current_context()
        assert ctx.user_id == "u1"
        assert ctx.feature == "persona_generate"
        # Unset fields remain None
        assert ctx.focus_group_id is None

    def test_set_llm_context_merges_with_existing(self):
        self.set_llm_context(user_id="u1")
        self.set_llm_context(focus_group_id="fg1")
        ctx = self.current_context()
        # Both values present
        assert ctx.user_id == "u1"
        assert ctx.focus_group_id == "fg1"

    def test_llm_context_restores_on_exit(self):
        self.set_llm_context(user_id="outer")
        with self.llm_context(user_id="inner", feature="moderator"):
            assert self.current_context().user_id == "inner"
            assert self.current_context().feature == "moderator"
        # Restored
        restored = self.current_context()
        assert restored.user_id == "outer"
        assert restored.feature == "other"

    def test_llm_context_restores_on_exception(self):
        self.set_llm_context(user_id="before")
        try:
            with self.llm_context(user_id="during"):
                raise ValueError("boom")
        except ValueError:
            pass
        assert self.current_context().user_id == "before"

    def test_llm_context_stacking(self):
        with self.llm_context(user_id="u1"):
            with self.llm_context(focus_group_id="fg1"):
                with self.llm_context(feature="key_themes"):
                    ctx = self.current_context()
                    assert ctx.user_id == "u1"
                    assert ctx.focus_group_id == "fg1"
                    assert ctx.feature == "key_themes"
                assert self.current_context().feature == "other"  # restored
            assert self.current_context().focus_group_id is None  # restored
        assert self.current_context().user_id is None  # restored

    def test_frozen_dataclass_immutable(self):
        ctx = self.LLMCallContext(user_id="x")
        with pytest.raises((AttributeError, TypeError)):
            ctx.user_id = "y"  # type: ignore
|
|
|
|
|
|
# ══════════════════════════════════════════════════════════════════════════════
|
|
# 2. ModelPricing pure logic — pick_tier and compute_cost
|
|
# ══════════════════════════════════════════════════════════════════════════════
|
|
|
|
class TestModelPricingPureLogic:
    """Tests for the DB-free helpers ``ModelPricing.pick_tier`` and ``ModelPricing.compute_cost``."""

    # Absolute tolerance (in USD) for cost comparisons. pytest.approx with only
    # ``abs`` set ignores relative tolerance, so this equals |a - b| <= 1e-9,
    # and it produces a readable diff on failure.
    _USD_TOL = 1e-9

    def setup_method(self):
        from app.models.model_pricing import ModelPricing
        self.ModelPricing = ModelPricing

    def _gemini_pricing(self):
        """Two-tier Gemini pricing: <200k and >=200k."""
        return {
            "model": "gemini-3.1-pro-preview",
            "tiers": [
                {
                    "threshold_input_tokens": 0,
                    "input_per_mtok": 2.0,
                    "cached_input_per_mtok": None,
                    "output_per_mtok": 12.0,
                },
                {
                    "threshold_input_tokens": 200_000,
                    "input_per_mtok": 4.0,
                    "cached_input_per_mtok": None,
                    "output_per_mtok": 18.0,
                },
            ],
        }

    def _gpt_pricing(self):
        """Single-tier GPT pricing."""
        return {
            "model": "gpt-5.4-2026-03-05",
            "tiers": [
                {
                    "threshold_input_tokens": 0,
                    "input_per_mtok": 2.50,
                    "cached_input_per_mtok": 0.25,
                    "output_per_mtok": 15.0,
                }
            ],
        }

    # ── pick_tier ──────────────────────────────────────────────────────────────

    def test_pick_tier_below_all_thresholds(self):
        pricing = self._gemini_pricing()
        tier = self.ModelPricing.pick_tier(pricing, prompt_tokens=1000)
        assert tier["threshold_input_tokens"] == 0
        assert tier["input_per_mtok"] == 2.0

    def test_pick_tier_above_high_threshold(self):
        pricing = self._gemini_pricing()
        tier = self.ModelPricing.pick_tier(pricing, prompt_tokens=250_000)
        assert tier["threshold_input_tokens"] == 200_000
        assert tier["input_per_mtok"] == 4.0

    def test_pick_tier_exactly_at_threshold(self):
        # A tier applies once prompt_tokens reaches its threshold (>=).
        # NOTE(review): Google's published Gemini Pro pricing is split as
        # "<= 200k" / "> 200k" prompt tokens, which would put exactly 200k in
        # the LOW tier. Confirm which boundary the billing model intends.
        pricing = self._gemini_pricing()
        tier = self.ModelPricing.pick_tier(pricing, prompt_tokens=200_000)
        assert tier["threshold_input_tokens"] == 200_000

    def test_pick_tier_single_tier(self):
        pricing = self._gpt_pricing()
        tier = self.ModelPricing.pick_tier(pricing, prompt_tokens=500_000)
        assert tier["input_per_mtok"] == 2.50

    def test_pick_tier_none_pricing_returns_none(self):
        assert self.ModelPricing.pick_tier(None, 1000) is None

    def test_pick_tier_empty_tiers_returns_none(self):
        assert self.ModelPricing.pick_tier({"tiers": []}, 1000) is None

    # ── compute_cost ──────────────────────────────────────────────────────────

    def test_compute_cost_none_pricing_returns_zeros(self):
        cost = self.ModelPricing.compute_cost(None, 100, 50)
        assert cost == {"input": 0.0, "cached": 0.0, "output": 0.0, "total": 0.0}

    def test_compute_cost_basic_gemini_low_tier(self):
        # 100k input, 10k output, no cached — stays in low tier
        cost = self.ModelPricing.compute_cost(
            self._gemini_pricing(),
            prompt_tokens=100_000,
            completion_tokens=10_000,
            cached_tokens=0,
        )
        expected_input = 100_000 * 2.0 / 1_000_000   # $0.20
        expected_output = 10_000 * 12.0 / 1_000_000  # $0.12
        assert cost["input"] == pytest.approx(expected_input, abs=self._USD_TOL)
        assert cost["cached"] == 0.0
        assert cost["output"] == pytest.approx(expected_output, abs=self._USD_TOL)
        assert cost["total"] == pytest.approx(expected_input + expected_output, abs=self._USD_TOL)

    def test_compute_cost_basic_gemini_high_tier(self):
        # 300k input → high tier: $4/$18. cached_tokens is left at its default.
        cost = self.ModelPricing.compute_cost(
            self._gemini_pricing(),
            prompt_tokens=300_000,
            completion_tokens=5_000,
        )
        expected_input = 300_000 * 4.0 / 1_000_000  # $1.20
        expected_output = 5_000 * 18.0 / 1_000_000  # $0.09
        assert cost["input"] == pytest.approx(expected_input, abs=self._USD_TOL)
        # No cached tokens were passed, so no cached cost is billed.
        assert cost["cached"] == 0.0
        assert cost["output"] == pytest.approx(expected_output, abs=self._USD_TOL)
        assert cost["total"] == pytest.approx(expected_input + expected_output, abs=self._USD_TOL)

    def test_compute_cost_with_cached_tokens_gpt(self):
        # 10k input, 2k cached (cheaper), 5k output
        cost = self.ModelPricing.compute_cost(
            self._gpt_pricing(),
            prompt_tokens=10_000,
            completion_tokens=5_000,
            cached_tokens=2_000,
        )
        # Billable input = 10k - 2k = 8k at $2.50/M
        expected_input = 8_000 * 2.50 / 1_000_000
        # Cached 2k at $0.25/M
        expected_cached = 2_000 * 0.25 / 1_000_000
        # Output 5k at $15/M
        expected_output = 5_000 * 15.0 / 1_000_000
        expected_total = expected_input + expected_cached + expected_output

        assert cost["input"] == pytest.approx(expected_input, abs=self._USD_TOL)
        assert cost["cached"] == pytest.approx(expected_cached, abs=self._USD_TOL)
        assert cost["output"] == pytest.approx(expected_output, abs=self._USD_TOL)
        assert cost["total"] == pytest.approx(expected_total, abs=self._USD_TOL)

    def test_compute_cost_total_equals_sum_of_components(self):
        cost = self.ModelPricing.compute_cost(
            self._gpt_pricing(),
            prompt_tokens=50_000,
            completion_tokens=20_000,
            cached_tokens=5_000,
        )
        component_sum = cost["input"] + cost["cached"] + cost["output"]
        assert cost["total"] == pytest.approx(component_sum, abs=self._USD_TOL)
|
|
|
|
|
|
# ══════════════════════════════════════════════════════════════════════════════
|
|
# 3. LLMService._extract_usage_metadata and _resolve_model
|
|
# ══════════════════════════════════════════════════════════════════════════════
|
|
|
|
class TestLLMServicePureStaticMethods:
    """Tests for ``LLMService._resolve_model`` and ``LLMService._extract_usage_metadata``."""

    # Usage dict expected whenever no token counts can be read from a response.
    _ZERO_USAGE = {"prompt": 0, "completion": 0, "cached": 0, "reasoning": 0}

    def setup_method(self):
        from app.services.llm_service import LLMService, MODEL_ALIASES, DEFAULT_MODEL
        self.LLMService = LLMService
        self.MODEL_ALIASES = MODEL_ALIASES
        self.DEFAULT_MODEL = DEFAULT_MODEL

    # ── _resolve_model ────────────────────────────────────────────────────────

    def test_resolve_model_none_returns_default(self):
        assert self.LLMService._resolve_model(None) == self.DEFAULT_MODEL

    def test_resolve_model_default_unchanged(self):
        assert self.LLMService._resolve_model("gemini-3.1-pro-preview") == "gemini-3.1-pro-preview"

    def test_resolve_model_gpt5_alias(self):
        assert self.LLMService._resolve_model("gpt-5") == "gpt-5.4-2026-03-05"

    def test_resolve_model_gpt52_alias(self):
        assert self.LLMService._resolve_model("gpt-5.2") == "gpt-5.4-2026-03-05"

    def test_resolve_model_gemini3_alias(self):
        assert self.LLMService._resolve_model("gemini-3-pro-preview") == "gemini-3.1-pro-preview"

    def test_resolve_model_gpt41_retired_falls_to_gemini(self):
        assert self.LLMService._resolve_model("gpt-4.1") == "gemini-3.1-pro-preview"

    def test_resolve_model_known_openai_unchanged(self):
        assert self.LLMService._resolve_model("gpt-5.4-2026-03-05") == "gpt-5.4-2026-03-05"

    def test_resolve_model_unknown_passthrough(self):
        # Unknown model names should pass through untouched
        assert self.LLMService._resolve_model("some-future-model") == "some-future-model"

    # ── _extract_usage_metadata ───────────────────────────────────────────────

    def test_extract_gemini_full(self):
        response = MagicMock()
        um = MagicMock()
        um.prompt_token_count = 1000
        um.candidates_token_count = 200
        um.cached_content_token_count = 50
        response.usage_metadata = um

        result = self.LLMService._extract_usage_metadata(response, "gemini")
        assert result == {"prompt": 1000, "completion": 200, "cached": 50, "reasoning": 0}

    def test_extract_gemini_missing_usage_metadata(self):
        response = MagicMock(spec=[])  # no attributes
        result = self.LLMService._extract_usage_metadata(response, "gemini")
        assert result == self._ZERO_USAGE

    def test_extract_gemini_none_usage_metadata(self):
        response = MagicMock()
        response.usage_metadata = None
        result = self.LLMService._extract_usage_metadata(response, "gemini")
        assert result == self._ZERO_USAGE

    def test_extract_openai_responses_api(self):
        """OpenAI Responses API format (gpt-5.4-2026-03-05)."""
        response = MagicMock()
        usage = MagicMock()
        usage.input_tokens = 5000
        usage.output_tokens = 1000
        input_details = MagicMock()
        input_details.cached_tokens = 200
        output_details = MagicMock()
        output_details.reasoning_tokens = 300
        usage.input_tokens_details = input_details
        usage.output_tokens_details = output_details
        response.usage = usage

        result = self.LLMService._extract_usage_metadata(response, "openai")
        assert result == {"prompt": 5000, "completion": 1000, "cached": 200, "reasoning": 300}

    def test_extract_openai_chat_completions(self):
        """OpenAI Chat Completions format (legacy)."""
        response = MagicMock()
        # spec restricts the mock to Chat Completions fields only, so
        # Responses-API attributes such as input_tokens do not exist on it.
        usage = MagicMock(spec=['prompt_tokens', 'completion_tokens', 'prompt_tokens_details'])
        usage.prompt_tokens = 3000
        usage.completion_tokens = 800
        details = MagicMock()
        details.cached_tokens = 100
        usage.prompt_tokens_details = details
        response.usage = usage

        result = self.LLMService._extract_usage_metadata(response, "openai")
        assert result == {"prompt": 3000, "completion": 800, "cached": 100, "reasoning": 0}

    def test_extract_openai_missing_usage(self):
        response = MagicMock()
        response.usage = None
        result = self.LLMService._extract_usage_metadata(response, "openai")
        assert result == self._ZERO_USAGE

    def test_extract_unknown_provider_returns_zeros(self):
        result = self.LLMService._extract_usage_metadata(MagicMock(), "anthropic")
        assert result == self._ZERO_USAGE

    def test_extract_gemini_token_counts_default_to_zero_when_none(self):
        """Fields that return None should be coerced to 0, not None."""
        response = MagicMock()
        um = MagicMock()
        um.prompt_token_count = None
        um.candidates_token_count = None
        um.cached_content_token_count = None
        response.usage_metadata = um

        result = self.LLMService._extract_usage_metadata(response, "gemini")
        # Compare the whole dict, as the sibling tests do, so that a None
        # leaking into any key (including "reasoning") is caught.
        assert result == self._ZERO_USAGE
|
|
|
|
|
|
# ══════════════════════════════════════════════════════════════════════════════
|
|
# 4. UsageEvent.record — non-fatal DB error handling, field coercion
|
|
# ══════════════════════════════════════════════════════════════════════════════
|
|
|
|
class TestUsageEventRecord:
    """Tests for ``UsageEvent.record``: field normalisation and non-fatal DB errors.

    ``app.models.usage_event.get_db`` is patched with an ``AsyncMock`` that
    returns a ``MagicMock`` database, so no real DB is touched.
    """

    @staticmethod
    def _zero_cost():
        """Return a fresh all-zero ``cost_usd`` mapping (fresh so tests cannot share state)."""
        return {"input": 0.0, "cached": 0.0, "output": 0.0, "total": 0.0}

    @staticmethod
    def _fake_db(insert_side_effect):
        """Build a mock DB whose ``usage_events.insert_one`` is an AsyncMock with *insert_side_effect*."""
        db = MagicMock()
        db.usage_events.insert_one = AsyncMock(side_effect=insert_side_effect)
        return db

    async def _record(self, db, **kwargs):
        """Await ``UsageEvent.record(**kwargs)`` with ``get_db`` patched to return *db*."""
        with patch('app.models.usage_event.get_db', AsyncMock(return_value=db)):
            from app.models.usage_event import UsageEvent
            await UsageEvent.record(**kwargs)

    async def _record_and_capture(self, **kwargs):
        """Await ``UsageEvent.record(**kwargs)`` and return the document passed to ``insert_one``."""
        captured = {}

        async def fake_insert(doc):
            captured['doc'] = doc

        await self._record(self._fake_db(fake_insert), **kwargs)
        # Fail with a clear message (instead of a bare KeyError) if nothing was written.
        assert 'doc' in captured, "UsageEvent.record did not call usage_events.insert_one"
        return captured['doc']

    @pytest.mark.asyncio
    async def test_record_normalizes_unknown_feature(self):
        """Unknown feature names should be stored as 'other'."""
        doc = await self._record_and_capture(
            provider="gemini",
            model="gemini-3.1-pro-preview",
            prompt_tokens=100,
            completion_tokens=50,
            cost_usd=self._zero_cost(),
            feature="totally_invalid_feature_xyz",
        )
        assert doc['feature'] == "other"

    @pytest.mark.asyncio
    async def test_record_valid_feature_preserved(self):
        doc = await self._record_and_capture(
            provider="gemini",
            model="gemini-3.1-pro-preview",
            prompt_tokens=100,
            completion_tokens=50,
            cost_usd=self._zero_cost(),
            feature="key_themes",
        )
        assert doc['feature'] == "key_themes"

    @pytest.mark.asyncio
    async def test_record_total_tokens_computed(self):
        doc = await self._record_and_capture(
            provider="gemini",
            model="gemini-3.1-pro-preview",
            prompt_tokens=300,
            completion_tokens=150,
            cost_usd=self._zero_cost(),
        )
        assert doc['total_tokens'] == 450

    @pytest.mark.asyncio
    async def test_record_error_truncated_to_500_chars(self):
        long_error = "E" * 1000
        doc = await self._record_and_capture(
            provider="gemini",
            model="gemini-3.1-pro-preview",
            prompt_tokens=10,
            completion_tokens=5,
            cost_usd=self._zero_cost(),
            status="error",
            error=long_error,
        )
        assert len(doc['error']) == 500

    @pytest.mark.asyncio
    async def test_record_db_error_is_swallowed(self):
        """A DB write failure must not propagate — telemetry cannot kill LLM calls."""
        db = self._fake_db(RuntimeError("DB down"))
        # Should not raise
        await self._record(
            db,
            provider="gemini",
            model="gemini-3.1-pro-preview",
            prompt_tokens=10,
            completion_tokens=5,
            cost_usd=self._zero_cost(),
        )
        # Guard against a vacuous pass: the failing insert must actually have
        # been attempted for "swallowed" to mean anything.
        db.usage_events.insert_one.assert_awaited()
|
|
|
|
|
|
# ══════════════════════════════════════════════════════════════════════════════
|
|
# 5. check_quota — admin bypass, warning threshold, exceeded error
|
|
# ══════════════════════════════════════════════════════════════════════════════
|
|
|
|
class TestCheckQuota:
    """
    Tests for ``check_quota`` covering per-user and per-focus-group limits.

    ``check_quota`` imports its collaborators lazily, inside the function body:
        from app.models.user import User
        from app.models.usage_event import UsageEvent
        from app.models.focus_group import FocusGroup

    Those import statements run when the function executes. Replacing the
    corresponding ``sys.modules`` entries with ``patch.dict`` therefore hands
    the function our mock modules.
    """

    @staticmethod
    def _user_module(user_doc):
        """Mock ``app.models.user`` whose ``User.find_by_id`` resolves to *user_doc*."""
        module = MagicMock()
        module.User.find_by_id = AsyncMock(return_value=user_doc)
        return module

    @staticmethod
    def _usage_module(total):
        """Mock ``app.models.usage_event`` whose ``UsageEvent.sum_cost`` resolves to *total*."""
        module = MagicMock()
        module.UsageEvent.sum_cost = AsyncMock(return_value=total)
        return module

    @staticmethod
    def _fg_module(fg_doc):
        """Mock ``app.models.focus_group`` whose ``FocusGroup.find_by_id`` resolves to *fg_doc*."""
        module = MagicMock()
        module.FocusGroup.find_by_id = AsyncMock(return_value=fg_doc)
        return module

    @staticmethod
    def _regular_user(monthly_usd, override_quota=False):
        """User document with role "user" and the given monthly budget."""
        return {"role": "user", "quota": {"monthly_usd": monthly_usd}, "override_quota": override_quota}

    @staticmethod
    def _module_overrides(user=None, usage=None, fg=None):
        """Map module names to the supplied mocks, skipping any left as None."""
        candidates = (
            ('app.models.user', user),
            ('app.models.usage_event', usage),
            ('app.models.focus_group', fg),
        )
        return {name: module for name, module in candidates if module is not None}

    @pytest.mark.asyncio
    async def test_no_user_no_fg_returns_none(self):
        from app.models.quota import check_quota
        assert await check_quota(user_id=None, focus_group_id=None) is None

    @pytest.mark.asyncio
    async def test_admin_bypasses_user_quota(self):
        overrides = self._module_overrides(
            user=self._user_module(
                {"role": "admin", "quota": {"monthly_usd": 0.01}, "override_quota": False}
            ),
            usage=self._usage_module(999.0),
        )
        with patch.dict(sys.modules, overrides):
            from app.models.quota import check_quota
            outcome = await check_quota(user_id="admin_id", focus_group_id=None)
        assert outcome is None

    @pytest.mark.asyncio
    async def test_override_quota_bypasses_user_quota(self):
        overrides = self._module_overrides(
            user=self._user_module(self._regular_user(1.0, override_quota=True)),
            usage=self._usage_module(999.0),
        )
        with patch.dict(sys.modules, overrides):
            from app.models.quota import check_quota
            outcome = await check_quota(user_id="u1", focus_group_id=None)
        assert outcome is None

    @pytest.mark.asyncio
    async def test_no_quota_configured_returns_none(self):
        overrides = self._module_overrides(
            user=self._user_module({"role": "user", "quota": {}, "override_quota": False}),
            usage=self._usage_module(10.0),
        )
        with patch.dict(sys.modules, overrides):
            from app.models.quota import check_quota
            outcome = await check_quota(user_id="u1", focus_group_id=None)
        assert outcome is None

    @pytest.mark.asyncio
    async def test_under_quota_returns_none(self):
        overrides = self._module_overrides(
            user=self._user_module(self._regular_user(50.0)),
            usage=self._usage_module(20.0),
        )
        with patch.dict(sys.modules, overrides):
            from app.models.quota import check_quota
            outcome = await check_quota(user_id="u1", focus_group_id=None)
        assert outcome is None

    @pytest.mark.asyncio
    async def test_quota_at_80_percent_returns_warning(self):
        overrides = self._module_overrides(
            user=self._user_module(self._regular_user(50.0)),
            usage=self._usage_module(42.0),  # 84%
        )
        with patch.dict(sys.modules, overrides):
            from app.models.quota import check_quota, QuotaWarning
            warning = await check_quota(user_id="u1", focus_group_id=None)

        assert isinstance(warning, QuotaWarning)
        assert warning.scope == "user"
        assert warning.limit_usd == 50.0
        assert warning.used_usd == 42.0
        assert warning.pct == pytest.approx(0.84, abs=0.001)

    @pytest.mark.asyncio
    async def test_quota_exceeded_raises_error(self):
        overrides = self._module_overrides(
            user=self._user_module(self._regular_user(50.0)),
            usage=self._usage_module(50.01),
        )
        with patch.dict(sys.modules, overrides):
            from app.models.quota import check_quota, QuotaExceededError
            with pytest.raises(QuotaExceededError) as raised:
                await check_quota(user_id="u1", focus_group_id=None)

        error = raised.value
        assert error.scope == "user"
        assert error.limit_usd == 50.0
        assert error.used_usd == pytest.approx(50.01)

    @pytest.mark.asyncio
    async def test_fg_quota_exceeded_raises_error(self):
        overrides = self._module_overrides(
            fg=self._fg_module({"quota": {"total_usd": 10.0}}),
            usage=self._usage_module(10.01),
        )
        with patch.dict(sys.modules, overrides):
            from app.models.quota import check_quota, QuotaExceededError
            with pytest.raises(QuotaExceededError) as raised:
                await check_quota(user_id=None, focus_group_id="fg1")

        assert raised.value.scope == "focus_group"

    @pytest.mark.asyncio
    async def test_fg_quota_at_80_percent_returns_warning(self):
        overrides = self._module_overrides(
            fg=self._fg_module({"quota": {"total_usd": 100.0}}),
            usage=self._usage_module(85.0),
        )
        with patch.dict(sys.modules, overrides):
            from app.models.quota import check_quota, QuotaWarning
            warning = await check_quota(user_id=None, focus_group_id="fg1")

        assert isinstance(warning, QuotaWarning)
        assert warning.scope == "focus_group"

    @pytest.mark.asyncio
    async def test_db_error_in_quota_check_is_non_fatal(self):
        """A failing User.find_by_id must be swallowed so the LLM call is still allowed."""
        failing_user_module = MagicMock()
        failing_user_module.User.find_by_id = AsyncMock(side_effect=RuntimeError("DB timeout"))

        with patch.dict(sys.modules, self._module_overrides(user=failing_user_module)):
            from app.models.quota import check_quota
            outcome = await check_quota(user_id="u1", focus_group_id=None)

        assert outcome is None

    @pytest.mark.asyncio
    async def test_exactly_at_quota_limit_is_exceeded(self):
        """Spending exactly the limit counts as exceeded (comparison is >=)."""
        overrides = self._module_overrides(
            user=self._user_module(self._regular_user(50.0)),
            usage=self._usage_module(50.0),  # exactly at limit
        )
        with patch.dict(sys.modules, overrides):
            from app.models.quota import check_quota, QuotaExceededError
            with pytest.raises(QuotaExceededError):
                await check_quota(user_id="u1", focus_group_id=None)
|
|
|
|
|
|
# ══════════════════════════════════════════════════════════════════════════════
|
|
# 6. QuotaExceededError — message and attributes
|
|
# ══════════════════════════════════════════════════════════════════════════════
|
|
|
|
class TestQuotaExceededError:
    """Checks the public attributes and string form of ``QuotaExceededError``."""

    def test_attributes(self):
        from datetime import datetime, timezone
        from app.models.quota import QuotaExceededError

        start_of_period = datetime(2026, 4, 1, tzinfo=timezone.utc)
        error = QuotaExceededError("user", 50.0, 52.34, start_of_period)

        assert (error.scope, error.limit_usd, error.used_usd) == ("user", 50.0, 52.34)
        assert error.period_start == start_of_period

        # The message should name the scope and show used (4 dp) and limit (2 dp) amounts.
        message = str(error)
        for fragment in ("user", "52.3400", "50.00"):
            assert fragment in message

    def test_is_exception(self):
        from app.models.quota import QuotaExceededError
        assert issubclass(QuotaExceededError, Exception)
|