hp_chatbot/graphRAG.py
michael 594f749d4c Initial commit: HP Marketing Materials GraphRAG Chatbot
Full-stack GraphRAG chatbot for HP marketing materials with:
- Python/Flask backend with custom ReAct agent (LlamaIndex)
- Neo4j knowledge graph + vector search hybrid retrieval
- LlamaParse multimodal document processing (text + images)
- React/Vite frontend with conversation management
- MongoDB conversation persistence
- MSAL authentication support

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-23 08:37:58 -06:00

581 lines
No EOL
22 KiB
Python

import os
import json
import re
import asyncio
import networkx as nx
from collections import defaultdict
from typing import Any, List, Callable, Optional, Union, Dict
from dotenv import load_dotenv
# Import LlamaIndex components
from llama_index.core import Document, Settings
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core import PropertyGraphIndex
from llama_index.core.async_utils import run_jobs
from llama_index.core.indices.property_graph.utils import default_parse_triplets_fn
from llama_index.core.graph_stores.types import EntityNode, KG_NODES_KEY, KG_RELATIONS_KEY, Relation
from llama_index.core.llms.llm import LLM
from llama_index.core.prompts import PromptTemplate
from llama_index.core.prompts.default_prompts import DEFAULT_KG_TRIPLET_EXTRACT_PROMPT
from llama_index.core.schema import TransformComponent, BaseNode
from llama_index.core.query_engine import CustomQueryEngine
from llama_index.llms.openai import OpenAI
from llama_index.core.llms import ChatMessage
from llama_index.graph_stores.neo4j import Neo4jPropertyGraphStore
from llama_index.core import SimpleDirectoryReader
# Community detection (using NetworkX instead of graspologic as a fallback)
try:
from community import best_partition # python-louvain package
except ImportError:
print("Community detection package not found, using NetworkX built-in community detection")
# Load environment variables from .env file
load_dotenv()
# Credentials must come from the environment (or .env). An OpenAI key used to
# be hard-coded here; since it was committed to version control it should be
# considered leaked and revoked.
if not os.environ.get("OPENAI_API_KEY"):
    print(
        "Warning: OPENAI_API_KEY is not set; OpenAI calls will fail until it "
        "is provided via the environment or a .env file"
    )
# Define the GraphRAGExtractor class
class GraphRAGExtractor(TransformComponent):
    """Extract entities and relations (triples) from text nodes with an LLM.

    Each node's text is sent to the LLM with ``extract_prompt``. ``parse_fn``
    turns the response into entity and relationship tuples, which are attached
    to the node's metadata under ``KG_NODES_KEY`` / ``KG_RELATIONS_KEY`` as
    ``EntityNode`` / ``Relation`` objects.

    Args:
        llm (LLM):
            The language model to use. Defaults to ``Settings.llm``.
        extract_prompt (Union[str, PromptTemplate]):
            The prompt to use for extracting triples. A plain string is
            wrapped in ``PromptTemplate``. Defaults to
            ``DEFAULT_KG_TRIPLET_EXTRACT_PROMPT``.
        parse_fn (callable):
            Parses the LLM output into an ``(entities, relationships)`` pair.
            Entities are ``(name, type, description)`` tuples; relationships
            are ``(source, target, relation, description)`` tuples.
        num_workers (int):
            The number of extraction jobs run concurrently.
        max_paths_per_chunk (int):
            The maximum number of triples requested per chunk (passed to the
            prompt as ``max_knowledge_triplets``).
    """

    llm: LLM
    extract_prompt: PromptTemplate
    parse_fn: Callable
    num_workers: int
    max_paths_per_chunk: int

    def __init__(
        self,
        llm: Optional[LLM] = None,
        extract_prompt: Optional[Union[str, PromptTemplate]] = None,
        parse_fn: Callable = default_parse_triplets_fn,
        max_paths_per_chunk: int = 10,
        num_workers: int = 4,
    ) -> None:
        """Init params."""
        # NOTE(review): parse_fn defaults to llama_index's
        # default_parse_triplets_fn, but _aextract unpacks its result as an
        # (entities, relationships) pair. main() always passes custom_parse_fn.
        # Confirm the default is compatible before relying on it.
        if isinstance(extract_prompt, str):
            extract_prompt = PromptTemplate(extract_prompt)
        super().__init__(
            llm=llm or Settings.llm,
            extract_prompt=extract_prompt or DEFAULT_KG_TRIPLET_EXTRACT_PROMPT,
            parse_fn=parse_fn,
            num_workers=num_workers,
            max_paths_per_chunk=max_paths_per_chunk,
        )

    @classmethod
    def class_name(cls) -> str:
        return "GraphExtractor"

    def __call__(
        self, nodes: List[BaseNode], show_progress: bool = False, **kwargs: Any
    ) -> List[BaseNode]:
        """Extract triples from nodes (synchronous wrapper around ``acall``)."""
        return asyncio.run(
            self.acall(nodes, show_progress=show_progress, **kwargs)
        )

    async def _aextract(self, node: BaseNode) -> BaseNode:
        """Extract entities/relations from one node and attach them to its metadata.

        Entities and relations are appended to any already stored under
        KG_NODES_KEY / KG_RELATIONS_KEY. The node is mutated and returned.
        """
        assert hasattr(node, "text")
        text = node.get_content(metadata_mode="llm")
        try:
            llm_response = await self.llm.apredict(
                self.extract_prompt,
                text=text,
                max_knowledge_triplets=self.max_paths_per_chunk,
            )
            entities, entities_relationship = self.parse_fn(llm_response)
        except ValueError:
            # An unparseable response yields no triples for this chunk instead
            # of aborting the whole extraction run.
            entities = []
            entities_relationship = []

        existing_nodes = node.metadata.pop(KG_NODES_KEY, [])
        existing_relations = node.metadata.pop(KG_RELATIONS_KEY, [])
        # Snapshot taken after popping the KG keys so they are not copied into
        # every entity/relation's properties.
        base_metadata = node.metadata.copy()

        for entity, entity_type, description in entities:
            # A fresh properties dict per entity. Reusing one mutable dict
            # would make every EntityNode alias the same object, so each new
            # description would overwrite the previous ones.
            existing_nodes.append(
                EntityNode(
                    name=entity,
                    label=entity_type,
                    properties={**base_metadata, "entity_description": description},
                )
            )

        for subj, obj, rel, description in entities_relationship:
            existing_relations.append(
                Relation(
                    label=rel,
                    source_id=subj,
                    target_id=obj,
                    properties={
                        **base_metadata,
                        "relationship_description": description,
                    },
                )
            )

        node.metadata[KG_NODES_KEY] = existing_nodes
        node.metadata[KG_RELATIONS_KEY] = existing_relations
        return node

    async def acall(
        self, nodes: List[BaseNode], show_progress: bool = False, **kwargs: Any
    ) -> List[BaseNode]:
        """Extract triples from nodes concurrently (up to ``num_workers`` at once)."""
        jobs = [self._aextract(node) for node in nodes]
        return await run_jobs(
            jobs,
            workers=self.num_workers,
            show_progress=show_progress,
            desc="Extracting paths from text",
        )
# Define the GraphRAGStore class
class GraphRAGStore(Neo4jPropertyGraphStore):
    """Neo4j property graph store extended with GraphRAG-style communities.

    ``build_communities`` clusters the entity graph and asks an LLM to
    summarise the relationships inside each cluster. Query engines can then
    answer from those per-community summaries.

    Attributes:
        community_summary: Mapping of community id -> summary text.
        entity_info: Mapping of entity name -> list of community ids, or
            ``None`` until ``build_communities`` has run.
        max_cluster_size: ``max_cluster_size`` passed to graspologic's
            hierarchical Leiden (ignored by the fallback algorithms).
    """

    # Class-level defaults only. The methods below always assign fresh
    # per-instance values and never mutate these objects. Summaries are
    # therefore not shared between store instances through the class dict.
    community_summary = {}
    entity_info = None
    max_cluster_size = 5

    class _Cluster:
        """(node, cluster) record with the attributes _collect_community_info reads."""

        __slots__ = ("node", "cluster")

        def __init__(self, node, cluster):
            self.node = node
            self.cluster = cluster

    def generate_community_summary(self, text):
        """Generate summary for a given text using an LLM.

        Uses a default ``OpenAI()`` client, not the query engine's LLM. The
        leading "assistant:" role prefix of the stringified response is
        stripped.
        """
        messages = [
            ChatMessage(
                role="system",
                content=(
                    "You are provided with a set of "
                    "relationships from a knowledge graph, each represented as "
                    "entity1->entity2->relation-"
                    ">relationship_description. Your task is to create a summary of "
                    "these "
                    "relationships. The summary should include "
                    "the names of the entities involved and a concise synthesis "
                    "of the relationship descriptions. The "
                    "goal is to capture the most critical and relevant details that "
                    "highlight the nature and significance of "
                    "each relationship. Ensure that the summary is coherent and "
                    "integrates the information in a way that "
                    "emphasizes the key aspects of the relationships."
                ),
            ),
            ChatMessage(role="user", content=text),
        ]
        response = OpenAI().chat(messages)
        clean_response = re.sub(r"^assistant:\s*", "", str(response)).strip()
        return clean_response

    def build_communities(self):
        """Cluster the entity graph and regenerate every community summary.

        Previous summaries are replaced, not merged, so communities that no
        longer exist do not linger. ``entity_info`` is updated only after all
        summaries were generated. If summarisation raises, the previous
        state is left untouched.
        """
        nx_graph = self._create_nx_graph()
        if nx_graph.number_of_nodes() == 0:
            # No triplets yet: nothing to cluster, and no LLM calls needed.
            self.community_summary = {}
            self.entity_info = {}
            return
        clusters = self._detect_clusters(nx_graph)
        entity_info, community_info = self._collect_community_info(
            nx_graph, clusters
        )
        self._summarize_communities(community_info)
        self.entity_info = entity_info

    def _detect_clusters(self, nx_graph):
        """Return (node, cluster) records, trying each clustering backend in turn.

        The backends are tried in this order:
        1. graspologic's hierarchical Leiden.
        2. python-louvain's ``best_partition``.
        3. NetworkX's greedy modularity communities.
        """
        try:
            from graspologic.partition import hierarchical_leiden
        except ImportError:
            pass
        else:
            return hierarchical_leiden(
                nx_graph, max_cluster_size=self.max_cluster_size
            )

        try:
            from community import best_partition
        except ImportError:
            pass
        else:
            partition = best_partition(nx_graph)
            return [
                self._Cluster(node, cluster_id)
                for node, cluster_id in partition.items()
            ]

        from networkx.algorithms import community

        communities = community.greedy_modularity_communities(nx_graph)
        return [
            self._Cluster(node, index)
            for index, members in enumerate(communities)
            for node in members
        ]

    def _create_nx_graph(self):
        """Convert the stored triplets into an undirected NetworkX graph.

        Edges carry ``relationship`` (relation label) and ``description``.
        nx.Graph keeps one edge per node pair, so a later relation between the
        same two entities overwrites the earlier edge's attributes.
        """
        nx_graph = nx.Graph()
        for entity1, relation, entity2 in self.get_triplets():
            nx_graph.add_node(entity1.name)
            nx_graph.add_node(entity2.name)
            nx_graph.add_edge(
                relation.source_id,
                relation.target_id,
                relationship=relation.label,
                description=relation.properties.get(
                    "relationship_description", "No description provided"
                ),
            )
        return nx_graph

    def _collect_community_info(self, nx_graph, clusters):
        """Group relationship text by community and record each entity's communities.

        Entities may belong to multiple clusters (e.g. several Leiden levels).

        Returns:
            ``(entity_info, community_info)``. ``entity_info`` maps each
            entity to a list of cluster ids. ``community_info`` maps each
            cluster id to a list of
            "node -> neighbor -> relationship -> description" strings, one
            per edge incident to a member node.
        """
        entity_info = defaultdict(set)
        community_info = defaultdict(list)
        for item in clusters:
            node = item.node
            cluster_id = item.cluster
            entity_info[node].add(cluster_id)
            for neighbor in nx_graph.neighbors(node):
                edge_data = nx_graph.get_edge_data(node, neighbor)
                if edge_data:
                    detail = f"{node} -> {neighbor} -> {edge_data['relationship']} -> {edge_data['description']}"
                    community_info[cluster_id].append(detail)
        # Sets -> lists so the result is easy to serialise.
        return {k: list(v) for k, v in entity_info.items()}, dict(community_info)

    def _summarize_communities(self, community_info):
        """Generate a summary per community and replace ``community_summary``.

        Summaries are collected into a new dict that is assigned only after
        every community has been summarised. A failure part-way therefore
        keeps the previous summaries intact.
        """
        summaries = {}
        for community_id, details in community_info.items():
            details_text = "\n".join(details) + "."  # Ensure it ends with a period
            summaries[community_id] = self.generate_community_summary(details_text)
        self.community_summary = summaries

    def get_community_summaries(self):
        """Return the community summaries, building communities on first use.

        A build is triggered only while ``entity_info`` is still ``None``.
        A graph that legitimately has no communities is therefore not
        re-clustered (and re-summarised) on every call.
        """
        if self.entity_info is None:
            self.build_communities()
        return self.community_summary
# Define the GraphRAGQueryEngine class
class GraphRAGQueryEngine(CustomQueryEngine):
    """Answer queries from community summaries, or by direct retrieval.

    If ``graph_store`` exposes ``get_community_summaries`` (GraphRAGStore),
    the query is handled in three steps:

    1. Map the query to entities.
    2. Map those entities to their communities.
    3. Answer from each relevant community summary and aggregate the answers.

    Otherwise the top-k retrieved nodes are concatenated and answered in one
    LLM call.
    """

    graph_store: Union[GraphRAGStore, Any]  # Accept any type of graph store
    index: PropertyGraphIndex
    llm: LLM
    similarity_top_k: int = 20

    @staticmethod
    def _clean_response(response) -> str:
        """Stringify a chat response and drop its leading "assistant:" role prefix."""
        return re.sub(r"^assistant:\s*", "", str(response)).strip()

    def custom_query(self, query_str: str) -> str:
        """Process query using either community-based approach or direct retrieval."""
        if hasattr(self.graph_store, "get_community_summaries"):
            return self._query_communities(query_str)
        return self._query_direct(query_str)

    def _query_communities(self, query_str: str) -> str:
        """GraphRAG path: answer from the summaries of the query's communities."""
        # Summaries must be fetched first: get_community_summaries may build
        # the communities, and entity_info is only populated by that build.
        community_summaries = self.graph_store.get_community_summaries()
        entities = self.get_entities(query_str, self.similarity_top_k)
        community_ids = set(
            self.retrieve_entity_communities(self.graph_store.entity_info, entities)
        )
        community_answers = [
            self.generate_answer_from_summary(community_summary, query_str)
            for community_id, community_summary in community_summaries.items()
            if community_id in community_ids
        ]
        if not community_answers:
            # Asking the LLM to "combine" zero answers invites a fabricated reply.
            return "I couldn't find any relevant information to answer your question."
        return self.aggregate_answers(community_answers)

    def _query_direct(self, query_str: str) -> str:
        """Fallback path (e.g. SimplePropertyGraphStore): answer from retrieved nodes."""
        nodes = self.index.as_retriever(
            similarity_top_k=self.similarity_top_k
        ).retrieve(query_str)
        if not nodes:
            return "I couldn't find any relevant information to answer your question."
        context = "\n\n".join(node.get_content() for node in nodes)
        prompt = f"Based on the following information, please answer this question: {query_str}\n\nInformation:\n{context}"
        messages = [
            ChatMessage(role="system", content=prompt),
            ChatMessage(role="user", content="Please provide a comprehensive answer based on the information provided.")
        ]
        # Strip the role prefix like the community path does; str(response)
        # would otherwise start with "assistant:".
        return self._clean_response(self.llm.chat(messages))

    def get_entities(self, query_str, similarity_top_k):
        """Return entity names found in triplet lines of the retrieved nodes.

        Each retrieved node's text is scanned for lines of the form
        ``subject -> relation -> object``. Subjects and objects are returned,
        de-duplicated, in no particular order.
        """
        nodes_retrieved = self.index.as_retriever(
            similarity_top_k=similarity_top_k
        ).retrieve(query_str)
        # The relation group accepts word characters (not only letters) so
        # labels with underscores or digits, e.g. "MANUFACTURED_BY", match.
        pattern = re.compile(
            r"^(\w+(?:\s+\w+)*)\s*->\s*([\w\s]+?)\s*->\s*(\w+(?:\s+\w+)*)$",
            re.MULTILINE | re.IGNORECASE,
        )
        entities = set()
        for node in nodes_retrieved:
            for subject, _relation, obj in pattern.findall(node.text):
                entities.add(subject)
                entities.add(obj)
        return list(entities)

    def retrieve_entity_communities(self, entity_info, entities):
        """
        Retrieve cluster information for given entities,
        allowing for multiple clusters per entity.
        Args:
            entity_info (dict | None): Dictionary mapping entities to their
                cluster IDs (list). ``None`` (communities not built) yields [].
            entities (list): List of entity names to retrieve information for.
        Returns:
            List of unique community or cluster IDs the entities belong to.
        """
        if not entity_info:
            return []
        community_ids = set()
        for entity in entities:
            community_ids.update(entity_info.get(entity, ()))
        return list(community_ids)

    def generate_answer_from_summary(self, community_summary, query):
        """Generate an answer from a community summary based on a given query using LLM."""
        prompt = (
            f"Given the community summary: {community_summary}, "
            f"how would you answer the following query? Query: {query}"
        )
        messages = [
            ChatMessage(role="system", content=prompt),
            ChatMessage(
                role="user",
                content="I need an answer based on the above information.",
            ),
        ]
        return self._clean_response(self.llm.chat(messages))

    def aggregate_answers(self, community_answers):
        """Aggregate individual community answers into a final, coherent response."""
        prompt = "Combine the following intermediate answers into a final, concise response."
        # Number the answers one per line rather than sending a Python list
        # repr, which escapes newlines and quotes inside each answer.
        numbered = "\n".join(
            f"{position}. {answer}"
            for position, answer in enumerate(community_answers, start=1)
        )
        messages = [
            ChatMessage(role="system", content=prompt),
            ChatMessage(
                role="user",
                content=f"Intermediate answers:\n{numbered}",
            ),
        ]
        return self._clean_response(self.llm.chat(messages))
def custom_parse_fn(response_str: str) -> Any:
    """Parse an LLM extraction response into entity and relationship tuples.

    The response should contain a JSON object with ``entities`` and
    ``relationships`` lists (the format KG_TRIPLET_EXTRACT_TMPL asks for).
    Text before the first ``{`` or after the last ``}`` is ignored.

    Args:
        response_str: Raw text returned by the LLM.

    Returns:
        ``(entities, relationships)``. ``entities`` is a list of
        ``(name, type, description)`` tuples. ``relationships`` is a list of
        ``(source, target, relation, description)`` tuples. Missing
        descriptions get a generated placeholder. Both lists are empty when
        no JSON object is found or it cannot be decoded. Individual items that
        are not objects or lack a required key are skipped (and reported)
        instead of discarding every other item.
    """
    entities = []
    relationships = []
    # Greedy + DOTALL: span from the first "{" to the last "}", so a JSON
    # object surrounded by prose or a code fence is still captured.
    match = re.search(r"\{.*\}", response_str, re.DOTALL)
    if not match:
        return entities, relationships
    json_str = match.group(0)
    try:
        data = json.loads(json_str)
    except json.JSONDecodeError as e:
        print(f"Error parsing response: {e}")
        print(f"Problematic JSON: {json_str[:200]}...")
        return entities, relationships

    def _as_list(value):
        # The LLM may emit null or a non-list for either key; treat it as empty.
        return value if isinstance(value, list) else []

    for entity in _as_list(data.get("entities")):
        try:
            name = entity["entity_name"]
            entities.append(
                (
                    name,
                    entity["entity_type"],
                    entity.get("entity_description", f"Description of {name}"),
                )
            )
        except (KeyError, TypeError, AttributeError) as e:
            print(f"Skipping malformed entity {entity!r}: {e}")

    for relation in _as_list(data.get("relationships")):
        try:
            source = relation["source_entity"]
            target = relation["target_entity"]
            relationships.append(
                (
                    source,
                    target,
                    relation["relation"],
                    relation.get(
                        "relationship_description",
                        f"Relationship between {source} and {target}",
                    ),
                )
            )
        except (KeyError, TypeError, AttributeError) as e:
            print(f"Skipping malformed relationship {relation!r}: {e}")

    return entities, relationships
# Prompt template for entity/relationship extraction, passed to
# GraphRAGExtractor as extract_prompt in main(). The extractor fills
# {max_knowledge_triplets} and {text}. The requested JSON shape
# ('entities' / 'relationships' with the listed field names) is what
# custom_parse_fn reads back.
# NOTE(review): the literal `{ "entities": [], "relationships": [] }` example
# relies on llama_index's PromptTemplate leaving unknown brace placeholders
# untouched when formatting; confirm this if the prompt formatter changes.
KG_TRIPLET_EXTRACT_TMPL = """
-Goal-
Given a text document, identify all entities and their entity types from the text and all relationships among the identified entities.
Given the text, extract up to {max_knowledge_triplets} entity-relation triplets.
-Steps-
1. Identify all entities. For each identified entity, extract the following information:
- entity_name: Name of the entity, capitalized
- entity_type: Type of the entity
- entity_description: Comprehensive description of the entity's attributes and activities
2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other.
For each pair of related entities, extract the following information:
- source_entity: name of the source entity, as identified in step 1
- target_entity: name of the target entity, as identified in step 1
- relation: relationship between source_entity and target_entity
- relationship_description: explanation as to why you think the source entity and the target entity are related to each other
3. Output Formatting:
- Return the result in valid JSON format with two keys: 'entities' (list of entity objects) and 'relationships' (list of relationship objects).
- Exclude any text outside the JSON structure (e.g., no explanations or comments).
- If no entities or relationships are identified, return empty lists: { "entities": [], "relationships": [] }.
-Real Data-
######################
text: {text}
######################
output:
"""
def _connect_graph_store():
    """Connect to Neo4j with credentials from the environment, else fall back to memory.

    Reads NEO4J_USERNAME (default "neo4j"), NEO4J_PASSWORD and NEO4J_URL
    (default "bolt://localhost:7687"). If constructing the GraphRAGStore
    raises, an in-memory SimplePropertyGraphStore is returned instead; it has
    no community-building support.
    """
    neo4j_username = os.environ.get("NEO4J_USERNAME", "neo4j")
    neo4j_password = os.environ.get("NEO4J_PASSWORD", "")
    neo4j_url = os.environ.get("NEO4J_URL", "bolt://localhost:7687")
    if not neo4j_password:
        print("Warning: NEO4J_PASSWORD is not set; the Neo4j connection will likely fail")
    print(f"Connecting to Neo4j at {neo4j_url} with username '{neo4j_username}'")
    try:
        graph_store = GraphRAGStore(
            username=neo4j_username,
            password=neo4j_password,
            url=neo4j_url
        )
    except Exception as e:
        print(f"Error connecting to Neo4j: {e}")
        print("Falling back to in-memory graph store. Some features may be limited.")
        from llama_index.core.graph_stores import SimplePropertyGraphStore
        return SimplePropertyGraphStore()
    print("Successfully connected to Neo4j database")
    return graph_store


def _run_query_loop(query_engine):
    """Read queries from stdin and print answers until exit/quit, EOF or Ctrl-C."""
    print("\n--- GraphRAG Query System Ready ---")
    print("Type 'exit' to quit")
    while True:
        try:
            query = input("\nEnter your query: ").strip()
        except (EOFError, KeyboardInterrupt):
            # End of piped input or Ctrl-C ends the session instead of
            # crashing with a traceback.
            print()
            break
        if not query:
            continue
        if query.lower() in ('exit', 'quit'):
            break
        try:
            response = query_engine.custom_query(query)
        except Exception as e:
            # Keep the session alive after a failed query (LLM/DB errors).
            print(f"Error processing query: {e}")
            continue
        print("\nResponse:")
        print(response)
    print("GraphRAG session ended")


def main():
    """Index the documents into a property graph and start an interactive query loop.

    The steps are:

    1. Load every file under supporting_files/files_for_rag_store.
    2. Chunk the files and extract triples with GPT-4.
    3. Store the triples in Neo4j, or in memory if the connection fails.
    4. Build communities when the store supports it.
    5. Serve queries from stdin.
    """
    print("Starting GraphRAG document processing...")
    documents = SimpleDirectoryReader(
        input_dir="supporting_files/files_for_rag_store"
    ).load_data()
    print(f"Loaded {len(documents)} documents")

    splitter = SentenceSplitter(chunk_size=1024, chunk_overlap=20)
    nodes = splitter.get_nodes_from_documents(documents)
    print(f"Created {len(nodes)} nodes from documents")

    llm = OpenAI(model="gpt-4")
    kg_extractor = GraphRAGExtractor(
        llm=llm,
        extract_prompt=KG_TRIPLET_EXTRACT_TMPL,
        max_paths_per_chunk=2,
        parse_fn=custom_parse_fn,
    )

    graph_store = _connect_graph_store()

    index = PropertyGraphIndex(
        nodes=nodes,
        kg_extractors=[kg_extractor],
        property_graph_store=graph_store,
        show_progress=True,
    )

    print("Building graph communities...")
    # Only GraphRAGStore can build communities; SimplePropertyGraphStore cannot.
    if hasattr(graph_store, 'build_communities'):
        try:
            graph_store.build_communities()
            print("Communities built successfully")
        except Exception as e:
            # Non-fatal: GraphRAGStore retries the build on the first query.
            print(f"Error building communities: {e}")
    else:
        print("Skipping community building (not using Neo4j)")

    query_engine = GraphRAGQueryEngine(
        graph_store=graph_store,
        llm=llm,
        index=index,
        similarity_top_k=10,
    )
    _run_query_loop(query_engine)


if __name__ == "__main__":
    main()