master_adapt_detect/test_cost_calculator.py
2025-10-01 14:32:55 -05:00

166 lines
No EOL
5.7 KiB
Python

#!/usr/bin/env python3
"""
Test script for the cost calculator functionality
"""
import sys
import os
from pathlib import Path
# Add current directory to path so we can import our modules
sys.path.insert(0, str(Path(__file__).parent))
from cost_calculator import CostCalculator, TokenUsage, ApiCallCost, extract_token_usage_from_response
def test_cost_calculator():
    """Smoke-test the main CostCalculator workflow end to end.

    Runs, in order: a direct ``calculate_cost`` call, three ``track_api_call``
    calls, a per-layout breakdown, the session summary, a monthly estimate,
    report saving and ``print_cost_summary``. Stages whose result may come
    back empty or as an error dict are asserted on rather than skipped by an
    ``if``, so a regression fails the test instead of silently printing less
    output.

    Raises:
        AssertionError: if a stage returns an empty or error result.
    """
    print("Testing Cost Calculator...")

    # Test 1: Basic cost calculation
    print("\n1. Testing basic cost calculation:")
    calc = CostCalculator(enable_tracking=True)
    # Test cost calculation with sample token usage
    input_cost, output_cost, cached_cost, total_cost = calc.calculate_cost(
        prompt_tokens=1500,
        completion_tokens=800,
        cached_tokens=200,
    )
    print(f" Input tokens (1500): ${input_cost:.4f}")
    print(f" Output tokens (800): ${output_cost:.4f}")
    print(f" Cached tokens (200): ${cached_cost:.4f}")
    print(f" Total cost: ${total_cost:.4f}")

    # Test 2: API call tracking
    print("\n2. Testing API call tracking:")
    # Simulate multiple API calls, each attributed to a different layout
    calc.track_api_call(
        operation_type="panel_counting_censorship",
        prompt_tokens=1500,
        completion_tokens=800,
        cached_tokens=200,
        layout_name="test_layout_1.jpg",
    )
    calc.track_api_call(
        operation_type="detection",
        prompt_tokens=2000,
        completion_tokens=1200,
        cached_tokens=0,
        layout_name="test_layout_2.jpg",
    )
    calc.track_api_call(
        operation_type="one_at_a_time_detection",
        prompt_tokens=800,
        completion_tokens=400,
        cached_tokens=100,
        layout_name="test_layout_3.jpg",
        master_id="1011A_1011_05",
    )
    # Three track_api_call() invocations above. NOTE(review): assumes
    # calculate_cost() does not itself append to api_calls — confirm.
    assert len(calc.api_calls) == 3, (
        f"expected 3 tracked API calls, got {len(calc.api_calls)}"
    )
    print(f" Tracked {len(calc.api_calls)} API calls")
    print(f" Total cost so far: ${calc.total_cost:.4f}")

    # Test 3: Layout cost breakdown
    print("\n3. Testing layout cost breakdown:")
    breakdown = calc.get_layout_cost_breakdown("test_layout_1.jpg")
    # A call was tracked for this layout above, so a breakdown must exist.
    assert breakdown, "no cost breakdown returned for test_layout_1.jpg"
    assert breakdown['layout_name'] == "test_layout_1.jpg", breakdown['layout_name']
    print(f" Layout: {breakdown['layout_name']}")
    print(f" Total cost: ${breakdown['total_cost']:.4f}")
    print(f" Input tokens: {breakdown['cost_breakdown']['input_tokens']}")
    print(f" Output tokens: {breakdown['cost_breakdown']['output_tokens']}")
    print(f" API calls: {breakdown['cost_breakdown']['api_calls_made']}")

    # Test 4: Session summary
    print("\n4. Testing session summary:")
    summary = calc.get_session_summary()
    # The calculator was built with enable_tracking=True.
    assert summary['tracking_enabled'], "session summary reports tracking disabled"
    totals = summary['session_totals']
    print(f" Total cost: ${totals['total_cost']:.4f}")
    print(f" Total tokens: {totals['total_input_tokens'] + totals['total_output_tokens']:,}")
    print(f" Layouts processed: {totals['layouts_processed']}")
    print(f" Avg cost per layout: ${summary['averages']['cost_per_layout']:.4f}")

    # Test 5: Monthly cost estimation
    print("\n5. Testing monthly cost estimation:")
    monthly_layouts = 300
    estimate = calc.estimate_monthly_cost(monthly_layouts)
    # The assert message is only evaluated on failure, when 'error' is present.
    assert 'error' not in estimate, f"monthly estimate failed: {estimate['error']}"
    print(f" Based on {estimate['based_on_layouts']} layouts:")
    print(f" Average cost per layout: ${estimate['average_cost_per_layout']:.4f}")
    print(f" Monthly estimate ({monthly_layouts} layouts): ${estimate['estimated_monthly_cost']:.2f}")
    print(f" Annual estimate: ${estimate['estimated_annual_cost']:.2f}")

    # Test 6: Cost report generation
    print("\n6. Testing cost report generation:")
    # NOTE(review): this writes a report to disk and never removes it —
    # confirm where save_cost_report() writes and whether cleanup is wanted.
    report_file = calc.save_cost_report("test_cost_report")
    assert report_file, "save_cost_report did not return a report location"
    print(f" Cost report saved to: {report_file}")

    # Test 7: Print cost summary
    print("\n7. Testing cost summary output:")
    calc.print_cost_summary()

    print("\nCost calculator test completed successfully!")
def test_token_usage():
    """Check TokenUsage field storage and its total-token validation.

    A consistent TokenUsage (1500 + 800 = 2300 total) must keep the values it
    was built with. One whose total_tokens (2000) does not equal
    prompt + completion is expected to raise ValueError.

    Raises:
        AssertionError: if a field is not stored as given, or if the
            inconsistent TokenUsage is accepted without raising ValueError.
    """
    print("\nTesting TokenUsage data class...")

    # Test valid token usage
    usage = TokenUsage(
        prompt_tokens=1500,
        completion_tokens=800,
        total_tokens=2300,
        cached_tokens=200,
    )
    print(f" Prompt tokens: {usage.prompt_tokens}")
    print(f" Completion tokens: {usage.completion_tokens}")
    print(f" Total tokens: {usage.total_tokens}")
    print(f" Cached tokens: {usage.cached_tokens}")
    assert (usage.prompt_tokens, usage.completion_tokens,
            usage.total_tokens, usage.cached_tokens) == (1500, 800, 2300, 200)

    # Test token usage validation: total_tokens should be 2300, not 2000.
    try:
        TokenUsage(
            prompt_tokens=1500,
            completion_tokens=800,
            total_tokens=2000,
            cached_tokens=200,
        )
    except ValueError as e:
        print(f" ✓ Correctly caught validation error: {e}")
    else:
        # Previously this only printed and the test still passed; fail loudly
        # so missing validation is actually reported.
        raise AssertionError(" ERROR: Should have raised ValueError for invalid total")
def test_disabled_tracking():
    """Check that a CostCalculator built with tracking disabled reports nothing.

    The stated expectation is that every operation on a disabled calculator
    returns zeros or empty results: a zero cost from calculate_cost, a
    zero-cost record from track_api_call, and a summary whose
    tracking_enabled flag is falsy.

    Raises:
        AssertionError: if any of those expectations is violated.
    """
    print("\nTesting disabled cost tracking...")
    calc = CostCalculator(enable_tracking=False)

    # All operations should return zeros or empty results
    _, _, _, total_cost = calc.calculate_cost(1500, 800, 200)
    print(f" Cost calculation (disabled): ${total_cost:.4f}")
    assert total_cost == 0, f"disabled calculate_cost returned {total_cost!r}"

    api_call = calc.track_api_call("test", 1500, 800, 200, "test.jpg")
    print(f" API call tracking (disabled): ${api_call.total_cost:.4f}")
    assert api_call.total_cost == 0, (
        f"disabled track_api_call returned cost {api_call.total_cost!r}"
    )

    summary = calc.get_session_summary()
    print(f" Session summary (disabled): {summary['tracking_enabled']}")
    assert not summary['tracking_enabled'], "summary claims tracking is enabled"
# Script entry point: run each smoke test in order and stop at the first
# exception, reporting it with a traceback and a non-zero exit status.
if __name__ == "__main__":
    try:
        for _run_test in (test_cost_calculator, test_token_usage, test_disabled_tracking):
            _run_test()
        print("\n✅ All tests passed!")
    except Exception as exc:
        # Top-level boundary: any failure (including AssertionError) ends the
        # run with the full traceback so the cause is visible.
        print(f"\n❌ Test failed: {exc}")
        import traceback
        traceback.print_exc()
        sys.exit(1)