"""GraphRAG pipeline: LLM triple extraction, community summaries and querying.

The module wires together three pieces on top of LlamaIndex:

* ``GraphRAGExtractor`` asks an LLM for entities/relationships per text chunk.
* ``GraphRAGStore`` extends the Neo4j property graph store with community
  detection and LLM-written community summaries.
* ``GraphRAGQueryEngine`` answers questions from those summaries, or from
  plain retrieval when the store has no community support.

Credentials (``OPENAI_API_KEY``, ``NEO4J_USERNAME``, ``NEO4J_PASSWORD``,
``NEO4J_URL``) are read from the environment / a ``.env`` file.
"""

import asyncio
import json
import os
import re
from collections import defaultdict
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Union

import networkx as nx
from dotenv import load_dotenv

# Import LlamaIndex components
from llama_index.core import Document, Settings
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core import PropertyGraphIndex
from llama_index.core.async_utils import asyncio_run, run_jobs
from llama_index.core.indices.property_graph.utils import default_parse_triplets_fn
from llama_index.core.graph_stores.types import (
    EntityNode,
    KG_NODES_KEY,
    KG_RELATIONS_KEY,
    Relation,
)
from llama_index.core.llms.llm import LLM
from llama_index.core.prompts import PromptTemplate
from llama_index.core.prompts.default_prompts import DEFAULT_KG_TRIPLET_EXTRACT_PROMPT
from llama_index.core.schema import TransformComponent, BaseNode
from llama_index.core.query_engine import CustomQueryEngine
from llama_index.llms.openai import OpenAI
from llama_index.core.llms import ChatMessage
from llama_index.graph_stores.neo4j import Neo4jPropertyGraphStore
from llama_index.core import SimpleDirectoryReader

# Community detection (using NetworkX instead of graspologic as a fallback)
try:
    from community import best_partition  # python-louvain package
except ImportError:
    print("Community detection package not found, using NetworkX built-in community detection")

# Load environment variables from .env file
load_dotenv()

# SECURITY: the API key must come from the environment or .env. A key that was
# previously hard-coded here has been removed; treat it as leaked and rotate it.
if not os.environ.get("OPENAI_API_KEY"):
    print("Warning: OPENAI_API_KEY is not set; OpenAI calls will fail until it is provided.")

# Returned whenever retrieval yields nothing usable for a query.
_NO_INFO_MESSAGE = "I couldn't find any relevant information to answer your question."


def _strip_assistant_prefix(response: Any) -> str:
    """Return ``str(response)`` without a leading ``assistant:`` tag."""
    return re.sub(r"^assistant:\s*", "", str(response)).strip()


class _Cluster(NamedTuple):
    """Node-to-community assignment exposing ``.node`` and ``.cluster``.

    ``_collect_community_info`` reads exactly these two attributes from every
    cluster item, so the fallback detectors wrap their output in this type.
    """

    node: Any
    cluster: Any


# Define the GraphRAGExtractor class
class GraphRAGExtractor(TransformComponent):
    """Extract triples from a graph.

    Uses an LLM and a simple prompt + output parsing to extract paths (i.e. triples)
    and entity, relation descriptions from text.

    Args:
        llm (LLM): The language model to use.
        extract_prompt (Union[str, PromptTemplate]): The prompt to use for
            extracting triples.
        parse_fn (callable): A function to parse the output of the language model.
        num_workers (int): The number of workers to use for parallel processing.
        max_paths_per_chunk (int): The maximum number of paths to extract per chunk.
    """

    llm: LLM
    extract_prompt: PromptTemplate
    parse_fn: Callable
    num_workers: int
    max_paths_per_chunk: int

    def __init__(
        self,
        llm: Optional[LLM] = None,
        extract_prompt: Optional[Union[str, PromptTemplate]] = None,
        parse_fn: Callable = default_parse_triplets_fn,
        max_paths_per_chunk: int = 10,
        num_workers: int = 4,
    ) -> None:
        """Init params.

        A plain-string ``extract_prompt`` is wrapped in a ``PromptTemplate``;
        missing ``llm`` / ``extract_prompt`` fall back to ``Settings.llm`` and
        ``DEFAULT_KG_TRIPLET_EXTRACT_PROMPT`` respectively.
        """
        if isinstance(extract_prompt, str):
            extract_prompt = PromptTemplate(extract_prompt)

        super().__init__(
            llm=llm or Settings.llm,
            extract_prompt=extract_prompt or DEFAULT_KG_TRIPLET_EXTRACT_PROMPT,
            parse_fn=parse_fn,
            num_workers=num_workers,
            max_paths_per_chunk=max_paths_per_chunk,
        )

    @classmethod
    def class_name(cls) -> str:
        """Return the component name reported for this class."""
        return "GraphExtractor"

    def __call__(
        self, nodes: List[BaseNode], show_progress: bool = False, **kwargs: Any
    ) -> List[BaseNode]:
        """Extract triples from nodes (synchronous wrapper around ``acall``)."""
        # asyncio_run is used instead of asyncio.run because the latter raises
        # when an event loop is already running (e.g. inside a notebook).
        return asyncio_run(
            self.acall(nodes, show_progress=show_progress, **kwargs)
        )

    async def _aextract(self, node: BaseNode) -> BaseNode:
        """Extract triples from a node.

        The node's LLM-mode content is sent through ``extract_prompt`` and the
        reply is handed to ``parse_fn``, which must return
        ``(entities, relationships)`` where entities unpack as
        ``(name, type, description)`` and relationships as
        ``(subject, object, relation, description)``. A ``ValueError`` from the
        LLM call or the parser is treated as "nothing extracted".

        Returns:
            The same node, with the extracted ``EntityNode`` / ``Relation``
            objects appended under ``KG_NODES_KEY`` / ``KG_RELATIONS_KEY`` in
            its metadata.
        """
        assert hasattr(node, "text")

        text = node.get_content(metadata_mode="llm")
        try:
            llm_response = await self.llm.apredict(
                self.extract_prompt,
                text=text,
                max_knowledge_triplets=self.max_paths_per_chunk,
            )
            entities, entities_relationship = self.parse_fn(llm_response)
        except ValueError:
            entities = []
            entities_relationship = []

        # Pop the KG keys first so the metadata snapshot below does not carry
        # previously extracted nodes/relations into every property dict.
        existing_nodes = node.metadata.pop(KG_NODES_KEY, [])
        existing_relations = node.metadata.pop(KG_RELATIONS_KEY, [])
        base_metadata = node.metadata.copy()

        # Each node/relation gets its OWN property dict; sharing one mutable
        # dict would let a later description overwrite an earlier one.
        for entity, entity_type, description in entities:
            existing_nodes.append(
                EntityNode(
                    name=entity,
                    label=entity_type,
                    properties={**base_metadata, "entity_description": description},
                )
            )

        for subj, obj, rel, description in entities_relationship:
            existing_relations.append(
                Relation(
                    label=rel,
                    source_id=subj,
                    target_id=obj,
                    properties={
                        **base_metadata,
                        "relationship_description": description,
                    },
                )
            )

        node.metadata[KG_NODES_KEY] = existing_nodes
        node.metadata[KG_RELATIONS_KEY] = existing_relations
        return node

    async def acall(
        self, nodes: List[BaseNode], show_progress: bool = False, **kwargs: Any
    ) -> List[BaseNode]:
        """Extract triples from nodes async, using ``num_workers`` workers."""
        jobs = [self._aextract(node) for node in nodes]
        return await run_jobs(
            jobs,
            workers=self.num_workers,
            show_progress=show_progress,
            desc="Extracting paths from text",
        )


# Define the GraphRAGStore class
class GraphRAGStore(Neo4jPropertyGraphStore):
    """Neo4j property graph store with community detection and summaries."""

    # Class-level defaults only. Instances rebind these attributes instead of
    # mutating them, so state is never shared between stores.
    community_summary = {}  # community id -> LLM-written summary
    entity_info = None  # entity name -> list of community ids (after build)
    max_cluster_size = 5  # forwarded to hierarchical_leiden

    def generate_community_summary(self, text):
        """Generate summary for a given text using an LLM.

        Args:
            text: Relationship lines of one community, joined into one string.

        Returns:
            The model's reply as a string, without the ``assistant:`` prefix.
        """
        messages = [
            ChatMessage(
                role="system",
                content=(
                    "You are provided with a set of "
                    "relationships from a knowledge graph, each represented as "
                    "entity1->entity2->relation-"
                    ">relationship_description. Your task is to create a summary of "
                    "these "
                    "relationships. The summary should include "
                    "the names of the entities involved and a concise synthesis "
                    "of the relationship descriptions. The "
                    "goal is to capture the most critical and relevant details that "
                    "highlight the nature and significance of "
                    "each relationship. Ensure that the summary is coherent and "
                    "integrates the information in a way that "
                    "emphasizes the key aspects of the relationships."
                ),
            ),
            ChatMessage(role="user", content=text),
        ]
        response = OpenAI().chat(messages)
        return _strip_assistant_prefix(response)

    def build_communities(self):
        """Builds communities from the graph and summarizes them."""
        nx_graph = self._create_nx_graph()

        # Nothing to cluster: leave an empty, well-defined state rather than
        # handing an empty graph to the detection algorithms.
        if nx_graph.number_of_nodes() == 0:
            self.entity_info = {}
            self.community_summary = {}
            return

        clusters = self._detect_clusters(nx_graph)
        self.entity_info, community_info = self._collect_community_info(
            nx_graph, clusters
        )
        self._summarize_communities(community_info)

    def _detect_clusters(self, nx_graph):
        """Return cluster items (each with ``.node`` / ``.cluster``).

        Tries, in order: graspologic's ``hierarchical_leiden``, python-louvain's
        ``best_partition``, then NetworkX's ``greedy_modularity_communities``.
        Only a missing package triggers the next fallback; errors raised by an
        available algorithm propagate to the caller.
        """
        try:
            from graspologic.partition import hierarchical_leiden
        except ImportError:
            pass
        else:
            return hierarchical_leiden(
                nx_graph, max_cluster_size=self.max_cluster_size
            )

        try:
            from community import best_partition
        except ImportError:
            pass
        else:
            partition = best_partition(nx_graph)
            return [
                _Cluster(node, cluster_id)
                for node, cluster_id in partition.items()
            ]

        # Use NetworkX's built-in community detection
        from networkx.algorithms import community as nx_community

        communities = nx_community.greedy_modularity_communities(nx_graph)
        return [
            _Cluster(node, index)
            for index, members in enumerate(communities)
            for node in members
        ]

    def _create_nx_graph(self):
        """Converts internal graph representation to NetworkX graph."""
        nx_graph = nx.Graph()
        triplets = self.get_triplets()
        for entity1, relation, entity2 in triplets:
            nx_graph.add_node(entity1.name)
            nx_graph.add_node(entity2.name)
            nx_graph.add_edge(
                relation.source_id,
                relation.target_id,
                relationship=relation.label,
                description=relation.properties.get(
                    "relationship_description", "No description provided"
                ),
            )
        return nx_graph

    def _collect_community_info(self, nx_graph, clusters):
        """
        Collect information for each node based on their community,
        allowing entities to belong to multiple clusters.

        Returns:
            A pair ``(entity_info, community_info)``: a dict mapping each node
            to the list of its cluster ids, and a dict mapping each cluster id
            to ``"node -> neighbor -> relationship -> description"`` strings.
        """
        entity_info = defaultdict(set)
        community_info = defaultdict(list)

        for item in clusters:
            node = item.node
            cluster_id = item.cluster

            # Update entity_info
            entity_info[node].add(cluster_id)

            for neighbor in nx_graph.neighbors(node):
                edge_data = nx_graph.get_edge_data(node, neighbor)
                if edge_data:
                    detail = f"{node} -> {neighbor} -> {edge_data['relationship']} -> {edge_data['description']}"
                    community_info[cluster_id].append(detail)

        # Convert sets to lists for easier serialization if needed
        entity_info = {k: list(v) for k, v in entity_info.items()}

        return dict(entity_info), dict(community_info)

    def _summarize_communities(self, community_info):
        """Generate and store summaries for each community."""
        # Build into a fresh dict and bind it to the instance, so the
        # class-level default dict is never mutated.
        summaries = {}
        for community_id, details in community_info.items():
            details_text = "\n".join(details) + "."  # Ensure it ends with a period
            summaries[community_id] = self.generate_community_summary(details_text)
        self.community_summary = summaries

    def get_community_summaries(self):
        """Returns the community summaries, building them if not already done."""
        if not self.community_summary:
            self.build_communities()
        return self.community_summary


# Define the GraphRAGQueryEngine class
class GraphRAGQueryEngine(CustomQueryEngine):
    """Answer queries from community summaries or, failing that, raw retrieval."""

    graph_store: Union[GraphRAGStore, Any]  # Accept any type of graph store
    index: PropertyGraphIndex
    llm: LLM
    similarity_top_k: int = 20

    def custom_query(self, query_str: str) -> str:
        """Process query using either community-based approach or direct retrieval."""
        # Check if we're using GraphRAGStore with communities or SimplePropertyGraphStore
        if hasattr(self.graph_store, "get_community_summaries"):
            return self._query_communities(query_str)
        return self._query_direct(query_str)

    def _query_communities(self, query_str: str) -> str:
        """Answer from the summaries of communities the query's entities are in."""
        # Fetch summaries FIRST: this builds the communities on demand, which
        # is what populates graph_store.entity_info (None until then).
        community_summaries = self.graph_store.get_community_summaries()
        entities = self.get_entities(query_str, self.similarity_top_k)
        community_ids = set(
            self.retrieve_entity_communities(self.graph_store.entity_info, entities)
        )
        community_answers = [
            self.generate_answer_from_summary(community_summary, query_str)
            for community_id, community_summary in community_summaries.items()
            if community_id in community_ids
        ]
        # Without any intermediate answer the aggregation prompt would have
        # nothing to ground on, so report that nothing was found instead.
        if not community_answers:
            return _NO_INFO_MESSAGE
        return self.aggregate_answers(community_answers)

    def _query_direct(self, query_str: str) -> str:
        """Answer directly from the text of retrieved nodes (no communities)."""
        nodes = self.index.as_retriever(
            similarity_top_k=self.similarity_top_k
        ).retrieve(query_str)

        if not nodes:
            return _NO_INFO_MESSAGE

        # Combine text from all retrieved nodes
        context = "\n\n".join([node.get_content() for node in nodes])

        # Generate answer using the LLM
        prompt = f"Based on the following information, please answer this question: {query_str}\n\nInformation:\n{context}"
        messages = [
            ChatMessage(role="system", content=prompt),
            ChatMessage(role="user", content="Please provide a comprehensive answer based on the information provided."),
        ]

        response = self.llm.chat(messages)
        # Strip the "assistant:" tag here too, consistent with the other paths.
        return _strip_assistant_prefix(response)

    def get_entities(self, query_str, similarity_top_k):
        """Return entity names found in ``subject -> relation -> object`` lines.

        Retrieves up to ``similarity_top_k`` nodes for ``query_str`` and scans
        each node's text line by line for that arrow-separated pattern.

        Returns:
            A list of the distinct subject and object names (unordered).
        """
        nodes_retrieved = self.index.as_retriever(
            similarity_top_k=similarity_top_k
        ).retrieve(query_str)

        entities = set()
        pattern = r"^(\w+(?:\s+\w+)*)\s*->\s*([a-zA-Z\s]+?)\s*->\s*(\w+(?:\s+\w+)*)$"

        for node in nodes_retrieved:
            matches = re.findall(
                pattern, node.text, re.MULTILINE | re.IGNORECASE
            )
            for subject, _relation, obj in matches:
                entities.add(subject)
                entities.add(obj)

        return list(entities)

    def retrieve_entity_communities(self, entity_info, entities):
        """
        Retrieve cluster information for given entities, allowing for multiple clusters per entity.

        Args:
            entity_info (dict): Dictionary mapping entities to their cluster IDs (list).
                ``None`` or an empty dict yields an empty result.
            entities (list): List of entity names to retrieve information for.

        Returns:
            List of community or cluster IDs to which an entity belongs.
        """
        if not entity_info:
            return []

        community_ids = set()
        for entity in entities:
            community_ids.update(entity_info.get(entity, ()))
        return list(community_ids)

    def generate_answer_from_summary(self, community_summary, query):
        """Generate an answer from a community summary based on a given query using LLM."""
        prompt = (
            f"Given the community summary: {community_summary}, "
            f"how would you answer the following query? Query: {query}"
        )
        messages = [
            ChatMessage(role="system", content=prompt),
            ChatMessage(
                role="user",
                content="I need an answer based on the above information.",
            ),
        ]
        response = self.llm.chat(messages)
        return _strip_assistant_prefix(response)

    def aggregate_answers(self, community_answers):
        """Aggregate individual community answers into a final, coherent response."""
        prompt = "Combine the following intermediate answers into a final, concise response."
        messages = [
            ChatMessage(role="system", content=prompt),
            ChatMessage(
                role="user",
                content=f"Intermediate answers: {community_answers}",
            ),
        ]
        final_response = self.llm.chat(messages)
        return _strip_assistant_prefix(final_response)


def custom_parse_fn(response_str: str) -> Any:
    """Custom parser for LLM responses that extract entities and relationships.

    Takes the span from the first ``{`` to the last ``}`` in ``response_str``
    and decodes it as JSON with ``entities`` and ``relationships`` lists.

    Returns:
        ``(entities, relationships)``: entities are
        ``(entity_name, entity_type, entity_description)`` tuples and
        relationships are ``(source_entity, target_entity, relation,
        relationship_description)`` tuples. Missing descriptions get a
        generated placeholder. Malformed input is reported on stdout and
        whatever was parsed before the failure is returned (possibly two
        empty lists).
    """
    json_pattern = r"\{.*\}"
    match = re.search(json_pattern, response_str, re.DOTALL)
    entities = []
    relationships = []

    if not match:
        return entities, relationships

    json_str = match.group(0)
    try:
        data = json.loads(json_str)
        entities = [
            (
                entity["entity_name"],
                entity["entity_type"],
                entity.get("entity_description", f"Description of {entity['entity_name']}"),
            )
            for entity in data.get("entities", [])
        ]
        relationships = [
            (
                relation["source_entity"],
                relation["target_entity"],
                relation["relation"],
                relation.get(
                    "relationship_description",
                    f"Relationship between {relation['source_entity']} and {relation['target_entity']}",
                ),
            )
            for relation in data.get("relationships", [])
        ]
        return entities, relationships
    except (json.JSONDecodeError, KeyError, TypeError, AttributeError) as e:
        # TypeError/AttributeError cover valid JSON of the wrong shape, e.g. a
        # list at the top level or plain strings instead of entity objects.
        print(f"Error parsing response: {e}")
        print(f"Problematic JSON: {json_str[:200]}...")
        return entities, relationships


# Define the prompt template for triple extraction
KG_TRIPLET_EXTRACT_TMPL = """
-Goal-
Given a text document, identify all entities and their entity types from the text and all relationships among the identified entities.
Given the text, extract up to {max_knowledge_triplets} entity-relation triplets.

-Steps-
1. Identify all entities. For each identified entity, extract the following information:
- entity_name: Name of the entity, capitalized
- entity_type: Type of the entity
- entity_description: Comprehensive description of the entity's attributes and activities

2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other.
For each pair of related entities, extract the following information:
- source_entity: name of the source entity, as identified in step 1
- target_entity: name of the target entity, as identified in step 1
- relation: relationship between source_entity and target_entity
- relationship_description: explanation as to why you think the source entity and the target entity are related to each other

3. Output Formatting:
- Return the result in valid JSON format with two keys: 'entities' (list of entity objects) and 'relationships' (list of relationship objects).
- Exclude any text outside the JSON structure (e.g., no explanations or comments).
- If no entities or relationships are identified, return empty lists: { "entities": [], "relationships": [] }.

-Real Data-
######################
text: {text}
######################
output:
"""


def main():
    """Build the graph index from local documents and run a query REPL.

    Neo4j connection settings come from ``NEO4J_USERNAME``, ``NEO4J_PASSWORD``
    and ``NEO4J_URL``. If the connection fails, an in-memory
    ``SimplePropertyGraphStore`` is used and community building is skipped.
    """
    print("Starting GraphRAG document processing...")

    # Load documents from specified directory
    documents = SimpleDirectoryReader(
        input_dir="supporting_files/files_for_rag_store"
    ).load_data()
    print(f"Loaded {len(documents)} documents")

    # Create nodes using a sentence splitter
    splitter = SentenceSplitter(chunk_size=1024, chunk_overlap=20)
    nodes = splitter.get_nodes_from_documents(documents)
    print(f"Created {len(nodes)} nodes from documents")

    # Initialize the LLM
    llm = OpenAI(model="gpt-4")

    # Create the knowledge graph extractor
    kg_extractor = GraphRAGExtractor(
        llm=llm,
        extract_prompt=KG_TRIPLET_EXTRACT_TMPL,
        max_paths_per_chunk=2,
        parse_fn=custom_parse_fn,
    )

    # Connect to Neo4j running in Docker.
    # SECURITY: the password is read from the environment (.env) rather than
    # being committed to source control.
    neo4j_username = os.environ.get("NEO4J_USERNAME", "neo4j")
    neo4j_password = os.environ.get("NEO4J_PASSWORD", "")
    neo4j_url = os.environ.get("NEO4J_URL", "bolt://localhost:7687")

    print(f"Connecting to Neo4j at {neo4j_url} with username '{neo4j_username}'")

    # Create GraphRAGStore (our extended Neo4j store)
    try:
        graph_store = GraphRAGStore(
            username=neo4j_username,
            password=neo4j_password,
            url=neo4j_url,
        )
        print("Successfully connected to Neo4j database")
    except Exception as e:
        print(f"Error connecting to Neo4j: {e}")
        print("Falling back to in-memory graph store. Some features may be limited.")
        # Fallback to in-memory graph store if Neo4j connection fails
        from llama_index.core.graph_stores import SimplePropertyGraphStore

        graph_store = SimplePropertyGraphStore()

    # Build the index
    index = PropertyGraphIndex(
        nodes=nodes,
        kg_extractors=[kg_extractor],
        property_graph_store=graph_store,
        show_progress=True,
    )

    print("Building graph communities...")
    try:
        # Build communities for graph-based querying
        # Only for GraphRAGStore, not for SimplePropertyGraphStore
        if hasattr(graph_store, "build_communities"):
            graph_store.build_communities()
            print("Communities built successfully")
        else:
            print("Skipping community building (not using Neo4j)")
    except Exception as e:
        print(f"Error building communities: {e}")

    # Create the query engine
    query_engine = GraphRAGQueryEngine(
        graph_store=graph_store,
        llm=llm,
        index=index,
        similarity_top_k=10,
    )

    # Simple interactive query loop
    print("\n--- GraphRAG Query System Ready ---")
    print("Type 'exit' to quit")

    while True:
        try:
            query = input("\nEnter your query: ")
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C at the prompt ends the session cleanly.
            print()
            break

        if query.lower() in ("exit", "quit"):
            break
        if not query.strip():
            # Do not spend LLM calls on a blank line.
            continue

        try:
            response = query_engine.custom_query(query)
            print("\nResponse:")
            print(response)
        except Exception as e:
            print(f"Error processing query: {e}")

    print("GraphRAG session ended")


if __name__ == "__main__":
    main()