""" Tests for the usage tracking infrastructure added in Phase A-C: - LLMCallContext / llm_usage_context module - ModelPricing.pick_tier and ModelPricing.compute_cost - LLMService._extract_usage_metadata and LLMService._resolve_model - UsageEvent.record (non-fatal DB error handling, field validation) - check_quota (admin bypass, warning threshold, exceeded error) """ import asyncio import sys import pytest from unittest.mock import MagicMock, AsyncMock, patch # ══════════════════════════════════════════════════════════════════════════════ # 1. LLM usage context — pure stdlib, no mocking needed # ══════════════════════════════════════════════════════════════════════════════ class TestLLMUsageContext: def setup_method(self): # Import fresh each test to avoid cross-test ContextVar pollution from app.services.llm_usage_context import ( LLMCallContext, _ctx, current_context, set_llm_context, llm_context ) self.LLMCallContext = LLMCallContext self._ctx = _ctx self.current_context = current_context self.set_llm_context = set_llm_context self.llm_context = llm_context # Reset context to default before each test self._ctx.set(LLMCallContext()) def test_default_context(self): ctx = self.current_context() assert ctx.user_id is None assert ctx.focus_group_id is None assert ctx.persona_id is None assert ctx.feature == "other" assert ctx.task_id is None def test_set_llm_context_mutates(self): self.set_llm_context(user_id="u1", feature="persona_generate") ctx = self.current_context() assert ctx.user_id == "u1" assert ctx.feature == "persona_generate" # Unset fields remain None assert ctx.focus_group_id is None def test_set_llm_context_merges_with_existing(self): self.set_llm_context(user_id="u1") self.set_llm_context(focus_group_id="fg1") ctx = self.current_context() # Both values present assert ctx.user_id == "u1" assert ctx.focus_group_id == "fg1" def test_llm_context_restores_on_exit(self): self.set_llm_context(user_id="outer") with self.llm_context(user_id="inner", feature="moderator"): assert self.current_context().user_id == "inner" assert self.current_context().feature == "moderator" # Restored restored = self.current_context() assert restored.user_id == "outer" assert restored.feature == "other" def test_llm_context_restores_on_exception(self): self.set_llm_context(user_id="before") try: with self.llm_context(user_id="during"): raise ValueError("boom") except ValueError: pass assert self.current_context().user_id == "before" def test_llm_context_stacking(self): with self.llm_context(user_id="u1"): with self.llm_context(focus_group_id="fg1"): with self.llm_context(feature="key_themes"): ctx = self.current_context() assert ctx.user_id == "u1" assert ctx.focus_group_id == "fg1" assert ctx.feature == "key_themes" assert self.current_context().feature == "other" # restored assert self.current_context().focus_group_id is None # restored assert self.current_context().user_id is None # restored def test_frozen_dataclass_immutable(self): ctx = self.LLMCallContext(user_id="x") with pytest.raises((AttributeError, TypeError)): ctx.user_id = "y" # type: ignore # ══════════════════════════════════════════════════════════════════════════════ # 2. ModelPricing pure logic — pick_tier and compute_cost # ══════════════════════════════════════════════════════════════════════════════ class TestModelPricingPureLogic: def setup_method(self): from app.models.model_pricing import ModelPricing self.ModelPricing = ModelPricing def _gemini_pricing(self): """Two-tier Gemini pricing: <200k and >=200k.""" return { "model": "gemini-3.1-pro-preview", "tiers": [ { "threshold_input_tokens": 0, "input_per_mtok": 2.0, "cached_input_per_mtok": None, "output_per_mtok": 12.0, }, { "threshold_input_tokens": 200_000, "input_per_mtok": 4.0, "cached_input_per_mtok": None, "output_per_mtok": 18.0, }, ], } def _gpt_pricing(self): """Single-tier GPT pricing.""" return { "model": "gpt-5.4-2026-03-05", "tiers": [ { "threshold_input_tokens": 0, "input_per_mtok": 2.50, "cached_input_per_mtok": 0.25, "output_per_mtok": 15.0, } ], } # ── pick_tier ────────────────────────────────────────────────────────────── def test_pick_tier_below_all_thresholds(self): pricing = self._gemini_pricing() tier = self.ModelPricing.pick_tier(pricing, prompt_tokens=1000) assert tier["threshold_input_tokens"] == 0 assert tier["input_per_mtok"] == 2.0 def test_pick_tier_above_high_threshold(self): pricing = self._gemini_pricing() tier = self.ModelPricing.pick_tier(pricing, prompt_tokens=250_000) assert tier["threshold_input_tokens"] == 200_000 assert tier["input_per_mtok"] == 4.0 def test_pick_tier_exactly_at_threshold(self): pricing = self._gemini_pricing() tier = self.ModelPricing.pick_tier(pricing, prompt_tokens=200_000) assert tier["threshold_input_tokens"] == 200_000 def test_pick_tier_single_tier(self): pricing = self._gpt_pricing() tier = self.ModelPricing.pick_tier(pricing, prompt_tokens=500_000) assert tier["input_per_mtok"] == 2.50 def test_pick_tier_none_pricing_returns_none(self): assert self.ModelPricing.pick_tier(None, 1000) is None def test_pick_tier_empty_tiers_returns_none(self): assert self.ModelPricing.pick_tier({"tiers": []}, 1000) is None # ── compute_cost ────────────────────────────────────────────────────────── def test_compute_cost_none_pricing_returns_zeros(self): cost = self.ModelPricing.compute_cost(None, 100, 50) assert cost == {"input": 0.0, "cached": 0.0, "output": 0.0, "total": 0.0} def test_compute_cost_basic_gemini_low_tier(self): # 100k input, 10k output, no cached — stays in low tier cost = self.ModelPricing.compute_cost( self._gemini_pricing(), prompt_tokens=100_000, completion_tokens=10_000, cached_tokens=0, ) expected_input = 100_000 * 2.0 / 1_000_000 # $0.20 expected_output = 10_000 * 12.0 / 1_000_000 # $0.12 assert abs(cost["input"] - expected_input) < 1e-9 assert cost["cached"] == 0.0 assert abs(cost["output"] - expected_output) < 1e-9 assert abs(cost["total"] - (expected_input + expected_output)) < 1e-9 def test_compute_cost_basic_gemini_high_tier(self): # 300k input → high tier: $4/$18 cost = self.ModelPricing.compute_cost( self._gemini_pricing(), prompt_tokens=300_000, completion_tokens=5_000, ) expected_input = 300_000 * 4.0 / 1_000_000 expected_output = 5_000 * 18.0 / 1_000_000 assert abs(cost["input"] - expected_input) < 1e-9 assert abs(cost["output"] - expected_output) < 1e-9 def test_compute_cost_with_cached_tokens_gpt(self): # 10k input, 2k cached (cheaper), 5k output cost = self.ModelPricing.compute_cost( self._gpt_pricing(), prompt_tokens=10_000, completion_tokens=5_000, cached_tokens=2_000, ) # Billable input = 10k - 2k = 8k at $2.50/M expected_input = 8_000 * 2.50 / 1_000_000 # Cached 2k at $0.25/M expected_cached = 2_000 * 0.25 / 1_000_000 # Output 5k at $15/M expected_output = 5_000 * 15.0 / 1_000_000 expected_total = expected_input + expected_cached + expected_output assert abs(cost["input"] - expected_input) < 1e-9 assert abs(cost["cached"] - expected_cached) < 1e-9 assert abs(cost["output"] - expected_output) < 1e-9 assert abs(cost["total"] - expected_total) < 1e-9 def test_compute_cost_total_equals_sum_of_components(self): cost = self.ModelPricing.compute_cost( self._gpt_pricing(), prompt_tokens=50_000, completion_tokens=20_000, cached_tokens=5_000, ) assert abs(cost["total"] - (cost["input"] + cost["cached"] + cost["output"])) < 1e-9 # ══════════════════════════════════════════════════════════════════════════════ # 3. LLMService._extract_usage_metadata and _resolve_model # ══════════════════════════════════════════════════════════════════════════════ class TestLLMServicePureStaticMethods: def setup_method(self): from app.services.llm_service import LLMService, MODEL_ALIASES, DEFAULT_MODEL self.LLMService = LLMService self.MODEL_ALIASES = MODEL_ALIASES self.DEFAULT_MODEL = DEFAULT_MODEL # ── _resolve_model ──────────────────────────────────────────────────────── def test_resolve_model_none_returns_default(self): assert self.LLMService._resolve_model(None) == self.DEFAULT_MODEL def test_resolve_model_default_unchanged(self): assert self.LLMService._resolve_model("gemini-3.1-pro-preview") == "gemini-3.1-pro-preview" def test_resolve_model_gpt5_alias(self): assert self.LLMService._resolve_model("gpt-5") == "gpt-5.4-2026-03-05" def test_resolve_model_gpt52_alias(self): assert self.LLMService._resolve_model("gpt-5.2") == "gpt-5.4-2026-03-05" def test_resolve_model_gemini3_alias(self): assert self.LLMService._resolve_model("gemini-3-pro-preview") == "gemini-3.1-pro-preview" def test_resolve_model_gpt41_retired_falls_to_gemini(self): assert self.LLMService._resolve_model("gpt-4.1") == "gemini-3.1-pro-preview" def test_resolve_model_known_openai_unchanged(self): assert self.LLMService._resolve_model("gpt-5.4-2026-03-05") == "gpt-5.4-2026-03-05" def test_resolve_model_unknown_passthrough(self): # Unknown model names should pass through untouched assert self.LLMService._resolve_model("some-future-model") == "some-future-model" # ── _extract_usage_metadata ─────────────────────────────────────────────── def test_extract_gemini_full(self): response = MagicMock() um = MagicMock() um.prompt_token_count = 1000 um.candidates_token_count = 200 um.cached_content_token_count = 50 response.usage_metadata = um result = self.LLMService._extract_usage_metadata(response, "gemini") assert result == {"prompt": 1000, "completion": 200, "cached": 50, "reasoning": 0} def test_extract_gemini_missing_usage_metadata(self): response = MagicMock(spec=[]) # no attributes result = self.LLMService._extract_usage_metadata(response, "gemini") assert result == {"prompt": 0, "completion": 0, "cached": 0, "reasoning": 0} def test_extract_gemini_none_usage_metadata(self): response = MagicMock() response.usage_metadata = None result = self.LLMService._extract_usage_metadata(response, "gemini") assert result == {"prompt": 0, "completion": 0, "cached": 0, "reasoning": 0} def test_extract_openai_responses_api(self): """OpenAI Responses API format (gpt-5.4-2026-03-05).""" response = MagicMock() usage = MagicMock() usage.input_tokens = 5000 usage.output_tokens = 1000 input_details = MagicMock() input_details.cached_tokens = 200 output_details = MagicMock() output_details.reasoning_tokens = 300 usage.input_tokens_details = input_details usage.output_tokens_details = output_details response.usage = usage result = self.LLMService._extract_usage_metadata(response, "openai") assert result == {"prompt": 5000, "completion": 1000, "cached": 200, "reasoning": 300} def test_extract_openai_chat_completions(self): """OpenAI Chat Completions format (legacy).""" response = MagicMock() usage = MagicMock(spec=['prompt_tokens', 'completion_tokens', 'prompt_tokens_details']) usage.prompt_tokens = 3000 usage.completion_tokens = 800 details = MagicMock() details.cached_tokens = 100 usage.prompt_tokens_details = details response.usage = usage result = self.LLMService._extract_usage_metadata(response, "openai") assert result == {"prompt": 3000, "completion": 800, "cached": 100, "reasoning": 0} def test_extract_openai_missing_usage(self): response = MagicMock() response.usage = None result = self.LLMService._extract_usage_metadata(response, "openai") assert result == {"prompt": 0, "completion": 0, "cached": 0, "reasoning": 0} def test_extract_unknown_provider_returns_zeros(self): result = self.LLMService._extract_usage_metadata(MagicMock(), "anthropic") assert result == {"prompt": 0, "completion": 0, "cached": 0, "reasoning": 0} def test_extract_gemini_token_counts_default_to_zero_when_none(self): """Fields that return None should be coerced to 0, not None.""" response = MagicMock() um = MagicMock() um.prompt_token_count = None um.candidates_token_count = None um.cached_content_token_count = None response.usage_metadata = um result = self.LLMService._extract_usage_metadata(response, "gemini") assert result["prompt"] == 0 assert result["completion"] == 0 assert result["cached"] == 0 # ══════════════════════════════════════════════════════════════════════════════ # 4. UsageEvent.record — non-fatal DB error handling, field coercion # ══════════════════════════════════════════════════════════════════════════════ class TestUsageEventRecord: def setup_method(self): # Reset module-level imports for clean state pass @pytest.mark.asyncio async def test_record_normalizes_unknown_feature(self): """Unknown feature names should be stored as 'other'.""" captured = {} async def fake_insert(doc): captured['doc'] = doc mock_db = MagicMock() mock_db.usage_events.insert_one = AsyncMock(side_effect=fake_insert) with patch('app.models.usage_event.get_db', AsyncMock(return_value=mock_db)): from app.models.usage_event import UsageEvent await UsageEvent.record( provider="gemini", model="gemini-3.1-pro-preview", prompt_tokens=100, completion_tokens=50, cost_usd={"input": 0.0, "cached": 0.0, "output": 0.0, "total": 0.0}, feature="totally_invalid_feature_xyz", ) assert captured['doc']['feature'] == "other" @pytest.mark.asyncio async def test_record_valid_feature_preserved(self): captured = {} async def fake_insert(doc): captured['doc'] = doc mock_db = MagicMock() mock_db.usage_events.insert_one = AsyncMock(side_effect=fake_insert) with patch('app.models.usage_event.get_db', AsyncMock(return_value=mock_db)): from app.models.usage_event import UsageEvent await UsageEvent.record( provider="gemini", model="gemini-3.1-pro-preview", prompt_tokens=100, completion_tokens=50, cost_usd={"input": 0.0, "cached": 0.0, "output": 0.0, "total": 0.0}, feature="key_themes", ) assert captured['doc']['feature'] == "key_themes" @pytest.mark.asyncio async def test_record_total_tokens_computed(self): captured = {} async def fake_insert(doc): captured['doc'] = doc mock_db = MagicMock() mock_db.usage_events.insert_one = AsyncMock(side_effect=fake_insert) with patch('app.models.usage_event.get_db', AsyncMock(return_value=mock_db)): from app.models.usage_event import UsageEvent await UsageEvent.record( provider="gemini", model="gemini-3.1-pro-preview", prompt_tokens=300, completion_tokens=150, cost_usd={"input": 0.0, "cached": 0.0, "output": 0.0, "total": 0.0}, ) assert captured['doc']['total_tokens'] == 450 @pytest.mark.asyncio async def test_record_error_truncated_to_500_chars(self): captured = {} async def fake_insert(doc): captured['doc'] = doc mock_db = MagicMock() mock_db.usage_events.insert_one = AsyncMock(side_effect=fake_insert) long_error = "E" * 1000 with patch('app.models.usage_event.get_db', AsyncMock(return_value=mock_db)): from app.models.usage_event import UsageEvent await UsageEvent.record( provider="gemini", model="gemini-3.1-pro-preview", prompt_tokens=10, completion_tokens=5, cost_usd={"input": 0.0, "cached": 0.0, "output": 0.0, "total": 0.0}, status="error", error=long_error, ) assert len(captured['doc']['error']) == 500 @pytest.mark.asyncio async def test_record_db_error_is_swallowed(self): """A DB write failure must not propagate — telemetry cannot kill LLM calls.""" mock_db = MagicMock() mock_db.usage_events.insert_one = AsyncMock(side_effect=RuntimeError("DB down")) with patch('app.models.usage_event.get_db', AsyncMock(return_value=mock_db)): from app.models.usage_event import UsageEvent # Should not raise await UsageEvent.record( provider="gemini", model="gemini-3.1-pro-preview", prompt_tokens=10, completion_tokens=5, cost_usd={"input": 0.0, "cached": 0.0, "output": 0.0, "total": 0.0}, ) # ══════════════════════════════════════════════════════════════════════════════ # 5. check_quota — admin bypass, warning threshold, exceeded error # ══════════════════════════════════════════════════════════════════════════════ class TestCheckQuota: """ check_quota does lazy imports inside the function body: from app.models.user import User from app.models.usage_event import UsageEvent from app.models.focus_group import FocusGroup We inject mocks via patch.dict(sys.modules) so that when the function executes those import statements at runtime it gets our mock objects. """ def _make_user_mod(self, user_doc): m = MagicMock() m.User.find_by_id = AsyncMock(return_value=user_doc) return m def _make_usage_mod(self, total): m = MagicMock() m.UsageEvent.sum_cost = AsyncMock(return_value=total) return m def _make_fg_mod(self, fg_doc): m = MagicMock() m.FocusGroup.find_by_id = AsyncMock(return_value=fg_doc) return m @pytest.mark.asyncio async def test_no_user_no_fg_returns_none(self): from app.models.quota import check_quota result = await check_quota(user_id=None, focus_group_id=None) assert result is None @pytest.mark.asyncio async def test_admin_bypasses_user_quota(self): user_mod = self._make_user_mod( {"role": "admin", "quota": {"monthly_usd": 0.01}, "override_quota": False} ) usage_mod = self._make_usage_mod(999.0) with patch.dict(sys.modules, {'app.models.user': user_mod, 'app.models.usage_event': usage_mod}): from app.models.quota import check_quota result = await check_quota(user_id="admin_id", focus_group_id=None) assert result is None @pytest.mark.asyncio async def test_override_quota_bypasses_user_quota(self): user_mod = self._make_user_mod( {"role": "user", "quota": {"monthly_usd": 1.0}, "override_quota": True} ) usage_mod = self._make_usage_mod(999.0) with patch.dict(sys.modules, {'app.models.user': user_mod, 'app.models.usage_event': usage_mod}): from app.models.quota import check_quota result = await check_quota(user_id="u1", focus_group_id=None) assert result is None @pytest.mark.asyncio async def test_no_quota_configured_returns_none(self): user_mod = self._make_user_mod({"role": "user", "quota": {}, "override_quota": False}) usage_mod = self._make_usage_mod(10.0) with patch.dict(sys.modules, {'app.models.user': user_mod, 'app.models.usage_event': usage_mod}): from app.models.quota import check_quota result = await check_quota(user_id="u1", focus_group_id=None) assert result is None @pytest.mark.asyncio async def test_under_quota_returns_none(self): user_mod = self._make_user_mod( {"role": "user", "quota": {"monthly_usd": 50.0}, "override_quota": False} ) usage_mod = self._make_usage_mod(20.0) with patch.dict(sys.modules, {'app.models.user': user_mod, 'app.models.usage_event': usage_mod}): from app.models.quota import check_quota result = await check_quota(user_id="u1", focus_group_id=None) assert result is None @pytest.mark.asyncio async def test_quota_at_80_percent_returns_warning(self): user_mod = self._make_user_mod( {"role": "user", "quota": {"monthly_usd": 50.0}, "override_quota": False} ) usage_mod = self._make_usage_mod(42.0) # 84% with patch.dict(sys.modules, {'app.models.user': user_mod, 'app.models.usage_event': usage_mod}): from app.models.quota import check_quota, QuotaWarning result = await check_quota(user_id="u1", focus_group_id=None) assert isinstance(result, QuotaWarning) assert result.scope == "user" assert result.limit_usd == 50.0 assert result.used_usd == 42.0 assert result.pct == pytest.approx(0.84, abs=0.001) @pytest.mark.asyncio async def test_quota_exceeded_raises_error(self): user_mod = self._make_user_mod( {"role": "user", "quota": {"monthly_usd": 50.0}, "override_quota": False} ) usage_mod = self._make_usage_mod(50.01) with patch.dict(sys.modules, {'app.models.user': user_mod, 'app.models.usage_event': usage_mod}): from app.models.quota import check_quota, QuotaExceededError with pytest.raises(QuotaExceededError) as exc_info: await check_quota(user_id="u1", focus_group_id=None) err = exc_info.value assert err.scope == "user" assert err.limit_usd == 50.0 assert err.used_usd == pytest.approx(50.01) @pytest.mark.asyncio async def test_fg_quota_exceeded_raises_error(self): fg_mod = self._make_fg_mod({"quota": {"total_usd": 10.0}}) usage_mod = self._make_usage_mod(10.01) with patch.dict(sys.modules, {'app.models.focus_group': fg_mod, 'app.models.usage_event': usage_mod}): from app.models.quota import check_quota, QuotaExceededError with pytest.raises(QuotaExceededError) as exc_info: await check_quota(user_id=None, focus_group_id="fg1") assert exc_info.value.scope == "focus_group" @pytest.mark.asyncio async def test_fg_quota_at_80_percent_returns_warning(self): fg_mod = self._make_fg_mod({"quota": {"total_usd": 100.0}}) usage_mod = self._make_usage_mod(85.0) with patch.dict(sys.modules, {'app.models.focus_group': fg_mod, 'app.models.usage_event': usage_mod}): from app.models.quota import check_quota, QuotaWarning result = await check_quota(user_id=None, focus_group_id="fg1") assert isinstance(result, QuotaWarning) assert result.scope == "focus_group" @pytest.mark.asyncio async def test_db_error_in_quota_check_is_non_fatal(self): """If User.find_by_id raises, quota check swallows and allows the call.""" user_mod = MagicMock() user_mod.User.find_by_id = AsyncMock(side_effect=RuntimeError("DB timeout")) with patch.dict(sys.modules, {'app.models.user': user_mod}): from app.models.quota import check_quota result = await check_quota(user_id="u1", focus_group_id=None) assert result is None @pytest.mark.asyncio async def test_exactly_at_quota_limit_is_exceeded(self): """Spending exactly the limit should trigger QuotaExceededError (>=).""" user_mod = self._make_user_mod( {"role": "user", "quota": {"monthly_usd": 50.0}, "override_quota": False} ) usage_mod = self._make_usage_mod(50.0) # exactly at limit with patch.dict(sys.modules, {'app.models.user': user_mod, 'app.models.usage_event': usage_mod}): from app.models.quota import check_quota, QuotaExceededError with pytest.raises(QuotaExceededError): await check_quota(user_id="u1", focus_group_id=None) # ══════════════════════════════════════════════════════════════════════════════ # 6. QuotaExceededError — message and attributes # ══════════════════════════════════════════════════════════════════════════════ class TestQuotaExceededError: def test_attributes(self): from app.models.quota import QuotaExceededError from datetime import datetime, timezone period = datetime(2026, 4, 1, tzinfo=timezone.utc) err = QuotaExceededError("user", 50.0, 52.34, period) assert err.scope == "user" assert err.limit_usd == 50.0 assert err.used_usd == 52.34 assert err.period_start == period assert "user" in str(err) assert "52.3400" in str(err) assert "50.00" in str(err) def test_is_exception(self): from app.models.quota import QuotaExceededError assert issubclass(QuotaExceededError, Exception)