Full-stack application combining LlamaIndex vector search with Neo4j knowledge graph (GraphRAG) for answering queries about Netflix marketing materials. Flask/Hypercorn backend with custom ReAct agent, React frontend. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
581 lines
No EOL
22 KiB
Python
581 lines
No EOL
22 KiB
Python
import os
|
|
import json
|
|
import re
|
|
import asyncio
|
|
import networkx as nx
|
|
from collections import defaultdict
|
|
from typing import Any, List, Callable, Optional, Union, Dict
|
|
from dotenv import load_dotenv
|
|
|
|
# Import LlamaIndex components
|
|
from llama_index.core import Document, Settings
|
|
from llama_index.core.node_parser import SentenceSplitter
|
|
from llama_index.core import PropertyGraphIndex
|
|
from llama_index.core.async_utils import run_jobs
|
|
from llama_index.core.indices.property_graph.utils import default_parse_triplets_fn
|
|
from llama_index.core.graph_stores.types import EntityNode, KG_NODES_KEY, KG_RELATIONS_KEY, Relation
|
|
from llama_index.core.llms.llm import LLM
|
|
from llama_index.core.prompts import PromptTemplate
|
|
from llama_index.core.prompts.default_prompts import DEFAULT_KG_TRIPLET_EXTRACT_PROMPT
|
|
from llama_index.core.schema import TransformComponent, BaseNode
|
|
from llama_index.core.query_engine import CustomQueryEngine
|
|
from llama_index.llms.openai import OpenAI
|
|
from llama_index.core.llms import ChatMessage
|
|
from llama_index.graph_stores.neo4j import Neo4jPropertyGraphStore
|
|
from llama_index.core import SimpleDirectoryReader
|
|
|
|
# Optional community-detection backend. GraphRAGStore.build_communities()
# imports its backends lazily (graspologic first, then python-louvain, then
# NetworkX), so this module-level import only serves as an early notice.
try:
    from community import best_partition  # python-louvain package
except ImportError:
    print(
        "python-louvain not found; community detection will use graspologic "
        "if installed, otherwise NetworkX's built-in greedy modularity."
    )

# Load environment variables (e.g. OPENAI_API_KEY, NEO4J_*) from a .env file.
load_dotenv()

# Never embed API keys in source code: a committed key is leaked to everyone
# with access to the repository and must be revoked. Require it from the
# environment (or .env) instead and warn early if it is missing.
if not os.environ.get("OPENAI_API_KEY"):
    print(
        "Warning: OPENAI_API_KEY is not set. Add it to your environment or a "
        ".env file; OpenAI calls will fail without it."
    )
|
|
|
|
# Define the GraphRAGExtractor class
class GraphRAGExtractor(TransformComponent):
    """Extract entities and relationships from text nodes with an LLM.

    For every node, ``extract_prompt`` is filled with the node's text and
    ``max_paths_per_chunk`` and sent to ``llm``; the completion is passed to
    ``parse_fn``. The parsed results are appended to the node's
    ``KG_NODES_KEY`` / ``KG_RELATIONS_KEY`` metadata lists as ``EntityNode``
    and ``Relation`` objects.

    ``parse_fn`` must return ``(entities, relationships)`` where each entity is
    a ``(name, type, description)`` tuple and each relationship is a
    ``(source, target, relation, description)`` tuple (the shape produced by
    ``custom_parse_fn`` in this module).

    NOTE(review): the defaults ``DEFAULT_KG_TRIPLET_EXTRACT_PROMPT`` /
    ``default_parse_triplets_fn`` come from LlamaIndex's plain triplet
    extractor and may not produce that shape -- pass a matching prompt and
    parser explicitly (as ``main`` does). TODO confirm.

    Args:
        llm (LLM): The language model to use; defaults to ``Settings.llm``.
        extract_prompt (Union[str, PromptTemplate]): Extraction prompt; a str
            is wrapped in a ``PromptTemplate``. It is formatted with ``text``
            and ``max_knowledge_triplets``.
        parse_fn (callable): Parses the LLM output (see above).
        num_workers (int): Number of extraction jobs run concurrently.
        max_paths_per_chunk (int): Maximum number of paths to request per chunk.
    """

    llm: LLM
    extract_prompt: PromptTemplate
    parse_fn: Callable
    num_workers: int
    max_paths_per_chunk: int

    def __init__(
        self,
        llm: Optional[LLM] = None,
        extract_prompt: Optional[Union[str, PromptTemplate]] = None,
        parse_fn: Callable = default_parse_triplets_fn,
        max_paths_per_chunk: int = 10,
        num_workers: int = 4,
    ) -> None:
        """Init params, falling back to ``Settings.llm`` and the default prompt."""
        from llama_index.core import Settings

        if isinstance(extract_prompt, str):
            extract_prompt = PromptTemplate(extract_prompt)

        super().__init__(
            llm=llm or Settings.llm,
            extract_prompt=extract_prompt or DEFAULT_KG_TRIPLET_EXTRACT_PROMPT,
            parse_fn=parse_fn,
            num_workers=num_workers,
            max_paths_per_chunk=max_paths_per_chunk,
        )

    @classmethod
    def class_name(cls) -> str:
        return "GraphExtractor"

    def __call__(
        self, nodes: List[BaseNode], show_progress: bool = False, **kwargs: Any
    ) -> List[BaseNode]:
        """Extract triples from nodes (synchronous wrapper around ``acall``).

        Uses ``asyncio.run``, so it cannot be called from inside a running
        event loop; use ``acall`` there instead.
        """
        return asyncio.run(
            self.acall(nodes, show_progress=show_progress, **kwargs)
        )

    async def _aextract(self, node: BaseNode) -> BaseNode:
        """Extract entities/relations from one node and attach them to its metadata."""
        assert hasattr(node, "text")

        text = node.get_content(metadata_mode="llm")
        try:
            llm_response = await self.llm.apredict(
                self.extract_prompt,
                text=text,
                max_knowledge_triplets=self.max_paths_per_chunk,
            )
            entities, entities_relationship = self.parse_fn(llm_response)
        except ValueError:
            # Unparseable output (or a parse_fn result of the wrong shape):
            # keep the node, just attach no new graph elements.
            entities = []
            entities_relationship = []

        # Pop so the KG lists are not part of the metadata copied into each
        # element's properties below; they are written back at the end.
        existing_nodes = node.metadata.pop(KG_NODES_KEY, [])
        existing_relations = node.metadata.pop(KG_RELATIONS_KEY, [])

        # Every graph element gets its OWN properties dict. Reusing one dict
        # and overwriting its description each iteration would make all
        # elements share (and potentially all report) the last description.
        base_metadata = node.metadata.copy()

        for entity, entity_type, description in entities:
            existing_nodes.append(
                EntityNode(
                    name=entity,
                    label=entity_type,
                    properties={**base_metadata, "entity_description": description},
                )
            )

        for subj, obj, rel, description in entities_relationship:
            existing_relations.append(
                Relation(
                    label=rel,
                    source_id=subj,
                    target_id=obj,
                    properties={
                        **base_metadata,
                        "relationship_description": description,
                    },
                )
            )

        node.metadata[KG_NODES_KEY] = existing_nodes
        node.metadata[KG_RELATIONS_KEY] = existing_relations
        return node

    async def acall(
        self, nodes: List[BaseNode], show_progress: bool = False, **kwargs: Any
    ) -> List[BaseNode]:
        """Extract triples from nodes async, ``num_workers`` jobs at a time."""
        jobs = [self._aextract(node) for node in nodes]
        return await run_jobs(
            jobs,
            workers=self.num_workers,
            show_progress=show_progress,
            desc="Extracting paths from text",
        )
|
|
|
# Define the GraphRAGStore class
class GraphRAGStore(Neo4jPropertyGraphStore):
    """Neo4j property-graph store with GraphRAG-style community summaries.

    ``build_communities`` clusters the stored graph and asks an LLM to
    summarise the relationships inside each cluster. Results live on the
    instance:

    * ``entity_info``: entity name -> list of community ids containing it
      (``None`` until communities have been built).
    * ``community_summary``: community id -> summary text.
    """

    # Class-level defaults only: build_communities() always assigns fresh
    # per-instance objects, so separate stores never share summaries and a
    # rebuild never keeps stale entries.
    community_summary = {}
    entity_info = None
    # Passed to graspologic's hierarchical_leiden as max_cluster_size.
    max_cluster_size = 5

    class _Cluster:
        """(node, cluster) record exposing the two attributes _collect_community_info reads."""

        __slots__ = ("node", "cluster")

        def __init__(self, node, cluster):
            self.node = node
            self.cluster = cluster

    def generate_community_summary(self, text):
        """Summarise ``text`` (newline-separated relationship lines) with an LLM.

        NOTE(review): this uses a default ``OpenAI()`` client, not the LLM the
        rest of the pipeline is configured with -- confirm that is intended.
        """
        messages = [
            ChatMessage(
                role="system",
                content=(
                    "You are provided with a set of "
                    "relationships from a knowledge graph, each represented as "
                    "entity1->entity2->relation-"
                    ">relationship_description. Your task is to create a summary of "
                    "these "
                    "relationships. The summary should include "
                    "the names of the entities involved and a concise synthesis "
                    "of the relationship descriptions. The "
                    "goal is to capture the most critical and relevant details that "
                    "highlight the nature and significance of "
                    "each relationship. Ensure that the summary is coherent and "
                    "integrates the information in a way that "
                    "emphasizes the key aspects of the relationships."
                ),
            ),
            ChatMessage(role="user", content=text),
        ]
        response = OpenAI().chat(messages)
        # str(ChatResponse) is prefixed with the role; strip it.
        return re.sub(r"^assistant:\s*", "", str(response)).strip()

    def build_communities(self):
        """Partition the graph into communities and summarise each one.

        Sets ``self.entity_info`` and replaces ``self.community_summary``.
        """
        nx_graph = self._create_nx_graph()

        if nx_graph.number_of_nodes() == 0:
            # Nothing to cluster; don't hand an empty graph to the partitioners.
            self.entity_info = {}
            self.community_summary = {}
            return

        clusters = self._partition_graph(nx_graph)
        self.entity_info, community_info = self._collect_community_info(
            nx_graph, clusters
        )
        self._summarize_communities(community_info)

    def _partition_graph(self, nx_graph):
        """Cluster ``nx_graph``; return records exposing ``.node`` and ``.cluster``.

        Backends in order of preference: graspologic's hierarchical Leiden
        (bounded by ``max_cluster_size``), python-louvain's ``best_partition``,
        then NetworkX's greedy modularity communities.
        """
        try:
            from graspologic.partition import hierarchical_leiden
        except ImportError:
            pass
        else:
            return hierarchical_leiden(
                nx_graph, max_cluster_size=self.max_cluster_size
            )

        try:
            from community import best_partition  # python-louvain package
        except ImportError:
            from networkx.algorithms import community as nx_community

            communities = nx_community.greedy_modularity_communities(nx_graph)
            return [
                self._Cluster(node, cluster_id)
                for cluster_id, members in enumerate(communities)
                for node in members
            ]

        partition = best_partition(nx_graph)
        return [
            self._Cluster(node, cluster_id)
            for node, cluster_id in partition.items()
        ]

    def _create_nx_graph(self):
        """Build an undirected NetworkX graph from the store's triplets.

        Nodes are entity names; each edge carries the relation label as
        ``relationship`` and its ``relationship_description`` property as
        ``description``.
        """
        nx_graph = nx.Graph()
        for entity1, relation, entity2 in self.get_triplets():
            nx_graph.add_node(entity1.name)
            nx_graph.add_node(entity2.name)
            nx_graph.add_edge(
                relation.source_id,
                relation.target_id,
                relationship=relation.label,
                description=relation.properties.get(
                    "relationship_description", "No description provided"
                ),
            )
        return nx_graph

    def _collect_community_info(self, nx_graph, clusters):
        """Group relationship details by community.

        An entity may belong to several clusters (e.g. several hierarchy levels
        from hierarchical Leiden). Every edge incident to a node is recorded
        under that node's cluster, as a line of the form
        ``"entity1 -> entity2 -> relation -> description"``.

        Returns:
            ``(entity_info, community_info)``: entity -> list of cluster ids,
            and cluster id -> list of detail lines.
        """
        entity_info = defaultdict(set)
        community_info = defaultdict(list)

        for item in clusters:
            node = item.node
            cluster_id = item.cluster
            entity_info[node].add(cluster_id)

            for neighbor in nx_graph.neighbors(node):
                edge_data = nx_graph.get_edge_data(node, neighbor)
                if edge_data:
                    detail = (
                        f"{node} -> {neighbor} -> "
                        f"{edge_data['relationship']} -> {edge_data['description']}"
                    )
                    community_info[cluster_id].append(detail)

        # Lists instead of sets for easier serialization.
        entity_info = {k: list(v) for k, v in entity_info.items()}
        return dict(entity_info), dict(community_info)

    def _summarize_communities(self, community_info):
        """Generate a summary per community and store them on this instance."""
        # Assign a new per-instance dict rather than mutating the class-level
        # default, which would be shared by every GraphRAGStore.
        self.community_summary = {
            community_id: self.generate_community_summary(
                "\n".join(details) + "."  # Ensure it ends with a period
            )
            for community_id, details in community_info.items()
        }

    def get_community_summaries(self):
        """Returns the community summaries, building them if not already done."""
        if not self.community_summary:
            self.build_communities()
        return self.community_summary
|
|
|
|
# Define the GraphRAGQueryEngine class
class GraphRAGQueryEngine(CustomQueryEngine):
    """Answer queries from GraphRAG community summaries, with a retrieval fallback.

    When ``graph_store`` exposes ``get_community_summaries`` (a
    ``GraphRAGStore``), entities found in retrieved graph paths are mapped to
    their communities. Each matching community summary is asked the question,
    and the partial answers are merged by the LLM. Otherwise, or when no
    community matches, the answer is generated directly from the retrieved
    nodes.
    """

    graph_store: Union[GraphRAGStore, Any]  # Accept any type of graph store
    index: PropertyGraphIndex
    llm: LLM
    similarity_top_k: int = 20

    def custom_query(self, query_str: str) -> str:
        """Process query using either community-based approach or direct retrieval."""
        if not hasattr(self.graph_store, "get_community_summaries"):
            # e.g. SimplePropertyGraphStore: no communities available.
            return self._answer_from_retrieval(query_str)

        # Fetch (building on first use) the summaries BEFORE reading
        # entity_info: building is what populates it; before that it is None.
        community_summaries = self.graph_store.get_community_summaries()

        entities = self.get_entities(query_str, self.similarity_top_k)
        community_ids = set(
            self.retrieve_entity_communities(self.graph_store.entity_info, entities)
        )

        community_answers = [
            self.generate_answer_from_summary(summary, query_str)
            for community_id, summary in community_summaries.items()
            if community_id in community_ids
        ]
        if not community_answers:
            # Aggregating an empty list would make the LLM answer with no
            # supporting context; use the retrieved nodes instead.
            return self._answer_from_retrieval(query_str)

        return self.aggregate_answers(community_answers)

    def _answer_from_retrieval(self, query_str: str) -> str:
        """Answer ``query_str`` from the top-k retrieved nodes' text."""
        nodes = self.index.as_retriever(
            similarity_top_k=self.similarity_top_k
        ).retrieve(query_str)

        if not nodes:
            return "I couldn't find any relevant information to answer your question."

        context = "\n\n".join(node.get_content() for node in nodes)
        prompt = f"Based on the following information, please answer this question: {query_str}\n\nInformation:\n{context}"
        messages = [
            ChatMessage(role="system", content=prompt),
            ChatMessage(role="user", content="Please provide a comprehensive answer based on the information provided."),
        ]
        return self._clean_response(self.llm.chat(messages))

    def _clean_response(self, response) -> str:
        """Return the chat response text without its leading ``assistant:`` role tag."""
        return re.sub(r"^assistant:\s*", "", str(response)).strip()

    def get_entities(self, query_str, similarity_top_k):
        """Return entity names from ``A -> relation -> B`` lines in retrieved node text.

        Both the subject and the object of every matching line are collected.
        """
        nodes_retrieved = self.index.as_retriever(
            similarity_top_k=similarity_top_k
        ).retrieve(query_str)

        triple_line = re.compile(
            r"^(\w+(?:\s+\w+)*)\s*->\s*([a-zA-Z\s]+?)\s*->\s*(\w+(?:\s+\w+)*)$",
            re.MULTILINE | re.IGNORECASE,
        )

        entities = set()
        for node in nodes_retrieved:
            for subject, _relation, obj in triple_line.findall(node.text):
                entities.add(subject)
                entities.add(obj)

        return list(entities)

    def retrieve_entity_communities(self, entity_info, entities):
        """
        Retrieve cluster information for given entities,
        allowing for multiple clusters per entity.

        Args:
            entity_info (dict | None): Dictionary mapping entities to their
                cluster IDs (list). ``None``/empty means no communities yet.
            entities (list): List of entity names to retrieve information for.

        Returns:
            List of unique community/cluster IDs the entities belong to.
        """
        if not entity_info:
            return []
        return list(
            {
                cluster_id
                for entity in entities
                for cluster_id in entity_info.get(entity, ())
            }
        )

    def generate_answer_from_summary(self, community_summary, query):
        """Generate an answer from a community summary based on a given query using LLM."""
        prompt = (
            f"Given the community summary: {community_summary}, "
            f"how would you answer the following query? Query: {query}"
        )
        messages = [
            ChatMessage(role="system", content=prompt),
            ChatMessage(
                role="user",
                content="I need an answer based on the above information.",
            ),
        ]
        return self._clean_response(self.llm.chat(messages))

    def aggregate_answers(self, community_answers):
        """Aggregate individual community answers into a final, coherent response."""
        prompt = "Combine the following intermediate answers into a final, concise response."
        messages = [
            ChatMessage(role="system", content=prompt),
            ChatMessage(
                role="user",
                content=f"Intermediate answers: {community_answers}",
            ),
        ]
        return self._clean_response(self.llm.chat(messages))
|
|
|
|
def custom_parse_fn(response_str: str) -> Any:
    """Parse the JSON entity/relationship payload out of an LLM completion.

    The span from the first ``{`` to the last ``}`` of ``response_str`` is
    decoded as a JSON object. This tolerates chatter or markdown fences
    around it. The object may carry ``"entities"`` and ``"relationships"``
    lists.

    Returns:
        ``(entities, relationships)``:

        * entities are ``(entity_name, entity_type, entity_description)`` tuples;
        * relationships are ``(source_entity, target_entity, relation,
          relationship_description)`` tuples.

        Missing descriptions get a generated placeholder. Individual malformed
        items are skipped (and reported) instead of discarding every other
        item. If no JSON object can be decoded, both lists are empty.
    """
    entities = []
    relationships = []

    match = re.search(r"\{.*\}", response_str, re.DOTALL)
    if not match:
        return entities, relationships

    json_str = match.group(0)
    try:
        data = json.loads(json_str)
    except json.JSONDecodeError as e:
        print(f"Error parsing response: {e}")
        print(f"Problematic JSON: {json_str[:200]}...")
        return entities, relationships

    # A key that is absent, null, or not a list yields no items.
    # An uncaught TypeError here would abort the whole extraction run.
    raw_entities = data.get("entities")
    for entity in raw_entities if isinstance(raw_entities, list) else []:
        try:
            name = entity["entity_name"]
            entities.append(
                (
                    name,
                    entity["entity_type"],
                    entity.get("entity_description", f"Description of {name}"),
                )
            )
        except (KeyError, TypeError, AttributeError) as e:
            print(f"Skipping malformed entity {entity!r}: {e!r}")

    raw_relationships = data.get("relationships")
    for relation in raw_relationships if isinstance(raw_relationships, list) else []:
        try:
            source = relation["source_entity"]
            target = relation["target_entity"]
            relationships.append(
                (
                    source,
                    target,
                    relation["relation"],
                    relation.get(
                        "relationship_description",
                        f"Relationship between {source} and {target}",
                    ),
                )
            )
        except (KeyError, TypeError, AttributeError) as e:
            print(f"Skipping malformed relationship {relation!r}: {e!r}")

    return entities, relationships
|
|
|
|
# Prompt template for triple extraction, passed to GraphRAGExtractor in main().
# Placeholders filled at extraction time: {max_knowledge_triplets} (from
# GraphRAGExtractor.max_paths_per_chunk) and {text} (the chunk content).
# The requested JSON layout ("entities" / "relationships" lists with the field
# names below) is what custom_parse_fn reads.
# NOTE(review): the literal `{ "entities": [], "relationships": [] }` example is
# not brace-escaped. This relies on LlamaIndex's PromptTemplate leaving unknown
# `{...}` fields untouched when formatting -- confirm for the pinned version.
KG_TRIPLET_EXTRACT_TMPL = """
-Goal-
Given a text document, identify all entities and their entity types from the text and all relationships among the identified entities.

Given the text, extract up to {max_knowledge_triplets} entity-relation triplets.

-Steps-
1. Identify all entities. For each identified entity, extract the following information:
- entity_name: Name of the entity, capitalized
- entity_type: Type of the entity
- entity_description: Comprehensive description of the entity's attributes and activities

2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other.
For each pair of related entities, extract the following information:
- source_entity: name of the source entity, as identified in step 1
- target_entity: name of the target entity, as identified in step 1
- relation: relationship between source_entity and target_entity
- relationship_description: explanation as to why you think the source entity and the target entity are related to each other

3. Output Formatting:
- Return the result in valid JSON format with two keys: 'entities' (list of entity objects) and 'relationships' (list of relationship objects).
- Exclude any text outside the JSON structure (e.g., no explanations or comments).
- If no entities or relationships are identified, return empty lists: { "entities": [], "relationships": [] }.

-Real Data-
######################
text: {text}
######################
output:
"""
|
|
|
|
def _load_nodes(input_dir):
    """Load every document under ``input_dir`` and split it into text nodes.

    Uses SentenceSplitter with chunk_size=1024 and chunk_overlap=20.
    """
    documents = SimpleDirectoryReader(input_dir=input_dir).load_data()
    print(f"Loaded {len(documents)} documents")

    splitter = SentenceSplitter(chunk_size=1024, chunk_overlap=20)
    nodes = splitter.get_nodes_from_documents(documents)
    print(f"Created {len(nodes)} nodes from documents")
    return nodes


def _connect_graph_store():
    """Return a Neo4j-backed GraphRAGStore, or an in-memory store as a fallback.

    Connection settings come from the NEO4J_URL (default
    ``bolt://localhost:7687``), NEO4J_USERNAME (default ``neo4j``) and
    NEO4J_PASSWORD environment variables. They can also be set in the .env
    file that load_dotenv() reads at import time. Credentials are
    deliberately not hard-coded in source.
    """
    neo4j_username = os.environ.get("NEO4J_USERNAME", "neo4j")
    neo4j_password = os.environ.get("NEO4J_PASSWORD")
    neo4j_url = os.environ.get("NEO4J_URL", "bolt://localhost:7687")

    print(f"Connecting to Neo4j at {neo4j_url} with username '{neo4j_username}'")

    if neo4j_password:
        try:
            graph_store = GraphRAGStore(
                username=neo4j_username,
                password=neo4j_password,
                url=neo4j_url,
            )
        except Exception as e:
            print(f"Error connecting to Neo4j: {e}")
        else:
            print("Successfully connected to Neo4j database")
            return graph_store
    else:
        print("NEO4J_PASSWORD is not set; skipping Neo4j.")

    print("Falling back to in-memory graph store. Some features may be limited.")
    from llama_index.core.graph_stores import SimplePropertyGraphStore

    return SimplePropertyGraphStore()


def _build_communities(graph_store):
    """Precompute community summaries when the store supports them (Neo4j only)."""
    print("Building graph communities...")
    try:
        if hasattr(graph_store, "build_communities"):
            graph_store.build_communities()
            print("Communities built successfully")
        else:
            print("Skipping community building (not using Neo4j)")
    except Exception as e:
        # Best effort: the failure is reported but does not stop the session.
        print(f"Error building communities: {e}")


def _run_query_loop(query_engine):
    """Read queries from stdin until 'exit'/'quit', EOF or Ctrl-C; print answers."""
    print("\n--- GraphRAG Query System Ready ---")
    print("Type 'exit' to quit")

    while True:
        try:
            query = input("\nEnter your query: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C ends the session instead of raising a traceback.
            print()
            break

        if not query:
            continue
        if query.lower() in ("exit", "quit"):
            break

        try:
            response = query_engine.custom_query(query)
        except Exception as e:
            print(f"Error processing query: {e}")
            continue

        print("\nResponse:")
        print(response)


def main(input_dir: str = "supporting_files/files_for_rag_store"):
    """Build the GraphRAG index from ``input_dir`` and run an interactive query loop.

    Args:
        input_dir: Directory whose documents are indexed (defaults to the
            project's RAG corpus).
    """
    print("Starting GraphRAG document processing...")

    nodes = _load_nodes(input_dir)

    llm = OpenAI(model="gpt-4")

    kg_extractor = GraphRAGExtractor(
        llm=llm,
        extract_prompt=KG_TRIPLET_EXTRACT_TMPL,
        max_paths_per_chunk=2,
        parse_fn=custom_parse_fn,
    )

    graph_store = _connect_graph_store()

    index = PropertyGraphIndex(
        nodes=nodes,
        kg_extractors=[kg_extractor],
        property_graph_store=graph_store,
        show_progress=True,
    )

    _build_communities(graph_store)

    query_engine = GraphRAGQueryEngine(
        graph_store=graph_store,
        llm=llm,
        index=index,
        similarity_top_k=10,
    )

    _run_query_loop(query_engine)
    print("GraphRAG session ended")


if __name__ == "__main__":
    main()