274 lines
No EOL
12 KiB
Python
274 lines
No EOL
12 KiB
Python
"""
|
|
Consolidation processor for merging multiple LLM analysis results
|
|
"""
|
|
|
|
import json
|
|
import logging
|
|
from typing import List, Dict, Any, Tuple
|
|
from dataclasses import dataclass
|
|
import os
|
|
|
|
from llm_service import ProviderManager, LLMResponse
|
|
from config import config
|
|
|
|
@dataclass
class ConsolidationResult:
    """Outcome of a single consolidation run.

    Instances are built by ``ConsolidationProcessor.consolidate_results`` and
    bundle the merged deliverables, their expansion into individual assets,
    run statistics, and any warnings raised during expansion.
    """

    # Merged base deliverables produced by the consolidation model
    # (BaseDeliverable instances).
    consolidated_deliverables: List[Any]
    # Individual assets obtained by expanding the base deliverables
    # (MarketingAsset instances).
    expanded_assets: List[Any]
    # Run statistics: model counts, token usage, cost breakdown and timings.
    consolidation_metadata: Dict[str, Any]
    # Warning messages collected while expanding deliverables into assets.
    warnings: List[str]
|
|
|
|
class ConsolidationProcessor:
    """Processes multiple LLM analysis results into a single consolidated output.

    Successful primary responses are formatted into one prompt. A consolidation
    model is then asked for a merged result constrained by
    ``UNIVERSAL_BASE_DELIVERABLE_SCHEMA``. The merged base deliverables are
    expanded into individual assets, and token/cost/timing metadata is recorded
    for the whole run.
    """

    def __init__(self):
        self.logger = logging.getLogger(self.__class__.__name__)
        self.provider_manager = ProviderManager()

    async def consolidate_results(
        self,
        analysis_responses: List[LLMResponse],
        consolidation_model: str,
        document_content: str = ""
    ) -> ConsolidationResult:
        """Consolidate multiple analysis results using the specified consolidation model.

        Args:
            analysis_responses: LLM responses from the primary analysis pass.
                Failed responses are logged and left out of the prompt.
            consolidation_model: Model key for consolidation
                (e.g. ``'anthropic-opus4'``), resolved through ``ProviderManager``.
            document_content: Optional original document content. Accepted for
                API compatibility; it is not currently used to build the prompt.

        Returns:
            ConsolidationResult with the consolidated base deliverables, the
            expanded assets, run metadata and expansion warnings.

        Raises:
            Exception: if the consolidation model reports failure. Also raised
                by any underlying error (invalid JSON, missing ``assets`` key,
                provider errors, ...). Errors inside the consolidation step are
                logged with a traceback and re-raised unchanged.
        """
        self.logger.info(
            "Starting consolidation with %d model results using %s",
            len(analysis_responses), consolidation_model,
        )
        self._log_primary_deliverable_counts(analysis_responses)

        # Build the prompt from every successful primary response.
        formatted_results = self._format_model_results(analysis_responses)
        consolidation_prompt = self._prepare_consolidation_prompt(formatted_results)
        system_message = self._load_consolidation_system_prompt()

        try:
            provider = self.provider_manager.get_provider(consolidation_model)
            messages = provider.prepare_messages(system_message, consolidation_prompt)

            # Imported lazily to avoid a circular import. Importing everything
            # up front means a broken import fails before the (paid) model call.
            from process_brief_enhanced import (
                UNIVERSAL_BASE_DELIVERABLE_SCHEMA,
                BaseDeliverable,
                expand_deliverables,
            )

            # The universal base-deliverable schema enforces structured output.
            consolidation_response = await provider.generate_response(
                messages=messages,
                schema=UNIVERSAL_BASE_DELIVERABLE_SCHEMA,
            )
            if not consolidation_response.success:
                raise Exception(f"Consolidation failed: {consolidation_response.error}")

            consolidated_data = json.loads(consolidation_response.content)
            base_deliverables = [BaseDeliverable(**item) for item in consolidated_data['assets']]
            self.logger.info("Consolidation completed: %d base deliverables", len(base_deliverables))

            # Expand consolidated base deliverables into individual assets.
            expanded_assets, expansion_warnings = expand_deliverables(base_deliverables)
            self.logger.info("Expansion completed: %d individual assets", len(expanded_assets))

            metadata = self._create_consolidation_metadata(
                analysis_responses,
                consolidation_response,
                base_deliverables,
                expanded_assets,
            )

            return ConsolidationResult(
                consolidated_deliverables=base_deliverables,
                expanded_assets=expanded_assets,
                consolidation_metadata=metadata,
                warnings=expansion_warnings,
            )
        except Exception as e:
            # logger.exception keeps the traceback; the error is re-raised as-is.
            self.logger.exception("Consolidation failed: %s", e)
            raise

    def _log_primary_deliverable_counts(self, analysis_responses: List[LLMResponse]) -> None:
        """Log the base-deliverable count of each successful response and their average."""
        deliverable_counts = []
        for i, response in enumerate(analysis_responses):
            if not response.success:
                continue
            count = self._count_deliverables_in_response(response.content)
            deliverable_counts.append(count)
            self.logger.info(
                "Model %d (%s %s): %d base deliverables",
                i + 1, response.provider, response.model_used, count,
            )

        if deliverable_counts:
            avg_deliverables = sum(deliverable_counts) / len(deliverable_counts)
            self.logger.info(
                "Average deliverables across %d models: %.1f",
                len(deliverable_counts), avg_deliverables,
            )
        else:
            self.logger.warning("No successful model responses to analyze")

    def _count_deliverables_in_response(self, content: str) -> int:
        """Count the deliverables in a model's JSON response.

        Returns ``len(data['assets'])`` when *content* is a JSON object whose
        ``assets`` value is a list. Returns 0 in every other case: invalid JSON,
        ``None`` content, a non-object top level, or a missing/non-list ``assets``.
        """
        try:
            data = json.loads(content)
        except (json.JSONDecodeError, TypeError):
            return 0
        assets = data.get('assets') if isinstance(data, dict) else None
        # Only a list is a deliverable collection; len() of a str/dict would
        # count characters/keys instead.
        return len(assets) if isinstance(assets, list) else 0

    def _format_model_results(self, responses: List[LLMResponse]) -> str:
        """Format analysis results from multiple models for the consolidation prompt.

        Each successful response becomes a ``**MODEL n: PROVIDER model**``
        header followed by a fenced JSON block. Content that is valid JSON is
        pretty-printed; anything else is embedded verbatim. Failed responses
        are logged and skipped, while ``n`` still reflects their original
        position in *responses*.
        """
        formatted_results = []

        for i, response in enumerate(responses):
            if not response.success:
                self.logger.warning(
                    "Skipping failed response from %s %s: %s",
                    response.provider, response.model_used, response.error,
                )
                continue

            model_info = f"**MODEL {i+1}: {response.provider.upper()} {response.model_used}**"
            try:
                # Round-trip through json to validate and normalize formatting.
                # ensure_ascii=False keeps non-ASCII text readable instead of
                # inflating the prompt with \uXXXX escapes.
                result_data = json.loads(response.content)
                formatted_content = json.dumps(result_data, indent=2, ensure_ascii=False)
            except (json.JSONDecodeError, TypeError):
                # Fall back to raw content if it is not valid JSON (or missing).
                formatted_content = response.content or ""

            formatted_results.append(f"{model_info}\n```json\n{formatted_content}\n```")

        return "\n\n".join(formatted_results)

    def _prepare_consolidation_prompt(self, formatted_results: str) -> str:
        """Fill the consolidation prompt template with the formatted model results.

        The template is read from ``prompts/consolidation_analysis.txt`` next to
        this module. It is rendered with ``str.format(models_results=...)``, so
        literal braces in the template must be escaped as ``{{``/``}}``.

        Raises:
            FileNotFoundError: if the template file does not exist.
            Exception: any other read/format error (logged, then re-raised).
        """
        prompt_path = os.path.join(os.path.dirname(__file__), 'prompts', 'consolidation_analysis.txt')
        try:
            with open(prompt_path, 'r', encoding='utf-8') as f:
                template = f.read()
            return template.format(models_results=formatted_results)
        except FileNotFoundError:
            self.logger.error("Consolidation prompt template not found: %s", prompt_path)
            raise
        except Exception as e:
            self.logger.error("Error preparing consolidation prompt: %s", e)
            raise

    def _load_consolidation_system_prompt(self) -> str:
        """Return the system prompt used for the consolidation call."""
        return """You are an expert data consolidation specialist. Your task is to intelligently merge multiple LLM analysis results into the most complete and accurate dataset possible. Follow the consolidation strategy provided in the user prompt, with emphasis on completeness and thoroughness. Return only valid JSON in the specified format."""

    def _resolve_model_key(self, provider_name: str, model_name: str):
        """Return the ``config.MODEL_MAPPINGS`` key for a provider/model pair, or None."""
        for key in config.MODEL_MAPPINGS:
            key_provider, key_model = config.get_model_info(key)
            if key_provider == provider_name and key_model == model_name:
                return key
        return None

    def _estimate_response_cost(self, response: LLMResponse) -> float:
        """Best-effort cost estimate for *response*; 0.0 when it cannot be priced.

        The response's provider/model is mapped back to a configured model key,
        and that provider's ``estimate_cost`` is called with the input, output
        and cached-input token counts.
        """
        try:
            model_key = self._resolve_model_key(response.provider, response.model_used)
            if model_key is None:
                return 0.0
            provider = self.provider_manager.get_provider(model_key)
            return provider.estimate_cost(
                response.token_usage.input_tokens,
                response.token_usage.output_tokens,
                response.token_usage.cached_input_tokens,
            )
        except Exception as e:
            # Cost reporting is informational: never let a pricing error abort
            # consolidation, but do not hide it silently either.
            self.logger.warning(
                "Could not estimate cost for %s %s: %s",
                response.provider, response.model_used, e,
            )
            return 0.0

    def _create_consolidation_metadata(
        self,
        analysis_responses: List[LLMResponse],
        consolidation_response: LLMResponse,
        base_deliverables: List[Any],
        expanded_assets: List[Any]
    ) -> Dict[str, Any]:
        """Create metadata about the consolidation process.

        Returns a dict containing:
        - per-model statistics, keyed ``"{provider}_{model_used}"``;
        - model counts;
        - deliverable and asset counts;
        - token totals;
        - a cost breakdown rounded to 4 decimals;
        - processing times.
        """
        model_stats = {}
        total_primary_tokens = 0
        total_primary_cost = 0.0
        successful_times = []

        for response in analysis_responses:
            # NOTE: responses from the same provider/model share this key, so a
            # later response overwrites an earlier one's entry (totals below
            # still include every successful response).
            model_key = f"{response.provider}_{response.model_used}"
            if response.success:
                tokens = response.token_usage.get_total()
                cost = self._estimate_response_cost(response)
                model_stats[model_key] = {
                    'tokens_used': tokens,
                    'processing_time': response.processing_time,
                    'success': True,
                    'estimated_cost': cost,
                }
                total_primary_tokens += tokens
                total_primary_cost += cost
                successful_times.append(response.processing_time)
            else:
                model_stats[model_key] = {
                    'tokens_used': 0,
                    'processing_time': response.processing_time,
                    'success': False,
                    'error': response.error,
                    'estimated_cost': 0.0,
                }

        consolidation_cost = self._estimate_response_cost(consolidation_response)
        consolidation_tokens = consolidation_response.token_usage.get_total()

        return {
            'consolidation_model': consolidation_response.model_used,
            'consolidation_provider': consolidation_response.provider,
            'primary_models_used': len(successful_times),
            'total_models_attempted': len(analysis_responses),
            'base_deliverables_count': len(base_deliverables),
            'final_assets_count': len(expanded_assets),
            'model_statistics': model_stats,
            'token_usage': {
                'primary_analysis_total': total_primary_tokens,
                'consolidation_tokens': consolidation_tokens,
                'grand_total': total_primary_tokens + consolidation_tokens
            },
            'cost_breakdown': {
                'primary_analysis_cost': round(total_primary_cost, 4),
                'consolidation_cost': round(consolidation_cost, 4),
                'total_cost': round(total_primary_cost + consolidation_cost, 4)
            },
            'processing_times': {
                'consolidation_time': consolidation_response.processing_time,
                # max(1, ...) avoids division by zero when every model failed.
                'primary_models_avg_time': sum(successful_times) / max(1, len(successful_times))
            }
        }