- Cache extracted triples to disk (neo4j_triples.pickle) so Neo4j can be repopulated without expensive LLM re-extraction on cold starts - Split initialization into two phases: fast vector-only (~1-2 min) and background GraphRAG, so the server serves requests while GraphRAG loads - Add GraphRAG status flags to shared_state for monitoring readiness - Update /status endpoint to expose graphrag_ready/initializing/error - Restructure main.py to use single event loop for background task support Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1008 lines
No EOL
41 KiB
Python
1008 lines
No EOL
41 KiB
Python
"""
|
|
HP GraphRAG Integration
|
|
|
|
Integrates GraphRAG functionality into the HP RAG pipeline.
|
|
- GraphRAG for knowledge graph construction from semantically split nodes
|
|
- Community detection and summarization for improved context retrieval
|
|
"""
|
|
|
|
import os
|
|
import json
|
|
import re
|
|
import asyncio
|
|
import networkx as nx
|
|
from collections import defaultdict
|
|
from typing import Any, List, Callable, Optional, Union, Dict
|
|
from pathlib import Path
|
|
|
|
# Import LlamaIndex components
|
|
from llama_index.core import Document, Settings
|
|
from llama_index.core.node_parser import SentenceSplitter
|
|
from llama_index.core import PropertyGraphIndex
|
|
from llama_index.core.async_utils import run_jobs
|
|
from llama_index.core.indices.property_graph.utils import default_parse_triplets_fn
|
|
from llama_index.core.graph_stores.types import EntityNode, KG_NODES_KEY, KG_RELATIONS_KEY, Relation
|
|
from llama_index.core.llms.llm import LLM
|
|
from llama_index.core.prompts import PromptTemplate
|
|
from llama_index.core.prompts.default_prompts import DEFAULT_KG_TRIPLET_EXTRACT_PROMPT
|
|
from llama_index.core.schema import TransformComponent, BaseNode
|
|
from llama_index.core.query_engine import CustomQueryEngine
|
|
from llama_index.llms.openai import OpenAI
|
|
from llama_index.core.llms import ChatMessage
|
|
from llama_index.graph_stores.neo4j import Neo4jPropertyGraphStore
|
|
from llama_index.core import SimpleDirectoryReader
|
|
from llama_index.core.vector_stores.types import VectorStoreInfo, MetadataInfo
|
|
from llama_index.core.retrievers import VectorIndexRetriever
|
|
|
|
# Optional dependency probe for community detection.
# ``community`` is the import name of the python-louvain package. If it is
# missing we only print a notice here; ``GraphRAGStore.build_communities``
# performs its own lazy imports (graspologic, then python-louvain, then
# NetworkX's built-in algorithms) when communities are actually built.
try:
    from community import best_partition  # python-louvain package
except ImportError:
    print("Community detection package not found, using NetworkX built-in community detection")
|
|
|
|
# Import from our modules
|
|
from utils import logger, log_structured
|
|
from config import NEO4J_URL, NEO4J_USERNAME, NEO4J_PASSWORD
|
|
import config
|
|
|
|
# Define the GraphRAGExtractor class
|
|
class GraphRAGExtractor(TransformComponent):
    """Extract triples from a graph.

    Uses an LLM and a simple prompt + output parsing to
    extract paths (i.e. triples) and entity, relation descriptions
    from text.

    Args:
        llm (LLM):
            The language model to use. Falls back to ``Settings.llm``.
        extract_prompt (Union[str, PromptTemplate]):
            The prompt to use for extracting triples. A plain string is
            wrapped in a ``PromptTemplate``.
        parse_fn (callable):
            A function to parse the output of the language
            model. Must return ``(entities, relationships)`` where entities
            are ``(name, type, description)`` tuples and relationships are
            ``(source, target, relation, description)`` tuples.
        num_workers (int):
            The number of workers to use for parallel
            processing.
        max_paths_per_chunk (int):
            The maximum number of paths to extract per chunk.
    """

    llm: LLM
    extract_prompt: PromptTemplate
    parse_fn: Callable
    num_workers: int
    max_paths_per_chunk: int

    def __init__(
        self,
        llm: Optional[LLM] = None,
        extract_prompt: Optional[Union[str, PromptTemplate]] = None,
        parse_fn: Callable = default_parse_triplets_fn,
        max_paths_per_chunk: int = 10,
        num_workers: int = 8,
    ) -> None:
        """Init params."""
        if isinstance(extract_prompt, str):
            extract_prompt = PromptTemplate(extract_prompt)

        super().__init__(
            llm=llm or Settings.llm,
            extract_prompt=extract_prompt or DEFAULT_KG_TRIPLET_EXTRACT_PROMPT,
            parse_fn=parse_fn,
            num_workers=num_workers,
            max_paths_per_chunk=max_paths_per_chunk,
        )

    @classmethod
    def class_name(cls) -> str:
        return "GraphExtractor"

    def __call__(
        self, nodes: List[BaseNode], show_progress: bool = False, **kwargs: Any
    ) -> List[BaseNode]:
        """Extract triples from nodes (synchronous entry point)."""
        # NOTE(review): asyncio.run() raises RuntimeError when invoked from a
        # thread that already has a running event loop. Confirm that callers
        # run this from a worker thread / outside the server loop.
        return asyncio.run(
            self.acall(nodes, show_progress=show_progress, **kwargs)
        )

    async def _aextract(self, node: BaseNode) -> BaseNode:
        """Extract triples from a single node and attach them to its metadata.

        Entities and relations found by the LLM are appended to whatever is
        already stored under ``KG_NODES_KEY`` / ``KG_RELATIONS_KEY``.
        A ``ValueError`` from the LLM call or the parser yields no triples
        for this node instead of aborting the whole batch.
        """
        assert hasattr(node, "text")

        text = node.get_content(metadata_mode="llm")
        try:
            llm_response = await self.llm.apredict(
                self.extract_prompt,
                text=text,
                max_knowledge_triplets=self.max_paths_per_chunk,
            )
            entities, entities_relationship = self.parse_fn(llm_response)
        except ValueError:
            entities = []
            entities_relationship = []

        existing_nodes = node.metadata.pop(KG_NODES_KEY, [])
        existing_relations = node.metadata.pop(KG_RELATIONS_KEY, [])

        # Snapshot taken after the KG keys were popped, so graph objects never
        # end up nested inside their own properties.
        base_metadata = node.metadata.copy()

        for entity, entity_type, description in entities:
            # A fresh dict per entity: sharing one mutable dict across all
            # EntityNodes would let every node alias the last description.
            entity_properties = {**base_metadata, "entity_description": description}
            existing_nodes.append(
                EntityNode(
                    name=entity,
                    label=entity_type,
                    properties=entity_properties,
                )
            )

        for subj, obj, rel, description in entities_relationship:
            relation_properties = {
                **base_metadata,
                "relationship_description": description,
            }
            existing_relations.append(
                Relation(
                    label=rel,
                    source_id=subj,
                    target_id=obj,
                    properties=relation_properties,
                )
            )

        node.metadata[KG_NODES_KEY] = existing_nodes
        node.metadata[KG_RELATIONS_KEY] = existing_relations
        return node

    async def acall(
        self, nodes: List[BaseNode], show_progress: bool = False, **kwargs: Any
    ) -> List[BaseNode]:
        """Extract triples from nodes async, at most ``num_workers`` at a time."""
        jobs = [self._aextract(node) for node in nodes]

        return await run_jobs(
            jobs,
            workers=self.num_workers,
            show_progress=show_progress,
            desc="Extracting paths from text",
        )
|
|
|
|
# Define the GraphRAGStore class (integrating with Neo4j)
|
|
import pickle
|
|
import os
|
|
from pathlib import Path
|
|
|
|
class GraphRAGStore:
    """Wrap a property graph store with community detection and disk caches.

    Responsibilities visible in this class:
      * thin pass-throughs to the wrapped store (add nodes/relations, triplets);
      * a pickle cache of the extracted triples so the graph can be restored
        without re-running LLM extraction;
      * community detection over the triplets, LLM summaries per community,
        and a pickle cache of those results.

    SECURITY: the caches are read with ``pickle.load``, which can execute
    arbitrary code. ``CACHE_DIR`` must only ever contain files written by
    this application.
    """

    community_summary = {}
    entity_info = None
    max_cluster_size = 5
    property_graph_store = None
    communities_built = False  # Track if communities have been built

    # Paths for cached community data and cached triples
    CACHE_DIR = Path("index_storage/graphrag_cache")
    COMMUNITY_CACHE_FILE = CACHE_DIR / "community_summary.pickle"
    ENTITY_INFO_CACHE_FILE = CACHE_DIR / "entity_info.pickle"
    TRIPLES_CACHE_FILE = CACHE_DIR / "neo4j_triples.pickle"

    # Summarization settings (previously repeated inline)
    SUMMARY_MODEL = "gpt-4.1-mini"
    SUMMARY_CHUNK_CHARS = 30000  # Approximate character limit per LLM call

    _RELATIONSHIP_SUMMARY_PROMPT = (
        "You are provided with a set of "
        "relationships from a knowledge graph, each represented as "
        "entity1->entity2->relation-"
        ">relationship_description. Your task is to create a summary of "
        "these relationships. The summary should include "
        "the names of the entities involved and a concise synthesis "
        "of the relationship descriptions. The "
        "goal is to capture the most critical and relevant details that "
        "highlight the nature and significance of "
        "each relationship. Ensure that the summary is coherent and "
        "integrates the information in a way that "
        "emphasizes the key aspects of the relationships."
    )

    class _Cluster:
        """(node, cluster) record exposing the two attributes that
        ``_collect_community_info`` reads from each clustering result."""

        __slots__ = ("node", "cluster")

        def __init__(self, node, cluster):
            self.node = node
            self.cluster = cluster

    def __init__(self, property_graph_store):
        """Initialize with a property_graph_store (Neo4j or in-memory)."""
        self.property_graph_store = property_graph_store
        self.community_summary = {}
        self.entity_info = None
        self.communities_built = False

        # Ensure cache directory exists
        os.makedirs(self.CACHE_DIR, exist_ok=True)

    # ------------------------------------------------------------------
    # Pass-throughs to the wrapped store
    # ------------------------------------------------------------------
    def add_nodes(self, nodes):
        """Add nodes to the property graph store."""
        return self.property_graph_store.add_nodes(nodes)

    def add_relationships(self, relationships):
        """Add relationships to the property graph store."""
        return self.property_graph_store.add_relationships(relationships)

    def get_triplets(self):
        """Get triplets from the property graph store."""
        return self.property_graph_store.get_triplets()

    # ------------------------------------------------------------------
    # Cache helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _dump_pickle(path, payload):
        """Pickle ``payload`` to ``path`` via a temp file and ``os.replace``.

        Writing in place would leave a truncated, unloadable cache if the
        process dies mid-write; the rename only happens after a full write.
        """
        tmp_path = Path(f"{path}.tmp")
        with open(tmp_path, 'wb') as f:
            pickle.dump(payload, f)
        os.replace(tmp_path, path)

    def save_triples_to_cache(self):
        """Save extracted triples (entities + relationships) from Neo4j to a disk cache.

        This allows restoring triples to Neo4j without expensive LLM re-extraction
        if Neo4j data is lost (e.g., container recreated without volume persistence).

        Returns:
            True when the cache file was written, False when the store holds
            no triplets or any error occurred (errors are logged, not raised).
        """
        try:
            triplets = self.get_triplets()
            if not triplets:
                log_structured('warning', 'No triplets to cache — Neo4j appears empty')
                return False

            # Collect unique entities (keyed by name) and all relations
            entities = {}
            relations = []
            for entity1, relation, entity2 in triplets:
                entities[entity1.name] = entity1
                entities[entity2.name] = entity2
                relations.append(relation)

            cache_data = {
                'entities': list(entities.values()),
                'relations': relations,
                'triplet_count': len(triplets),
            }
            self._dump_pickle(self.TRIPLES_CACHE_FILE, cache_data)

            log_structured('info', 'Successfully cached Neo4j triples to disk', {
                'entity_count': len(entities),
                'relation_count': len(relations),
                'triplet_count': len(triplets),
                'cache_file': str(self.TRIPLES_CACHE_FILE)
            })
            return True
        except Exception as e:
            log_structured('error', f'Error saving triples cache: {e}')
            return False

    def load_triples_from_cache(self):
        """Load triples from disk cache and restore them to Neo4j.

        Returns True if triples were successfully restored, False otherwise.
        Errors are logged and reported as False.
        """
        if not self.TRIPLES_CACHE_FILE.exists():
            log_structured('info', 'No triples cache file found')
            return False

        try:
            # Trusted, application-written file only (see class docstring).
            with open(self.TRIPLES_CACHE_FILE, 'rb') as f:
                cache_data = pickle.load(f)

            entities = cache_data.get('entities', [])
            relations = cache_data.get('relations', [])

            if not entities:
                log_structured('warning', 'Triples cache file exists but contains no entities')
                return False

            log_structured('info', 'Restoring triples from disk cache to Neo4j', {
                'entity_count': len(entities),
                'relation_count': len(relations),
                'cached_triplet_count': cache_data.get('triplet_count', 'unknown')
            })

            # Restore entities (nodes) to Neo4j
            self.property_graph_store.upsert_nodes(entities)
            log_structured('info', f'Restored {len(entities)} entity nodes to Neo4j')

            # Restore relations to Neo4j
            if relations:
                self.property_graph_store.upsert_relations(relations)
                log_structured('info', f'Restored {len(relations)} relations to Neo4j')

            # Verify restoration by reading the triplets back
            restored_triplets = self.get_triplets()
            log_structured('info', f'Neo4j now contains {len(restored_triplets)} triplets after cache restore')

            return len(restored_triplets) > 0
        except Exception as e:
            log_structured('error', f'Error restoring triples from cache: {e}')
            return False

    # ------------------------------------------------------------------
    # Community summarization
    # ------------------------------------------------------------------
    def _chat_summary(self, system_prompt, user_text):
        """Send one system+user exchange to the summary model; return clean text.

        ``str(response)`` may start with an ``assistant:`` role prefix, which
        is stripped so every code path returns the same shape of text.
        """
        llm = OpenAI(model=self.SUMMARY_MODEL)
        messages = [
            ChatMessage(role="system", content=system_prompt),
            ChatMessage(role="user", content=user_text),
        ]
        response = llm.chat(messages)
        return re.sub(r"^assistant:\s*", "", str(response)).strip()

    def generate_community_summary(self, text):
        """Generate summary for a given text using an LLM with handling for large contexts.

        Text longer than ``SUMMARY_CHUNK_CHARS`` is summarized chunk by chunk
        and the partial summaries are then summarized again. Never raises:
        failures are logged and a fallback string is returned.
        """
        chunk_size = self.SUMMARY_CHUNK_CHARS

        if len(text) > chunk_size:
            log_structured('info', f'Community text is large ({len(text)} chars). Chunking for summarization.')
            # Split into fixed-size character chunks (simple approach)
            chunks = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
            summaries = []

            for position, chunk in enumerate(chunks, start=1):
                try:
                    summaries.append(
                        self._chat_summary(
                            "Summarize these knowledge graph relationships concisely.",
                            chunk,
                        )
                    )
                    log_structured('info', f'Successfully summarized community chunk {position}/{len(chunks)}')
                except Exception as e:
                    # A failed chunk is skipped; the remaining chunks still count.
                    log_structured('error', f'Error summarizing community chunk {position}/{len(chunks)}: {e}')

            if not summaries:
                return "Unable to generate community summary due to size limitations."

            # Then summarize the summaries
            final_summary_text = "\n\n".join(summaries)
            try:
                return self._chat_summary(
                    "Create a coherent summary from these partial summaries.",
                    final_summary_text,
                )
            except Exception as e:
                log_structured('error', f'Error creating final summary from chunks: {e}')
                # Return the concatenated summaries if we can't summarize them
                return final_summary_text

        # Normal-size text: single call with the detailed prompt
        try:
            return self._chat_summary(self._RELATIONSHIP_SUMMARY_PROMPT, text)
        except Exception as e:
            log_structured('error', f'Error generating community summary: {e}')
            return f"Error summarizing community: {str(e)}"

    # ------------------------------------------------------------------
    # Community detection
    # ------------------------------------------------------------------
    def _detect_clusters(self, nx_graph):
        """Return objects with ``.node`` / ``.cluster`` for every graph node.

        Tries graspologic's hierarchical Leiden first, then python-louvain,
        then NetworkX's greedy modularity communities.
        """
        try:
            from graspologic.partition import hierarchical_leiden
        except ImportError:
            pass
        else:
            return hierarchical_leiden(
                nx_graph, max_cluster_size=self.max_cluster_size
            )

        try:
            from community import best_partition
        except ImportError:
            pass
        else:
            partition = best_partition(nx_graph)
            return [
                self._Cluster(node, cluster_id)
                for node, cluster_id in partition.items()
            ]

        # Use NetworkX's built-in community detection
        from networkx.algorithms import community
        communities = community.greedy_modularity_communities(nx_graph)
        return [
            self._Cluster(node, index)
            for index, members in enumerate(communities)
            for node in members
        ]

    def build_communities(self):
        """Builds communities from the graph and summarizes them.

        No-op when communities were already built in this session or can be
        loaded from the disk cache. An empty graph is not clustered and not
        cached, so a later call can still build once data exists.
        """
        # Skip if communities are already built in this session
        if self.communities_built:
            log_structured('info', 'Communities already built in this session, skipping rebuild')
            return

        # First check if we can load from cache
        if self.load_from_cache():
            log_structured('info', 'Using cached community data instead of rebuilding')
            self.communities_built = True
            return

        log_structured('info', 'Building communities from graph data')
        nx_graph = self._create_nx_graph()

        if nx_graph.number_of_nodes() == 0:
            # Caching an empty result would make every later start-up load
            # "no communities" from disk and never rebuild.
            log_structured('warning', 'Graph contains no nodes; skipping community build')
            return

        clusters = self._detect_clusters(nx_graph)
        self.entity_info, community_info = self._collect_community_info(
            nx_graph, clusters
        )
        self._summarize_communities(community_info)

        # Cache the results after building
        self.save_to_cache()

        # Mark communities as built for this session
        self.communities_built = True

    def _create_nx_graph(self):
        """Converts internal graph representation to NetworkX graph."""
        nx_graph = nx.Graph()
        for entity1, relation, entity2 in self.get_triplets():
            nx_graph.add_node(entity1.name)
            nx_graph.add_node(entity2.name)
            nx_graph.add_edge(
                relation.source_id,
                relation.target_id,
                relationship=relation.label,
                description=relation.properties.get("relationship_description", "No description provided"),
            )
        return nx_graph

    def _collect_community_info(self, nx_graph, clusters):
        """
        Collect information for each node based on their community,
        allowing entities to belong to multiple clusters.

        Returns:
            (entity_info, community_info): ``entity_info`` maps node -> list of
            cluster ids; ``community_info`` maps cluster id -> list of
            "node -> neighbor -> relationship -> description" strings.
        """
        entity_info = defaultdict(set)
        community_info = defaultdict(list)

        for item in clusters:
            node = item.node
            cluster_id = item.cluster

            entity_info[node].add(cluster_id)

            for neighbor in nx_graph.neighbors(node):
                edge_data = nx_graph.get_edge_data(node, neighbor)
                if edge_data:
                    detail = f"{node} -> {neighbor} -> {edge_data['relationship']} -> {edge_data['description']}"
                    community_info[cluster_id].append(detail)

        # Convert sets to lists for easier serialization
        return (
            {name: list(cluster_ids) for name, cluster_ids in entity_info.items()},
            dict(community_info),
        )

    def _summarize_communities(self, community_info):
        """Generate and store summaries for each community."""
        for community_id, details in community_info.items():
            details_text = "\n".join(details) + "."  # Ensure it ends with a period
            self.community_summary[community_id] = self.generate_community_summary(details_text)

    # ------------------------------------------------------------------
    # Community cache
    # ------------------------------------------------------------------
    def save_to_cache(self):
        """Save community data to disk cache. Returns True on success."""
        try:
            self._dump_pickle(self.COMMUNITY_CACHE_FILE, self.community_summary)
            self._dump_pickle(self.ENTITY_INFO_CACHE_FILE, self.entity_info)

            log_structured('info', 'Successfully cached GraphRAG community data',
                           {'community_count': len(self.community_summary),
                            'entity_count': len(self.entity_info) if self.entity_info else 0})
            return True
        except Exception as e:
            log_structured('error', f'Error saving GraphRAG cache: {e}')
            return False

    def load_from_cache(self):
        """Load community data from disk cache if available.

        Returns True and marks communities as built when both cache files
        load; otherwise resets the in-memory state and returns False.
        """
        if not self.COMMUNITY_CACHE_FILE.exists() or not self.ENTITY_INFO_CACHE_FILE.exists():
            log_structured('info', 'GraphRAG cache files not found, will build communities from scratch')
            return False

        try:
            # Trusted, application-written files only (see class docstring).
            with open(self.COMMUNITY_CACHE_FILE, 'rb') as f:
                self.community_summary = pickle.load(f)

            with open(self.ENTITY_INFO_CACHE_FILE, 'rb') as f:
                self.entity_info = pickle.load(f)

            log_structured('info', 'Successfully loaded GraphRAG community data from cache',
                           {'community_count': len(self.community_summary),
                            'entity_count': len(self.entity_info) if self.entity_info else 0})

            # Mark communities as built when successfully loaded from cache
            self.communities_built = True
            return True
        except Exception as e:
            log_structured('error', f'Error loading GraphRAG cache: {e}')
            # Reset to empty in case of partial load
            self.community_summary = {}
            self.entity_info = None
            self.communities_built = False
            return False

    def get_community_summaries(self):
        """Returns the community summaries, building them if not already done."""
        if not self.community_summary:
            # Try the disk cache first; build_communities() persists its own
            # results, so no extra save is needed afterwards.
            if not self.load_from_cache():
                self.build_communities()
        return self.community_summary
|
|
|
|
# Define the GraphRAGQueryEngine class
|
|
from typing import Dict, Any
|
|
|
|
class GraphRAGQueryEngine:
    """Query engine that combines vector retrieval with graph-based community retrieval."""

    # Fallback heuristic for entity names: a capitalized run of
    # letters/digits/whitespace delimited by whitespace or string ends.
    _CAPITALIZED_RUN = re.compile(r"(?:^|\s)([A-Z][a-zA-Z0-9\s]+)(?:\s|$)")

    def __init__(
        self,
        vector_retriever: VectorIndexRetriever,
        graph_store: GraphRAGStore,
        llm: Optional[LLM] = None,
        similarity_top_k: int = 20
    ):
        """Initialize with both a vector retriever and graph store.

        Args:
            vector_retriever: Retriever used for the vector-search half.
            graph_store: Store providing ``entity_info`` and community summaries.
            llm: LLM to keep on the engine; defaults to ``Settings.llm``.
            similarity_top_k: Stored on the engine; not used by the methods
                defined in this class.
        """
        self.vector_retriever = vector_retriever
        self.graph_store = graph_store
        self.llm = llm or Settings.llm
        self.similarity_top_k = similarity_top_k

        # Only warn here; building communities in the constructor could be
        # slow or fail on large graphs.
        if getattr(self.graph_store, 'entity_info', None) is None:
            log_structured('warning', 'GraphRAGQueryEngine initialized without community data. Vector retrieval will still work, but community retrieval may be limited.')

    def custom_query(self, query_str: str) -> Dict:
        """Process query using both vector retrieval and community-based approach.

        Returns:
            dict with ``vector_context`` (joined node text), ``graphrag_context``
            (joined summaries of matching communities, possibly empty),
            ``vector_nodes`` and ``community_ids``. Community-side failures are
            logged and degrade to vector-only results.
        """
        log_structured('info', 'GraphRAG query engine: Starting dual retrieval', {'query': query_str})

        # Step 1: Get vector search results
        vector_nodes = self.vector_retriever.retrieve(query_str)
        vector_context = "\n\n".join([node.node.get_content() for node in vector_nodes])
        log_structured('info', 'GraphRAG query engine: Vector retrieval complete',
                       {'node_count': len(vector_nodes)})

        # Step 2: Get GraphRAG community results (if communities exist)
        graphrag_context = ""
        community_ids = []

        entity_info = getattr(self.graph_store, 'entity_info', None)
        if entity_info is None:
            log_structured('warning', 'GraphRAG query engine: No community data available. Using only vector retrieval.')
        else:
            try:
                entities = self.get_entities(query_str, vector_nodes)
                community_ids = self.retrieve_entity_communities(entity_info, entities)
            except Exception as e:
                log_structured('error', f'Error during community retrieval: {e}')
                # Continue with just vector context
            else:
                if community_ids:
                    # Summaries are fetched only when something matched:
                    # get_community_summaries() may fall back to building
                    # communities, which is far too expensive to trigger
                    # for a query that cannot use the result.
                    try:
                        community_summaries = self.graph_store.get_community_summaries()
                        wanted = set(community_ids)
                        filtered_summaries = {
                            community_id: summary
                            for community_id, summary in community_summaries.items()
                            if community_id in wanted
                        }
                        graphrag_context = "\n\n".join(filtered_summaries.values())
                        log_structured('info', 'GraphRAG query engine: Community retrieval complete',
                                       {'community_count': len(filtered_summaries)})
                    except Exception as e:
                        log_structured('error', f'Error getting community summaries: {e}')
                        # Continue without graph context
                else:
                    log_structured('info', 'GraphRAG query engine: No relevant communities found')

        # Step 3: Return both contexts; answer generation happens in the caller
        return {
            "vector_context": vector_context,
            "graphrag_context": graphrag_context,
            "vector_nodes": vector_nodes,
            "community_ids": community_ids
        }

    def get_entities(self, query_str, vector_nodes):
        """Collect entity names from the retrieved nodes.

        Prefers entity objects stored under ``KG_NODES_KEY`` in node metadata;
        if none are found, falls back to a capitalized-phrase regex over the
        node text. ``query_str`` is accepted for interface compatibility and
        is not used.
        """
        entities = set()

        # Extract entities from the retrieved nodes' metadata
        for node_with_score in vector_nodes:
            node = node_with_score.node
            if hasattr(node, 'metadata') and KG_NODES_KEY in node.metadata:
                for entity_node in node.metadata[KG_NODES_KEY]:
                    if hasattr(entity_node, 'name'):
                        entities.add(entity_node.name)

        # If no entities were found in metadata, try extracting them from text
        if not entities:
            for node_with_score in vector_nodes:
                entities.update(
                    self._CAPITALIZED_RUN.findall(node_with_score.node.get_content())
                )

        log_structured('debug', 'GraphRAG query engine: Extracted entities',
                       {'entities': list(entities), 'count': len(entities)})
        return list(entities)

    def retrieve_entity_communities(self, entity_info, entities):
        """
        Retrieve cluster information for given entities,
        allowing for multiple clusters per entity.

        Args:
            entity_info (dict): Dictionary mapping entities to their cluster IDs (list).
            entities (list): List of entity names to retrieve information for.

        Returns:
            De-duplicated list of community or cluster IDs to which the
            entities belong (order is unspecified).
        """
        community_ids = []
        lowercase_index = None  # built lazily, at most once per call

        for entity in entities:
            if entity in entity_info:
                community_ids.extend(entity_info[entity])
                continue

            # Case-insensitive fallback. setdefault keeps the first stored
            # entity for each lowercase key, matching a first-hit linear scan.
            if lowercase_index is None:
                lowercase_index = {}
                for stored_entity, clusters in entity_info.items():
                    lowercase_index.setdefault(stored_entity.lower(), clusters)
            community_ids.extend(lowercase_index.get(entity.lower(), ()))

        return list(set(community_ids))
|
|
|
|
def custom_parse_fn(response_str: str) -> Any:
    """Custom parser for LLM responses that extract entities and relationships.

    Takes the span from the first ``{`` to the last ``}`` in ``response_str``
    and parses it as JSON with ``entities`` and ``relationships`` lists.

    Returns:
        ``(entities, relationships)`` where entities are
        ``(entity_name, entity_type, entity_description)`` tuples and
        relationships are ``(source_entity, target_entity, relation,
        relationship_description)`` tuples. Missing descriptions get a generic
        default. Malformed input never raises: the error is logged and
        whatever was parsed before the failure is returned.
    """
    entities = []
    relationships = []

    match = re.search(r"\{.*\}", response_str, re.DOTALL)
    if not match:
        return entities, relationships

    json_str = match.group(0)
    try:
        data = json.loads(json_str)
        # ``or []`` tolerates an explicit JSON null for either list.
        entities = [
            (
                entity["entity_name"],
                entity["entity_type"],
                entity.get("entity_description", f"Description of {entity['entity_name']}"),
            )
            for entity in data.get("entities") or []
        ]
        relationships = [
            (
                relation["source_entity"],
                relation["target_entity"],
                relation["relation"],
                relation.get("relationship_description", f"Relationship between {relation['source_entity']} and {relation['target_entity']}"),
            )
            for relation in data.get("relationships") or []
        ]
        return entities, relationships
    except (json.JSONDecodeError, KeyError, TypeError, AttributeError) as e:
        # TypeError/AttributeError: list items that are not JSON objects
        # (e.g. plain strings) or lists that are not lists at all.
        log_structured('error', f"Error parsing response: {e}", {'json_str': json_str[:200]})
        return entities, relationships
|
|
|
|
# Prompt template for triple extraction. It is passed to GraphRAGExtractor as
# ``extract_prompt`` and filled with ``{text}`` and ``{max_knowledge_triplets}``;
# the JSON shape it requests is the one ``custom_parse_fn`` reads.
# NOTE(review): the literal ``{ "entities": [], "relationships": [] }`` below is
# not brace-escaped — presumably the prompt formatter leaves unknown
# placeholders untouched; confirm against the installed llama_index version.
KG_TRIPLET_EXTRACT_TMPL = """
-Goal-
Given a text document, identify all entities and their entity types from the text and all relationships among the identified entities.

Given the text, extract up to {max_knowledge_triplets} entity-relation triplets.

-Steps-
1. Identify all entities. For each identified entity, extract the following information:
- entity_name: Name of the entity, capitalized
- entity_type: Type of the entity
- entity_description: Comprehensive description of the entity's attributes and activities

2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other.
For each pair of related entities, extract the following information:
- source_entity: name of the source entity, as identified in step 1
- target_entity: name of the target entity, as identified in step 1
- relation: relationship between source_entity and target_entity
- relationship_description: explanation as to why you think the source entity and the target entity are related to each other

3. Output Formatting:
- Return the result in valid JSON format with two keys: 'entities' (list of entity objects) and 'relationships' (list of relationship objects).
- Exclude any text outside the JSON structure (e.g., no explanations or comments).
- If no entities or relationships are identified, return empty lists: { "entities": [], "relationships": [] }.

-Real Data-
######################
text: {text}
######################
output:
"""
|
|
|
|
def _clear_neo4j_database():
    """Delete every node and relationship from the configured Neo4j database."""
    from neo4j import GraphDatabase

    driver = GraphDatabase.driver(NEO4J_URL, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))
    try:
        with driver.session() as session:
            session.run("MATCH (n) DETACH DELETE n")
    finally:
        # Close the driver even when the delete query fails.
        driver.close()


def _build_communities_logged(graph_store, error_prefix):
    """Build communities, logging (not raising) any failure."""
    try:
        graph_store.build_communities()
    except Exception as e:
        log_structured('error', f'{error_prefix}: {e}')


def _open_existing_graph_index(property_graph_store):
    """Create a PropertyGraphIndex over an already-populated store (no indexing)."""
    return PropertyGraphIndex(
        nodes=[],
        property_graph_store=property_graph_store,
    )


def _extract_graph_from_nodes(llm, nodes, graph_store, property_graph_store, max_paths_per_chunk):
    """Run LLM triple extraction over ``nodes``, then cache triples and build communities."""
    kg_extractor = GraphRAGExtractor(
        llm=llm,
        extract_prompt=KG_TRIPLET_EXTRACT_TMPL,
        max_paths_per_chunk=max_paths_per_chunk,
        parse_fn=custom_parse_fn,
    )

    property_graph_index = PropertyGraphIndex(
        nodes=nodes,
        kg_extractors=[kg_extractor],
        property_graph_store=property_graph_store,
        show_progress=True,
    )

    # Cache the newly extracted triples
    graph_store.save_triples_to_cache()

    # Community caches on disk describe the previous graph. Drop them so
    # build_communities() recomputes from the triples just extracted instead
    # of silently loading stale summaries.
    for stale_cache in (graph_store.COMMUNITY_CACHE_FILE, graph_store.ENTITY_INFO_CACHE_FILE):
        try:
            stale_cache.unlink()
        except FileNotFoundError:
            pass
        except OSError as e:
            log_structured('warning', f'Could not remove stale community cache {stale_cache}: {e}')

    _build_communities_logged(graph_store, 'Error building communities')
    return property_graph_index


def create_graph_components(llm, nodes=None, max_paths_per_chunk=10, force_reindex=False):
    """
    Create GraphRAG components for the HP RAG pipeline.

    Three paths, chosen from the state of Neo4j and the disk cache:
      1. Neo4j already has triplets (and no force_reindex): reuse them and
         make sure a triples cache exists on disk.
      2. Neo4j is empty but a triples cache exists: restore from the cache;
         if that fails, fall back to LLM extraction.
      3. Otherwise (force_reindex, or nothing cached): full LLM extraction,
         clearing Neo4j first when it holds old data.

    Args:
        llm: The LLM to use for graph extraction and querying
        nodes: List of nodes to create the graph from (only used if indexing is needed)
        max_paths_per_chunk: Maximum number of paths to extract per chunk
        force_reindex: If True, always recreate the index even if content exists

    Returns:
        tuple: (graph_store, property_graph_index)

    Raises:
        RuntimeError: Neo4j could not be reached.
        ValueError: LLM extraction is required but ``nodes`` is empty/None.
    """
    log_structured('info', 'Creating GraphRAG components')

    # Connect to Neo4j - hard error if not available
    try:
        log_structured('info', f'Connecting to Neo4j at {NEO4J_URL}')
        property_graph_store = Neo4jPropertyGraphStore(
            username=NEO4J_USERNAME,
            password=NEO4J_PASSWORD,
            url=NEO4J_URL
        )
        log_structured('info', 'Successfully connected to Neo4j database')
    except Exception as e:
        log_structured('critical', f'FATAL ERROR: Cannot connect to Neo4j: {e}')
        raise RuntimeError(f"Neo4j connection failed. This application requires Neo4j to be running. Error: {e}") from e

    # The wrapper loads/saves community data from its disk cache whenever
    # build_communities() is called.
    graph_store = GraphRAGStore(property_graph_store)

    # Check if Neo4j already has content
    triplets = graph_store.get_triplets()
    has_existing_content = len(triplets) > 0
    log_structured('info', f'Neo4j check: Found {len(triplets)} triplets')

    if has_existing_content and not force_reindex:
        # BRANCH 1: Neo4j has data — use it, but also ensure disk cache exists
        log_structured('info', f'Neo4j already contains {len(triplets)} triplets. Skipping indexing.')

        if not graph_store.TRIPLES_CACHE_FILE.exists():
            log_structured('info', 'Neo4j has data but no triples cache on disk — creating cache now')
            graph_store.save_triples_to_cache()

        property_graph_index = _open_existing_graph_index(property_graph_store)

        if not graph_store.communities_built:
            log_structured('info', 'Building graph communities from existing Neo4j data')
            _build_communities_logged(graph_store, 'Error building communities')

        return graph_store, property_graph_index

    # From here on Neo4j is empty or a reindex was requested.
    if not force_reindex and graph_store.TRIPLES_CACHE_FILE.exists():
        # BRANCH 2: Neo4j is empty but triples cache exists — restore from cache
        log_structured('info', 'Neo4j is empty but triples cache exists. Restoring from disk cache.')

        if graph_store.load_triples_from_cache():
            log_structured('info', 'Successfully restored triples from cache.')
            property_graph_index = _open_existing_graph_index(property_graph_store)

            if not graph_store.communities_built:
                _build_communities_logged(graph_store, 'Error building communities from restored data')

            return graph_store, property_graph_index

        # Cache restore failed — fall through to LLM extraction
        log_structured('warning', 'Triples cache restore failed. Falling back to LLM extraction.')
        if not nodes:
            raise ValueError("Nodes must be provided when Neo4j is empty and cache restore fails")
    else:
        # BRANCH 3: Full LLM extraction (force_reindex or no cache)
        if not nodes:
            raise ValueError("Nodes must be provided for indexing when Neo4j is empty or force_reindex=True")

        if has_existing_content and force_reindex:
            # Clear Neo4j before re-extraction; a failure is logged and
            # extraction proceeds on top of the existing data (best effort).
            try:
                _clear_neo4j_database()
            except Exception as e:
                log_structured('warning', f'Error clearing Neo4j database: {e}')

    property_graph_index = _extract_graph_from_nodes(
        llm, nodes, graph_store, property_graph_store, max_paths_per_chunk
    )
    return graph_store, property_graph_index
|
|
|
|
def create_graphrag_query_engine(vector_retriever, graph_store, llm, similarity_top_k=20):
    """
    Build the query engine that pairs vector retrieval with graph communities.

    Args:
        vector_retriever: Retriever used for the standard vector lookup.
        graph_store: Graph store used for the community-based lookup.
        llm: LLM handed to the engine for answer generation.
        similarity_top_k: Number of top results the engine should retrieve.

    Returns:
        GraphRAGQueryEngine: Engine constructed from the arguments above.

    Raises:
        ValueError: If vector_retriever, graph_store or llm is None.
        Exception: Any failure while constructing the engine is logged at
            error level and then re-raised unchanged.
    """
    from utils import log_structured

    try:
        # Reject missing collaborators up front, in declaration order, so the
        # caller gets a precise message rather than a constructor failure.
        required_arguments = (
            (vector_retriever, "vector_retriever cannot be None"),
            (graph_store, "graph_store cannot be None"),
            (llm, "llm cannot be None"),
        )
        for candidate, complaint in required_arguments:
            if candidate is None:
                raise ValueError(complaint)

        # Record only the concrete types: enough to debug wiring problems
        # without dumping the objects themselves.
        debug_details = {
            'vector_retriever_type': type(vector_retriever).__name__,
            'graph_store_type': type(graph_store).__name__,
            'llm_type': type(llm).__name__,
            'similarity_top_k': similarity_top_k
        }
        log_structured('debug', 'Creating GraphRAGQueryEngine with parameters', debug_details)

        engine = GraphRAGQueryEngine(
            vector_retriever=vector_retriever,
            graph_store=graph_store,
            llm=llm,
            similarity_top_k=similarity_top_k,
        )
        return engine
    except Exception as e:
        # Validation errors land here too: log once, then let the caller decide.
        log_structured('error', f'Error in create_graphrag_query_engine: {e}')
        raise
|
|
|
|
def generate_final_answer(query, retrieval_result, llm):
    """
    Generate a final answer using both vector and graph-based context.

    Args:
        query: The user's query.
        retrieval_result: Mapping that may carry "vector_context" and
            "graphrag_context" entries. A missing or None entry is treated
            as an empty string.
        llm: Object exposing ``chat(messages)``. When it is None or has no
            ``chat`` attribute, an ``OpenAI(model="gpt-4.1-mini")`` instance
            is used instead.

    Returns:
        str: The cleaned answer text. A fixed "couldn't find" message is
        returned when both contexts are empty, and a fixed generic message
        when nothing usable remains after cleaning.
    """
    # Imported locally, as create_graphrag_query_engine does, so this function
    # does not depend on a module-level import being present.
    from utils import log_structured

    # `.get(key, "")` would still return None for an explicit None value and
    # break len() below, so normalise with `or`.
    vector_context = retrieval_result.get("vector_context") or ""
    graphrag_context = retrieval_result.get("graphrag_context") or ""

    # Log only the sizes of the contexts, not their contents.
    log_structured('debug', 'Generating final answer with dual context', {
        'query': query,
        'vector_context_length': len(vector_context),
        'graphrag_context_length': len(graphrag_context)
    })

    if not vector_context and not graphrag_context:
        return "I couldn't find any relevant information to answer your question."

    # If no usable model was provided, fall back to gpt-4.1-mini.
    if llm is None or not hasattr(llm, 'chat'):
        llm = OpenAI(model="gpt-4.1-mini")
        log_structured('info', 'Using gpt-4.1-mini model for final answer generation')

    prompt = f"""
    Based on the following information from two different sources, please answer this question: {query}

    SOURCE 1 - VECTOR RETRIEVAL:
    {vector_context}

    SOURCE 2 - KNOWLEDGE GRAPH COMMUNITIES:
    {graphrag_context}

    Answer the question based on all the provided information. If there are differences between the sources,
    try to reconcile them or note the discrepancy. Please be concise and direct.
    """

    messages = [
        ChatMessage(role="system", content=prompt),
        ChatMessage(role="user", content="Please provide a comprehensive answer based on all the information provided.")
    ]

    response = llm.chat(messages)

    # Extract just the message content, not the entire response object.
    if hasattr(response, 'message') and hasattr(response.message, 'content'):
        content = response.message.content
    elif hasattr(response, 'content'):
        content = response.content
    else:
        # Fallback: stringify the response; it is cleaned below like any other.
        content = str(response)

    # The content attribute may be None or a non-string; re.sub needs a str.
    if content is None:
        content = ""
    elif not isinstance(content, str):
        content = str(content)

    # Each pattern carries its own flags, applied in the original order.
    # The span patterns keep DOTALL so a trace can cross newlines. The
    # line-anchored patterns use MULTILINE only: with DOTALL their ".*" ran
    # to the end of the text, so one stray "Thought:" line erased the answer.
    thinking_patterns = [
        (r'(?i)Thought:.*?Action:.*?Action Input:.*', re.DOTALL),   # full ReAct trace
        (r'(?i)^Thought:.*', re.MULTILINE),                         # any "Thought:" line
        (r'(?i)Action:.*?Action Input:.*', re.DOTALL),              # Action/Action Input span
        (r'(?i)^(Thought|Action|Observation):.*', re.MULTILINE),    # leftover ReAct lines
    ]

    for pattern, flags in thinking_patterns:
        content = re.sub(pattern, '', content, flags=flags)

    # Collapse the blank-line runs left behind by the removals.
    content = re.sub(r'\n{3,}', '\n\n', content)
    content = content.strip()

    # Final safety check: empty output or a leading ReAct marker means the
    # model produced nothing presentable.
    if not content or re.search(r'(?i)^(Thought|Action|Observation):', content):
        log_structured('warning', 'GraphRAG final answer still contains thinking patterns, using fallback')
        content = "I found relevant information in the HP marketing materials that can help answer your question. Please let me know if you need more specific details."

    return content