hp_chatbot/graph_rag_integration.py
michael 5554aa043f Add GraphRAG startup optimization: triple caching and background init
- Cache extracted triples to disk (neo4j_triples.pickle) so Neo4j can be
  repopulated without expensive LLM re-extraction on cold starts
- Split initialization into two phases: fast vector-only (~1-2 min) and
  background GraphRAG, so the server serves requests while GraphRAG loads
- Add GraphRAG status flags to shared_state for monitoring readiness
- Update /status endpoint to expose graphrag_ready/initializing/error
- Restructure main.py to use single event loop for background task support

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-02-23 17:45:05 -06:00

1008 lines
No EOL
41 KiB
Python

"""
HP GraphRAG Integration
Integrates GraphRAG functionality into the HP RAG pipeline.
- GraphRAG for knowledge graph construction from semantically split nodes
- Community detection and summarization for improved context retrieval
"""
import os
import json
import re
import asyncio
import networkx as nx
from collections import defaultdict
from typing import Any, List, Callable, Optional, Union, Dict
from pathlib import Path
# Import LlamaIndex components
from llama_index.core import Document, Settings
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core import PropertyGraphIndex
from llama_index.core.async_utils import run_jobs
from llama_index.core.indices.property_graph.utils import default_parse_triplets_fn
from llama_index.core.graph_stores.types import EntityNode, KG_NODES_KEY, KG_RELATIONS_KEY, Relation
from llama_index.core.llms.llm import LLM
from llama_index.core.prompts import PromptTemplate
from llama_index.core.prompts.default_prompts import DEFAULT_KG_TRIPLET_EXTRACT_PROMPT
from llama_index.core.schema import TransformComponent, BaseNode
from llama_index.core.query_engine import CustomQueryEngine
from llama_index.llms.openai import OpenAI
from llama_index.core.llms import ChatMessage
from llama_index.graph_stores.neo4j import Neo4jPropertyGraphStore
from llama_index.core import SimpleDirectoryReader
from llama_index.core.vector_stores.types import VectorStoreInfo, MetadataInfo
from llama_index.core.retrievers import VectorIndexRetriever
# Community detection (using NetworkX instead of graspologic as a fallback)
try:
from community import best_partition # python-louvain package
except ImportError:
print("Community detection package not found, using NetworkX built-in community detection")
# Import from our modules
from utils import logger, log_structured
from config import NEO4J_URL, NEO4J_USERNAME, NEO4J_PASSWORD
import config
# Define the GraphRAGExtractor class
class GraphRAGExtractor(TransformComponent):
    """Extract entities and relationship triples from text nodes with an LLM.

    Each node's content is sent to the LLM using ``extract_prompt``; the raw
    response is handed to ``parse_fn``, whose result is unpacked as
    ``(entities, relationships)`` where every entity is a
    ``(name, type, description)`` tuple and every relationship is a
    ``(source, target, relation, description)`` tuple.  The resulting
    ``EntityNode`` / ``Relation`` objects are stored in the node metadata
    under ``KG_NODES_KEY`` and ``KG_RELATIONS_KEY``.

    Args:
        llm (LLM):
            The language model to use.
        extract_prompt (Union[str, PromptTemplate]):
            The prompt to use for extracting triples.
        parse_fn (callable):
            A function to parse the output of the language model.
        num_workers (int):
            The number of workers to use for parallel processing.
        max_paths_per_chunk (int):
            The maximum number of paths to extract per chunk.
    """

    llm: LLM
    extract_prompt: PromptTemplate
    parse_fn: Callable
    num_workers: int
    max_paths_per_chunk: int

    def __init__(
        self,
        llm: Optional[LLM] = None,
        extract_prompt: Optional[Union[str, PromptTemplate]] = None,
        parse_fn: Callable = default_parse_triplets_fn,
        max_paths_per_chunk: int = 10,
        num_workers: int = 8,
    ) -> None:
        """Store the configuration, falling back to ``Settings.llm`` and the default prompt."""
        from llama_index.core import Settings

        # A plain string prompt is wrapped so the field always holds a PromptTemplate.
        if isinstance(extract_prompt, str):
            extract_prompt = PromptTemplate(extract_prompt)

        super().__init__(
            llm=llm or Settings.llm,
            extract_prompt=extract_prompt or DEFAULT_KG_TRIPLET_EXTRACT_PROMPT,
            parse_fn=parse_fn,
            num_workers=num_workers,
            max_paths_per_chunk=max_paths_per_chunk,
        )

    @classmethod
    def class_name(cls) -> str:
        """Return the component name used for this transform."""
        return "GraphExtractor"

    def __call__(
        self, nodes: List[BaseNode], show_progress: bool = False, **kwargs: Any
    ) -> List[BaseNode]:
        """Extract triples from nodes (synchronous wrapper around ``acall``).

        NOTE: ``asyncio.run`` raises ``RuntimeError`` when invoked from a
        thread that already has a running event loop; use ``acall`` there.
        """
        return asyncio.run(
            self.acall(nodes, show_progress=show_progress, **kwargs)
        )

    async def _aextract(self, node: BaseNode) -> BaseNode:
        """Run the LLM on one node and attach the parsed entities/relations.

        A ``ValueError`` raised by the LLM call or by ``parse_fn`` is treated
        as "nothing extracted"; any other exception propagates.
        """
        assert hasattr(node, "text")

        text = node.get_content(metadata_mode="llm")
        try:
            llm_response = await self.llm.apredict(
                self.extract_prompt,
                text=text,
                max_knowledge_triplets=self.max_paths_per_chunk,
            )
            entities, entities_relationship = self.parse_fn(llm_response)
        except ValueError:
            entities = []
            entities_relationship = []

        # Pop previously attached graph data first so it is not copied into
        # the property dicts built from the remaining metadata below.
        existing_nodes = node.metadata.pop(KG_NODES_KEY, [])
        existing_relations = node.metadata.pop(KG_RELATIONS_KEY, [])
        base_metadata = node.metadata.copy()

        # Every entity/relation gets its OWN properties dict.  Re-using one
        # mutated dict would make all of them share the last description
        # unless the model class happens to copy its input.
        for entity, entity_type, description in entities:
            existing_nodes.append(
                EntityNode(
                    name=entity,
                    label=entity_type,
                    properties={**base_metadata, "entity_description": description},
                )
            )

        for subj, obj, rel, description in entities_relationship:
            existing_relations.append(
                Relation(
                    label=rel,
                    source_id=subj,
                    target_id=obj,
                    properties={
                        **base_metadata,
                        "relationship_description": description,
                    },
                )
            )

        node.metadata[KG_NODES_KEY] = existing_nodes
        node.metadata[KG_RELATIONS_KEY] = existing_relations
        return node

    async def acall(
        self, nodes: List[BaseNode], show_progress: bool = False, **kwargs: Any
    ) -> List[BaseNode]:
        """Extract triples from all nodes concurrently, using ``num_workers`` workers."""
        jobs = [self._aextract(node) for node in nodes]
        return await run_jobs(
            jobs,
            workers=self.num_workers,
            show_progress=show_progress,
            desc="Extracting paths from text",
        )
# Define the GraphRAGStore class (integrating with Neo4j)
import pickle
import os
from pathlib import Path
class GraphRAGStore:
    """Wrap a property graph store with community detection, LLM summaries and disk caches.

    Three pickle caches live under ``CACHE_DIR``: the community summaries,
    the entity -> community-id mapping, and the raw triples read from the
    underlying store (so the store can be repopulated without re-running
    LLM extraction).
    """

    # Class-level defaults; __init__ replaces the mutable ones per instance.
    community_summary = {}
    entity_info = None
    max_cluster_size = 5
    property_graph_store = None
    communities_built = False  # Track if communities have been built

    # Paths for cached data
    CACHE_DIR = Path("index_storage/graphrag_cache")
    COMMUNITY_CACHE_FILE = CACHE_DIR / "community_summary.pickle"
    ENTITY_INFO_CACHE_FILE = CACHE_DIR / "entity_info.pickle"
    TRIPLES_CACHE_FILE = CACHE_DIR / "neo4j_triples.pickle"

    # Model used for every summarisation call, and the approximate character
    # budget above which community text is summarised chunk by chunk.
    SUMMARY_MODEL = "gpt-4.1-mini"
    SUMMARY_CHUNK_CHARS = 30000

    _SUMMARY_SYSTEM_PROMPT = (
        "You are provided with a set of "
        "relationships from a knowledge graph, each represented as "
        "entity1->entity2->relation-"
        ">relationship_description. Your task is to create a summary of "
        "these relationships. The summary should include "
        "the names of the entities involved and a concise synthesis "
        "of the relationship descriptions. The "
        "goal is to capture the most critical and relevant details that "
        "highlight the nature and significance of "
        "each relationship. Ensure that the summary is coherent and "
        "integrates the information in a way that "
        "emphasizes the key aspects of the relationships."
    )

    class _Cluster:
        """Minimal ``(node, cluster)`` record used by the fallback detectors."""

        __slots__ = ("node", "cluster")

        def __init__(self, node, cluster):
            self.node = node
            self.cluster = cluster

    def __init__(self, property_graph_store):
        """Initialize with a property_graph_store (Neo4j or in-memory)."""
        self.property_graph_store = property_graph_store
        self.community_summary = {}
        self.entity_info = None
        self.communities_built = False
        # Ensure cache directory exists
        os.makedirs(self.CACHE_DIR, exist_ok=True)

    # ------------------------------------------------------------------
    # Thin pass-throughs to the wrapped store
    # ------------------------------------------------------------------
    def add_nodes(self, nodes):
        """Add nodes to the property graph store."""
        return self.property_graph_store.add_nodes(nodes)

    def add_relationships(self, relationships):
        """Add relationships to the property graph store."""
        return self.property_graph_store.add_relationships(relationships)

    def get_triplets(self):
        """Get triplets from the property graph store."""
        return self.property_graph_store.get_triplets()

    # ------------------------------------------------------------------
    # Cache helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _write_pickle(path, payload):
        """Pickle *payload* to *path* via a temp file plus ``os.replace``.

        Writing to a sibling ``.tmp`` file first means an interrupted write
        never leaves a truncated cache file at *path*.
        """
        tmp_path = Path(f"{path}.tmp")
        with open(tmp_path, 'wb') as f:
            pickle.dump(payload, f)
        os.replace(tmp_path, path)

    def save_triples_to_cache(self):
        """Save extracted triples (entities + relationships) from Neo4j to a disk cache.

        This allows restoring triples to Neo4j without expensive LLM re-extraction
        if Neo4j data is lost (e.g., container recreated without volume persistence).

        Returns:
            bool: True when the cache file was written, False when the store
            is empty or an error occurred (errors are logged, not raised).
        """
        try:
            triplets = self.get_triplets()
            if not triplets:
                log_structured('warning', 'No triplets to cache — Neo4j appears empty')
                return False

            # Entities are de-duplicated by name; relations are kept as-is.
            entities = {}
            relations = []
            for entity1, relation, entity2 in triplets:
                entities[entity1.name] = entity1
                entities[entity2.name] = entity2
                relations.append(relation)

            cache_data = {
                'entities': list(entities.values()),
                'relations': relations,
                'triplet_count': len(triplets),
            }
            self._write_pickle(self.TRIPLES_CACHE_FILE, cache_data)

            log_structured('info', 'Successfully cached Neo4j triples to disk', {
                'entity_count': len(entities),
                'relation_count': len(relations),
                'triplet_count': len(triplets),
                'cache_file': str(self.TRIPLES_CACHE_FILE)
            })
            return True
        except Exception as e:
            log_structured('error', f'Error saving triples cache: {e}')
            return False

    def load_triples_from_cache(self):
        """Load triples from disk cache and restore them to Neo4j.

        Returns True if triples were successfully restored, False otherwise.
        """
        if not self.TRIPLES_CACHE_FILE.exists():
            log_structured('info', 'No triples cache file found')
            return False

        try:
            # SECURITY: pickle.load can execute arbitrary code embedded in the
            # file.  This is only acceptable while CACHE_DIR is writable by
            # trusted processes alone.
            with open(self.TRIPLES_CACHE_FILE, 'rb') as f:
                cache_data = pickle.load(f)

            entities = cache_data.get('entities', [])
            relations = cache_data.get('relations', [])
            if not entities:
                log_structured('warning', 'Triples cache file exists but contains no entities')
                return False

            log_structured('info', 'Restoring triples from disk cache to Neo4j', {
                'entity_count': len(entities),
                'relation_count': len(relations),
                'cached_triplet_count': cache_data.get('triplet_count', 'unknown')
            })

            # Restore entities (nodes) to Neo4j
            self.property_graph_store.upsert_nodes(entities)
            log_structured('info', f'Restored {len(entities)} entity nodes to Neo4j')

            # Restore relations to Neo4j
            if relations:
                self.property_graph_store.upsert_relations(relations)
                log_structured('info', f'Restored {len(relations)} relations to Neo4j')

            # Verify restoration
            restored_triplets = self.get_triplets()
            log_structured('info', f'Neo4j now contains {len(restored_triplets)} triplets after cache restore')
            return len(restored_triplets) > 0
        except Exception as e:
            log_structured('error', f'Error restoring triples from cache: {e}')
            return False

    def save_to_cache(self):
        """Save community summaries and entity info to the disk cache.

        Returns:
            bool: True on success, False if writing failed (error is logged).
        """
        try:
            self._write_pickle(self.COMMUNITY_CACHE_FILE, self.community_summary)
            self._write_pickle(self.ENTITY_INFO_CACHE_FILE, self.entity_info)
            log_structured('info', 'Successfully cached GraphRAG community data',
                           {'community_count': len(self.community_summary),
                            'entity_count': len(self.entity_info) if self.entity_info else 0})
            return True
        except Exception as e:
            log_structured('error', f'Error saving GraphRAG cache: {e}')
            return False

    def load_from_cache(self):
        """Load community data from disk cache if available.

        Returns:
            bool: True when both cache files were read; on any failure the
            in-memory state is reset and False is returned.
        """
        if not self.COMMUNITY_CACHE_FILE.exists() or not self.ENTITY_INFO_CACHE_FILE.exists():
            log_structured('info', 'GraphRAG cache files not found, will build communities from scratch')
            return False

        try:
            # SECURITY: see load_triples_from_cache — pickle must only be fed
            # files from a trusted location.
            with open(self.COMMUNITY_CACHE_FILE, 'rb') as f:
                self.community_summary = pickle.load(f)
            with open(self.ENTITY_INFO_CACHE_FILE, 'rb') as f:
                self.entity_info = pickle.load(f)

            log_structured('info', 'Successfully loaded GraphRAG community data from cache',
                           {'community_count': len(self.community_summary),
                            'entity_count': len(self.entity_info) if self.entity_info else 0})
            # Mark communities as built when successfully loaded from cache
            self.communities_built = True
            return True
        except Exception as e:
            log_structured('error', f'Error loading GraphRAG cache: {e}')
            # Reset to empty in case of partial load
            self.community_summary = {}
            self.entity_info = None
            self.communities_built = False
            return False

    # ------------------------------------------------------------------
    # Summarisation
    # ------------------------------------------------------------------
    def _chat(self, system_prompt, user_text):
        """Send one system+user exchange to ``SUMMARY_MODEL`` and return the cleaned text.

        ``str(response)`` is used for the reply; a leading ``assistant:``
        role prefix is stripped so every caller gets the bare text.
        """
        llm = OpenAI(model=self.SUMMARY_MODEL)
        messages = [
            ChatMessage(role="system", content=system_prompt),
            ChatMessage(role="user", content=user_text),
        ]
        response = llm.chat(messages)
        return re.sub(r"^assistant:\s*", "", str(response)).strip()

    def generate_community_summary(self, text):
        """Generate summary for a given text using an LLM with handling for large contexts.

        Text longer than ``SUMMARY_CHUNK_CHARS`` is summarised chunk by chunk
        and the partial summaries are then merged with a second LLM call.

        Returns:
            str: The summary.  On failure a descriptive placeholder string is
            returned instead of raising.
            NOTE(review): that placeholder is stored (and later cached) as if
            it were a real summary — confirm this is intended.
        """
        chunk_size = self.SUMMARY_CHUNK_CHARS
        if len(text) > chunk_size:
            log_structured('info', f'Community text is large ({len(text)} chars). Chunking for summarization.')
            # Split into fixed-size character chunks (simple approach)
            chunks = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
            summaries = []
            for position, chunk in enumerate(chunks, start=1):
                try:
                    summaries.append(
                        self._chat("Summarize these knowledge graph relationships concisely.", chunk)
                    )
                    log_structured('info', f'Successfully summarized community chunk {position}/{len(chunks)}')
                except Exception as e:
                    log_structured('error', f'Error summarizing community chunk {position}/{len(chunks)}: {e}')

            if not summaries:
                return "Unable to generate community summary due to size limitations."

            # Then summarize the summaries
            final_summary_text = "\n\n".join(summaries)
            try:
                return self._chat(
                    "Create a coherent summary from these partial summaries.",
                    final_summary_text,
                )
            except Exception as e:
                log_structured('error', f'Error creating final summary from chunks: {e}')
                # Return the concatenated summaries if we can't summarize them
                return final_summary_text

        # Normal-size text: a single call with the detailed instructions.
        try:
            return self._chat(self._SUMMARY_SYSTEM_PROMPT, text)
        except Exception as e:
            log_structured('error', f'Error generating community summary: {e}')
            return f"Error summarizing community: {str(e)}"

    def _summarize_communities(self, community_info):
        """Generate and store summaries for each community."""
        for community_id, details in community_info.items():
            details_text = "\n".join(details) + "."  # Ensure it ends with a period
            self.community_summary[community_id] = self.generate_community_summary(details_text)

    # ------------------------------------------------------------------
    # Community detection
    # ------------------------------------------------------------------
    def build_communities(self):
        """Build communities from the graph and summarize them.

        Does nothing if communities were already built in this session, and
        prefers the disk cache over recomputation when cache files exist.
        """
        if self.communities_built:
            log_structured('info', 'Communities already built in this session, skipping rebuild')
            return

        if self.load_from_cache():
            log_structured('info', 'Using cached community data instead of rebuilding')
            self.communities_built = True
            return

        log_structured('info', 'Building communities from graph data')
        nx_graph = self._create_nx_graph()
        clusters = self._detect_clusters(nx_graph)
        self.entity_info, community_info = self._collect_community_info(nx_graph, clusters)
        self._summarize_communities(community_info)

        # Cache the results after building
        self.save_to_cache()
        self.communities_built = True

    def _detect_clusters(self, nx_graph):
        """Return an iterable of objects exposing ``.node`` and ``.cluster``.

        Detectors are tried in order, moving on only when one cannot be
        imported: graspologic's ``hierarchical_leiden``, python-louvain's
        ``best_partition``, then NetworkX ``greedy_modularity_communities``.
        """
        try:
            from graspologic.partition import hierarchical_leiden
            return hierarchical_leiden(nx_graph, max_cluster_size=self.max_cluster_size)
        except ImportError:
            pass

        try:
            from community import best_partition
            partition = best_partition(nx_graph)
            return [self._Cluster(node, cluster_id) for node, cluster_id in partition.items()]
        except ImportError:
            pass

        from networkx.algorithms import community as nx_community
        communities = nx_community.greedy_modularity_communities(nx_graph)
        return [
            self._Cluster(node, index)
            for index, members in enumerate(communities)
            for node in members
        ]

    def _create_nx_graph(self):
        """Convert the store's triplets into an undirected NetworkX graph."""
        nx_graph = nx.Graph()
        for entity1, relation, entity2 in self.get_triplets():
            nx_graph.add_node(entity1.name)
            nx_graph.add_node(entity2.name)
            nx_graph.add_edge(
                relation.source_id,
                relation.target_id,
                relationship=relation.label,
                description=relation.properties.get("relationship_description", "No description provided"),
            )
        return nx_graph

    def _collect_community_info(self, nx_graph, clusters):
        """Collect information for each node based on its community.

        Entities may belong to multiple clusters.

        Returns:
            tuple: ``(entity_info, community_info)`` where ``entity_info``
            maps node -> list of cluster ids and ``community_info`` maps
            cluster id -> list of ``"a -> b -> relation -> description"``
            strings for every edge touching a member node.
        """
        memberships = defaultdict(set)
        community_info = defaultdict(list)

        for item in clusters:
            node = item.node
            cluster_id = item.cluster
            memberships[node].add(cluster_id)

            for neighbor in nx_graph.neighbors(node):
                edge_data = nx_graph.get_edge_data(node, neighbor)
                if edge_data:
                    community_info[cluster_id].append(
                        f"{node} -> {neighbor} -> {edge_data['relationship']} -> {edge_data['description']}"
                    )

        # Sets become lists so the mapping pickles/serializes predictably.
        entity_info = {node: list(ids) for node, ids in memberships.items()}
        return entity_info, dict(community_info)

    def get_community_summaries(self):
        """Return the community summaries, building them if not already done.

        ``build_communities`` already tries the disk cache first and saves
        after a fresh build, so no extra load/save is done here.
        """
        if not self.community_summary:
            self.build_communities()
        return self.community_summary
# Define the GraphRAGQueryEngine class
from typing import Dict, Any
class GraphRAGQueryEngine:
    """Query engine that combines vector retrieval with graph-based community retrieval."""

    # Fallback heuristic for entity names: a run of letters/digits/whitespace
    # that starts with a capital letter.  Compiled once for reuse.
    _ENTITY_FALLBACK_PATTERN = re.compile(r"(?:^|\s)([A-Z][a-zA-Z0-9\s]+)(?:\s|$)")

    def __init__(
        self,
        vector_retriever: "VectorIndexRetriever",
        graph_store: "GraphRAGStore",
        llm: Optional["LLM"] = None,
        similarity_top_k: int = 20
    ):
        """Initialize with both a vector retriever and graph store.

        Args:
            vector_retriever: Retriever used for the vector-search half.
            graph_store: Store providing ``entity_info`` and community summaries.
            llm: LLM kept on the instance; defaults to ``Settings.llm``.
            similarity_top_k: Stored on the instance (not used by the
                methods in this class).
        """
        self.vector_retriever = vector_retriever
        self.graph_store = graph_store
        self.llm = llm or Settings.llm
        self.similarity_top_k = similarity_top_k

        # Communities are deliberately NOT built here, since that might
        # cause errors with large graphs; only warn when they are missing.
        if getattr(self.graph_store, 'entity_info', None) is None:
            log_structured('warning', 'GraphRAGQueryEngine initialized without community data. Vector retrieval will still work, but community retrieval may be limited.')

    def custom_query(self, query_str: str) -> Dict:
        """Process query using both vector retrieval and community-based approach.

        Returns:
            dict with keys ``vector_context`` (joined node text),
            ``graphrag_context`` (joined community summaries, possibly empty),
            ``vector_nodes`` (raw retriever results) and ``community_ids``.
            Failures in the community half are logged and leave
            ``graphrag_context`` empty.
        """
        log_structured('info', 'GraphRAG query engine: Starting dual retrieval', {'query': query_str})

        # Step 1: Get vector search results
        vector_nodes = self.vector_retriever.retrieve(query_str)
        vector_context = "\n\n".join([node.node.get_content() for node in vector_nodes])
        log_structured('info', 'GraphRAG query engine: Vector retrieval complete',
                       {'node_count': len(vector_nodes)})

        # Step 2: Get GraphRAG community results (if communities exist)
        graphrag_context = ""
        community_ids = []
        entity_info = getattr(self.graph_store, 'entity_info', None)
        if entity_info is not None:
            try:
                entities = self.get_entities(query_str, vector_nodes)
                community_ids = self.retrieve_entity_communities(entity_info, entities)
                try:
                    community_summaries = self.graph_store.get_community_summaries()
                    if community_ids:
                        wanted_ids = set(community_ids)
                        filtered_summaries = {
                            community_id: summary
                            for community_id, summary in community_summaries.items()
                            if community_id in wanted_ids
                        }
                        graphrag_context = "\n\n".join(filtered_summaries.values())
                        log_structured('info', 'GraphRAG query engine: Community retrieval complete',
                                       {'community_count': len(filtered_summaries)})
                    else:
                        log_structured('info', 'GraphRAG query engine: No relevant communities found')
                except Exception as e:
                    log_structured('error', f'Error getting community summaries: {e}')
                    # Continue without graph context
            except Exception as e:
                log_structured('error', f'Error during community retrieval: {e}')
                # Continue with just vector context
        else:
            log_structured('warning', 'GraphRAG query engine: No community data available. Using only vector retrieval.')

        # Step 3: Hand both contexts back; answer generation happens elsewhere.
        return {
            "vector_context": vector_context,
            "graphrag_context": graphrag_context,
            "vector_nodes": vector_nodes,
            "community_ids": community_ids
        }

    def get_entities(self, query_str, vector_nodes):
        """Collect entity names attached to, or guessed from, the retrieved nodes.

        Entities stored in node metadata under ``KG_NODES_KEY`` are
        preferred.  When none are present, capitalised word runs are pulled
        from the node text; those names are whitespace-normalised so they
        can match stored entity names.  ``query_str`` is currently unused.
        """
        entities = set()

        for node_with_score in vector_nodes:
            node = node_with_score.node
            if hasattr(node, 'metadata') and KG_NODES_KEY in node.metadata:
                for entity_node in node.metadata[KG_NODES_KEY]:
                    if hasattr(entity_node, 'name'):
                        entities.add(entity_node.name)

        # If no entities were found in metadata, try extracting them from text
        if not entities:
            for node_with_score in vector_nodes:
                content = node_with_score.node.get_content()
                for raw_name in self._ENTITY_FALLBACK_PATTERN.findall(content):
                    # The pattern admits whitespace (including newlines)
                    # inside and at the end of a capture; collapse it.
                    name = " ".join(raw_name.split())
                    if name:
                        entities.add(name)

        log_structured('debug', 'GraphRAG query engine: Extracted entities',
                       {'entities': list(entities), 'count': len(entities)})
        return list(entities)

    def retrieve_entity_communities(self, entity_info, entities):
        """Retrieve cluster information for given entities.

        Multiple clusters per entity are allowed.

        Args:
            entity_info (dict): Dictionary mapping entities to their cluster IDs (list).
            entities (list): List of entity names to retrieve information for.

        Returns:
            List of de-duplicated community or cluster IDs (order unspecified).
        """
        community_ids = []
        lowered_index = None  # built lazily, only if an exact lookup misses

        for entity in entities:
            if entity in entity_info:
                community_ids.extend(entity_info[entity])
                continue

            # Case-insensitive fallback.  setdefault keeps the first stored
            # entity for a given lowercase form, matching a first-hit scan.
            if lowered_index is None:
                lowered_index = {}
                for stored_entity, clusters in entity_info.items():
                    lowered_index.setdefault(stored_entity.lower(), clusters)
            community_ids.extend(lowered_index.get(entity.lower(), ()))

        return list(set(community_ids))
def custom_parse_fn(response_str: str) -> Any:
    """Parse an LLM response into ``(entities, relationships)`` tuples.

    The first ``{`` through the last ``}`` of *response_str* is treated as a
    JSON object with optional ``"entities"`` and ``"relationships"`` lists.

    Returns:
        tuple: ``(entities, relationships)`` where entities are
        ``(entity_name, entity_type, entity_description)`` and relationships
        are ``(source_entity, target_entity, relation,
        relationship_description)``.  Missing descriptions get a generated
        placeholder.  Unparseable responses yield two empty lists, and
        individual malformed items are logged and skipped rather than
        discarding the rest of the response.
    """
    entities = []
    relationships = []

    match = re.search(r"\{.*\}", response_str, re.DOTALL)
    if not match:
        return entities, relationships

    json_str = match.group(0)
    try:
        data = json.loads(json_str)
    except json.JSONDecodeError as e:
        log_structured('error', f"Error parsing response: {e}", {'json_str': json_str[:200]})
        return entities, relationships

    def _items(key):
        # Tolerate `null` or a non-list value where a list was requested.
        value = data.get(key)
        return value if isinstance(value, list) else []

    for entity in _items("entities"):
        try:
            name = entity["entity_name"]
            entities.append(
                (
                    name,
                    entity["entity_type"],
                    entity.get("entity_description", f"Description of {name}"),
                )
            )
        except (KeyError, TypeError, AttributeError) as e:
            log_structured('error', f"Error parsing response: {e}", {'json_str': json_str[:200]})

    for relation in _items("relationships"):
        try:
            source = relation["source_entity"]
            target = relation["target_entity"]
            relationships.append(
                (
                    source,
                    target,
                    relation["relation"],
                    relation.get(
                        "relationship_description",
                        f"Relationship between {source} and {target}",
                    ),
                )
            )
        except (KeyError, TypeError, AttributeError) as e:
            log_structured('error', f"Error parsing response: {e}", {'json_str': json_str[:200]})

    return entities, relationships
# Prompt template for triple extraction.
# Placeholders: {max_knowledge_triplets} and {text}, which are the keyword
# arguments GraphRAGExtractor._aextract passes to llm.apredict.
# The field names listed in the steps (entity_name, entity_type,
# entity_description, source_entity, target_entity, relation,
# relationship_description) are the keys custom_parse_fn reads from the
# model's JSON reply, so the two must be kept in sync.
# NOTE(review): the example `{ "entities": [], "relationships": [] }` contains
# literal braces; presumably the prompt formatter in use leaves unknown
# brace groups untouched — confirm before switching formatter or version.
KG_TRIPLET_EXTRACT_TMPL = """
-Goal-
Given a text document, identify all entities and their entity types from the text and all relationships among the identified entities.
Given the text, extract up to {max_knowledge_triplets} entity-relation triplets.
-Steps-
1. Identify all entities. For each identified entity, extract the following information:
- entity_name: Name of the entity, capitalized
- entity_type: Type of the entity
- entity_description: Comprehensive description of the entity's attributes and activities
2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other.
For each pair of related entities, extract the following information:
- source_entity: name of the source entity, as identified in step 1
- target_entity: name of the target entity, as identified in step 1
- relation: relationship between source_entity and target_entity
- relationship_description: explanation as to why you think the source entity and the target entity are related to each other
3. Output Formatting:
- Return the result in valid JSON format with two keys: 'entities' (list of entity objects) and 'relationships' (list of relationship objects).
- Exclude any text outside the JSON structure (e.g., no explanations or comments).
- If no entities or relationships are identified, return empty lists: { "entities": [], "relationships": [] }.
-Real Data-
######################
text: {text}
######################
output:
"""
def _build_communities_best_effort(graph_store, error_prefix='Error building communities'):
    """Build communities, logging (not raising) any failure."""
    try:
        graph_store.build_communities()
    except Exception as e:
        log_structured('error', f'{error_prefix}: {e}')


def _clear_neo4j_database():
    """Delete every node and relationship in Neo4j; failures are logged as warnings."""
    try:
        from neo4j import GraphDatabase
        driver = GraphDatabase.driver(NEO4J_URL, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))
        try:
            with driver.session() as session:
                session.run("MATCH (n) DETACH DELETE n")
        finally:
            # Close the driver even when the delete query raises.
            driver.close()
    except Exception as e:
        log_structured('warning', f'Error clearing Neo4j database: {e}')


def _invalidate_community_cache(graph_store):
    """Remove cached community data so it is rebuilt from the current graph."""
    for cache_file in (graph_store.COMMUNITY_CACHE_FILE, graph_store.ENTITY_INFO_CACHE_FILE):
        try:
            os.remove(cache_file)
        except FileNotFoundError:
            continue
        except OSError as e:
            log_structured('warning', f'Could not remove stale community cache {cache_file}: {e}')


def _extract_graph_with_llm(llm, nodes, property_graph_store, max_paths_per_chunk):
    """Run LLM triple extraction over *nodes* and index the result into the store."""
    kg_extractor = GraphRAGExtractor(
        llm=llm,
        extract_prompt=KG_TRIPLET_EXTRACT_TMPL,
        max_paths_per_chunk=max_paths_per_chunk,
        parse_fn=custom_parse_fn,
    )
    return PropertyGraphIndex(
        nodes=nodes,
        kg_extractors=[kg_extractor],
        property_graph_store=property_graph_store,
        show_progress=True,
    )


def create_graph_components(llm, nodes=None, max_paths_per_chunk=10, force_reindex=False):
    """
    Create GraphRAG components for the HP RAG pipeline.

    Three paths are possible:
      1. Neo4j already has triplets (and no reindex is forced): reuse them,
         and create the on-disk triples cache if it is missing.
      2. Neo4j is empty but a triples cache exists: restore from the cache;
         if that fails, fall back to LLM extraction.
      3. Otherwise (forced reindex, or nothing to restore): full LLM
         extraction from ``nodes``.
    Community building is best-effort in every path: errors are logged and
    the components are still returned.

    Args:
        llm: The LLM to use for graph extraction and querying
        nodes: List of nodes to create the graph from (only used if indexing is needed)
        max_paths_per_chunk: Maximum number of paths to extract per chunk
        force_reindex: If True, always recreate the index even if content exists

    Returns:
        tuple: (graph_store, property_graph_index)

    Raises:
        RuntimeError: If the Neo4j connection cannot be established.
        ValueError: If LLM extraction is required but ``nodes`` is empty.
    """
    log_structured('info', 'Creating GraphRAG components')

    # Connect to Neo4j - hard error if not available
    try:
        log_structured('info', f'Connecting to Neo4j at {NEO4J_URL}')
        property_graph_store = Neo4jPropertyGraphStore(
            username=NEO4J_USERNAME,
            password=NEO4J_PASSWORD,
            url=NEO4J_URL
        )
        log_structured('info', 'Successfully connected to Neo4j database')
    except Exception as e:
        log_structured('critical', f'FATAL ERROR: Cannot connect to Neo4j: {e}')
        raise RuntimeError(
            f"Neo4j connection failed. This application requires Neo4j to be running. Error: {e}"
        ) from e

    graph_store = GraphRAGStore(property_graph_store)

    # Check if Neo4j already has content
    triplets = graph_store.get_triplets()
    has_existing_content = len(triplets) > 0
    log_structured('info', f'Neo4j check: Found {len(triplets)} triplets')

    if has_existing_content and not force_reindex:
        # BRANCH 1: Neo4j has data — use it, but also ensure disk cache exists
        log_structured('info', f'Neo4j already contains {len(triplets)} triplets. Skipping indexing.')
        if not graph_store.TRIPLES_CACHE_FILE.exists():
            log_structured('info', 'Neo4j has data but no triples cache on disk — creating cache now')
            graph_store.save_triples_to_cache()

        # Create a minimal PropertyGraphIndex without indexing
        property_graph_index = PropertyGraphIndex(
            nodes=[],
            property_graph_store=property_graph_store,
        )
        if not graph_store.communities_built:
            log_structured('info', 'Building graph communities from existing Neo4j data')
            _build_communities_best_effort(graph_store)
        return graph_store, property_graph_index

    if not has_existing_content and not force_reindex and graph_store.TRIPLES_CACHE_FILE.exists():
        # BRANCH 2: Neo4j is empty but triples cache exists — restore from cache
        log_structured('info', 'Neo4j is empty but triples cache exists. Restoring from disk cache.')
        if graph_store.load_triples_from_cache():
            log_structured('info', 'Successfully restored triples from cache.')
            property_graph_index = PropertyGraphIndex(
                nodes=[],
                property_graph_store=property_graph_store,
            )
            if not graph_store.communities_built:
                _build_communities_best_effort(
                    graph_store, 'Error building communities from restored data'
                )
            return graph_store, property_graph_index

        # Cache restore failed — fall through to LLM extraction
        log_structured('warning', 'Triples cache restore failed. Falling back to LLM extraction.')
        if not nodes:
            raise ValueError("Nodes must be provided when Neo4j is empty and cache restore fails")
    elif not nodes:
        # BRANCH 3 precondition: full LLM extraction (force_reindex or no cache)
        raise ValueError("Nodes must be provided for indexing when Neo4j is empty or force_reindex=True")

    if force_reindex:
        if has_existing_content:
            # Clear Neo4j before re-extraction
            _clear_neo4j_database()
        # Without this, build_communities() below would reload community
        # data cached from the graph that is being replaced.
        # NOTE(review): a non-forced fresh extraction still reuses any
        # existing community cache — confirm whether that is desired.
        _invalidate_community_cache(graph_store)

    property_graph_index = _extract_graph_with_llm(
        llm, nodes, property_graph_store, max_paths_per_chunk
    )

    # Cache the newly extracted triples
    graph_store.save_triples_to_cache()
    _build_communities_best_effort(graph_store)
    return graph_store, property_graph_index
def create_graphrag_query_engine(vector_retriever, graph_store, llm, similarity_top_k=20):
    """
    Build the hybrid (vector + graph community) query engine.

    Args:
        vector_retriever: VectorIndexRetriever for standard retrieval
        graph_store: GraphRAGStore for community-based retrieval
        llm: LLM for generating answer
        similarity_top_k: Number of top results to retrieve

    Returns:
        GraphRAGQueryEngine: Query engine for hybrid retrieval

    Raises:
        ValueError: If any of the three required components is None.  Any
            error (including this one) is logged and then re-raised.
    """
    from utils import log_structured

    # Checked in this order so the first missing argument is the one reported.
    required_arguments = (
        ("vector_retriever", vector_retriever),
        ("graph_store", graph_store),
        ("llm", llm),
    )

    try:
        for argument_name, argument_value in required_arguments:
            if argument_value is None:
                raise ValueError(f"{argument_name} cannot be None")

        debug_details = {
            'vector_retriever_type': type(vector_retriever).__name__,
            'graph_store_type': type(graph_store).__name__,
            'llm_type': type(llm).__name__,
            'similarity_top_k': similarity_top_k
        }
        log_structured('debug', 'Creating GraphRAGQueryEngine with parameters', debug_details)

        engine = GraphRAGQueryEngine(
            vector_retriever=vector_retriever,
            graph_store=graph_store,
            llm=llm,
            similarity_top_k=similarity_top_k,
        )
    except Exception as e:
        log_structured('error', f'Error in create_graphrag_query_engine: {e}')
        raise  # Re-raise the exception for proper handling
    return engine
def _strip_react_artifacts(content):
    """Remove ReAct-style scaffolding (Thought/Action/Observation) from *content*."""
    # A complete "Action: ... Action Input: ..." block is tool-call residue
    # with no answer after it, so these two patterns span lines (DOTALL) and
    # remove everything through the end of the text.
    block_patterns = (
        r'(?i)Thought:.*?Action:.*?Action Input:.*',
        r'(?i)Action:.*?Action Input:.*',
    )
    for pattern in block_patterns:
        content = re.sub(pattern, '', content, flags=re.DOTALL)

    # Stray marker lines are removed one line at a time.  DOTALL must NOT be
    # used here: with it, `.*` would also swallow every following line and
    # delete the real answer after a single "Thought:" line.
    content = re.sub(
        r'(?i)^(Thought|Action Input|Action|Observation):.*$',
        '',
        content,
        flags=re.MULTILINE,
    )

    # Clean up extra whitespace
    content = re.sub(r'\n{3,}', '\n\n', content)
    return content.strip()


def generate_final_answer(query, retrieval_result, llm):
    """
    Generate a final answer using both vector and graph-based context.

    Args:
        query: The user's query
        retrieval_result: Result from GraphRAGQueryEngine with vector and graph contexts
        llm: LLM for generating the final response; when it is None or has
            no ``chat`` attribute, an OpenAI ``gpt-4.1-mini`` client is used.

    Returns:
        str: The final answer.  A fixed "nothing found" sentence is returned
        when both contexts are empty, and a generic fallback sentence when
        the cleaned model output is empty or still starts with a ReAct
        marker.
    """
    vector_context = retrieval_result.get("vector_context", "")
    graphrag_context = retrieval_result.get("graphrag_context", "")

    # Log only the sizes of the contexts, not their content
    log_structured('debug', 'Generating final answer with dual context', {
        'query': query,
        'vector_context_length': len(vector_context),
        'graphrag_context_length': len(graphrag_context)
    })

    if not vector_context and not graphrag_context:
        return "I couldn't find any relevant information to answer your question."

    if llm is None or not hasattr(llm, 'chat'):
        # Fallback to gpt-4.1-mini for better cost efficiency
        llm = OpenAI(model="gpt-4.1-mini")
        log_structured('info', 'Using gpt-4.1-mini model for final answer generation')

    prompt = f"""
Based on the following information from two different sources, please answer this question: {query}
SOURCE 1 - VECTOR RETRIEVAL:
{vector_context}
SOURCE 2 - KNOWLEDGE GRAPH COMMUNITIES:
{graphrag_context}
Answer the question based on all the provided information. If there are differences between the sources,
try to reconcile them or note the discrepancy. Please be concise and direct.
"""
    messages = [
        ChatMessage(role="system", content=prompt),
        ChatMessage(role="user", content="Please provide a comprehensive answer based on all the information provided.")
    ]
    response = llm.chat(messages)

    # Extract just the message content, not the entire response object
    if hasattr(response, 'message') and hasattr(response.message, 'content'):
        content = response.message.content
    elif hasattr(response, 'content'):
        content = response.content
    else:
        # Fallback: convert to string but clean it
        content = str(response)

    # Content may be None (nothing returned); treat it as empty so the
    # cleanup below does not fail and the generic fallback is used.
    content = _strip_react_artifacts(content or "")

    # Final safety check
    if not content or re.search(r'(?i)^(Thought|Action|Observation):', content):
        log_structured('warning', 'GraphRAG final answer still contains thinking patterns, using fallback')
        content = "I found relevant information in the HP marketing materials that can help answer your question. Please let me know if you need more specific details."
    return content